def test_merge_tables(self):
    """Merge the PID tables under an electronID threshold and validate the result.

    With the threshold ``{11: ('electronID', 0.8)}`` the merged table must have
    154 rows and expose both a 'PDG' and an 'mcPDG' column. Merging with
    ``{11: ('electronID', 0.9)}`` must instead raise ValueError.
    """
    reweighter = Reweighter()
    accepted_thresholds = {11: ('electronID', 0.8)}
    rejected_thresholds = {11: ('electronID', 0.9)}

    merged = reweighter.merge_pid_weight_tables(self.pid_tables, accepted_thresholds)
    self.assertEqual(len(merged), 154)
    for column in ('PDG', 'mcPDG'):
        self.assertIn(column, merged.columns)

    # Only the merge call sits inside the context manager, so the ValueError
    # can only come from merge_pid_weight_tables itself.
    with self.assertRaises(ValueError):
        reweighter.merge_pid_weight_tables(self.pid_tables, rejected_thresholds)
def test_binning(self):
    """Check the binning that get_binning extracts from individual PID tables.

    The table keyed ``(11, 11)`` must yield exactly the variables 'p',
    'cosTheta' and 'charge', with 7, 8 and 3 entries respectively. The table
    keyed ``(11, 211)`` must also yield three variables, with 6 entries for 'p'.
    """
    reweighter = Reweighter()
    expected_entry_counts = {'p': 7, 'cosTheta': 8, 'charge': 3}

    first_binning = reweighter.get_binning(self.pid_tables[(11, 11)])
    self.assertEqual(len(first_binning), 3)
    self.assertCountEqual(first_binning.keys(), ['p', 'cosTheta', 'charge'])
    for variable, n_entries in expected_entry_counts.items():
        self.assertEqual(len(first_binning[variable]), n_entries)

    second_binning = reweighter.get_binning(self.pid_tables[(11, 211)])
    self.assertEqual(len(second_binning), 3)
    self.assertEqual(len(second_binning['p']), 6)
def test_add_pid_particle(self):
    """Register a PID particle and verify it; a duplicate registration must fail.

    After ``add_pid_particle`` the particle is retrievable by name. Its merged
    table has 154 rows and it is binned in 'p', 'cosTheta' and 'charge'.
    Registering the same name a second time must raise ValueError.
    """
    thresholds = {11: ('electronID', 0.8)}
    particle_name = ''
    reweighter = Reweighter()
    reweighter.add_pid_particle(particle_name, self.pid_tables, thresholds)

    particle = reweighter.get_particle(particle_name)
    self.assertIsNotNone(particle)
    self.assertEqual(len(particle.merged_table), 154)
    self.assertCountEqual(particle.get_binning_variables(), ['p', 'cosTheta', 'charge'])

    # Re-adding under an already-registered name is expected to be rejected.
    with self.assertRaises(ValueError):
        reweighter.add_pid_particle(particle_name, self.pid_tables, thresholds)
def test_add_fei_particle(self):
    """Tests add_fei_particle method.

    Registers an FEI particle under the name '' with 0.01 as the third
    argument. The particle must then be retrievable, have a 4-row merged table
    and be binned only in 'dec_mode'. A second call under the same name, with
    0.001 as the third argument, must raise ValueError. From this test alone it
    is not distinguishable whether the duplicate name or the 0.001 value
    triggers the error.
    """
    reweighter = Reweighter()
    reweighter.add_fei_particle('', self.fei_table, 0.01, None)
    particle = reweighter.get_particle('')
    # Fail with a clear message (not an AttributeError below) if the particle
    # was not registered; mirrors test_add_pid_particle.
    self.assertIsNotNone(particle)
    self.assertEqual(len(particle.merged_table), 4)
    self.assertCountEqual(particle.get_binning_variables(), ['dec_mode'])

    with self.assertRaises(ValueError):
        reweighter.add_fei_particle('', self.fei_table, 0.001, None)
def test_reweight(self):
    """Tests reweighting of PID and FEI particles.

    A PID particle is registered under '' (its columns are unprefixed:
    'Weight', 'Weight_<i>'). An FEI particle is registered under 'B0' (its
    columns are 'B0_Weight', 'B0_Weight_<i>'). The test checks:

    * both nominal columns and exactly ``n_variations`` variation columns per
      particle are added;
    * rows with ``abs(mcPDG) == 11`` get PID weights below 1.5 with a
      per-row spread below 0.015, and all other rows get weights above 1.5
      with a spread above 0.015;
    * FEI weights are all present and strictly positive;
    * reweighting a copy whose index is shifted by 100 yields no missing
      weights and the same nominal PID weights as the unshifted data.
    """
    # NOTE(review): n_variations is not defined in this method; presumably a
    # module-level constant of the test file — confirm.
    reweighter = Reweighter(n_variations=n_variations)
    thresholds = {11: ('electronID', 0.8)}
    reweighter.add_pid_particle('', self.pid_tables, thresholds, sys_seed=42)
    reweighter.add_fei_particle('B0', self.fei_table, 0.01, None)

    pid_variation_cols = [f'Weight_{i}' for i in range(n_variations)]
    fei_variation_cols = [f'B0_Weight_{i}' for i in range(n_variations)]
    pid_cols = ['Weight'] + pid_variation_cols
    fei_cols = ['B0_Weight'] + fei_variation_cols

    local_data = self.user_data.copy(deep=True)
    reweighter.reweight(local_data)
    self.assertEqual(len(local_data), 4)
    self.assertIn('Weight', local_data.columns)
    self.assertIn('B0_Weight', local_data.columns)

    # PID weights: exactly the expected variation columns, split by true PDG.
    self.assertCountEqual(
        pid_variation_cols,
        [col for col in local_data.columns if col.startswith('Weight_')])
    true_electrons = local_data.query('abs(mcPDG) == 11')[pid_cols]
    other_particles = local_data.query('abs(mcPDG) != 11')[pid_cols]
    self.assertTrue((true_electrons < 1.5).all().all())
    self.assertTrue((other_particles > 1.5).all().all())
    # Per-row spread across the nominal weight and its variations.
    self.assertTrue((true_electrons.std(axis=1) < 0.015).all())
    self.assertTrue((other_particles.std(axis=1) > 0.015).all())

    # FEI weights: exactly the expected variation columns, all present and positive.
    self.assertCountEqual(
        fei_variation_cols,
        [col for col in local_data.columns if col.startswith('B0_Weight_')])
    fei_weights = local_data[fei_cols]
    self.assertFalse(fei_weights.isna().any().any())
    self.assertTrue((fei_weights > 0).all().all())

    # Shifted index: reweighting must not depend on the index labels.
    local_data_shift = self.user_data.copy(deep=True)
    local_data_shift.index += 100
    reweighter.reweight(local_data_shift)
    self.assertIn('Weight', local_data_shift.columns)
    self.assertFalse(local_data_shift[pid_cols].isna().any().any())
    self.assertFalse(local_data_shift[fei_cols].isna().any().any())
    # Replaces the former visual print comparison: same rows, same nominal weights.
    self.assertEqual(len(local_data_shift), len(local_data))
    for shifted, unshifted in zip(local_data_shift['Weight'], local_data['Weight']):
        self.assertAlmostEqual(shifted, unshifted)