Belle II Software  light-2403-persian
GraphDataSet Class Reference

Public Member Functions

def __init__ (self, root, n_files=None, samples=None, features=[], edge_features=[], global_features=[], normalize=None, **kwargs)
 
def processed_file_names (self)
 
def process (self)
 

Public Attributes

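 root
 Root path of the ROOT input files, stored as a ``pathlib.Path``.
 
 normalize
 Whether to normalize input features.
 
 n_files
 Number of files to load.
 
 node_features
 Names of the node features.
 
 edge_features
 Names of the edge features.
 
 global_features
 Names of the global features.
 
 samples
 Number of events to load.
 
 slices
 Slice indices of the collated graph data, loaded together with ``data`` from the processed file.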
 

Detailed Description

Dataset handler that converts Belle II data into a PyTorch Geometric ``InMemoryDataset``.

The ROOT input files are expected to contain a tree named ``Tree``,
and all node features must be named ``feat_FEATNAME``.
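A minimal sketch of this naming convention, assuming that each node feature ``NAME`` is stored as a column called ``feat_NAME`` inside the tree ``Tree``. The helpers ``branch_names`` and ``feature_names`` are hypothetical and are not part of ``geometric_datasets.py``.

```python
TREE_NAME = "Tree"  # tree name expected in every ROOT file
FEATURE_PREFIX = "feat_"  # prefix of node feature columns (assumed convention)


def branch_names(features):
    """Hypothetical helper: map node feature names to their expected column names."""
    return [f"{FEATURE_PREFIX}{name}" for name in features]


def feature_names(branches):
    """Hypothetical helper: recover node feature names from a list of column names."""
    return [b[len(FEATURE_PREFIX):] for b in branches if b.startswith(FEATURE_PREFIX)]


print(branch_names(["p", "E"]))  # ['feat_p', 'feat_E']
# 'event_id' is an illustrative non-feature column and is skipped
print(feature_names(["feat_p", "feat_E", "event_id"]))  # ['p', 'E']
```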

.. note:: This expects the files under ``root`` to follow the structure ``root/**/<file_name>.root``,
    with separate root paths for the training and validation sets.
    The ``**/`` component accommodates subdirectories, e.g. ``sub00``.
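The expected directory layout can be sketched with ``pathlib``. This sketch assumes that files are collected recursively with a ``**/*.root`` glob and that the list is cut to the first ``n_files`` entries after sorting. ``find_root_files`` is a hypothetical helper, and the actual discovery and ordering logic in ``geometric_datasets.py`` may differ.

```python
import tempfile
from pathlib import Path


def find_root_files(root, n_files=None):
    """Hypothetical helper: list ROOT files under root/**/ and keep at most n_files.

    Sorting is an assumption made here so that the n_files cut is reproducible.
    """
    files = sorted(Path(root).glob("**/*.root"))
    if n_files is not None:
        files = files[:n_files]
    return files


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Illustrative layout with two subdirectories, as in root/sub00/...
        for sub, name in [("sub00", "a.root"), ("sub01", "b.root")]:
            Path(tmp, sub).mkdir()
            Path(tmp, sub, name).touch()
        print([f.name for f in find_root_files(tmp)])  # ['a.root', 'b.root']
        print([f.name for f in find_root_files(tmp, n_files=1)])  # ['a.root']
```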

Args:
    root (str): Path to ROOT files.
    n_files (int): Load only ``n_files`` files.
    samples (int): Load only ``samples`` events.
    features (list): List of node feature names.
    edge_features (list): List of edge feature names.
    global_features (list): List of global feature names.
    normalize (bool): Whether to normalize input features.

Definition at line 258 of file geometric_datasets.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (self, root, n_files=None, samples=None, features=[], edge_features=[], global_features=[], normalize=None, **kwargs)
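Initialization. This validates that ``features`` is a non-empty list and stores the arguments as attributes. It then deletes any ``*.pt`` files in ``root/processed`` and calls the parent class constructor, which triggers ``process()``. Finally, it loads the processed data and slices.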

Definition at line 279 of file geometric_datasets.py.

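    from pathlib import Path

    import torch
    from torch_geometric.data import InMemoryDataset


    class GraphDataSet(InMemoryDataset):
        def __init__(
            self,
            root,
            n_files=None,
            samples=None,
            features=[],
            edge_features=[],
            global_features=[],
            normalize=None,
            **kwargs,
        ):
            """
            Initialization.
            """
            # The dataset needs a non-empty list of node feature names
            assert isinstance(
                features, list
            ), f'Argument "features" must be a list and not {type(features)}'
            assert len(features) > 0, "You need to use at least one node feature"

            # Root path
            self.root = Path(root)
            # Whether to normalize input features
            self.normalize = normalize
            # Number of files to load
            self.n_files = n_files
            # Node, edge and global feature names
            self.node_features = features
            self.edge_features = edge_features
            self.global_features = global_features
            # Number of events to load
            self.samples = samples

            # Delete previously processed files so that process() runs again
            file_path = Path(self.root, "processed")
            files = list(file_path.glob("*.pt"))
            for f in files:
                f.unlink(missing_ok=True)

            # Needs to be called after having assigned all attributes;
            # positional arguments are (root, transform, pre_transform, pre_filter).
            # Note that **kwargs is not forwarded.
            super().__init__(root, None, None, None)

            # Data and slices written by process()
            self.data, self.slices = torch.load(self.processed_paths[0])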
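The following sketch shows what happens when the ``*.pt`` files are deleted before ``super().__init__`` is called. ``needs_processing`` is a simplified, hypothetical stand-in for the check that PyTorch Geometric performs before it calls ``process()``. It ignores ``force_reload`` and the ``pre_transform``/``pre_filter`` bookkeeping. Because the cleanup removes any cached processed file, the check then reports that processing is needed.

```python
import tempfile
from pathlib import Path


def clear_processed(root):
    """Delete cached *.pt files in root/processed, as the constructor does."""
    for f in Path(root, "processed").glob("*.pt"):
        f.unlink(missing_ok=True)


def needs_processing(root, file_names):
    """Simplified, hypothetical version of PyG's check: process if any file is missing."""
    processed_dir = Path(root, "processed")
    return not all((processed_dir / name).exists() for name in file_names)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, "processed").mkdir()
        Path(tmp, "processed", "processed_data.pt").touch()  # stale cache from an earlier run
        print(needs_processing(tmp, ["processed_data.pt"]))  # False
        clear_processed(tmp)
        print(needs_processing(tmp, ["processed_data.pt"]))  # True
```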

Member Function Documentation

◆ process()

def process (   self)
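Processes the data to create graph objects and stores them in ``root/processed/processed_data.pt``,
where ``root`` differs between the training and validation sets.

Called internally by PyTorch Geometric from the parent class constructor whenever the processed file is missing. Because ``__init__`` first deletes all ``*.pt`` files in ``root/processed``, processing runs every time the dataset is instantiated.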

Definition at line 334 of file geometric_datasets.py.
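``InMemoryDataset`` keeps all graphs in a single collated storage, together with a ``slices`` dictionary of offsets. This is the pair that ``__init__`` reloads into ``self.data`` and ``self.slices``. The code below is a conceptual, pure-Python illustration of that layout for node features only. ``collate`` and ``get_graph`` are hypothetical names, and the graphs are made-up data. PyTorch Geometric's real collation works on tensors and handles every attribute, including edge indices.

```python
from itertools import accumulate


def collate(graphs):
    """Concatenate the node feature rows of all graphs and record per-graph offsets."""
    x = [row for graph in graphs for row in graph]
    # slices["x"][i]:slices["x"][i + 1] is the row range of graph i
    slices = {"x": [0, *accumulate(len(graph) for graph in graphs)]}
    return x, slices


def get_graph(x, slices, idx):
    """Recover the node feature rows of graph idx from the collated storage."""
    start, stop = slices["x"][idx], slices["x"][idx + 1]
    return x[start:stop]


graphs = [
    [[0.1, 1.0], [0.2, 2.0]],  # graph 0: two nodes, two features each
    [[0.3, 3.0]],  # graph 1: one node
    [[0.4, 4.0], [0.5, 5.0], [0.6, 6.0]],  # graph 2: three nodes
]
x, slices = collate(graphs)
print(slices["x"])  # [0, 2, 3, 6]
print(get_graph(x, slices, 2))  # [[0.4, 4.0], [0.5, 5.0], [0.6, 6.0]]
```

Storing one flat buffer plus offsets avoids keeping thousands of small Python objects in memory and makes saving to a single file straightforward.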

◆ processed_file_names()

def processed_file_names (   self)
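Returns the name of the processed file (``processed_data.pt``), which PyTorch Geometric resolves to ``root/processed/processed_data.pt``.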

Definition at line 328 of file geometric_datasets.py.
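PyTorch Geometric joins the names returned by ``processed_file_names`` onto ``root/processed``, and the constructor then loads ``self.processed_paths[0]``. The following is a minimal sketch of that mapping. It assumes the single name ``processed_data.pt`` from the ``process()`` documentation. ``processed_paths`` here is a hypothetical free function, not the library's property, and ``"train"`` is an illustrative root.

```python
import os.path as osp


def processed_paths(root, file_names):
    """Join each processed file name onto root/processed (a single str is accepted too)."""
    if isinstance(file_names, str):
        file_names = [file_names]
    return [osp.join(root, "processed", name) for name in file_names]


print(processed_paths("train", "processed_data.pt"))  # ['train/processed/processed_data.pt'] on POSIX
```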

