<>RuntimeError: Expected object of scalar type Double but got scalar type
Float for argument #3 ‘mat2’ in call to _th_addmm_out
<>1. 说明
在训练网络的过程中由于类型的冲突导致这种错误,主要是模型内部参数和输入类型不一致所导致的。主要有两个部分需要注意到:1.自己定义的变量要设置为一种数据类型;2.网络内部的变量类型也要统一。
<>2. 解决办法一
统一声明变量的类型。
# 将接下来创建的变量类型均为Double torch.set_default_tensor_type(torch.DoubleTensor)
or
#将接下来创建的变量类型均为Float torch.set_default_tensor_type(torch.FloatTensor)
一定要注意要在变量创建之间声明类型。
<>3. 解决办法二
在训练过程中加入一下两点即可:
# For your model net = net.double() # For your data net(input_x.double)