<>RuntimeError: Expected object of scalar type Double but got scalar type
Float for argument #3 ‘mat2’ in call to _th_addmm_out

<>1. 说明

在训练网络的过程中由于类型的冲突导致这种错误,主要是模型内部参数和输入类型不一致所导致的。主要有两个部分需要注意到:1.自己定义的变量要设置为一种数据类型;2.网络内部的变量类型也要统一。

<>2. 解决办法一

统一声明变量的类型。
# 将接下来创建的变量类型均为Double torch.set_default_tensor_type(torch.DoubleTensor)
or
#将接下来创建的变量类型均为Float torch.set_default_tensor_type(torch.FloatTensor)
一定要注意要在变量创建之间声明类型。

<>3. 解决办法二

在训练过程中加入一下两点即可:
# For your model net = net.double() # For your data net(input_x.double)

技术
今日推荐
下载桌面版
GitHub
百度网盘(提取码:draw)
Gitee
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:[email protected]
QQ群:766591547
关注微信