##########################################################################
# basf2 (Belle II Analysis Software Framework) #
# Author: The Belle II Collaboration #
# #
# See git log for contributors and copyright holders. #
# This file is licensed under LGPL-3.0, see LICENSE.md. #
##########################################################################
##
# @package pxd.background_generator
# Generate PXD background samples for background overlay on the fly.
"""
Generate PXD background samples for background overlay on the fly.
"""
import functools
import hashlib
import os.path
import pathlib
from itertools import product
from typing import Callable, Union
import basf2
from ROOT.Belle2 import DataStore, PyStoreArray, PyStoreObj
from ROOT.Belle2 import DBAccessorBase, DBStoreEntry
from ROOT.Belle2 import VxdID, PXDDigit
from .models import MODELS
from .models import _get_model_cls, _get_generate_func
##
# Sequence of forty tuples `(layer, ladder, sensor)` used to
# instantiate `VxdID` specifier objects for distinct PXD modules.
#
# It is assumed that the indices of tuples in the sequence match
# indices along the first axis of tensors with shape `(40, 250, 768)`
# that are produced by the generator.
VXDID_ARGS = tuple(
tuple(product([1], [ladder + 1 for ladder in range(8)], [1, 2]))
+ tuple(product([2], [ladder + 1 for ladder in range(12)], [1, 2]))
)
##
# Function to verify that `model`:
# * is a string,
# * is a valid name for a model that is available for selection.
#
# The value of `model` is returned if the conditions are met.
# An exception is raised otherwise.
def _verify_model(model: str) -> str:
if not isinstance(model, str):
raise TypeError("expecting type str `model`")
elif model not in MODELS:
options = ", ".join(map(repr, MODELS))
raise ValueError(f"invalid `model`: {model!r} (options: {options}")
return model
##
# Function to verify that `checkpoint`:
# * is either None, a string, or a `pathlib.Path` object,
# * is a valid path to an existing file - if not None.
#
# The value of `checkpoint` is returned if the conditions are met.
# An exception is raised otherwise.
def _verify_checkpoint(checkpoint: Union[None, str, pathlib.Path]) -> str:
if not isinstance(checkpoint, (type(None), str, pathlib.Path)):
raise TypeError("expecting None or type str `checkpoint`")
if checkpoint is None:
return checkpoint
checkpoint = os.path.expandvars(str(checkpoint))
if not (os.path.exists(checkpoint) and os.path.isfile(checkpoint)):
raise ValueError(f"invalid `checkpoint`: {checkpoint!r}")
return checkpoint
##
# Function to verify that `seed`:
# * is either None or an integer,
# * is in the interval \f$ [-2^{63}, 2^{63} - 1] \f$ - if not None.
#
# The value of `seed` is returned if the conditions are met.
# An exception is raised otherwise.
def _verify_seed(seed: Union[None, int]) -> Union[None, int]:
if not isinstance(seed, (type(None), int)):
raise TypeError("expecting None or type int `seed`")
if seed is None:
return seed
if not -(2 ** 63) <= seed < 2 ** 63:
raise ValueError(f"expecting -2^63 <= `seed` < 2^63 (got: {seed})")
return seed
##
# Function to verify that `nintra`:
# * is an integer,
# * is larger than zero.
#
# The value of `nintra` is returned if the conditions are met.
# An exception is raised otherwise.
def _verify_nintra(nintra: int) -> int:
if not isinstance(nintra, int):
raise TypeError("expecting type int `nintra`")
elif not nintra > 0:
raise ValueError(f"expecting `nintra` > 0 (got: {nintra}")
return nintra
##
# Function to verify that `ninter`:
# * is an integer,
# * is larger than zero.
#
# The value of `ninter` is returned if the conditions are met.
# An exception is raised otherwise.
def _verify_ninter(ninter: int) -> int:
if not isinstance(ninter, int):
raise TypeError("expecting type int `ninter`")
elif not ninter > 0:
raise ValueError(f"expecting `ninter` > 0 (got: {ninter}")
return ninter
##
# Function to verify that `globaltag`:
# * is a string.
#
# The value of `globaltag` is returned if the condition is met.
# An exception is raised otherwise.
def _verify_globaltag(globaltag: str) -> str:
if not isinstance(globaltag, str):
raise TypeError("expecting type str `globaltag`")
return globaltag
##
# Class for the PXD background generator module.
[docs]class PXDBackgroundGenerator(basf2.Module):
"""Generates PXD background samples for background overlay on the fly.
:param model: Name of the generator model to use - either "convnet" or "resnet",
defaults to "convnet" (optional)
:type model: str
:param checkpoint: Path to the checkpoint file with weights for the selected model,
defaults to None - use the default checkpoint from the conditions database (optional)
:type checkpoint: str
:param seed: Integer number in the interval :math:`[-2^{63}, 2^{63} - 1]`
used internally as the initial seed,
defaults to None - derive a deterministic seed from the
value returned by :py:func:`basf2.get_random_seed` (optional)
:type seed: int
:param nintra: Number of intra-op threads to be utilized for the generation,
defaults to 1 (optional)
:type nintra: int
:param ninter: Number of inter-op threads to be utilized for the generation,
defaults to 1 (optional)
:type ninter: int
:param globaltag: Global tag of the conditions database
providing the default checkpoints stored as payloads,
defaults to "PXDBackgroundGenerator" (optional)
:type globaltag: str
"""
##
# Constructor for the PXD background generator module.
#
# @param model: Name of the generator model to use - either "convnet" or "resnet",
# defaults to "convnet"
#
# @param checkpoint: Path to the checkpoint file with weights for the selected model,
# defaults to None - use the default checkpoint from the conditions database
#
# @param seed: Integer number in the interval \f$ [-2^{63}, 2^{63} - 1] \f$
# used internally as the initial seed,
# defaults to None - derive a deterministic seed from the
# value returned by `basf2.get_random_seed()`
#
# @param nintra: Number of intra-op threads to be utilized for the generation,
# defaults to 1
#
# @param ninter: Number of inter-op threads to be utilized for the generation,
# defaults to 1
#
# @param globaltag: Global tag of the conditions database
# providing the default checkpoints stored as payloads,
# defaults to "PXDBackgroundGenerator"
def __init__(
self,
model: str = "convnet",
checkpoint: Union[None, str, pathlib.Path] = None,
seed: Union[None, int] = None,
nintra: int = 1,
ninter: int = 1,
globaltag: str = "PXDBackgroundGenerator",
):
super().__init__()
# process `model`
self.model = _verify_model(model)
# process `checkpoint`
self.checkpoint = _verify_checkpoint(checkpoint)
# process `seed`
self.seed = _verify_seed(seed)
# process `nintra`
self.nintra = _verify_nintra(nintra)
# process `ninter`
self.ninter = _verify_ninter(ninter)
# process `globaltag`
self.globaltag = _verify_globaltag(globaltag)
# enable the specified global tag
basf2.conditions.append_globaltag(self.globaltag)
##
# Method called before event processing to initialize the module.
def initialize(self):
try:
import torch
except ImportError as exc:
exc.msg = "please install PyTorch: `pip3 install torch==1.4.0`"
raise
# set the number of inter-op CPU threads
torch.set_num_interop_threads(self.ninter)
# set the number of intra-op CPU threads
torch.set_num_threads(self.nintra)
# get the generation function for the selected model
self._generate_func = _get_generate_func(self.model)
# instantiate the generator model
self._generator = _get_model_cls(self.model)()
# initialize the model weights
checkpoint = self.checkpoint
if self.checkpoint is None:
# use the default checkpoint from the conditions database
payload = f"PXDBackgroundGenerator_{self.model}"
accessor = DBAccessorBase(DBStoreEntry.c_RawFile, payload, True)
checkpoint = accessor.getFilename()
self._generator.load_state_dict(torch.load(checkpoint, map_location="cpu"))
# set mode of operation to inference
self._generator.eval()
# disable the computation of gradients
for param in self._generator.parameters():
param.requires_grad = False
# initialize the seed
seed = self.seed
if seed is None:
# derive from the basf2 seed value
obj = basf2.get_random_seed()
func = hashlib.sha512()
func.update(bytes(str(obj), "utf-8"))
digest = func.digest()[:8]
seed = int.from_bytes(digest, "big", signed=True)
basf2.B2INFO(f"PXD background generator seed initialized to {seed}")
torch.manual_seed(seed)
# instantiate objects for specifying distinct PXD modules
self._vxdids = tuple(VxdID(arg) for arg in VXDID_ARGS)
# get the name of the extension used by BGOverlayInput for background collections
bkginfo = PyStoreObj("BackgroundInfo", DataStore.c_Persistent)
if not bkginfo.isValid():
# no information about background overlay input available
basf2.B2ERROR("path must contain the BGOverlayInput module")
extension = bkginfo.getExtensionName()
# register the PXD background digit collection - array - in the data store
self._pystorearray = PyStoreArray("PXDDigits", DataStore.c_DontWriteOut)
self._pystorearray.registerInDataStore(f"PXDDigits{extension}")
##
# Method called each time an event is processed.
def event(self):
# get the low-level array accessor
digit_array = self._pystorearray.getPtr()
# clear digits stored by BGOverlayInput
digit_array.Clear()
# generate a batch of images - one image for each PXD module
x = self._generate_func(self._generator)
# locate indices of pixels with non-zero values - pixel hits
nonzero = x.nonzero(as_tuple=True)
args = [indices.tolist() for indices in nonzero]
vals = x[nonzero].tolist()
# store indices and pixel values into the data store as background digits
for n, (idx, ucell, vcell, charge) in enumerate(zip(*args, vals)):
# append a new default digit to expand the array
self._pystorearray.appendNew()
# modify the array to point to the correct digit
digit_array[n] = PXDDigit(self._vxdids[idx], ucell, vcell, charge)
# delete references to release memory
del x, nonzero, args, vals
##
# @var model
# Name of the generator model
##
# @var checkpoint
# Path to the checkpoint file with the pre-trained model weights
##
# @var seed
# Integer number in the interval \f$ [-2^{63}, 2^{63} - 1] \f$
# set as the initial seed
##
# @var nintra
# Number of intra-op threads utilized
##
# @var ninter
# Number of inter-op threads utilized
##
# @var globaltag
# Global tag of the conditions database
# providing the default checkpoints stored as payloads
##
# @var _generator
# Generator model instance
##
# @var _generate_func
# Generation function applied on the model instance to
# return an output that is transcoded into digits
##
# @var _pystorearray
# Accessor for PXD background digits in the data store
##
# @var _vxdids
# Sequence of identifier objects for each PXD module
##
# Helper function to incorporate a module instance
# into `add_simulation` after `BGOverlayInput`.
#
# @param module: Module instance to be incorporated,
# defaults to None - return unmodified function
#
# @return Drop-in replacement function for `add_simulation`
[docs]def inject_simulation(module: Union[None, basf2.Module] = None) -> Callable:
"""Incorporate a module instance
into :py:func:`.add_simulation` after `!BGOverlayInput`.
:param module: Module instance to be incorporated,
defaults to None - return unmodified function
:type module: :py:class:`basf2.Module`, optional
:returns: Drop-in replacement function for :py:func:`.add_simulation`
"""
from simulation import add_simulation
if module is None:
# no modifications necessary
return add_simulation
elif not isinstance(module, basf2.Module):
raise TypeError("expecting None or type basf2.Module `module`")
@functools.wraps(add_simulation)
def injected_simulation(path, *args, **kwargs):
# create a separate path with simulation modules
simulation_path = basf2.Path()
add_simulation(simulation_path, *args, **kwargs)
# manually add the simulation modules to the given path
for simulation_module in simulation_path.modules():
# append the next module from the simulation path
path.add_module(simulation_module)
if simulation_module.name() == "BGOverlayInput":
# incorporate the given module
path.add_module(module)
return injected_simulation