21 Example model with 2 different input tensors
26 Initialize with a single Linear layer
35 Run the forward pass - b is flattened to shape (-1, 6) and cast to float, then concatenated with a along axis 1
37 return self.
linear(torch.cat([a, b.reshape(-1, 6).float()], axis=1))
42 Wrapper class to create simple non-parametric models with multiple inputs and outputs for ONNX tests
47 Initialize with forward pass function passed to the constructor
56 Run the forward pass based on `forward_fn`
61if __name__ ==
"__main__":
63 model.load_state_dict({
64 "linear.weight": torch.tensor(
65 [[0.0040, -0.1127, -0.0641, 0.0129, -0.0216, 0.2783, -0.0190, -0.0011],
66 [-0.0772, -0.2133, -0.0243, 0.1520, 0.0784, 0.1187, -0.1681, 0.0372]]
68 "linear.bias": torch.tensor([-0.2196, 0.1375]),
70 a = torch.tensor([[0.5309, 0.4930]])
71 b = torch.tensor([[[1, 0, 1], [1, -1, 0]]])
73 torch.set_printoptions(precision=10)
75 print(
"Outputs to test against for ModelForStandalone.onnx:", model(a, b))
79 "ModelForStandalone.onnx",
80 input_names=[
"a",
"b"],
81 output_names=[
"output"],
85 (torch.zeros(4), torch.zeros(4)),
87 input_names=[
"input_a",
"input_b"],
88 output_names=[
"output_a",
"output_b"],
93 "ModelAToATwiceA.onnx",
94 input_names=[
"input_a"],
95 output_names=[
"output_a",
"output_twice_a"],
linear
Linear layer with 8 inputs and 2 outputs
__init__(self, forward_fn)
forward_fn
forward pass function