pytorch不同类型进行转换

    科技2022-07-14  120

    不同类型进行转换

    1.1 张量的数据类型

    i = torch.tensor(1); print(i,i.dtype) x = i.float(); print(x,x.dtype) #调用 float方法转换成浮点类型 y = i.type(torch.float); print(y,y.dtype) #使用type函数转换成浮点类型 z = i.type_as(x); print(z,z.dtype) #使用type_as方法转换成某个Tensor相同类型

    tensor(1) torch.int64 tensor(1.) torch.float32 tensor(1.) torch.float32 tensor(1.) torch.float32

    1.2 张量的维度

    不同类型的数据可以用不同维度(dimension)的张量来表示。

    标量为0维张量,向量为1维张量,矩阵为2维张量。

    彩色图像有rgb三个通道,可以表示为3维张量。

    视频还有时间维,可以表示为4维张量。

    可以简单地总结为:有几层中括号,就是多少维的张量。

    scalar = torch.tensor(True) print(scalar) print(scalar.dim()) # 标量,0维张量

    tensor(True) 0

    vector = torch.tensor([1.0,2.0,3.0,4.0]) #向量,1维张量 print(vector) print(vector.dim())

    tensor([1., 2., 3., 4.]) 1

    matrix = torch.tensor([[1.0,2.0],[3.0,4.0]]) #矩阵, 2维张量 print(matrix) print(matrix.dim())

    tensor([[1., 2.], [3., 4.]]) 2

    tensor3 = torch.tensor([[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]]) # 3维张量 print(tensor3) print(tensor3.dim())

    tensor([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]]) 3

    1.3 张量的尺寸

    可以使用 shape属性或者 size()方法查看张量在每一维的长度.

    可以使用view方法改变张量的尺寸。

    如果view方法改变尺寸失败,可以使用reshape方法.

    scalar = torch.tensor(True) vector = torch.tensor([1.0,2.0,3.0,4.0]) print(vector.size()) print(vector.shape)

    torch.Size([]) torch.Size([4])

    matrix = torch.tensor([[1.0,2.0],[3.0,4.0]]) print(matrix.size())

    torch.Size([2, 2])

    # 使用view可以改变张量尺寸 vector = torch.arange(0,12) print(vector) print(vector.shape) matrix34 = vector.view(3,4) print(matrix34) print(matrix34.shape) matrix43 = vector.view(4,-1) #-1表示该位置长度由程序自动推断 print(matrix43) print(matrix43.shape)

    tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) torch.Size([12]) tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) torch.Size([3, 4]) tensor([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) torch.Size([4, 3])

    # 有些操作会让张量存储结构扭曲,直接使用view会失败,可以用reshape方法 matrix26 = torch.arange(0,12).view(2,6) print(matrix26) print(matrix26.shape) # 转置操作让张量存储结构扭曲 matrix62 = matrix26.t() print(matrix62.is_contiguous()) # 直接使用view方法会失败,可以使用reshape方法 #matrix34 = matrix62.view(3,4) #error! matrix34 = matrix62.reshape(3,4) #等价于matrix34 = matrix62.contiguous().view(3,4) print(matrix34)

    tensor([[ 0, 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10, 11]]) torch.Size([2, 6]) False tensor([[ 0, 6, 1, 7], [ 2, 8, 3, 9], [ 4, 10, 5, 11]])

    1.4 张量和numpy数组

    可以用numpy方法从Tensor得到numpy数组,也可以用torch.from_numpy从numpy数组得到Tensor。

    这两种方法关联的Tensor和numpy数组是共享数据内存的。

    如果改变其中一个,另外一个的值也会发生改变。

    如果有需要,可以用张量的clone方法拷贝张量,中断这种关联。

    此外,还可以使用item方法从标量张量得到对应的Python数值。

    使用tolist方法从张量得到对应的Python数值列表。

    #torch.from_numpy函数从numpy数组得到Tensor arr = np.zeros(3) tensor = torch.from_numpy(arr) print("before add 1:") print(arr) print(tensor) print("\nafter add 1:") np.add(arr,1, out = arr) #给 arr增加1,tensor也随之改变 print(arr) print(tensor)

    before add 1: [0. 0. 0.] tensor([0., 0., 0.], dtype=torch.float64)

    after add 1: [1. 1. 1.] tensor([1., 1., 1.], dtype=torch.float64)

    # numpy方法从Tensor得到numpy数组 tensor = torch.zeros(3) arr = tensor.numpy() print("before add 1:") print(tensor) print(arr) print("\nafter add 1:") #使用带下划线的方法表示计算结果会返回给调用 张量 tensor.add_(1) #给 tensor增加1,arr也随之改变 #或: torch.add(tensor,1,out = tensor) print(tensor) print(arr)

    before add 1: tensor([0., 0., 0.]) [0. 0. 0.]

    after add 1: tensor([1., 1., 1.]) [1. 1. 1.]

    # 可以用clone() 方法拷贝张量,中断这种关联 tensor = torch.zeros(3) #使用clone方法拷贝张量, 拷贝后的张量和原始张量内存独立 arr = tensor.clone().numpy() # 也可以使用tensor.data.numpy() print("before add 1:") print(tensor) print(arr) print("\nafter add 1:") #使用 带下划线的方法表示计算结果会返回给调用 张量 tensor.add_(1) #给 tensor增加1,arr不再随之改变 print(tensor) print(arr)

    before add 1: tensor([0., 0., 0.]) [0. 0. 0.]

    after add 1: tensor([1., 1., 1.]) [0. 0. 0.]

    # item方法和tolist方法可以将张量转换成Python数值和数值列表 scalar = torch.tensor(1.0) s = scalar.item() print(s) print(type(s)) tensor = torch.rand(2,2) t = tensor.tolist() print(t) print(type(t))

    1.0 <class ‘float’> [[0.5526873469352722, 0.46957558393478394], [0.6724914312362671, 0.26923561096191406]] <class ‘list’>

    Pytorch教程链接

    Processed: 0.024, SQL: 8