<>模型的加载
import torchvision.models as models resnet34 = models.resnet34() resnet34.
load_state_dict(torch.load('latest.pth')['model'])
<>要解决的疑问
* load_state_dict torch.load作用
网络结构有了 这部分是在加载参数
* dummy input作用
给网络一个输入
* 如果dynamic_axes 后面输入可以更改指定的维度
* binding inputname outputname作用
binding 每个engine有且只有两个binding,对应输入输出
name可以理解为指针,在转onnx时候就指定根据这个指针拿到输入输出的内容 dummy_input=torch.randn(BATCH_SIZE, 3,
224, 224) import torch.onnx torch.onnx.export(resnet34, dummy_input,
"rp_rec.onnx", verbose=False)
<>注意
torchvision和mmcls的Resnet模型不一样
resnet34 = models.resnet34() resnet34.load_state_dict(torch.load('latest.pth')[
'model'])
模型必须和参数对应起来
不能用torchvision的模型加载mmcls的参数
<>Pytorch转TensorRT方法总结
采用mmclassification框架,根据网络推理时的输入指定网络输入dummy_input,看推理代码,如果网络允许某个维度有变化,那么可以设定dynamic_axes(某个维度定死了,就不要dynamic_axes),采用verify参数,对比模型的输出是否一致
步骤:在服务器上完成trt到onnx转换(configs等等不好往板卡放)
然后将deployment复制到板卡上,执行转trt代码