Belle II Software development
write_test_file.py
#!/usr/bin/env python3
2
3
10
# This example creates the file ModelForStandalone.onnx used in tests
12
13import torch
14from torch import nn
15
16
class Model(nn.Module):
    """
    Example model with 2 different input tensors
    """

    def __init__(self):
        """
        Initialize with a single Linear layer
        """
        super().__init__()

        ## linear Layer with 8 inputs, 2 outputs
        self.linear = nn.Linear(8, 2)

    def forward(self, a, b):
        """
        Run the forward pass - a and b are concatenated and b is flattened

        ``b`` is reshaped to rows of 6 values and cast to float32, then
        appended to ``a`` along dimension 1. Because the linear layer takes
        8 inputs, ``a`` must provide 2 features per row and both tensors must
        describe the same number of rows.

        Args:
            a: tensor with 2 features per row, e.g. shape (N, 2)
            b: tensor with 6 elements per row, e.g. shape (N, 2, 3);
               integer values are accepted because of the float cast
        Returns:
            tensor of shape (N, 2), the output of the linear layer
        """
        # reshape(-1, 6) keeps one row per example; .float() converts b
        # (integer in the example inputs) to float32 before concatenation
        flat_b = b.reshape(-1, 6).float()
        return self.linear(torch.cat([a, flat_b], dim=1))
35
36
if __name__ == "__main__":
    model = Model()

    # Overwrite the random initialisation with fixed parameters so that the
    # exported file and the printed reference outputs are reproducible.
    fixed_state = {
        "linear.weight": torch.tensor([
            [0.0040, -0.1127, -0.0641, 0.0129, -0.0216, 0.2783, -0.0190, -0.0011],
            [-0.0772, -0.2133, -0.0243, 0.1520, 0.0784, 0.1187, -0.1681, 0.0372],
        ]),
        "linear.bias": torch.tensor([-0.2196, 0.1375]),
    }
    model.load_state_dict(fixed_state)

    # One example row per input; input_b is built from integer literals
    input_a = torch.tensor([[0.5309, 0.4930]])
    input_b = torch.tensor([[[1, 0, 1], [1, -1, 0]]])
    example_inputs = (input_a, input_b)

    # Print with 10 digits of precision so the values can be copied into tests
    torch.set_printoptions(precision=10)
    reference_output = model(*example_inputs)
    print("Outputs to test against:", reference_output)

    # Graph inputs are named "a" and "b", the graph output "output"
    torch.onnx.export(
        model,
        example_inputs,
        "ModelForStandalone.onnx",
        input_names=["a", "b"],
        output_names=["output"],
    )
linear
Linear layer with 8 inputs and 2 outputs