Pytorch-tensor的分割,属性统计

    科技2025-03-24  15

    1.矩阵的分割

    方法:split(分割长度,所分割的维度),split([分割所占的百分比],所分割的维度)

    a=torch.rand(32,8) aa,bb=a.split(16,dim=0) print(aa.shape) print(bb.shape) cc,dd=a.split([20,12],dim=0) print(cc.shape) print(dd.shape)

    输出结果

    torch.Size([16, 8]) torch.Size([16, 8]) torch.Size([20, 8]) torch.Size([12, 8])

    2.tensor的属性统计

    min(dim=1):返回第一维的所有最小值,以及下标

    max(dim=1):返回第一维的所有最大值,以及下标

    a=torch.rand(4,3) print(a,'\n') print(a.min(dim=1),'\n') print(a.max(dim=1))

    输出结果

    tensor([[0.3876, 0.5638, 0.5768], [0.7615, 0.9885, 0.9660], [0.3622, 0.4334, 0.1226], [0.9390, 0.6292, 0.8370]]) torch.return_types.min( values=tensor([0.3876, 0.7615, 0.1226, 0.6292]), indices=tensor([0, 0, 2, 1])) torch.return_types.max( values=tensor([0.5768, 0.9885, 0.4334, 0.9390]), indices=tensor([2, 1, 1, 0]))

    mean:求平均值

    prod:求累乘

    sum:求累加

    argmin:求最小值下标

    argmax:求最大值下标

    a=torch.rand(1,3) print(a) print(a.mean()) print(a.prod()) print(a.sum()) print(a.argmin()) print(a.argmax())

    输出结果

    tensor([[0.5366, 0.9145, 0.0606]]) tensor(0.5039) tensor(0.0297) tensor(1.5117) tensor(2) tensor(1)

    3.tensor的topk()和kthvalue()

    topk(k,dim=a,largest=):输出维度为1的前k大的值,以及它们的下标。

    kthvalue(k,dim=a):输出维度为a的第k小的值,并输出它的下标。

    a=torch.rand(4,4) print(a,'\n') # 输出每一行中2个最大的值,并输出它们的下标 print(a.topk(2,dim=1),'\n') # 输出每一行中3个最小的值,并输出它们的下标 print(a.topk(3,dim=1,largest=False),'\n') # 输出每一行第2小的值,并输出下标 print(a.kthvalue(2,dim=1))

    输出结果

    tensor([[0.7131, 0.8148, 0.8036, 0.4720], [0.9135, 0.4639, 0.5114, 0.2277], [0.1314, 0.8407, 0.7990, 0.9426], [0.6556, 0.7316, 0.9648, 0.9223]]) torch.return_types.topk( values=tensor([[0.8148, 0.8036], [0.9135, 0.5114], [0.9426, 0.8407], [0.9648, 0.9223]]), indices=tensor([[1, 2], [0, 2], [3, 1], [2, 3]])) torch.return_types.topk( values=tensor([[0.4720, 0.7131, 0.8036], [0.2277, 0.4639, 0.5114], [0.1314, 0.7990, 0.8407], [0.6556, 0.7316, 0.9223]]), indices=tensor([[3, 0, 2], [3, 1, 2], [0, 2, 1], [0, 1, 3]])) torch.return_types.kthvalue( values=tensor([0.7131, 0.4639, 0.7990, 0.7316]), indices=tensor([0, 1, 2, 1]))
    Processed: 0.009, SQL: 9