1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
| import numpy as np import onnxruntime as ort import torch import torch.nn as nn
class TextClassifier(nn.Module):
    """Token-id sequence classifier: embedding -> single-layer LSTM -> linear.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: size of each token embedding vector.
        hidden_dim: size of the LSTM hidden state.
        num_classes: number of logits produced per example.
    """

    def __init__(
        self,
        vocab_size=1000,
        embed_dim=128,
        hidden_dim=512,
        num_classes=2,
    ):
        super().__init__()
        # These attribute names become the state_dict key prefixes
        # ("embedding.", "rnn.", "fc."); keep them stable so saved
        # checkpoints keep loading.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, text):
        """Return class logits for a batch of token-id sequences.

        The LSTM is constructed with the default ``batch_first=False``, so
        ``text`` is laid out as (seq_len, batch); the result has shape
        (batch, num_classes) and is computed from the LSTM's final hidden
        state only.
        """
        token_vectors = self.embedding(text)
        _, (final_hidden, _) = self.rnn(token_vectors)
        # final_hidden is (num_layers, batch, hidden_dim) with num_layers == 1;
        # drop the layer axis before the linear head.
        return self.fc(final_hidden.squeeze(0))
def convert(
    checkpoint_path="model.pt",
    onnx_path="text_classifier.onnx",
    seq_len=256,
    opset_version=11,
):
    """Load trained TextClassifier weights and export the model to ONNX.

    The model is traced with a (seq_len, 1) int64 dummy input and no
    ``dynamic_axes``, so the exported graph has a fixed sequence length of
    ``seq_len`` and a batch size of 1.

    Args:
        checkpoint_path: path to a saved ``state_dict`` for TextClassifier.
        onnx_path: destination path for the exported ``.onnx`` file.
        seq_len: sequence length baked into the exported graph; inference
            inputs must be padded to this length.
        opset_version: ONNX opset to target.

    Returns:
        ``onnx_path``, so callers can chain the export into inference.
    """
    model = TextClassifier()
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a
    # CPU-only host. NOTE(security): torch.load unpickles the file — only
    # load checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    dummy_input = torch.ones([seq_len, 1], dtype=torch.long)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        opset_version=opset_version,
        input_names=['input'],
        output_names=['output'],
    )
    return onnx_path
def inference(model_path, input, seq_len=256, pad_id=0, providers=None):
    """Run an exported ONNX classifier on one token-id sequence.

    The tokens are right-padded with ``pad_id`` to ``seq_len`` and fed as a
    (seq_len, 1) int64 array, matching the fixed-shape dummy input used when
    the model was exported.

    Args:
        model_path: path to the ``.onnx`` file.
        input: sequence of integer token ids (list, tuple, array, ...).
        seq_len: the fixed sequence length the model was exported with.
        pad_id: token id written into positions after the real tokens.
        providers: onnxruntime execution providers; defaults to CPU only.

    Returns:
        The model's first output converted with ``tolist()`` — for a
        TextClassifier export presumably ``[[logit_0, ..., logit_k]]``.

    Raises:
        ValueError: if ``input`` has more than ``seq_len`` tokens.
    """
    tokens = list(input)  # accept tuples/arrays, not only lists
    if len(tokens) > seq_len:
        # Fail fast with a clear message instead of an opaque reshape error
        # (``[pad] * negative`` would silently produce no padding).
        raise ValueError(
            f"input has {len(tokens)} tokens but the model accepts at most "
            f"{seq_len}"
        )
    # NOTE(review): the LSTM runs over every position, so the padding after
    # the real tokens influences the final hidden state used for the logits.
    # Fixing that needs a model change (e.g. packed sequences), not this code.
    padded = tokens + [pad_id] * (seq_len - len(tokens))
    input_array = np.asarray(padded, dtype=np.int64).reshape(seq_len, 1)

    if providers is None:
        providers = ["CPUExecutionProvider"]
    # onnxruntime >= 1.9 requires explicit providers on multi-provider builds.
    session = ort.InferenceSession(model_path, providers=providers)
    feed = {session.get_inputs()[0].name: input_array}
    outputs = session.run(None, feed)
    return outputs[0].tolist()
def main():
    """Export the trained model to ONNX, then classify one example sequence."""
    # Must match the file convert() writes by default. The old hard-coded
    # "/home/project/..." path only matched when run from /home/project and
    # could otherwise pick up a stale or missing model.
    onnx_path = "text_classifier.onnx"
    convert()
    result = inference(onnx_path, [101, 304, 993, 108, 102])
    print(result)


if __name__ == '__main__':
    main()
|