35 lines
700 B
Python
35 lines
700 B
Python
'''
|
|
@作者:你遇了我321640253@qq.com
|
|
@文件:prediction.py
|
|
@创建时间:2023 11 20
|
|
|
|
模型预测功能
|
|
'''
|
|
import torch
|
|
import train.VGG16Net as VGG16Net
|
|
|
|
class Predict():
|
|
'''
|
|
:description
|
|
使用模型进行预测
|
|
:author 你遇了我
|
|
'''
|
|
def __init__(self,modelPath:str) -> None:
|
|
|
|
#获取模型结构、加载权重
|
|
self.model = VGG16Net.VGG16()
|
|
self.model.load_state_dict(torch.load(modelPath))
|
|
|
|
def predict_img(imgpath:str):
|
|
'''
|
|
:description 预测图片
|
|
:author 你遇了我
|
|
:param
|
|
imgpath 图片路径
|
|
:return
|
|
'''
|
|
pass
|
|
|
|
|
|
|