Belle II Software  light-2403-persian
geometric_network.py
import torch
from torch_geometric.nn import MetaLayer
from .geometric_layers import NodeLayer, EdgeLayer, GlobalLayer


class GraFEIModel(torch.nn.Module):
    """
    Actual implementation of the model, based on the
    `MetaLayer <https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MetaLayer.html>`_
    class.

    .. seealso::
        `Relational inductive biases, deep learning, and graph networks <https://arxiv.org/abs/1806.01261>`_

    The network is composed of:

    1. A first MetaLayer that increases the number of node and edge features;
    2. A number of intermediate MetaLayers (tunable in the config file);
    3. A last MetaLayer that decreases the number of node and edge features to the desired output dimensions.

    .. figure:: figs/graFEI.png
        :width: 42em
        :align: center

    Each MetaLayer is in turn composed of `EdgeLayer`, `NodeLayer` and `GlobalLayer` sub-blocks.

    Args:
        nfeat_in_dim (int): Node features dimension (number of input node features).
        efeat_in_dim (int): Edge features dimension (number of input edge features).
        gfeat_in_dim (int): Global features dimension (number of input global features).
        edge_classes (int): Edge features output dimension (i.e. the number of different edge labels in the LCAS matrix).
        x_classes (int): Node features output dimension (i.e. the number of different mass hypotheses).
        hidden_layer_dim (int): Intermediate features dimension (the same for node, edge and global features).
        num_hid_layers (int): Number of hidden layers in every MetaLayer.
        num_ML (int): Number of intermediate MetaLayers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        global_layer (bool): Whether to use the global layer.

    :return: Node, edge and global features after model evaluation.
    :rtype: tuple(`Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_)
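
    A minimal instantiation sketch (the feature dimensions below are illustrative
    placeholders, not values taken from any real configuration):

    .. code-block:: python

        model = GraFEIModel(
            nfeat_in_dim=12,
            efeat_in_dim=4,
            gfeat_in_dim=1,
            edge_classes=6,
            x_classes=7,
            hidden_layer_dim=128,
            num_hid_layers=1,
            num_ML=1,
            dropout=0.1,
            global_layer=True,
        )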
    """
52 
53  def __init__(
54  self,
55  nfeat_in_dim,
56  efeat_in_dim,
57  gfeat_in_dim,
58  edge_classes=6,
59  x_classes=7,
60  hidden_layer_dim=128,
61  num_hid_layers=1,
62  num_ML=1,
63  dropout=0.0,
64  global_layer=True,
65  **kwargs
66  ):
67  """
68  Initialization.
69  """
70  super(GraFEIModel, self).__init__()
71 
72 
        self.first_ML = MetaLayer(
            EdgeLayer(
                nfeat_in_dim,
                efeat_in_dim,
                gfeat_in_dim,
                hidden_layer_dim,
                hidden_layer_dim,
                num_hid_layers,
                dropout,
            ),
            NodeLayer(
                nfeat_in_dim,
                hidden_layer_dim,
                gfeat_in_dim,
                hidden_layer_dim,
                hidden_layer_dim,
                num_hid_layers,
                dropout,
            ),
            GlobalLayer(
                hidden_layer_dim,
                hidden_layer_dim,
                gfeat_in_dim,
                hidden_layer_dim,
                hidden_layer_dim,
                num_hid_layers,
                dropout,
            )
            if global_layer
            else None,
        )

        self.ML_list = torch.nn.ModuleList(
            [
                MetaLayer(
                    EdgeLayer(
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim if global_layer else 0,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        num_hid_layers,
                        dropout,
                    ),
                    NodeLayer(
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim if global_layer else 0,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        num_hid_layers,
                        dropout,
                    ),
                    GlobalLayer(
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        num_hid_layers,
                        dropout,
                    )
                    if global_layer
                    else None,
                )
                for _ in range(num_ML)
            ]
        )

        self.last_ML = MetaLayer(
            EdgeLayer(
                hidden_layer_dim,
                hidden_layer_dim,
                hidden_layer_dim if global_layer else 0,
                hidden_layer_dim,
                edge_classes,
                num_hid_layers,
                dropout,
                normalize=False,  # Do not normalize output layer
            ),
            NodeLayer(
                hidden_layer_dim,
                edge_classes,
                hidden_layer_dim if global_layer else 0,
                hidden_layer_dim,
                x_classes,
                num_hid_layers,
                dropout,
                normalize=False,  # Do not normalize output layer
            ),
            GlobalLayer(
                x_classes,
                edge_classes,
                hidden_layer_dim,
                hidden_layer_dim,
                1,
                num_hid_layers,
                dropout,
                normalize=False,  # Do not normalize output layer
            )
            if global_layer
            else None,
        )

    def forward(self, batch):
        """
        Called internally by PyTorch to propagate the input through the network.
        """
        x, u, edge_index, edge_attr, torch_batch = (
            batch.x,
            batch.u,
            batch.edge_index,
            batch.edge_attr,
            batch.batch,
        )

        x, edge_attr, u = self.first_ML(
            x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
        )

        del batch

        for ML in self.ML_list:
            x_skip = x
            edge_skip = edge_attr
            u_skip = u

            x, edge_attr, u = ML(
                x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
            )

            # Skip connections are added
            x += x_skip
            edge_attr += edge_skip
            u += u_skip

            del x_skip, edge_skip, u_skip

        x, edge_attr, u = self.last_ML(
            x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
        )

        # Edge labels are symmetrized
        edge_index_t = edge_index[[1, 0]]  # edge_index transposed
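
        # The LCAS matrix is symmetric by construction: the score assigned to the
        # directed edge (i, j) must equal the one assigned to (j, i). For every edge
        # class, the loop below builds a sparse matrix of scores and averages it
        # with its transpose to enforce this symmetry.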
        for i in range(edge_attr.shape[1]):
            # edge_attr converted to a sparse tensor...
            edge_matrix = torch.sparse_coo_tensor(
                edge_index, edge_attr[:, i]
            )
            # ... and its transpose
            edge_matrix_t = torch.sparse_coo_tensor(
                edge_index_t, edge_attr[:, i]
            )

            # Symmetrization happens here
            edge_attr[:, i] = (
                ((edge_matrix + edge_matrix_t) / 2.0).coalesce()
            ).values()

        return x, edge_attr, u