Belle II Software  release-08-01-10
cnn_pid_cluster_image.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 # @cond
13 
14 import torch
15 import numpy as np
16 from torch.utils.data import Dataset
17 from sklearn.preprocessing import LabelEncoder
18 
19 
class ClusterImage(Dataset):
    """Prepare an image with the necessary inputs for the ConvNet.

    Each entry consists of an image_length x image_length image of pixel
    energies (7x7 = 49 pixels in the intended configuration) plus one
    thetaId, one phiId and one pt value. Each input is converted into the
    format the network expects:

    * energy pixels below ``params['threshold']`` (1 MeV in the intended
      configuration) are set to 0, and pixels above 1.0 GeV are clipped
      to 1.0 GeV;
    * thetaId / phiId values are mapped to their index inside the supported
      ID ranges (``THETA_ID_RANGE`` / ``PHI_ID_RANGE``);
    * pt is converted to a float32 tensor.

    Args:
        params (dict): inputs, read with the following keys:
            'energy_image': array-like of pixel energies; its size must be a
                multiple of image_length**2 (one image per entry).
            'image_length': side length of the square energy image.
            'threshold': pixels with energy below this value are zeroed.
            'thetaId': array-like of theta IDs, each within THETA_ID_RANGE.
            'phiId': array-like of phi IDs, each within PHI_ID_RANGE.
            'pt': array-like of pt values, one per entry.

    Raises:
        ValueError: if a thetaId/phiId lies outside its supported range
            (raised by ``LabelEncoder.transform``), or if the energy array
            cannot be split into whole image_length x image_length images.
    """

    # Since the CNN can only predict the PID of tracks inside the barrel,
    # only thetaIDs 14..57 are supported (hard-coded barrel limits).
    THETA_ID_RANGE = range(14, 58)
    # Supported phiIDs: 0..143.
    PHI_ID_RANGE = range(0, 144)
    # Every energy pixel is clipped to at most this value (1.0 GeV).
    ENERGY_CLIP_MAX = 1.0

    def __init__(
        self,
        params
    ):
        image_length = params['image_length']

        # np.array(..., dtype=...) always copies, so the in-place masking
        # below can never modify the caller's array.
        np_energy = np.array(params['energy_image'], dtype=np.float32)
        # One single-channel (1, L, L) image per entry. The leading -1 lets
        # several clusters be stored at once; a single cluster still yields
        # the shape (1, 1, L, L).
        np_energy = np_energy.reshape(-1, 1, image_length, image_length)
        self.energy = torch.from_numpy(np_energy)

        self.energy[self.energy < params['threshold']] = 0.
        # The data is already a tensor: clamp it with torch instead of
        # relying on np.clip's ndarray fallback to hand a tensor back.
        self.energy = torch.clamp(
            self.energy, min=0., max=self.ENERGY_CLIP_MAX)

        self.theta_input = self._encode_ids(
            params['thetaId'], self.THETA_ID_RANGE)
        self.phi_input = self._encode_ids(
            params['phiId'], self.PHI_ID_RANGE)

        # Copy (as the original astype did) so the tensor does not alias
        # the caller's array.
        self.pt = torch.from_numpy(np.array(params['pt'], dtype=np.float32))

    @staticmethod
    def _encode_ids(ids, valid_ids):
        """Map each ID in ``ids`` to its index within ``valid_ids``.

        ``ids`` is flattened first, so the result is a 1-D tensor of indices.
        ``LabelEncoder.transform`` raises ValueError for IDs that are not in
        ``valid_ids``.
        """
        encoder = LabelEncoder()
        encoder.fit(np.asarray(valid_ids, dtype=np.float64))
        return torch.from_numpy(encoder.transform(np.asarray(ids).ravel()))

    def __len__(self):
        """Return the number of entries (the first dimension of pt)."""
        return self.pt.shape[0]

    def __getitem__(self, idx):
        """Return (energy image, theta index, phi index, pt) of entry idx."""
        return (self.energy[idx],
                self.theta_input[idx],
                self.phi_input[idx],
                self.pt[idx])
75 # @endcond