Pytorch中list遇到隐式.cuda()避雷

定义了一个神经网络类,为了灵活性需要使用list,如下

1
2
3
4
5
6
class Net(nn.Module):
def __init__(self, ...):
super(Net, self).__init__()
self.network_layer1 = []
for i in range(self.max_itr_num):
self.network_layer1.append(PointerNet.PointerNet())

在对这个类的实例进行.cuda()操作不会报错

1
2
A = Net()
A = A.cuda()

但是在使用时就会报错一部分数据/模型在CPU上,另一部分在GPU。检查后发现PointerNet实例都不在GPU上

1
RuntimeError: Tensor for 'out' is on CPU, Tensor for argument #1 'self' is on CPU, but expected them to be on GPU

经过实验,发现是因为list这个数据类型无法.cuda()。在对A进行.cuda()操作时,list受到隐式的.cuda()操作,所以虽然无法.cuda()但是不会报错。如果想要对list里面的内容进行.cuda()操作,就需要单独进行显式的.cuda()操作:

1
2
3
A = Net()
for i in range(len(A.network_layer1)):
A.network_layer1[i] = A.network_layer1[i].cuda()