ZeroLauncher/util/OnnxManager.cs
2024-05-12 16:43:07 +08:00

96 lines
2.8 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using Microsoft.ML.OnnxRuntime.Tensors;
using Microsoft.ML.OnnxRuntime;
using System.Diagnostics;
using System.Text;
using System.Drawing;
namespace Zerolauncher.util
{
/// <summary>
/// Loads a text-recognition ONNX model and runs it on single images, decoding the per-timestep class scores
/// with greedy CTC decoding into a string over <see cref="Labels"/>.
/// NOTE(review): the 3x48x320 input shape and [-1, 1] normalisation look like a PaddleOCR-style recogniser;
/// confirm against the model that ships in <c>ocr/model.onnx</c>.
/// </summary>
public class OnnxManager : IDisposable
{
    // Class index treated as the CTC blank; class k (k >= 1) maps to Labels[k - 1].
    private const int BlankIndex = 0;

    // (channels, height, width) passed to CtcPreprocess; the width is widened for images wider than 320/48.
    private static readonly (int Channels, int Height, int Width) RecImageShape = (3, 48, 320);

    // Character set: output class k (k >= 1) corresponds to Labels[k - 1].
    private static readonly string[] Labels =
    [
        "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
        "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z",
        "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
        "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z",
        "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"
    ];

    private readonly string _modelPath;
    private readonly InferenceSession _session;
    private readonly string _inputName;
    private bool _disposed;

    /// <summary>
    /// Creates an inference session for the model at <paramref name="modelPath"/>.
    /// </summary>
    /// <param name="modelPath">Path to the ONNX model file.</param>
    public OnnxManager(string modelPath = "ocr/model.onnx")
    {
        _modelPath = modelPath;
        _session = new InferenceSession(_modelPath);
        // Feed the model through its first declared input node.
        _inputName = _session.InputMetadata.Keys.First();
    }

    /// <summary>
    /// Recognises the text in <paramref name="image"/>.
    /// </summary>
    /// <param name="image">Image to recognise; it is resized internally and not modified.</param>
    /// <returns>The decoded text (possibly empty).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="image"/> is null.</exception>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public string RunInference(Bitmap image)
    {
        ArgumentNullException.ThrowIfNull(image);
        if (_disposed)
        {
            throw new ObjectDisposedException(nameof(OnnxManager));
        }

        var tensor = CtcPreprocess(image, RecImageShape);
        var inputs = new List<NamedOnnxValue> { NamedOnnxValue.CreateFromTensor(_inputName, tensor) };
        using var results = _session.Run(inputs);
        var outputs = results.First().AsTensor<float>();
        return CtcGreedyDecode(outputs);
    }

    /// <summary>Releases the underlying ONNX Runtime session.</summary>
    public void Dispose()
    {
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }

    /// <summary>Releases resources; managed ones only when <paramref name="disposing"/> is true.</summary>
    protected virtual void Dispose(bool disposing)
    {
        if (_disposed)
        {
            return;
        }
        if (disposing)
        {
            _session.Dispose();
        }
        _disposed = true;
    }

    // Greedy CTC decode of outputs[0, t, k]: argmax per timestep, collapse consecutive repeats, drop blanks.
    private static string CtcGreedyDecode(Tensor<float> outputs)
    {
        var dimensions = outputs.Dimensions;
        int steps = dimensions[1];
        int classes = dimensions[2];
        var sb = new StringBuilder();
        int previous = BlankIndex;
        for (int t = 0; t < steps; t++)
        {
            int best = ArgMax(outputs, t, classes);
            // A character is emitted only when the class changes; a genuine double letter
            // is separated by a blank frame, which resets `previous` to BlankIndex.
            // Classes beyond the label set (e.g. an extra space class) are skipped.
            if (best != previous && best != BlankIndex && best - 1 < Labels.Length)
            {
                sb.Append(Labels[best - 1]);
            }
            previous = best;
        }
        return sb.ToString();
    }

    // Index of the highest score at timestep `step` (first one wins on ties); BlankIndex if there are no classes.
    private static int ArgMax(Tensor<float> outputs, int step, int classes)
    {
        int bestIndex = BlankIndex;
        float bestScore = float.NegativeInfinity;
        for (int k = 0; k < classes; k++)
        {
            float score = outputs[0, step, k];
            if (score > bestScore)
            {
                bestScore = score;
                bestIndex = k;
            }
        }
        return bestIndex;
    }

    /// <summary>
    /// Resizes <paramref name="image"/> to the model height keeping its aspect ratio, normalises RGB to [-1, 1],
    /// and right-pads with zeros into a [1, C, H, W] tensor.
    /// </summary>
    static DenseTensor<float> CtcPreprocess(Bitmap image, (int, int, int) recImageShape)
    {
        var (imgC, imgH, imgW) = recImageShape;
        Debug.Assert(imgC == 3); // pixels are read as R, G, B below
        var h = image.Height;
        var w = image.Width;
        var ratio = w / (float)h;
        // Widen the target width for images wider than the default aspect ratio so text is not squeezed.
        var maxWhRatio = Math.Max(imgW / (float)imgH, ratio);
        imgW = (int)(imgH * maxWhRatio);
        var resizedW = Math.Ceiling(imgH * ratio) > imgW ? imgW : (int)Math.Ceiling(imgH * ratio);

        // Freshly allocated tensor is zero-filled, so columns beyond resizedW act as padding.
        var paddingIm = new DenseTensor<float>([1, imgC, imgH, imgW]);
        // The resized copy holds a GDI handle; dispose it once the pixels are copied out.
        using var resizedImage = new Bitmap(image, new Size(resizedW, imgH));
        for (int i = 0; i < resizedW; i++)
        {
            for (int j = 0; j < imgH; j++)
            {
                var pixel = resizedImage.GetPixel(i, j);
                paddingIm[0, 0, j, i] = (pixel.R / 255.0f - 0.5f) / 0.5f;
                paddingIm[0, 1, j, i] = (pixel.G / 255.0f - 0.5f) / 0.5f;
                paddingIm[0, 2, j, i] = (pixel.B / 255.0f - 0.5f) / 0.5f;
            }
        }
        return paddingIm;
    }
}
}