Belle II Software development
GraFEIModel Class Reference

Public Member Functions

def __init__ (self, nfeat_in_dim, efeat_in_dim, gfeat_in_dim, edge_classes=6, x_classes=7, hidden_layer_dim=128, num_hid_layers=1, num_ML=1, dropout=0.0, global_layer=True, **kwargs)
 
def forward (self, batch)
 

Public Attributes

 first_ML
 First MetaLayer.
 
 ML_list
 Intermediate MetaLayers.
 
 last_ML
 Output MetaLayer.
 

Detailed Description

Actual implementation of the model,
based on the
`MetaLayer <https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MetaLayer.html>`_
class.

.. seealso::
    `Relational inductive biases, deep learning, and graph networks <https://arxiv.org/abs/1806.01261>`_

The network is composed of:

1. A first MetaLayer to increase the number of node and edge features;
2. A number of intermediate MetaLayers (`num_ML`, tunable in the config file), each wrapped in a skip connection;
3. A last MetaLayer to decrease the number of node and edge features to the desired output dimension.
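The stacking pattern above, including the skip connection that `forward` adds around each intermediate MetaLayer, can be sketched in plain Python. This is a toy illustration, not the Belle II implementation: each "layer" is a hypothetical function acting on a flat list of floats in place of a real MetaLayer, and `run_stack` and the `toy_*` layers are made-up names.

```python
from typing import Callable, List

Features = List[float]
Layer = Callable[[Features], Features]


def run_stack(first: Layer, intermediate: List[Layer], last: Layer,
              x: Features) -> Features:
    """Toy version of the first -> intermediate -> last MetaLayer stack."""
    # 1. First layer: changes the feature dimension (input -> hidden)
    x = first(x)
    # 2. Intermediate layers keep the hidden dimension, so each layer's
    #    input can be added element-wise to its output (skip connection)
    for layer in intermediate:
        x_skip = x
        x = layer(x)
        x = [a + b for a, b in zip(x, x_skip)]
    # 3. Last layer: changes the feature dimension (hidden -> output)
    return last(x)


def toy_first(x: Features) -> Features:
    # 2 input features -> 3 hidden features
    return [x[0], x[1], x[0] + x[1]]


def toy_double(x: Features) -> Features:
    # Hidden -> hidden, same dimension
    return [2.0 * v for v in x]


def toy_last(x: Features) -> Features:
    # 3 hidden features -> 1 output feature
    return [sum(x)]


print(run_stack(toy_first, [toy_double], toy_last, [1.0, 2.0]))  # [18.0]
```

Passing an empty list of intermediate layers reduces the stack to the first and last layers only, just as `num_ML=0` leaves `ML_list` empty.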

.. figure:: figs/graFEI.png
    :width: 42em
    :align: center

Each MetaLayer is in turn composed of `EdgeLayer`, `NodeLayer` and `GlobalLayer` sub-blocks; the `GlobalLayer` is omitted when `global_layer` is `False`.
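Inside a MetaLayer the sub-blocks run in a fixed order: PyTorch Geometric's `MetaLayer` first updates the edges, then the nodes (using the updated edges), then the global features (using both). The pure-Python sketch below shows only that ordering on a single graph with scalar features; the sum-based updates and the `meta_layer` name are hypothetical stand-ins for the learned `EdgeLayer`, `NodeLayer` and `GlobalLayer` blocks.

```python
from typing import List, Tuple

Edge = Tuple[int, int]  # (source node, target node)


def meta_layer(x: List[float], edges: List[Edge], edge_attr: List[float],
               u: float, use_global: bool = True):
    """Toy edge -> node -> global update on one graph with scalar features."""
    # 1. Edge update: each edge sees its source and target node features,
    #    its own feature and the global feature
    edge_attr = [x[s] + x[t] + e + u for (s, t), e in zip(edges, edge_attr)]
    # 2. Node update: each node sums the *updated* features of its
    #    incoming edges
    incoming = [0.0] * len(x)
    for (_, t), e in zip(edges, edge_attr):
        incoming[t] += e
    x = [xi + agg + u for xi, agg in zip(x, incoming)]
    # 3. Global update, skipped when there is no global block
    #    (as with global_layer=False); u is then passed through unchanged
    if use_global:
        u = u + sum(x) + sum(edge_attr)
    return x, edge_attr, u


print(meta_layer([1.0, 2.0], [(0, 1), (1, 0)], [0.5, 0.5], 0.0))
# ([4.5, 5.5], [3.5, 3.5], 17.0)
```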

Args:
    nfeat_in_dim (int): Node features dimension (number of input node features).
    efeat_in_dim (int): Edge features dimension (number of input edge features).
    gfeat_in_dim (int): Global features dimension (number of input global features).
    edge_classes (int): Edge features output dimension (i.e. number of different edge labels in the LCAS matrix).
    x_classes (int): Node features output dimension (i.e. number of different mass hypotheses).
    hidden_layer_dim (int): Intermediate features dimension (same for node, edge and global).
    num_hid_layers (int): Number of hidden layers in every MetaLayer.
    num_ML (int): Number of intermediate MetaLayers.
    dropout (float): Dropout rate :math:`r \in [0,1]`.
    global_layer (bool): Whether to include a `GlobalLayer` in every MetaLayer.

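:return: Node, edge and global features after model evaluation: `x_classes` features per node and `edge_classes` features per edge (symmetrized, so edges (i, j) and (j, i) carry the same values), plus one global feature per graph when `global_layer` is `True`.
:rtype: tuple(`Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_)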

Definition at line 15 of file geometric_network.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (self, nfeat_in_dim, efeat_in_dim, gfeat_in_dim, edge_classes=6, x_classes=7, hidden_layer_dim=128, num_hid_layers=1, num_ML=1, dropout=0.0, global_layer=True, **kwargs)
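Initialization. Builds the input MetaLayer (first_ML), the intermediate MetaLayers (ML_list) and the output MetaLayer (last_ML).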

Definition at line 53 of file geometric_network.py.

```python
def __init__(
    self,
    nfeat_in_dim,
    efeat_in_dim,
    gfeat_in_dim,
    edge_classes=6,
    x_classes=7,
    hidden_layer_dim=128,
    num_hid_layers=1,
    num_ML=1,
    dropout=0.0,
    global_layer=True,
    **kwargs
):
    """
    Initialization.
    """
    super(GraFEIModel, self).__init__()

    # First MetaLayer: maps the input node and edge features (and the global
    # features, if global_layer is True) to hidden_layer_dim
    self.first_ML = MetaLayer(
        EdgeLayer(
            nfeat_in_dim,
            efeat_in_dim,
            gfeat_in_dim,
            hidden_layer_dim,
            hidden_layer_dim,
            num_hid_layers,
            dropout,
        ),
        NodeLayer(
            nfeat_in_dim,
            hidden_layer_dim,
            gfeat_in_dim,
            hidden_layer_dim,
            hidden_layer_dim,
            num_hid_layers,
            dropout,
        ),
        GlobalLayer(
            hidden_layer_dim,
            hidden_layer_dim,
            gfeat_in_dim,
            hidden_layer_dim,
            hidden_layer_dim,
            num_hid_layers,
            dropout,
        )
        if global_layer
        else None,
    )

    # Intermediate MetaLayers: num_ML blocks that keep the node and edge
    # features at hidden_layer_dim, so forward() can add skip connections
    self.ML_list = torch.nn.ModuleList(
        [
            MetaLayer(
                EdgeLayer(
                    hidden_layer_dim,
                    hidden_layer_dim,
                    hidden_layer_dim if global_layer else 0,
                    hidden_layer_dim,
                    hidden_layer_dim,
                    num_hid_layers,
                    dropout,
                ),
                NodeLayer(
                    hidden_layer_dim,
                    hidden_layer_dim,
                    hidden_layer_dim if global_layer else 0,
                    hidden_layer_dim,
                    hidden_layer_dim,
                    num_hid_layers,
                    dropout,
                ),
                GlobalLayer(
                    hidden_layer_dim,
                    hidden_layer_dim,
                    hidden_layer_dim,
                    hidden_layer_dim,
                    hidden_layer_dim,
                    num_hid_layers,
                    dropout,
                )
                if global_layer
                else None,
            )
            for _ in range(num_ML)
        ]
    )

    # Output MetaLayer: edge features -> edge_classes, node features ->
    # x_classes and (if global_layer is True) global features -> 1
    self.last_ML = MetaLayer(
        EdgeLayer(
            hidden_layer_dim,
            hidden_layer_dim,
            hidden_layer_dim if global_layer else 0,
            hidden_layer_dim,
            edge_classes,
            num_hid_layers,
            dropout,
            normalize=False,  # Do not normalize output layer
        ),
        NodeLayer(
            hidden_layer_dim,
            edge_classes,
            hidden_layer_dim if global_layer else 0,
            hidden_layer_dim,
            x_classes,
            num_hid_layers,
            dropout,
            normalize=False,  # Do not normalize output layer
        ),
        GlobalLayer(
            x_classes,
            edge_classes,
            hidden_layer_dim,
            hidden_layer_dim,
            1,
            num_hid_layers,
            dropout,
            normalize=False,  # Do not normalize output layer
        )
        if global_layer
        else None,
    )
```
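The feature dimensions produced by this constructor can be tabulated with a small helper. The sketch below, assuming PyTorch Geometric's `MetaLayer` semantics (global features pass through unchanged when no global model is given), lists the (node, edge, global) dimensions entering and leaving each MetaLayer; `metalayer_dims` is a hypothetical name and the example input sizes are arbitrary.

```python
from typing import List, Tuple

Dims = Tuple[int, int, int]  # (node, edge, global) feature dimensions


def metalayer_dims(
    nfeat_in_dim: int,
    efeat_in_dim: int,
    gfeat_in_dim: int,
    edge_classes: int = 6,
    x_classes: int = 7,
    hidden_layer_dim: int = 128,
    num_ML: int = 1,
    global_layer: bool = True,
) -> List[Tuple[str, Dims, Dims]]:
    """Return (name, input dims, output dims) for every MetaLayer."""
    # Without a GlobalLayer the global features keep their input dimension
    g_hid = hidden_layer_dim if global_layer else gfeat_in_dim
    g_out = 1 if global_layer else gfeat_in_dim
    hidden: Dims = (hidden_layer_dim, hidden_layer_dim, g_hid)

    # first_ML: input dimensions -> hidden_layer_dim
    layers = [("first_ML", (nfeat_in_dim, efeat_in_dim, gfeat_in_dim), hidden)]
    # ML_list: num_ML blocks that preserve the hidden dimensions
    layers += [(f"ML_list[{i}]", hidden, hidden) for i in range(num_ML)]
    # last_ML: hidden_layer_dim -> x_classes / edge_classes (/ 1 global output)
    layers.append(("last_ML", hidden, (x_classes, edge_classes, g_out)))
    return layers


# Arbitrary example input sizes: 10 node, 4 edge and 2 global features
for name, dims_in, dims_out in metalayer_dims(10, 4, 2, num_ML=2):
    print(f"{name}: {dims_in} -> {dims_out}")
```

Because every intermediate block maps hidden_layer_dim to hidden_layer_dim, the element-wise skip connections in forward() always add tensors of matching shape.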

Member Function Documentation

◆ forward()

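def forward (self, batch)
Called internally by PyTorch to propagate the input through the network. The batch must provide the node features x, the global features u, edge_index, the edge features edge_attr and the node-to-graph assignment vector batch (e.g. a PyTorch Geometric Batch carrying an additional u attribute).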

Definition at line 179 of file geometric_network.py.

```python
def forward(self, batch):
    """
    Called internally by PyTorch to propagate the input through the network.
    """
    x, u, edge_index, edge_attr, torch_batch = (
        batch.x,
        batch.u,
        batch.edge_index,
        batch.edge_attr,
        batch.batch,
    )

    x, edge_attr, u = self.first_ML(
        x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
    )

    del batch

    for ML in self.ML_list:
        x_skip = x
        edge_skip = edge_attr
        u_skip = u

        x, edge_attr, u = ML(
            x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
        )

        # Skip connections are added
        x += x_skip
        edge_attr += edge_skip
        u += u_skip

        del x_skip, edge_skip, u_skip

    x, edge_attr, u = self.last_ML(
        x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
    )

    # Edge labels are symmetrized
    edge_index_t = edge_index[[1, 0]]  # edge_index transposed

    for i in range(edge_attr.shape[1]):
        # edge_attr converted to sparse tensor...
        edge_matrix = torch.sparse_coo_tensor(edge_index, edge_attr[:, i])
        # ... and its transpose
        edge_matrix_t = torch.sparse_coo_tensor(edge_index_t, edge_attr[:, i])

        # Symmetrization happens here: coalesce() returns the entries of
        # (A + A^T) / 2 sorted by (row, column), so writing them back assumes
        # edge_index is sorted the same way, has no duplicates and contains
        # (j, i) for every edge (i, j)
        edge_attr[:, i] = (
            ((edge_matrix + edge_matrix_t) / 2.0).coalesce()
        ).values()

    return x, edge_attr, u
```
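For a single edge class, the symmetrization at the end of forward() replaces the value on edge (i, j) by the average of the values on (i, j) and (j, i), i.e. it computes (A + A^T) / 2 on the sparse edge matrix A. Because coalesce() returns the summed entries sorted by (row, column), writing them back into edge_attr is only correct when edge_index is already sorted that way and contains every reverse edge. The pure-Python sketch below reproduces this for one column and, unlike the original code, checks that assumption explicitly; symmetrize_edge_values is a hypothetical helper.

```python
from collections import defaultdict
from typing import DefaultDict, List, Tuple

Edge = Tuple[int, int]  # (row, column) of the edge matrix


def symmetrize_edge_values(edges: List[Edge], values: List[float]) -> List[float]:
    """Average each edge value with the value on the reverse edge."""
    summed: DefaultDict[Edge, float] = defaultdict(float)
    for (i, j), v in zip(edges, values):
        summed[(i, j)] += v  # entry of A
        summed[(j, i)] += v  # entry of A^T
    # Like coalesce(): one entry per index, sorted by (row, column)
    keys = sorted(summed)
    # Writing back position by position needs the input edges to be in
    # exactly this order (sorted, no duplicates, every reverse present)
    if keys != edges:
        raise ValueError("edges must be sorted, unique and include reverse edges")
    return [summed[k] / 2.0 for k in keys]


edges = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
print(symmetrize_edge_values(edges, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))
# [2.0, 3.5, 2.0, 5.0, 3.5, 5.0]
```

A self-loop (i, i) is counted once from A and once from A^T, so its value is left unchanged, matching the sparse-tensor version.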

Member Data Documentation

◆ first_ML

First MetaLayer.

Definition at line 73 of file geometric_network.py.

◆ last_ML

Output MetaLayer.

Definition at line 144 of file geometric_network.py.

◆ ML_list

Intermediate MetaLayers.

Definition at line 106 of file geometric_network.py.


The documentation for this class was generated from the following file: