完成模型日志记录模块
This commit is contained in:
@@ -1,13 +1,34 @@
|
||||
'''
|
||||
1 加载数据
|
||||
2 构建模型
|
||||
3 获取损失函数
|
||||
4 获取优化器
|
||||
5 开始训练 调用3、4
|
||||
1 img--->model--->out
|
||||
2 out y 计算loss
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:prediction.py
|
||||
@创建时间:2023 11 20
|
||||
|
||||
模型预测功能
|
||||
'''
|
||||
import torch
|
||||
import train.VGG16Net as VGG16Net
|
||||
|
||||
class Predict():
|
||||
'''
|
||||
:description
|
||||
使用模型进行预测
|
||||
:author 你遇了我
|
||||
'''
|
||||
def __init__(self,modelPath:str) -> None:
|
||||
|
||||
#获取模型结构、加载权重
|
||||
self.model = VGG16Net.VGG16()
|
||||
self.model.load_state_dict(torch.load(modelPath))
|
||||
|
||||
def predict_img(imgpath:str):
|
||||
'''
|
||||
:description 预测图片
|
||||
:author 你遇了我
|
||||
:param
|
||||
imgpath 图片路径
|
||||
:return
|
||||
'''
|
||||
pass
|
||||
|
||||
|
||||
for i in range(100):
|
||||
|
||||
for img in dasf:
|
||||
|
||||
Reference in New Issue
Block a user