ONNX 库笔记

ONNX 模型转换与运行

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#task-start
import numpy as np
import onnxruntime as ort
import torch
import torch.nn as nn


class TextClassifier(nn.Module):
    """LSTM-based text classifier: embedding -> LSTM -> linear head.

    Expects a token-id tensor of shape (seq_len, batch) and returns
    logits of shape (batch, num_classes).
    """

    def __init__(self, vocab_size=1000, embed_dim=128, hidden_dim=512, num_classes=2):
        super().__init__()
        # Attribute names are kept stable so saved state_dict keys still match.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, text):
        """Return classification logits for a (seq_len, batch) id tensor."""
        emb = self.embedding(text)            # (seq_len, batch, embed_dim)
        _, (last_hidden, _) = self.rnn(emb)   # last_hidden: (1, batch, hidden_dim)
        # Drop the num_layers=1 axis before the linear head.
        return self.fc(last_hidden.squeeze(0))


def convert():
    """Load trained weights and export the classifier to text_classifier.onnx.

    The dummy input fixes the exported graph to a (256, 1) int64 token
    tensor (seq_len=256, batch=1), matching the zero-padding performed
    in inference().
    """
    model = TextClassifier()
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host.
    model.load_state_dict(torch.load('model.pt', map_location='cpu'))
    model.eval()  # disable dropout/batch-norm training behavior before tracing
    dummy_input = torch.ones([256, 1], dtype=torch.long)
    torch.onnx.export(
        model,
        dummy_input,
        "text_classifier.onnx",
        opset_version=11,
        input_names=['input'],
        output_names=['output'],
    )


def inference(model_path, input):
    """Run the exported ONNX classifier on a token-id sequence.

    Args:
        model_path: path to the .onnx file produced by convert().
        input: list of int token ids; zero-padded (or truncated) to the
            fixed length 256 the model was exported with.

    Returns:
        The model's output logits as a nested list.
    """
    ort_session = ort.InferenceSession(model_path)
    # Pad with zeros up to 256 AND truncate longer inputs: the original
    # expression `input + [0] * (256 - len(input))` produced a wrong-sized
    # array for len(input) > 256, making the reshape below fail.
    padded_input = (input + [0] * 256)[:256]
    input_array = np.array(padded_input, dtype=np.int64).reshape(256, 1)
    inputs = {ort_session.get_inputs()[0].name: input_array}
    ort_outs = ort_session.run(None, inputs)
    return ort_outs[0].tolist()


def main():
    """Export the model to ONNX, then classify a sample token sequence."""
    convert()
    scores = inference(
        '/home/project/text_classifier.onnx',
        [101, 304, 993, 108, 102],
    )
    print(scores)


if __name__ == '__main__':
    main()
#task-end