Generate PXD background samples for background overlay on the fly.
18from itertools
import product
19from typing
import Callable, Union
22from ROOT
import Belle2
23from ROOT.Belle2
import DataStore, PyStoreArray, PyStoreObj
24from ROOT.Belle2
import DBAccessorBase, DBStoreEntry
25from ROOT.Belle2
import VxdID, PXDDigit
27from .models
import MODELS
28from .models
import _get_model_cls, _get_generate_func
# Argument triples later unpacked into ``VxdID(*arg)``: first element 1 is
# combined with second elements 1..8, first element 2 with second elements
# 1..12, each with third elements 1 and 2 (40 triples in total).
# NOTE(review): presumably (layer, ladder, sensor) - confirm against VxdID.
VXDID_ARGS = (
    tuple(product([1], range(1, 9), [1, 2]))
    + tuple(product([2], range(1, 13), [1, 2]))
)
def _verify_model(model: str) -> str:
    """Verify the `model` argument and return it unchanged.

    :param model: Name of the generator model
    :type model: str

    :returns: The given `model`
    :raises TypeError: if `model` is not of type str
    :raises ValueError: if `model` is not one of the names in ``MODELS``
    """
    if not isinstance(model, str):
        raise TypeError("expecting type str `model`")
    if model not in MODELS:
        options = ", ".join(map(repr, MODELS))
        # the closing parenthesis of "(options: ...)" was missing in the message
        raise ValueError(f"invalid `model`: {model!r} (options: {options})")
    return model
def _verify_checkpoint(
    checkpoint: Union[None, str, pathlib.Path],
) -> Union[None, str]:
    """Verify the `checkpoint` argument and return it in normalized form.

    :param checkpoint: None, or path to a checkpoint file
    :type checkpoint: str

    :returns: None if `checkpoint` is None, otherwise the path as a string
        with environment variables expanded
    :raises TypeError: if `checkpoint` is not None, str, or pathlib.Path
    :raises ValueError: if the expanded path is not an existing regular file
    """
    if not isinstance(checkpoint, (type(None), str, pathlib.Path)):
        raise TypeError("expecting None or type str `checkpoint`")
    if checkpoint is None:
        return checkpoint
    checkpoint = os.path.expandvars(str(checkpoint))
    # isfile() is already False for paths that do not exist,
    # so a separate exists() check is not needed
    if not os.path.isfile(checkpoint):
        raise ValueError(f"invalid `checkpoint`: {checkpoint!r}")
    return checkpoint
def _verify_seed(seed: Union[None, int]) -> Union[None, int]:
    """Verify the `seed` argument and return it unchanged.

    :param seed: None, or integer number in the interval
        :math:`[-2^{63}, 2^{63} - 1]`
    :type seed: int

    :returns: The given `seed`
    :raises TypeError: if `seed` is neither None nor of type int
    :raises ValueError: if `seed` lies outside the signed 64-bit range
    """
    if not isinstance(seed, (type(None), int)):
        raise TypeError("expecting None or type int `seed`")
    if seed is None:
        # None is passed through; the range check only applies to integers
        return seed
    if not -(2 ** 63) <= seed < 2 ** 63:
        raise ValueError(f"expecting -2^63 <= `seed` < 2^63 (got: {seed})")
    return seed
def _verify_nintra(nintra: int) -> int:
    """Verify the `nintra` argument and return it unchanged.

    :param nintra: Number of intra-op threads
    :type nintra: int

    :returns: The given `nintra`
    :raises TypeError: if `nintra` is not of type int
    :raises ValueError: if `nintra` is not greater than zero
    """
    if not isinstance(nintra, int):
        raise TypeError("expecting type int `nintra`")
    if nintra <= 0:
        # the closing parenthesis of "(got: ...)" was missing in the message
        raise ValueError(f"expecting `nintra` > 0 (got: {nintra})")
    return nintra
def _verify_ninter(ninter: int) -> int:
    """Verify the `ninter` argument and return it unchanged.

    :param ninter: Number of inter-op threads
    :type ninter: int

    :returns: The given `ninter`
    :raises TypeError: if `ninter` is not of type int
    :raises ValueError: if `ninter` is not greater than zero
    """
    if not isinstance(ninter, int):
        raise TypeError("expecting type int `ninter`")
    if ninter <= 0:
        # the closing parenthesis of "(got: ...)" was missing in the message
        raise ValueError(f"expecting `ninter` > 0 (got: {ninter})")
    return ninter
def _verify_globaltag(globaltag: str) -> str:
    """Verify the `globaltag` argument and return it unchanged.

    :param globaltag: Global tag of the conditions database
    :type globaltag: str

    :returns: The given `globaltag`
    :raises TypeError: if `globaltag` is not of type str
    """
    if not isinstance(globaltag, str):
        raise TypeError("expecting type str `globaltag`")
    return globaltag
    """Generates PXD background samples for background overlay on the fly.

    :param model: Name of the generator model to use - either "convnet" or "resnet",
        defaults to "convnet" (optional)

    :param checkpoint: Path to the checkpoint file with weights for the selected model,
        defaults to None - use the default checkpoint from the conditions database (optional)
    :type checkpoint: str

    :param seed: Integer number in the interval :math:`[-2^{63}, 2^{63} - 1]`
        used internally as the initial seed,
        defaults to None - derive a deterministic seed from the
        value returned by :py:func:`basf2.get_random_seed` (optional)

    :param nintra: Number of intra-op threads to be utilized for the generation,
        defaults to 1 (optional)

    :param ninter: Number of inter-op threads to be utilized for the generation,
        defaults to 1 (optional)

    :param globaltag: Global tag of the conditions database
        providing the default checkpoints stored as payloads,
        defaults to "PXDBackgroundGenerator" (optional)
    """
194 model: str =
"convnet",
195 checkpoint: Union[
None, str, pathlib.Path] =
None,
196 seed: Union[
None, int] =
None,
199 globaltag: str =
"PXDBackgroundGenerator",
221 basf2.conditions.append_globaltag(self.
globaltag)
228 except ImportError
as exc:
229 exc.msg =
"please install PyTorch: `pip3 install torch==1.4.0`"
233 torch.set_num_interop_threads(self.
ninter)
236 torch.set_num_threads(self.
nintra)
248 payload = f
"PXDBackgroundGenerator_{self.model}"
249 accessor = DBAccessorBase(DBStoreEntry.c_RawFile, payload,
True)
250 checkpoint = accessor.getFilename()
251 self.
_generator.load_state_dict(torch.load(checkpoint, map_location=
"cpu"))
258 param.requires_grad =
False
264 obj = basf2.get_random_seed()
265 func = hashlib.sha512()
266 func.update(bytes(str(obj),
"utf-8"))
267 digest = func.digest()[:8]
268 seed = int.from_bytes(digest,
"big", signed=
True)
269 basf2.B2INFO(f
"PXD background generator seed initialized to {seed}")
270 torch.manual_seed(seed)
273 self.
_vxdids = tuple(VxdID(*arg)
for arg
in VXDID_ARGS)
276 bkginfo = PyStoreObj(
"BackgroundInfo", DataStore.c_Persistent)
277 if not bkginfo.isValid():
279 basf2.B2ERROR(
"path must contain the BGOverlayInput module")
280 extension = bkginfo.getExtensionName()
284 self.
_pystorearray.registerInDataStore(f
"PXDDigits{extension}")
299 nonzero = x.nonzero(as_tuple=
True)
300 args = [indices.tolist()
for indices
in nonzero]
301 vals = x[nonzero].tolist()
304 for n, (idx, ucell, vcell, charge)
in enumerate(zip(*args, vals)):
308 digit_array[n] = PXDDigit(self.
_vxdids[idx], ucell, vcell, charge)
311 del x, nonzero, args, vals
def inject_simulation(module: Union[None, basf2.Module] = None) -> Callable:
    """Incorporate a module instance
    into :py:func:`.add_simulation` after `!BGOverlayInput`.

    :param module: Module instance to be incorporated,
        defaults to None - return unmodified function
    :type module: :py:class:`basf2.Module`, optional

    :returns: Drop-in replacement function for :py:func:`.add_simulation`
    :raises TypeError: if `module` is neither None nor a :py:class:`basf2.Module`
    """
    from simulation import add_simulation

    if module is None:
        # nothing to inject - hand back the unmodified function
        return add_simulation
    if not isinstance(module, basf2.Module):
        raise TypeError("expecting None or type basf2.Module `module`")

    @functools.wraps(add_simulation)
    def injected_simulation(path, *args, **kwargs):
        # let add_simulation populate a separate path first, so that its
        # modules can be copied over one by one
        simulation_path = basf2.Path()
        add_simulation(simulation_path, *args, **kwargs)

        for simulation_module in simulation_path.modules():
            path.add_module(simulation_module)

            # the given module goes directly after BGOverlayInput
            if simulation_module.name() == "BGOverlayInput":
                path.add_module(module)

    return injected_simulation
Class for the PXD background generator module.
_generate_func
Generation function applied on the model instance to return an output that is transcoded into digits.
ninter
Number of inter-op threads utilized.
def __init__(self, str model="convnet", Union[None, str, pathlib.Path] checkpoint=None, Union[None, int] seed=None, int nintra=1, int ninter=1, str globaltag="PXDBackgroundGenerator")
Constructor for the PXD background generator module.
model
Name of the generator model.
nintra
Number of intra-op threads utilized.
checkpoint
Path to the checkpoint file with the pre-trained model weights.
globaltag
Global tag of the conditions database providing the default checkpoints stored as payloads.
seed
Integer number in the interval :math:`[-2^{63}, 2^{63} - 1]` set as the initial seed.
_vxdids
Sequence of identifier objects for each PXD module.
def initialize(self)
Method called before event processing to initialize the module.
_pystorearray
Accessor for PXD background digits in the data store.
_generator
Generator model instance.
def event(self)
Method called each time an event is processed.
str _verify_model(str model)
Function to verify that `model` is of type str and is one of the names in `MODELS`.
int _verify_ninter(int ninter)
Function to verify that `ninter` is of type int and is greater than zero.
str _verify_globaltag(str globaltag)
Function to verify that `globaltag` is of type str.
Union[None, int] _verify_seed(Union[None, int] seed)
Function to verify that `seed` is either None or of type int within the interval :math:`[-2^{63}, 2^{63} - 1]`.
int _verify_nintra(int nintra)
Function to verify that `nintra` is of type int and is greater than zero.
str _verify_checkpoint(Union[None, str, pathlib.Path] checkpoint)
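Function to verify that `checkpoint` is either None or a str / pathlib.Path that, after expansion of environment variables, points to an existing file.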