16 from torch.utils.data
import Dataset
17 from sklearn.preprocessing
import LabelEncoder
20 class ClusterImage(Dataset):
21 """ Prepare an image with necessary inputs for ConvNet.
23 It gets 7x7(=49) pixels of energy + 1 thetaId + 1 PhiId + 1 Pt.
24 Then prepare proper format for each input.
26 Regarding energy pixels, the specified threshold of 1 MeV is applied
27 which means pixels with < 1 MeV will become 0.
28 Clipping is also applied on pixels with the condition that
29 pixels with energy more than 1.0 GeV will be replaced with 1.0 GeV.
36 np_energy = params[
'energy_image'].astype(dtype=np.float32)
38 np_shape = (1, 1, params[
'image_length'], params[
'image_length'])
40 np_energy_reshaped = np_energy.reshape(np_shape)
41 self.energy = torch.from_numpy(np_energy_reshaped)
43 self.energy[self.energy < params[
'threshold']] = 0.
45 self.energy = np.clip(self.energy, 0, 1.0)
47 theta_input = params[
'thetaId']
48 encoder_theta_input = LabelEncoder()
52 encoder_theta_input.fit(np.array([float(i)
for i
in range(14, 58)]))
53 theta_input = encoder_theta_input.transform(theta_input.ravel())
54 self.theta_input = torch.from_numpy(theta_input)
56 phi_input = params[
'phiId']
57 encoder_phi_input = LabelEncoder()
58 encoder_phi_input.fit(np.array([float(i)
for i
in range(0, 144)]))
59 phi_input = encoder_phi_input.transform(phi_input.ravel())
60 self.phi_input = torch.from_numpy(phi_input)
62 pt = params[
'pt'].astype(dtype=np.float32)
63 self.pt = torch.from_numpy(pt)
67 return(self.pt.shape[0])
69 def __getitem__(self, idx):
71 return(self.energy[idx],
72 self.theta_input[idx],