导读 在PyTorch中,`torch.cat()` 是一个非常实用的函数,用于将多个张量按指定维度拼接在一起。简单来说,它就像是把不同的积木块拼成一个更大...
在PyTorch中,`torch.cat()` 是一个非常实用的函数,用于将多个张量按指定维度拼接在一起。简单来说,它就像是把不同的积木块拼成一个更大的结构!🤔
首先,确保你导入了PyTorch:`import torch` 。然后,假设你有两个形状相同的张量 `tensor1 = torch.tensor([[1, 2], [3, 4]])` 和 `tensor2 = torch.tensor([[5, 6], [7, 8]])`。如果你想沿行方向(即第0维)拼接它们,只需执行:
```python
result = torch.cat((tensor1, tensor2), dim=0)
```
结果会是:
```
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
```
如果想沿列方向(即第1维)拼接,则设置 `dim=1`:
```python
result = torch.cat((tensor1, tensor2), dim=1)
```
输出变为:
```
tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])
```
需要注意的是,所有张量的形状必须在非拼接维度上保持一致哦!💼
掌握这个小技巧,处理多维数据时会更加得心应手!🚀