Belle II Software development
write_test_files.py
1#!/usr/bin/env python3
2
3
10
11# This example creates the files
12# - ModelForStandalone.onnx used in test for the standalone ONNX interface
13# - ModelABToAB and ModelAToATwiceA to test the behaviour of the ONNX MVA interface for naming inputs/outputs
14
15import torch
16from torch import nn
17
18
class Model(nn.Module):
    """
    Small example network that consumes two separate input tensors
    """

    def __init__(self):
        """
        Set up the network, which consists of a single Linear layer
        """
        super().__init__()

        #: linear Layer with 8 inputs, 2 outputs
        self.linear = nn.Linear(8, 2)

    def forward(self, a, b):
        """
        Run the forward pass - b is reshaped to rows of 6 values and cast to
        float, appended to a along dimension 1 and fed through the linear layer
        """
        # Flatten b so that it can be joined with a along the feature dimension.
        flat_b = b.reshape(-1, 6).float()
        features = torch.cat((a, flat_b), dim=1)
        return self.linear(features)
38
39
class TrivialModel(nn.Module):
    """
    Thin wrapper turning a plain function into a module, used to build simple
    non-parametric models with multiple inputs and outputs for ONNX tests
    """

    def __init__(self, forward_fn):
        """
        Store the function that will be used as the forward pass
        """
        super().__init__()

        #: forward pass function
        self.forward_fn = forward_fn

    def forward(self, *args):
        """
        Run the forward pass by handing all inputs over to `forward_fn`
        """
        delegate = self.forward_fn
        return delegate(*args)
59
60
if __name__ == "__main__":
    # Load fixed weights so that the reference outputs printed below are
    # reproducible from run to run.
    model = Model()
    model.load_state_dict({
        "linear.weight": torch.tensor(
            [[0.0040, -0.1127, -0.0641, 0.0129, -0.0216, 0.2783, -0.0190, -0.0011],
             [-0.0772, -0.2133, -0.0243, 0.1520, 0.0784, 0.1187, -0.1681, 0.0372]]
        ),
        "linear.bias": torch.tensor([-0.2196, 0.1375]),
    })
    # Example inputs: a is a float tensor of shape (1, 2), b an integer tensor
    # of shape (1, 2, 3) that the model flattens and casts to float.
    a = torch.tensor([[0.5309, 0.4930]])
    b = torch.tensor([[[1, 0, 1], [1, -1, 0]]])

    torch.set_printoptions(precision=10)
    # Evaluate without gradient tracking so only the plain values are printed.
    with torch.no_grad():
        print("Outputs to test against for ModelForStandalone.onnx:", model(a, b))

    # Model with two differently shaped/typed inputs for the standalone interface.
    torch.onnx.export(
        model,
        (a, b),
        "ModelForStandalone.onnx",
        input_names=["a", "b"],
        output_names=["output"],
    )
    # Identity on two inputs, to test the naming of multiple inputs/outputs.
    torch.onnx.export(
        TrivialModel(lambda a, b: (a, b)),
        (torch.zeros(4), torch.zeros(4)),
        "ModelABToAB.onnx",
        input_names=["input_a", "input_b"],
        output_names=["output_a", "output_b"],
    )
    # One input mapped to two outputs (a and 2*a).
    # The example inputs are passed as a real 1-tuple, consistent with the
    # other export calls ("(x)" without a trailing comma is just x).
    torch.onnx.export(
        TrivialModel(lambda a: (a, 2*a)),
        (torch.zeros(4),),
        "ModelAToATwiceA.onnx",
        input_names=["input_a"],
        output_names=["output_a", "output_twice_a"],
    )
linear
linear Layer with 8 inputs, 2 outputs
forward_fn
forward pass function