🔥torch.cat()用法✨
在PyTorch中,`torch.cat()` 是一个非常实用的函数,用于将多个张量按指定维度拼接在一起。简单来说,它就像是把不同的积木块拼成一个更大的结构!🤔
首先,确保你导入了PyTorch:`import torch` 。然后,假设你有两个形状相同的张量 `tensor1 = torch.tensor([[1, 2], [3, 4]])` 和 `tensor2 = torch.tensor([[5, 6], [7, 8]])`。如果你想沿行方向(即第0维)拼接它们,只需执行:
```python
result = torch.cat((tensor1, tensor2), dim=0)
```
结果会是:
```
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
```
如果想沿列方向(即第1维)拼接,则设置 `dim=1`:
```python
result = torch.cat((tensor1, tensor2), dim=1)
```
输出变为:
```
tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])
```
需要注意的是,所有张量的形状必须在非拼接维度上保持一致哦!💼
掌握这个小技巧,处理多维数据时会更加得心应手!🚀
免责声明:本答案或内容为用户上传,不代表本网观点。其原创性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容、文字的真实性、完整性、及时性本站不作任何保证或承诺,请读者仅作参考,并请自行核实相关内容。 如遇侵权请及时联系本站删除。