''' @作者:你遇了我321640253@qq.com @文件:prediction.py @创建时间:2023 11 20 模型预测功能 ''' import torch import train.VGG16Net as VGG16Net class Predict(): ''' :description 使用模型进行预测 :author 你遇了我 ''' def __init__(self,modelPath:str) -> None: #获取模型结构、加载权重 self.model = VGG16Net.VGG16() self.model.load_state_dict(torch.load(modelPath)) def predict_img(imgpath:str): ''' :description 预测图片 :author 你遇了我 :param imgpath 图片路径 :return ''' pass