ZeroLauncher/util/OnnxManager.cs

100 lines
2.8 KiB
C#
Raw Permalink Normal View History

2024-05-12 16:43:07 +08:00
using Microsoft.ML.OnnxRuntime.Tensors;
using Microsoft.ML.OnnxRuntime;
using System.Diagnostics;
using System.Text;
using System.Drawing;
namespace Zerolauncher.util
{
/// <summary>
/// Holds the application-wide <see cref="OnnxVerify"/> instance.
/// </summary>
public class OnnxManager
{
    /// <summary>
    /// Shared recognizer instance; <c>null</c> until a caller assigns it.
    /// </summary>
    public static OnnxVerify? onnxVerify;
}
/// <summary>
/// Recognizes alphanumeric text in an image by running an ONNX model and
/// greedily decoding its per-time-step class scores (CTC style, class 0 = blank).
/// </summary>
/// <remarks>
/// The instance owns a native <see cref="InferenceSession"/>; call
/// <see cref="Dispose"/> when the recognizer is no longer needed.
/// </remarks>
public class OnnxVerify : IDisposable
{
    /// <summary>Index of the CTC "blank" class in the model output.</summary>
    private const int BlankIndex = 0;

    private readonly string _modelPath;
    private readonly InferenceSession _session;
    private readonly string _inputName;
    private bool _disposed;

    // Output class k (k >= 1) maps to labels[k - 1]; class 0 is the blank.
    static readonly string[] labels = [
        "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
        "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z",
        "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
        "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z",
        "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"
    ];

    /// <summary>
    /// Loads the model and creates the inference session.
    /// </summary>
    /// <param name="modelPath">Path of the ONNX model file.</param>
    public OnnxVerify(string modelPath = "plugin/model.bin")
    {
        _modelPath = modelPath;
        _session = new InferenceSession(_modelPath);
        // Get the name of the model's input node.
        _inputName = _session.InputMetadata.Keys.First();
    }

    /// <summary>
    /// Runs the model on <paramref name="image"/> and decodes the result to text.
    /// </summary>
    /// <param name="image">Image to recognize; it is not modified or disposed.</param>
    /// <returns>The decoded string (empty when every time step decodes to blank).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="image"/> is <c>null</c>.</exception>
    /// <exception cref="ObjectDisposedException">The recognizer has been disposed.</exception>
    public string RunInference(Bitmap image)
    {
        ArgumentNullException.ThrowIfNull(image);
        ObjectDisposedException.ThrowIf(_disposed, this);

        var tensor = CtcPreprocess(image, (3, 48, 320));
        var inputs = new List<NamedOnnxValue> { NamedOnnxValue.CreateFromTensor(_inputName, tensor) };
        using var results = _session.Run(inputs);
        var outputs = results.First().AsTensor<float>();

        // Output is indexed as [batch, time step, class]; only batch 0 is read.
        var dimensions = outputs.Dimensions;
        int stepCount = dimensions[1];
        int classCount = dimensions[2];

        var sb = new StringBuilder();
        int previousIndex = BlankIndex;
        for (int step = 0; step < stepCount; step++)
        {
            // Arg-max over the classes, seeded with class 0 so that outputs
            // consisting only of non-positive scores are still ranked correctly.
            int bestIndex = 0;
            float bestScore = outputs[0, step, 0];
            for (int k = 1; k < classCount; k++)
            {
                float score = outputs[0, step, k];
                if (score > bestScore)
                {
                    bestScore = score;
                    bestIndex = k;
                }
            }

            // Greedy CTC decoding: a run of identical classes yields one character,
            // and a blank between two equal classes separates genuine repeats.
            if (bestIndex != BlankIndex && bestIndex != previousIndex)
            {
                sb.Append(labels[bestIndex - 1]);
            }
            previousIndex = bestIndex;
        }
        return sb.ToString();
    }

    /// <summary>
    /// Releases the underlying inference session. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }
        _session.Dispose();
        _disposed = true;
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Resizes <paramref name="image"/> to the target height (keeping its aspect
    /// ratio) and converts it into a normalized <c>[1, C, H, W]</c> float tensor.
    /// </summary>
    /// <param name="image">Source image.</param>
    /// <param name="recImageShape">Target (channels, height, width); channels must be 3.</param>
    /// <returns>
    /// Tensor whose first <c>resizedW</c> columns hold the pixels scaled to [-1, 1];
    /// any remaining columns are left at their default value.
    /// </returns>
    static DenseTensor<float> CtcPreprocess(Bitmap image, (int, int, int) recImageShape)
    {
        var (imgC, imgH, imgW) = recImageShape;
        Debug.Assert(imgC == 3); // Assuming the image is in RGB format

        var whRatio = image.Width / (float)image.Height;
        // The tensor is widened when the image is wider than the nominal shape.
        var maxWhRatio = Math.Max(imgW / (float)imgH, whRatio);
        imgW = (int)(imgH * maxWhRatio);
        var resizedW = Math.Min((int)Math.Ceiling(imgH * whRatio), imgW);

        // The resized copy is only needed while filling the tensor, so release it afterwards.
        using var resizedImage = new Bitmap(image, new Size(resizedW, imgH));
        var paddingIm = new DenseTensor<float>([1, imgC, imgH, imgW]);
        for (int x = 0; x < resizedW; x++)
        {
            for (int y = 0; y < imgH; y++)
            {
                var pixel = resizedImage.GetPixel(x, y);
                // (v / 255 - 0.5) / 0.5 maps a byte channel value to [-1, 1].
                paddingIm[0, 0, y, x] = (pixel.R / 255.0f - 0.5f) / 0.5f;
                paddingIm[0, 1, y, x] = (pixel.G / 255.0f - 0.5f) / 0.5f;
                paddingIm[0, 2, y, x] = (pixel.B / 255.0f - 0.5f) / 0.5f;
            }
        }
        return paddingIm;
    }
}
}