MNIST图像识别模型使用问题

我在实现了MNIST图像识别的模型训练后,运行了测试例子,运行结果没有问题,但是我遇到了一个问题:我无法加载本地图片进行图片识别。

from datetime import date
from matplotlib import pyplot as plt
import jittor as jt
from numpy.core.fromnumeric import shape
from model import Model
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import numpy as np
import cv2    

model_path = '/home/*****/Python_Demo/JittorMNISTImageClassification/mnist_model.pkl'
new_model = Model()
new_model.load_parameters(jt.load(model_path))
val_loader = MNIST(train=False, transform=trans.Resize(28)).set_attrs(batch_size=1, shuffle=False)
data_iter = iter(val_loader)
val_data, val_label = next(data_iter)
outputs = new_model(val_data)
prediction = np.argmax(outputs.data, axis=1)
print(val_label.data)
print(prediction)

程序运行结果如图

请大佬告知如何加载本地图片并利用模型进行图片识别。

您可以用一张图试试以下代码。我也用 ppt 做了一张手写数字的图,可以下载保存命名为 7.jpg
7

# 读取图像并处理
image = cv2.imread('7.jpg')         # 得到一个 HxWx3 的 array
image = cv2.resize(image, (28, 28)) # 把图像缩放到 28x28 个像素
image = image / 255.0               # 把图像的 RGB 值从 [0, 255] 变为 [0, 1]
image = image.transpose(2, 0, 1)    # 把输入格式从 HWC 改为 CHW
image = jt.float32(image)           # 变为 Jittor Var
image = image.unsqueeze(dim=0)      # 加入 batch 维度,变为 [1, C, H, W]

outputs = model(image)
prediction = np.argmax(outputs.data, axis=1)
print(prediction)
1 Like

非常感谢!!!!
已成功运行。