完成模型日志记录模块

This commit is contained in:
2023-11-20 23:53:12 +08:00
parent 3a76ff507f
commit 7e23f6a7b3
8 changed files with 176 additions and 22 deletions

View File

@@ -1,13 +1,34 @@
'''
1 加载数据
2 构建模型
3 获取损失函数
4 获取优化器
5 开始训练 调用3、4
1 img--->model--->out
2 out y 计算loss
@作者:你遇了我321640253@qq.com
@文件:prediction.py
@创建时间:2023 11 20
模型预测功能
'''
import torch
import train.VGG16Net as VGG16Net
class Predict():
'''
:description
使用模型进行预测
:author 你遇了我
'''
def __init__(self,modelPath:str) -> None:
#获取模型结构、加载权重
self.model = VGG16Net.VGG16()
self.model.load_state_dict(torch.load(modelPath))
def predict_img(imgpath:str):
'''
:description 预测图片
:author 你遇了我
:param
imgpath 图片路径
:return
'''
pass
for i in range(100):
for img in dasf: