12 Generate PXD background samples for background overlay on the fly.
18 from itertools
import product
19 from typing
import Callable, Union
22 from ROOT.Belle2
import DataStore, PyStoreArray, PyStoreObj
23 from ROOT.Belle2
import DBAccessorBase, DBStoreEntry
24 from ROOT.Belle2
import VxdID, PXDDigit
26 from .models
import MODELS
27 from .models
import _get_model_cls, _get_generate_func
37 tuple(product([1], [ladder + 1
for ladder
in range(8)], [1, 2]))
38 + tuple(product([2], [ladder + 1
for ladder
in range(12)], [1, 2]))
50 if not isinstance(model, str):
51 raise TypeError(
"expecting type str `model`")
52 elif model
not in MODELS:
53 options =
", ".join(map(repr, MODELS))
54 raise ValueError(f
"invalid `model`: {model!r} (options: {options}")
66 if not isinstance(checkpoint, (type(
None), str, pathlib.Path)):
67 raise TypeError(
"expecting None or type str `checkpoint`")
68 if checkpoint
is None:
70 checkpoint = os.path.expandvars(str(checkpoint))
71 if not (os.path.exists(checkpoint)
and os.path.isfile(checkpoint)):
72 raise ValueError(f
"invalid `checkpoint`: {checkpoint!r}")
84 if not isinstance(seed, (type(
None), int)):
85 raise TypeError(
"expecting None or type int `seed`")
88 if not -(2 ** 63) <= seed < 2 ** 63:
89 raise ValueError(f
"expecting -2^63 <= `seed` < 2^63 (got: {seed})")
101 if not isinstance(nintra, int):
102 raise TypeError(
"expecting type int `nintra`")
104 raise ValueError(f
"expecting `nintra` > 0 (got: {nintra}")
116 if not isinstance(ninter, int):
117 raise TypeError(
"expecting type int `ninter`")
119 raise ValueError(f
"expecting `ninter` > 0 (got: {ninter}")
130 if not isinstance(globaltag, str):
131 raise TypeError(
"expecting type str `globaltag`")
138 """Generates PXD background samples for background overlay on the fly.
140 :param model: Name of the generator model to use - either "convnet" or "resnet",
141 defaults to "convnet" (optional)
144 :param checkpoint: Path to the checkpoint file with weights for the selected model,
145 defaults to None - use the default checkpoint from the conditions database (optional)
146 :type checkpoint: str
148 :param seed: Integer number in the interval :math:`[-2^{63}, 2^{63} - 1]`
149 used internally as the initial seed,
150 defaults to None - derive a deterministic seed from the
151 value returned by :py:func:`basf2.get_random_seed` (optional)
154 :param nintra: Number of intra-op threads to be utilized for the generation,
155 defaults to 1 (optional)
158 :param ninter: Number of inter-op threads to be utilized for the generation,
159 defaults to 1 (optional)
162 :param globaltag: Global tag of the conditions database
163 providing the default checkpoints stored as payloads,
164 defaults to "PXDBackgroundGenerator" (optional)
193 model: str =
"convnet",
194 checkpoint: Union[
None, str, pathlib.Path] =
None,
195 seed: Union[
None, int] =
None,
198 globaltag: str =
"PXDBackgroundGenerator",
220 basf2.conditions.append_globaltag(self.
globaltagglobaltag)
227 except ImportError
as exc:
228 exc.msg =
"please install PyTorch: `pip3 install torch==1.4.0`"
232 torch.set_num_interop_threads(self.
ninterninter)
235 torch.set_num_threads(self.
nintranintra)
247 payload = f
"PXDBackgroundGenerator_{self.model}"
248 accessor = DBAccessorBase(DBStoreEntry.c_RawFile, payload,
True)
249 checkpoint = accessor.getFilename()
250 self.
_generator_generator.load_state_dict(torch.load(checkpoint, map_location=
"cpu"))
256 for param
in self.
_generator_generator.parameters():
257 param.requires_grad =
False
263 obj = basf2.get_random_seed()
264 func = hashlib.sha512()
265 func.update(bytes(str(obj),
"utf-8"))
266 digest = func.digest()[:8]
267 seed = int.from_bytes(digest,
"big", signed=
True)
268 basf2.B2INFO(f
"PXD background generator seed initialized to {seed}")
269 torch.manual_seed(seed)
272 self.
_vxdids_vxdids = tuple(VxdID(arg)
for arg
in VXDID_ARGS)
275 bkginfo = PyStoreObj(
"BackgroundInfo", DataStore.c_Persistent)
276 if not bkginfo.isValid():
278 basf2.B2ERROR(
"path must contain the BGOverlayInput module")
279 extension = bkginfo.getExtensionName()
282 self.
_pystorearray_pystorearray = PyStoreArray(
"PXDDigits", DataStore.c_DontWriteOut)
283 self.
_pystorearray_pystorearray.registerInDataStore(f
"PXDDigits{extension}")
298 nonzero = x.nonzero(as_tuple=
True)
299 args = [indices.tolist()
for indices
in nonzero]
300 vals = x[nonzero].tolist()
303 for n, (idx, ucell, vcell, charge)
in enumerate(zip(*args, vals)):
307 digit_array[n] = PXDDigit(self.
_vxdids_vxdids[idx], ucell, vcell, charge)
310 del x, nonzero, args, vals
365 """Incorporate a module instance
366 into :py:func:`.add_simulation` after `!BGOverlayInput`.
368 :param module: Module instance to be incorporated,
369 defaults to None - return unmodified function
370 :type module: :py:class:`basf2.Module`, optional
372 :returns: Drop-in replacement function for :py:func:`.add_simulation`
374 from simulation
import add_simulation
378 return add_simulation
379 elif not isinstance(module, basf2.Module):
380 raise TypeError(
"expecting None or type basf2.Module `module`")
382 @functools.wraps(add_simulation)
383 def injected_simulation(path, *args, **kwargs):
385 simulation_path = basf2.Path()
386 add_simulation(simulation_path, *args, **kwargs)
389 for simulation_module
in simulation_path.modules():
391 path.add_module(simulation_module)
393 if simulation_module.name() ==
"BGOverlayInput":
395 path.add_module(module)
397 return injected_simulation
Class for the PXD background generator module.
_generate_func
Generation function applied on the model instance to return an output that is transcoded into digits.
ninter
Number of inter-op threads utilized.
def __init__(self, str model="convnet", Union[None, str, pathlib.Path] checkpoint=None, Union[None, int] seed=None, int nintra=1, int ninter=1, str globaltag="PXDBackgroundGenerator")
Constructor for the PXD background generator module.
model
Name of the generator model.
nintra
Number of intra-op threads utilized.
checkpoint
Path to the checkpoint file with the pre-trained model weights.
globaltag
Global tag of the conditions database providing the default checkpoints stored as payloads.
seed
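Integer number in the interval $[-2^{63}, 2^{63} - 1]$ used as the initial seed; if None, a deterministic seed is derived from the value returned by basf2.get_random_seed.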
_vxdids
Sequence of identifier objects for each PXD module.
def initialize(self)
Method called before event processing to initialize the module.
_pystorearray
Accessor for PXD background digits in the data store.
_generator
Generator model instance.
def event(self)
Method called each time an event is processed.
double eval(const std::vector< double > &spl, const std::vector< double > &vals, double x)
Evaluate spline (zero order or first order) at point x.
int _verify_nintra(int nintra)
Function to verify that nintra is of type int and greater than zero.
str _verify_model(str model)
Function to verify that model is of type str and names one of the available generator models (the keys of MODELS).
Callable inject_simulation(Union[None, basf2.Module] module=None)
Helper function to incorporate a module instance into add_simulation after BGOverlayInput.
int _verify_ninter(int ninter)
Function to verify that ninter is of type int and greater than zero.
Union[None, int] _verify_seed(Union[None, int] seed)
Function to verify that seed is None or of type int, and that an integer seed lies in the interval $[-2^{63}, 2^{63} - 1]$.
str _verify_checkpoint(Union[None, str, pathlib.Path] checkpoint)
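Function to verify that checkpoint is None, a str, or a pathlib.Path, and that a given path, after environment-variable expansion, refers to an existing file.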
str _verify_globaltag(str globaltag)
Function to verify that globaltag is of type str.