Belle II Software development
__init__.py
11"""
12Generate PXD background samples for background overlay on the fly.
13"""
14import functools
15import hashlib
16import os.path
17import pathlib
18from itertools import product
19from typing import Callable, Union
20
21import basf2
22from ROOT import Belle2 # noqa: make Belle2 namespace available
23from ROOT.Belle2 import DataStore, PyStoreArray, PyStoreObj
24from ROOT.Belle2 import DBAccessorBase, DBStoreEntry
25from ROOT.Belle2 import VxdID, PXDDigit
26
27from .models import MODELS
28from .models import _get_model_cls, _get_generate_func
29
30
# Constructor arguments for one VxdID per PXD module, ordered so that the
# index of an entry matches the image index produced by the generator.
# Presumably each entry is (layer, ladder, sensor) as accepted by VxdID;
# layer 1 has 8 ladders, layer 2 has 12, each with sensors 1 and 2.
VXDID_ARGS = tuple(
    (layer, ladder, sensor)
    for layer, nladders in ((1, 8), (2, 12))
    for ladder in range(1, nladders + 1)
    for sensor in (1, 2)
)
41
42
43
def _verify_model(model: str) -> str:
    """Verify that `model` names one of the available generator models.

    :param model: Name of the generator model
    :returns: `model` unchanged
    :raises TypeError: if `model` is not of type str
    :raises ValueError: if `model` is not one of the keys in ``MODELS``
    """
    if not isinstance(model, str):
        raise TypeError("expecting type str `model`")
    elif model not in MODELS:
        options = ", ".join(map(repr, MODELS))
        # closing parenthesis was missing from the original message
        raise ValueError(f"invalid `model`: {model!r} (options: {options})")
    return model
57
58
59
def _verify_checkpoint(
    checkpoint: Union[None, str, pathlib.Path]
) -> Union[None, str]:
    """Verify that `checkpoint` is None or refers to an existing file.

    Environment variables in the path (e.g. ``$HOME``) are expanded.

    :param checkpoint: Path to a checkpoint file, or None
    :returns: None if `checkpoint` is None, otherwise the expanded path as str
    :raises TypeError: if `checkpoint` is not None, str or pathlib.Path
    :raises ValueError: if the expanded path is not an existing regular file
    """
    if not isinstance(checkpoint, (type(None), str, pathlib.Path)):
        raise TypeError("expecting None or type str or pathlib.Path `checkpoint`")
    if checkpoint is None:
        return None
    checkpoint = os.path.expandvars(str(checkpoint))
    # os.path.isfile is False for missing paths, so no separate exists() check
    if not os.path.isfile(checkpoint):
        raise ValueError(f"invalid `checkpoint`: {checkpoint!r}")
    return checkpoint
75
76
77
def _verify_seed(seed: Union[None, int]) -> Union[None, int]:
    """Verify that `seed` is None or a 64-bit signed integer.

    :param seed: Initial seed, or None to derive one later
    :returns: `seed` unchanged
    :raises TypeError: if `seed` is neither None nor of type int
    :raises ValueError: if `seed` lies outside [-2^63, 2^63 - 1]
    """
    if seed is None:
        return None
    if not isinstance(seed, int):
        raise TypeError("expecting None or type int `seed`")
    lower, upper = -(2 ** 63), 2 ** 63
    if seed < lower or seed >= upper:
        raise ValueError(f"expecting -2^63 <= `seed` < 2^63 (got: {seed})")
    return seed
92
93
94
def _verify_nintra(nintra: int) -> int:
    """Verify that `nintra` is a positive integer number of intra-op threads.

    :param nintra: Number of intra-op threads
    :returns: `nintra` unchanged
    :raises TypeError: if `nintra` is not of type int
    :raises ValueError: if `nintra` is not greater than zero
    """
    if not isinstance(nintra, int):
        raise TypeError("expecting type int `nintra`")
    elif not nintra > 0:
        # closing parenthesis was missing from the original message
        raise ValueError(f"expecting `nintra` > 0 (got: {nintra})")
    return nintra
107
108
109
def _verify_ninter(ninter: int) -> int:
    """Verify that `ninter` is a positive integer number of inter-op threads.

    :param ninter: Number of inter-op threads
    :returns: `ninter` unchanged
    :raises TypeError: if `ninter` is not of type int
    :raises ValueError: if `ninter` is not greater than zero
    """
    if not isinstance(ninter, int):
        raise TypeError("expecting type int `ninter`")
    elif not ninter > 0:
        # closing parenthesis was missing from the original message
        raise ValueError(f"expecting `ninter` > 0 (got: {ninter})")
    return ninter
122
123
124
def _verify_globaltag(globaltag: str) -> str:
    """Verify that `globaltag` is of type str.

    :param globaltag: Name of the conditions-database global tag
    :returns: `globaltag` unchanged
    :raises TypeError: if `globaltag` is not of type str
    """
    if isinstance(globaltag, str):
        return globaltag
    raise TypeError("expecting type str `globaltag`")
134
135
136
class PXDBackgroundGenerator(basf2.Module):
    """Generates PXD background samples for background overlay on the fly.

    :param model: Name of the generator model to use - either "convnet" or "resnet",
        defaults to "convnet" (optional)
    :type model: str

    :param checkpoint: Path to the checkpoint file with weights for the selected model,
        defaults to None - use the default checkpoint from the conditions database (optional)
    :type checkpoint: str

    :param seed: Integer number in the interval :math:`[-2^{63}, 2^{63} - 1]`
        used internally as the initial seed,
        defaults to None - derive a deterministic seed from the
        value returned by :py:func:`basf2.get_random_seed` (optional)
    :type seed: int

    :param nintra: Number of intra-op threads to be utilized for the generation,
        defaults to 1 (optional)
    :type nintra: int

    :param ninter: Number of inter-op threads to be utilized for the generation,
        defaults to 1 (optional)
    :type ninter: int

    :param globaltag: Global tag of the conditions database
        providing the default checkpoints stored as payloads,
        defaults to "PXDBackgroundGenerator" (optional)
    :type globaltag: str
    """

    def __init__(
        self,
        model: str = "convnet",
        checkpoint: Union[None, str, pathlib.Path] = None,
        seed: Union[None, int] = None,
        nintra: int = 1,
        ninter: int = 1,
        globaltag: str = "PXDBackgroundGenerator",
    ):
        """Constructor for the PXD background generator module.

        Validates every parameter and appends `globaltag` to the
        conditions-database global tags.

        :raises TypeError: if a parameter has the wrong type
        :raises ValueError: if a parameter has an invalid value
        """
        super().__init__()
        #: Name of the generator model
        self.model = _verify_model(model)

        #: Path to the checkpoint file with the pre-trained model weights,
        #: or None to use the default payload from the conditions database
        self.checkpoint = _verify_checkpoint(checkpoint)

        #: Initial seed, or None to derive one from the basf2 seed
        self.seed = _verify_seed(seed)

        #: Number of intra-op threads utilized
        self.nintra = _verify_nintra(nintra)

        #: Number of inter-op threads utilized
        self.ninter = _verify_ninter(ninter)

        #: Global tag providing the default checkpoints stored as payloads
        self.globaltag = _verify_globaltag(globaltag)

        # enable the specified global tag
        basf2.conditions.append_globaltag(self.globaltag)

    def initialize(self):
        """Method called before event processing to initialize the module.

        Configures PyTorch threading, builds the generator model and loads
        its weights, seeds PyTorch, and registers the background PXD digit
        array in the data store.

        :raises ImportError: if PyTorch is not installed
        """
        try:
            import torch
        except ImportError as exc:
            # keep the original import failure attached as the cause
            raise ImportError(
                "please install PyTorch: `pip3 install torch==1.4.0`"
            ) from exc

        # set the number of inter-op CPU threads
        torch.set_num_interop_threads(self.ninter)

        # set the number of intra-op CPU threads
        torch.set_num_threads(self.nintra)

        #: Generation function applied on the model instance to return
        #: an output that is transcoded into digits
        self._generate_func = _get_generate_func(self.model)

        #: Generator model instance
        self._generator = _get_model_cls(self.model)()

        # initialize the model weights
        checkpoint = self.checkpoint
        if checkpoint is None:
            # use the default checkpoint from the conditions database
            payload = f"PXDBackgroundGenerator_{self.model}"
            accessor = DBAccessorBase(DBStoreEntry.c_RawFile, payload, True)
            checkpoint = accessor.getFilename()
        self._generator.load_state_dict(torch.load(checkpoint, map_location="cpu"))

        # set mode of operation to inference
        self._generator.eval()

        # disable the computation of gradients
        for param in self._generator.parameters():
            param.requires_grad = False

        # initialize the seed
        seed = self.seed
        if seed is None:
            # derive deterministically from the basf2 seed value: the first
            # 8 bytes of its SHA-512 digest, read as a signed 64-bit integer
            obj = basf2.get_random_seed()
            func = hashlib.sha512()
            func.update(bytes(str(obj), "utf-8"))
            digest = func.digest()[:8]
            seed = int.from_bytes(digest, "big", signed=True)
            basf2.B2INFO(f"PXD background generator seed initialized to {seed}")
        torch.manual_seed(seed)

        #: Sequence of identifier objects for each PXD module
        self._vxdids = tuple(VxdID(*arg) for arg in VXDID_ARGS)

        # get the name of the extension used by BGOverlayInput for background collections
        bkginfo = PyStoreObj("BackgroundInfo", DataStore.c_Persistent)
        if not bkginfo.isValid():
            # no information about background overlay input available; the
            # extension name cannot be read from an invalid object, so stop
            # here instead of dereferencing it below
            basf2.B2FATAL("path must contain the BGOverlayInput module")
        extension = bkginfo.getExtensionName()

        #: Accessor for PXD background digits in the data store
        self._pystorearray = PyStoreArray("PXDDigits", DataStore.c_DontWriteOut)
        self._pystorearray.registerInDataStore(f"PXDDigits{extension}")

    def event(self):
        """Method called each time an event is processed.

        Replaces the content of the background PXD digit array with digits
        transcoded from a newly generated batch of images, one image per
        PXD module.
        """
        # get the low-level array accessor
        digit_array = self._pystorearray.getPtr()

        # clear digits stored by BGOverlayInput
        digit_array.Clear()

        # generate a batch of images - one image for each PXD module;
        # the 4-way unpacking below implies x is indexed as
        # (module, ucell, vcell)
        x = self._generate_func(self._generator)

        # locate indices of pixels with non-zero values - pixel hits
        nonzero = x.nonzero(as_tuple=True)
        args = [indices.tolist() for indices in nonzero]
        vals = x[nonzero].tolist()

        # store indices and pixel values into the data store as background digits
        for n, (idx, ucell, vcell, charge) in enumerate(zip(*args, vals)):
            # append a new default digit to expand the array
            self._pystorearray.appendNew()
            # overwrite the appended slot with the actual digit
            digit_array[n] = PXDDigit(self._vxdids[idx], ucell, vcell, charge)

        # delete references to release memory
        del x, nonzero, args, vals
312
313
316
317
320
321
325
326
329
330
333
334
338
339
342
343
347
348
351
352
355
356
357
def inject_simulation(module: Union[None, basf2.Module] = None) -> Callable:
    """Incorporate a module instance
    into :py:func:`.add_simulation` after `!BGOverlayInput`.

    :param module: Module instance to be incorporated,
        defaults to None - return unmodified function
    :type module: :py:class:`basf2.Module`, optional

    :returns: Drop-in replacement function for :py:func:`.add_simulation`

    :raises TypeError: if `module` is neither None nor a :py:class:`basf2.Module`
    """
    from simulation import add_simulation

    if module is None:
        # no modifications necessary
        return add_simulation
    if not isinstance(module, basf2.Module):
        raise TypeError("expecting None or type basf2.Module `module`")

    @functools.wraps(add_simulation)
    def injected_simulation(path, *args, **kwargs):
        # create a separate path with simulation modules
        simulation_path = basf2.Path()
        add_simulation(simulation_path, *args, **kwargs)

        # manually add the simulation modules to the given path, placing the
        # given module directly after BGOverlayInput
        injected = False
        for simulation_module in simulation_path.modules():
            path.add_module(simulation_module)
            if simulation_module.name() == "BGOverlayInput":
                path.add_module(module)
                injected = True

        if not injected:
            # previously the module was dropped silently in this case
            basf2.B2WARNING(
                "BGOverlayInput not found in the simulation path - "
                f"module {module.name()!r} was not added"
            )

    return injected_simulation
Class for the PXD background generator module.
Definition: __init__.py:138
_generate_func
Generation function applied on the model instance to return an output that is transcoded into digits.
Definition: __init__.py:239
ninter
Number of inter-op threads utilized.
Definition: __init__.py:215
def __init__(self, str model="convnet", Union[None, str, pathlib.Path] checkpoint=None, Union[None, int] seed=None, int nintra=1, int ninter=1, str globaltag="PXDBackgroundGenerator")
Constructor for the PXD background generator module.
Definition: __init__.py:200
nintra
Number of intra-op threads utilized.
Definition: __init__.py:212
checkpoint
Path to the checkpoint file with the pre-trained model weights.
Definition: __init__.py:206
globaltag
Global tag of the conditions database providing the default checkpoints stored as payloads.
Definition: __init__.py:218
seed
Integer number in the interval :math:`[-2^{63}, 2^{63} - 1]` set as the initial seed.
Definition: __init__.py:209
_vxdids
Sequence of identifier objects for each PXD module.
Definition: __init__.py:273
def initialize(self)
Method called before event processing to initialize the module.
Definition: __init__.py:225
_pystorearray
Accessor for PXD background digits in the data store.
Definition: __init__.py:283
def event(self)
Method called each time an event is processed.
Definition: __init__.py:288
str _verify_model(str model)
Function to verify that `model` is of type str and is one of the available generator models.
Definition: __init__.py:50
int _verify_ninter(int ninter)
Function to verify that `ninter` is of type int and greater than zero.
Definition: __init__.py:116
str _verify_globaltag(str globaltag)
Function to verify that `globaltag` is of type str.
Definition: __init__.py:130
Union[None, int] _verify_seed(Union[None, int] seed)
Function to verify that `seed` is None or of type int in the interval :math:`[-2^{63}, 2^{63} - 1]`.
Definition: __init__.py:84
int _verify_nintra(int nintra)
Function to verify that `nintra` is of type int and greater than zero.
Definition: __init__.py:101
str _verify_checkpoint(Union[None, str, pathlib.Path] checkpoint)
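Function to verify that `checkpoint` is None, or a str or pathlib.Path that (after environment-variable expansion) names an existing file.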
Definition: __init__.py:66