🔥torch.cat()用法✨

2025-03-23 07:09:07
导读 在PyTorch中,`torch.cat()` 是一个非常实用的函数,用于将多个张量按指定维度拼接在一起。简单来说,它就像是把不同的积木块拼成一个更大...

在PyTorch中,`torch.cat()` 是一个非常实用的函数,用于将多个张量按指定维度拼接在一起。简单来说,它就像是把不同的积木块拼成一个更大的结构!🤔

首先,确保你导入了PyTorch:`import torch` 。然后,假设你有两个形状相同的张量 `tensor1 = torch.tensor([[1, 2], [3, 4]])` 和 `tensor2 = torch.tensor([[5, 6], [7, 8]])`。如果你想沿行方向(即第0维)拼接它们,只需执行:

```python

result = torch.cat((tensor1, tensor2), dim=0)

```

结果会是:

```

tensor([[1, 2],

[3, 4],

[5, 6],

[7, 8]])

```

如果想沿列方向(即第1维)拼接,则设置 `dim=1`:

```python

result = torch.cat((tensor1, tensor2), dim=1)

```

输出变为:

```

tensor([[1, 2, 5, 6],

[3, 4, 7, 8]])

```

需要注意的是,所有张量的形状必须在非拼接维度上保持一致哦!💼

掌握这个小技巧,处理多维数据时会更加得心应手!🚀

郑重声明:本文版权归原作者所有,转载文章仅为传播更多信息之目的,如作者信息标记有误,请第一时间联系我们修改或删除,多谢。