Belle II Software release-09-00-00
__init__.py
11"""
12Generate PXD background samples for background overlay on the fly.
13"""
14import functools
15import hashlib
16import os.path
17import pathlib
18from itertools import product
19from typing import Callable, Union
20
21import basf2
22from ROOT.Belle2 import DataStore, PyStoreArray, PyStoreObj
23from ROOT.Belle2 import DBAccessorBase, DBStoreEntry
24from ROOT.Belle2 import VxdID, PXDDigit
25
26from .models import MODELS
27from .models import _get_model_cls, _get_generate_func
28
29
# Argument tuples used to construct one VxdID per PXD module
# (presumably ordered as (layer, ladder, sensor) - verify against VxdID).
# Layer 1 contributes ladders 1..8 and layer 2 ladders 1..12, each with
# sensors 1 and 2, giving 40 entries in total.
VXDID_ARGS = tuple(
    (layer, ladder, sensor)
    for layer, nladders in ((1, 8), (2, 12))
    for ladder in range(1, nladders + 1)
    for sensor in (1, 2)
)
40
41
42
def _verify_model(model: str) -> str:
    """Verify that `model` names one of the available generator models.

    :param model: Name of the generator model
    :type model: str
    :raises TypeError: If `model` is not of type str
    :raises ValueError: If `model` is not contained in ``MODELS``
    :returns: The verified `model`, unchanged
    """
    if not isinstance(model, str):
        raise TypeError("expecting type str `model`")
    if model not in MODELS:
        options = ", ".join(map(repr, MODELS))
        # the message previously lacked the closing parenthesis
        raise ValueError(f"invalid `model`: {model!r} (options: {options})")
    return model
56
57
58
def _verify_checkpoint(
    checkpoint: Union[None, str, pathlib.Path]
) -> Union[None, str]:
    """Verify that `checkpoint` is None or refers to an existing file.

    Environment variables in the path are expanded with
    :py:func:`os.path.expandvars` before the check.

    :param checkpoint: Path to the checkpoint file, or None
    :type checkpoint: None, str or pathlib.Path
    :raises TypeError: If `checkpoint` is neither None, str nor pathlib.Path
    :raises ValueError: If the expanded path is not an existing regular file
    :returns: None if `checkpoint` is None, otherwise the expanded path as str
    """
    if checkpoint is None:
        return None
    if not isinstance(checkpoint, (str, pathlib.Path)):
        raise TypeError("expecting None or type str or pathlib.Path `checkpoint`")
    checkpoint = os.path.expandvars(str(checkpoint))
    # os.path.isfile already implies os.path.exists
    if not os.path.isfile(checkpoint):
        raise ValueError(f"invalid `checkpoint`: {checkpoint!r}")
    return checkpoint
74
75
76
def _verify_seed(seed: Union[None, int]) -> Union[None, int]:
    """Check that `seed` is None or a signed 64-bit integer.

    :param seed: Candidate seed value, or None
    :type seed: None or int
    :raises TypeError: If `seed` is neither None nor an int
    :raises ValueError: If `seed` lies outside :math:`[-2^{63}, 2^{63} - 1]`
    :returns: The verified `seed`, unchanged
    """
    if seed is None:
        return None
    if not isinstance(seed, int):
        raise TypeError("expecting None or type int `seed`")
    lowest, past_highest = -(2 ** 63), 2 ** 63
    if seed < lowest or seed >= past_highest:
        raise ValueError(f"expecting -2^63 <= `seed` < 2^63 (got: {seed})")
    return seed
91
92
93
def _verify_nintra(nintra: int) -> int:
    """Verify that `nintra` is a positive int.

    :param nintra: Number of intra-op threads
    :type nintra: int
    :raises TypeError: If `nintra` is not of type int
    :raises ValueError: If `nintra` is not greater than zero
    :returns: The verified `nintra`, unchanged
    """
    if not isinstance(nintra, int):
        raise TypeError("expecting type int `nintra`")
    if nintra <= 0:
        # the message previously lacked the closing parenthesis
        raise ValueError(f"expecting `nintra` > 0 (got: {nintra})")
    return nintra
106
107
108
def _verify_ninter(ninter: int) -> int:
    """Verify that `ninter` is a positive int.

    :param ninter: Number of inter-op threads
    :type ninter: int
    :raises TypeError: If `ninter` is not of type int
    :raises ValueError: If `ninter` is not greater than zero
    :returns: The verified `ninter`, unchanged
    """
    if not isinstance(ninter, int):
        raise TypeError("expecting type int `ninter`")
    if ninter <= 0:
        # the message previously lacked the closing parenthesis
        raise ValueError(f"expecting `ninter` > 0 (got: {ninter})")
    return ninter
121
122
123
def _verify_globaltag(globaltag: str) -> str:
    """Ensure the conditions-database global tag name is a string.

    :param globaltag: Name of the global tag
    :type globaltag: str
    :raises TypeError: If `globaltag` is not of type str
    :returns: The verified `globaltag`, unchanged
    """
    if isinstance(globaltag, str):
        return globaltag
    raise TypeError("expecting type str `globaltag`")
133
134
135
class PXDBackgroundGenerator(basf2.Module):
    """Generates PXD background samples for background overlay on the fly.

    :param model: Name of the generator model to use - either "convnet" or "resnet",
        defaults to "convnet" (optional)
    :type model: str

    :param checkpoint: Path to the checkpoint file with weights for the selected model,
        defaults to None - use the default checkpoint from the conditions database (optional)
    :type checkpoint: str or pathlib.Path

    :param seed: Integer number in the interval :math:`[-2^{63}, 2^{63} - 1]`
        used internally as the initial seed,
        defaults to None - derive a deterministic seed from the
        value returned by :py:func:`basf2.get_random_seed` (optional)
    :type seed: int

    :param nintra: Number of intra-op threads to be utilized for the generation,
        defaults to 1 (optional)
    :type nintra: int

    :param ninter: Number of inter-op threads to be utilized for the generation,
        defaults to 1 (optional)
    :type ninter: int

    :param globaltag: Global tag of the conditions database
        providing the default checkpoints stored as payloads,
        defaults to "PXDBackgroundGenerator" (optional)
    :type globaltag: str
    """

    def __init__(
        self,
        model: str = "convnet",
        checkpoint: Union[None, str, pathlib.Path] = None,
        seed: Union[None, int] = None,
        nintra: int = 1,
        ninter: int = 1,
        globaltag: str = "PXDBackgroundGenerator",
    ):
        """Constructor for the PXD background generator module.

        Validates every argument and appends `globaltag` to the basf2
        conditions database configuration.

        :raises TypeError: If an argument has an unexpected type
        :raises ValueError: If an argument has an invalid value
        """
        super().__init__()
        # name of the selected generator model
        self.model = _verify_model(model)

        # path to the checkpoint file with the pre-trained model weights,
        # or None to use the payload from the conditions database
        self.checkpoint = _verify_checkpoint(checkpoint)

        # initial seed, or None to derive one from the basf2 seed
        self.seed = _verify_seed(seed)

        # number of intra-op threads utilized
        self.nintra = _verify_nintra(nintra)

        # number of inter-op threads utilized
        self.ninter = _verify_ninter(ninter)

        # global tag providing the default checkpoints stored as payloads
        self.globaltag = _verify_globaltag(globaltag)

        # enable the specified global tag
        basf2.conditions.append_globaltag(self.globaltag)

    def initialize(self):
        """Method called before event processing to initialize the module.

        Configures PyTorch threading, loads the generator weights, seeds
        PyTorch and registers the PXD background digit array in the data
        store under the extension used by BGOverlayInput.

        :raises ImportError: If PyTorch is not installed
        """
        try:
            import torch
        except ImportError as exc:
            exc.msg = "please install PyTorch: `pip3 install torch==1.4.0`"
            raise

        # set the number of inter-op CPU threads
        torch.set_num_interop_threads(self.ninter)

        # set the number of intra-op CPU threads
        torch.set_num_threads(self.nintra)

        # generation function applied on the model instance; its output is
        # transcoded into digits in event()
        self._generate_func = _get_generate_func(self.model)

        # the generator model instance
        self._generator = _get_model_cls(self.model)()

        # initialize the model weights
        checkpoint = self.checkpoint
        if checkpoint is None:
            # use the default checkpoint from the conditions database
            payload = f"PXDBackgroundGenerator_{self.model}"
            accessor = DBAccessorBase(DBStoreEntry.c_RawFile, payload, True)
            checkpoint = accessor.getFilename()
        # NOTE(review): torch.load unpickles the file - only load checkpoints
        # from trusted sources
        self._generator.load_state_dict(torch.load(checkpoint, map_location="cpu"))

        # set mode of operation to inference
        self._generator.eval()

        # disable the computation of gradients
        for param in self._generator.parameters():
            param.requires_grad = False

        # initialize the seed
        seed = self.seed
        if seed is None:
            # derive a deterministic signed 64-bit seed from the basf2 seed:
            # the first 8 bytes of its SHA-512 digest
            obj = basf2.get_random_seed()
            func = hashlib.sha512()
            func.update(bytes(str(obj), "utf-8"))
            digest = func.digest()[:8]
            seed = int.from_bytes(digest, "big", signed=True)
            basf2.B2INFO(f"PXD background generator seed initialized to {seed}")
        torch.manual_seed(seed)

        # one identifier object per PXD module, indexed like the generated batch
        self._vxdids = tuple(VxdID(*arg) for arg in VXDID_ARGS)

        # get the name of the extension used by BGOverlayInput for background collections
        bkginfo = PyStoreObj("BackgroundInfo", DataStore.c_Persistent)
        if not bkginfo.isValid():
            # no information about background overlay input available;
            # B2ERROR only logs, so execution would continue into
            # getExtensionName() on an invalid object - abort instead
            basf2.B2FATAL("path must contain the BGOverlayInput module")
        extension = bkginfo.getExtensionName()

        # accessor for the PXD background digits in the data store
        self._pystorearray = PyStoreArray("PXDDigits", DataStore.c_DontWriteOut)
        self._pystorearray.registerInDataStore(f"PXDDigits{extension}")

    def event(self):
        """Method called each time an event is processed.

        Clears the PXD background digits stored by BGOverlayInput and
        replaces them with digits built from a freshly generated batch of
        images, one image per PXD module.
        """
        # get the low-level array accessor
        digit_array = self._pystorearray.getPtr()

        # clear digits stored by BGOverlayInput
        digit_array.Clear()

        # generate a batch of images - one image for each PXD module
        x = self._generate_func(self._generator)

        # locate indices of pixels with non-zero values - pixel hits;
        # the unpacking below expects three index tensors, i.e. a rank-3
        # output indexed as (module, ucell, vcell)
        nonzero = x.nonzero(as_tuple=True)
        args = [indices.tolist() for indices in nonzero]
        vals = x[nonzero].tolist()

        # store indices and pixel values into the data store as background digits
        for n, (idx, ucell, vcell, charge) in enumerate(zip(*args, vals)):
            # append a new default digit to expand the array
            self._pystorearray.appendNew()
            # modify the array to point to the correct digit
            digit_array[n] = PXDDigit(self._vxdids[idx], ucell, vcell, charge)

        # delete references to release memory
        del x, nonzero, args, vals
311
312
315
316
319
320
324
325
328
329
332
333
337
338
341
342
346
347
350
351
354
355
356
def inject_simulation(module: Union[None, basf2.Module] = None) -> Callable:
    """Incorporate a module instance
    into :py:func:`.add_simulation` after `!BGOverlayInput`.

    :param module: Module instance to be incorporated,
        defaults to None - return unmodified function
    :type module: :py:class:`basf2.Module`, optional

    :raises TypeError: If `module` is neither None nor a :py:class:`basf2.Module`

    :returns: Drop-in replacement function for :py:func:`.add_simulation`
    """
    from simulation import add_simulation

    if module is None:
        # no modifications necessary
        return add_simulation
    if not isinstance(module, basf2.Module):
        raise TypeError("expecting None or type basf2.Module `module`")

    @functools.wraps(add_simulation)
    def injected_simulation(path, *args, **kwargs):
        # create a separate path with simulation modules
        simulation_path = basf2.Path()
        add_simulation(simulation_path, *args, **kwargs)

        # manually add the simulation modules to the given path
        injected = False
        for simulation_module in simulation_path.modules():
            # append the next module from the simulation path
            path.add_module(simulation_module)

            if simulation_module.name() == "BGOverlayInput":
                # incorporate the given module
                path.add_module(module)
                injected = True

        if not injected:
            # without BGOverlayInput in the simulation path the module would
            # otherwise be dropped silently
            basf2.B2WARNING(
                "BGOverlayInput not found in the simulation path, "
                f"module {module.name()!r} was not added"
            )

    return injected_simulation
Class for the PXD background generator module.
Definition: __init__.py:137
_generate_func
Generation function applied on the model instance to return an output that is transcoded into digits.
Definition: __init__.py:238
ninter
Number of inter-op threads utilized.
Definition: __init__.py:214
def __init__(self, str model="convnet", Union[None, str, pathlib.Path] checkpoint=None, Union[None, int] seed=None, int nintra=1, int ninter=1, str globaltag="PXDBackgroundGenerator")
Constructor for the PXD background generator module.
Definition: __init__.py:199
nintra
Number of intra-op threads utilized.
Definition: __init__.py:211
checkpoint
Path to the checkpoint file with the pre-trained model weights.
Definition: __init__.py:205
globaltag
Global tag of the conditions database providing the default checkpoints stored as payloads.
Definition: __init__.py:217
seed
Integer number in the interval $[-2^{63}, 2^{63} - 1]$ set as the initial seed.
Definition: __init__.py:208
_vxdids
Sequence of identifier objects for each PXD module.
Definition: __init__.py:272
def initialize(self)
Method called before event processing to initialize the module.
Definition: __init__.py:224
_pystorearray
Accessor for PXD background digits in the data store.
Definition: __init__.py:282
def event(self)
Method called each time an event is processed.
Definition: __init__.py:287
str _verify_model(str model)
Function to verify that `model` is of type str and is one of the available generator models.
Definition: __init__.py:49
int _verify_ninter(int ninter)
Function to verify that `ninter` is of type int and greater than zero.
Definition: __init__.py:115
str _verify_globaltag(str globaltag)
Function to verify that `globaltag` is of type str.
Definition: __init__.py:129
Union[None, int] _verify_seed(Union[None, int] seed)
Function to verify that `seed` is either None or of type int in the interval $[-2^{63}, 2^{63} - 1]$.
Definition: __init__.py:83
int _verify_nintra(int nintra)
Function to verify that `nintra` is of type int and greater than zero.
Definition: __init__.py:100
str _verify_checkpoint(Union[None, str, pathlib.Path] checkpoint)
Function to verify that `checkpoint` is either None or a path (str or pathlib.Path) to an existing file.
Definition: __init__.py:65