Belle II Software development
save_model_to_payload.py
#!/usr/bin/env python

"""Register a trained graFEI model and its configuration file as database payloads.

Usage::

    save_model_to_payload.py <model_file> <config_file>

* ``config_file`` is added unchanged as the ``graFEIConfigFile`` payload.
* ``model_file`` is loaded with :func:`torch.load`; only its ``"model"`` entry
  is re-saved to a temporary ``.pt`` file, which is added as the
  ``graFEIModelFile`` payload. Any other entries of the checkpoint are dropped.

Both payloads are registered with an always-valid interval of validity.
Afterwards a one-event basf2 path (EventInfoSetter only) is processed, and the
temporary model file is deleted.
"""

import os
import sys
import tempfile

import torch
import basf2 as b2

from ROOT import Belle2

if __name__ == "__main__":
    if len(sys.argv) != 3:
        sys.exit(f"usage: {sys.argv[0]} <model_file> <config_file>")
    model_file = sys.argv[1]
    config_file = sys.argv[2]

    main = b2.create_path()

    # Minimal path: one event (evtNumList [1]) with experiment 0, run 0.
    eventinfosetter = b2.register_module('EventInfoSetter')
    eventinfosetter.param({'evtNumList': [1], 'expList': 0, 'runList': 0})
    main.add_module(eventinfosetter)

    # Interval of validity that is always valid, shared by both payloads.
    iov = Belle2.IntervalOfValidity.always()
    # Singleton database instance the payloads are added to.
    db = Belle2.Database.Instance()
    db.addPayload('graFEIConfigFile', config_file, iov)

    # Map the checkpoint onto the CPU: the re-saved tensors are then not tied to
    # a CUDA device, and no GPU is needed to create the payload.
    # NOTE(review): torch.load unpickles its input; only pass trusted checkpoints.
    model = torch.load(model_file, map_location="cpu")

    # Reserve a temporary file name. delete=False keeps the (empty) file after the
    # handle is closed so torch.save can write to it by path; the file is removed
    # explicitly in the `finally` below instead of being leaked.
    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as temp_file:
        temp_model_path = temp_file.name
    try:
        torch.save({"model": model["model"]}, temp_model_path)
        db.addPayload('graFEIModelFile', temp_model_path, iov)
        # Remove the file only after processing, in case the payload file is
        # read from disk later than addPayload itself — TODO confirm when the
        # database copies it.
        b2.process(main)
    finally:
        os.remove(temp_model_path)
static IntervalOfValidity always()
Function that returns an interval of validity that is always valid.
static Database & Instance()
Instance of a singleton Database.
Definition at Database.cc:41