def __init__(self, fit_function, dimension_of_fit_function, result_function, use_sigma_for_result_fitting):
    """Create a trainer that fits ``fit_function`` to the grouped p-distributions.

    Parameters
    ----------
    fit_function : callable
        Model handed to ``scipy.optimize.curve_fit`` as ``f(x, *params)``.
    dimension_of_fit_function : int
        Number of free parameters of ``fit_function`` (must be at least 3 when
        training is run).
    result_function, use_sigma_for_result_fitting
        Forwarded unchanged to ``FittedGroupedDEDXEstimatorTrainer.__init__``.
    """
    ## number of free parameters of the fit function (length of the initial guess)
    self.dimension_of_fit_function = dimension_of_fit_function
    ## cached copy of the fitting function
    self.fit_function = fit_function

    FittedGroupedDEDXEstimatorTrainer.__init__(self, result_function, use_sigma_for_result_fitting)

    def train_function(fit_data):
        """Train on the fit to curated-data highest values whose truth value is known.

        Fits ``self.fit_function`` to ``fit_data.number_of_p_values`` as a
        function of ``fit_data.p_bin_centers`` using ``scipy.optimize.curve_fit``.

        Returns
        -------
        list
            ``[sigma, popt]``: the standard error of fit parameter 1 (square root
            of its diagonal covariance entry) and the array of best-fit parameters.

        Raises
        ------
        ValueError
            If ``dimension_of_fit_function`` is smaller than 3.
        """
        dimension = self.dimension_of_fit_function
        if dimension < 3:
            raise ValueError(
                "dimension_of_fit_function must be at least 3, got {}".format(dimension))

        # p-bin centre of the entry returned by use_only_the_highest_values(..., 1);
        # used as the starting guess for fit parameter 1 (presumably the peak position).
        max_value = self.use_only_the_highest_values(fit_data, 1).p_bin_centers.values[0]

        # Fixed starting values for the first three parameters; any additional
        # parameters start at 1, which is curve_fit's own default initial guess.
        # This reproduces the former hard-coded guesses for 3 and 6 parameters.
        p0 = (1e3, max_value, 4e-2) + (1,) * int(dimension - 3)

        popt, pcov = curve_fit(self.fit_function, fit_data.p_bin_centers, fit_data.number_of_p_values, p0=p0)

        return [np.sqrt(np.diag(pcov)[1]), popt]

    ## this class's training function
    self.train_function = train_function
def plot_grouped_result(self, data):
    """Plot the fitted grouped results.

    Draws the base-class plot, then overlays for every dedx bin the curve
    ``fit_function(p, *parameters)`` using the parameters stored in
    ``result_parameters_for_each_dedx_bin`` for that bin.

    Parameters
    ----------
    data : pandas.DataFrame
        Data to plot; needs a ``p`` column and must be accepted by
        ``create_dedx_bins``.
    """
    FittedGroupedDEDXEstimatorTrainer.plot_grouped_result(self, data)

    dedx_binned_data, _ = self.create_dedx_bins(data)

    p_plot_data = np.linspace(data.p.min(), data.p.max(), 1000)

    # pandas' apply may call the callback more than once for the same group
    # (it has historically evaluated the first group twice), so remember which
    # bins were already drawn to avoid plotting a curve twice.
    already_plotted = set()

    def plot_fitted_results(dedx_bin):
        # Same expression as used for the parameter lookup, so the de-duplication
        # key and the dictionary key always agree.
        dedx_bin_center = dedx_bin.mean()[self.dedx_column]
        if dedx_bin_center in already_plotted:
            return
        already_plotted.add(dedx_bin_center)

        # Stored entries are pairs whose second element is the fit-parameter
        # array; only the parameters are needed for drawing.
        _, fit_parameters = self.result_parameters_for_each_dedx_bin[dedx_bin_center]

        dedx_plot_data = self.fit_function(p_plot_data, *fit_parameters)
        plt.plot(p_plot_data, dedx_plot_data)

    dedx_binned_data.apply(plot_fitted_results)