a = torch.Tensor([[1, 2, 3, 7, 8, 5, 4, 9],
[1, 3, 4, 6, 1, 3, 3, 9]])
print(a[0])
b = torch.Tensor([1, 1])
c = torch.Tensor([1, 2])
for i in c:
index = i == b
print(index)
print(a[index])
输入:
tensor([1., 2., 3., 7., 8., 5., 4., 9.])
tensor([True, True])
tensor([[1., 2., 3., 7., 8., 5., 4., 9.],
[1., 3., 4., 6., 1., 3., 3., 9.]])
tensor([False, False])
tensor([], size=(0, 8))