torch.Tensor

张量是pytorch中一种基本的数据结构,简单说就是多维数组。

这个图讲的挺好的 image-20220730194357227

  • 0维张量其实就是一个常数,1维张量也可以看成是一个向量或者一个列表,所以求他们的shape就是一个长度值。2维张量也好理解,就是一个矩阵。

image-20220730194901640

  • 3维以上就不太好理解了。把数据打印出来也不太好看。总结了一下,就是从外往里数方括号,最外面的是第一维,接着是第二维,然后越往里维数越大。

image-20220730202944833

  • 比如说下面这个,可以把他摊到一条直线上:

    • 第一维是3,所以第一个括号里有三个括号,[[],[],[]]
    • 第二维是2,所以第二个括号里面有两个括号,[[[],[]],[[],[]],[[],[]]]

    • 第三维是1,第三维就是最后一维了,所以第三个括号里就是一个数,不用再加括号了。如果第三维大于1,则是填充对应个数的数字,最后一维可以看成是个列表。

image-20220730203044705

  • pytorch打印出来最开始看着挺不习惯的,按照他这个打印的效果来看,最后一维的数字个数其实就是打印出来的列数(和二维中的行、列概念不一样),就是最后一维有几个数,在打印结果中就能看到竖着有几列数。这个只是按我自己方便理解来的,不一定严谨。

image-20220730204423902

image-20220730204557154

  • 这个维数比较大,jupyter notebook都打印不全了。 image-20220730205300508
  • (3,2,4,5)结果是这样:其实能发现括号是竖着对齐的,第一个括号竖着往下看没有对应的括号,因为这个是最外层的括号,里面应该有三组括号;第二个括号往下看,加上本身确实能数出来三个括号;从第三个括号开始就不能单纯的数了,往下看,加上本身能数出来六个括号,也就是3×2=6;第四个括号往下数,能数出来3×2×4=24个括号;第四个括号后面包围着的就是一个长度为5的列表。
tensor([[[[0.1718, 0.8527, 0.9666, 0.7092, 0.5154],
          [0.2044, 0.8089, 0.9778, 0.9318, 0.9228],
          [0.5465, 0.1736, 0.2631, 0.2034, 0.9088],
          [0.4270, 0.4081, 0.4631, 0.3560, 0.6983]],

         [[0.2999, 0.7122, 0.5903, 0.1277, 0.5751],
          [0.1246, 0.5702, 0.1140, 0.8634, 0.9751],
          [0.3634, 0.7508, 0.0780, 0.7560, 0.0486],
          [0.2376, 0.0524, 0.6017, 0.1852, 0.7268]]],


        [[[0.4310, 0.0380, 0.6877, 0.6871, 0.1552],
          [0.0410, 0.7598, 0.5961, 0.2067, 0.5000],
          [0.2841, 0.0070, 0.7234, 0.4698, 0.5778],
          [0.0604, 0.6361, 0.3279, 0.3193, 0.0935]],

         [[0.5848, 0.5055, 0.1627, 0.1865, 0.6084],
          [0.8139, 0.8762, 0.8718, 0.9310, 0.5559],
          [0.9700, 0.0542, 0.3832, 0.7804, 0.0634],
          [0.5485, 0.8710, 0.0250, 0.5183, 0.2197]]],


        [[[0.4483, 0.8109, 0.3941, 0.4923, 0.1219],
          [0.4432, 0.1553, 0.9352, 0.8578, 0.9011],
          [0.6211, 0.4054, 0.0948, 0.8497, 0.0453],
          [0.3323, 0.5797, 0.9602, 0.2465, 0.1827]],

         [[0.8947, 0.4333, 0.6646, 0.3788, 0.3188],
          [0.6941, 0.8336, 0.0191, 0.7483, 0.1941],
          [0.6901, 0.2755, 0.2638, 0.4162, 0.9805],
          [0.8687, 0.4539, 0.1989, 0.0908, 0.3199]]]])

为啥要费劲说这么多这个呢,因为我发现不理解清楚多维数组,根本没法理解torch.view()、torch.squeeze() / torch.unsqueeze()、torch.permute()等改变tensor形状的函数。