Belle II Software development
cnn_pid_cluster_image.py
1#!/usr/bin/env python3
2
3
10
11# @cond
12
13import torch
14import numpy as np
15from torch.utils.data import Dataset
16from sklearn.preprocessing import LabelEncoder
17
18
class ClusterImage(Dataset):
    """Prepare the inputs of a cluster image for the ConvNet.

    The inputs are the energy pixels of the cluster image (7x7 = 49 pixels
    in the standard configuration) plus one thetaId, one phiId and one pt
    value. Each input is converted into a torch tensor.

    Energy pixels below ``params['threshold']`` (documented as 1 MeV) are
    set to 0. The pixels are then clipped to the range [0, 1.0], i.e.
    pixels with an energy above 1.0 GeV are replaced with 1.0 GeV.
    """

    def __init__(
        self,
        params
    ):
        """Build the tensors from the raw inputs.

        ``params`` is indexed with the following keys:

        * ``'energy_image'``: numpy array holding the pixel energies,
        * ``'image_length'``: number of pixels along one side of the image,
        * ``'threshold'``: energy below which a pixel is set to 0,
        * ``'thetaId'``: numpy array of theta IDs,
        * ``'phiId'``: numpy array of phi IDs,
        * ``'pt'``: numpy array of transverse momenta.
        """
        # ``astype`` returns a new float32 array, which is then viewed as a
        # single one-channel square image.
        np_energy = params['energy_image'].astype(dtype=np.float32)
        np_shape = (1, 1, params['image_length'], params['image_length'])
        self.energy = torch.from_numpy(np_energy.reshape(np_shape))

        # Suppress the pixels below the threshold.
        self.energy[self.energy < params['threshold']] = 0.

        # Clip with the native torch operation: the tensor is never routed
        # through NumPy's dispatch machinery and stays a float32 tensor.
        self.energy = torch.clamp(self.energy, 0, 1.0)

        # Since CNN can only predict PID of the tracks inside the barrel,
        # there are two hard-coded numbers in the following line (14 and 58),
        # representing the thetaID limit.
        encoder_theta_input = LabelEncoder()
        encoder_theta_input.fit(np.arange(14, 58, dtype=float))
        theta_input = encoder_theta_input.transform(params['thetaId'].ravel())
        self.theta_input = torch.from_numpy(theta_input)

        # The phi IDs are encoded against the hard-coded range 0..143.
        encoder_phi_input = LabelEncoder()
        encoder_phi_input.fit(np.arange(0, 144, dtype=float))
        phi_input = encoder_phi_input.transform(params['phiId'].ravel())
        self.phi_input = torch.from_numpy(phi_input)

        pt = params['pt'].astype(dtype=np.float32)
        self.pt = torch.from_numpy(pt)

    def __len__(self):
        """Return the number of entries, taken from the pt tensor."""
        return self.pt.shape[0]

    def __getitem__(self, idx):
        """Return the tuple (energy, theta_input, phi_input, pt) at idx."""
        return (self.energy[idx],
                self.theta_input[idx],
                self.phi_input[idx],
                self.pt[idx])
74# @endcond