27 return logging.getLogger(__name__)
31formatter = TolerateMissingKeyFormatter()
35 """Python module to refine a peeled dictionary"""
38 """Constructor of the Refiner instance"""
42 def __get__(self, harvesting_module, cls=None):
43 """Getter of the Refiner instance"""
44 if harvesting_module
is None:
51 def bound_call(*args, **kwds):
52 return refine(harvesting_module, *args, **kwds)
42 def __get__(self, harvesting_module, cls=None):
…
55 def __call__(self, harvesting_module, crops=None, *args, **kwds):
56 """implementation of the function-call of the Refiner instance
58 r(harvester) # decoration
59 r(harvester, crops, args, keywords) # refinement
63 harvesting_module.refiners.append(self)
64 return harvesting_module
67 return self.
refine(harvesting_module, crops, *args, **kwds)
55 def __call__(self, harvesting_module, crops=None, *args, **kwds):
…
69 def refine(self, harvesting_module, *args, **kwds):
70 """Apply the instance's refiner function"""
69 def refine(self, harvesting_module, *args, **kwds):
…
75 """Refiner for figures of merit"""
77 default_name =
"{module.id}_figures_of_merit{groupby_key}"
79 default_title =
"Figures of merit in {module.title}"
81 default_contact =
"{module.contact}"
83 default_description =
"Figures of merit are the {aggregation.__name__} of {keys}"
85 default_check =
"Check for reasonable values"
87 default_key =
"{aggregation.__name__}_{part_name}"
95 default_aggregation = mean
106 """Constructor for this refiner"""
110 ## cached name of the figure of merit
112 ## cached title of the figure of merit
115 ## cached description of the figure of merit
116 self.description = description
117 ## cached user-check action of the figure of merit
119 ## cached contact person of the figure of merit
120 self.contact = contact
122 ## cached copy of the figures-of-merit key
124 ## cached copy of the crops-aggregation method
125 self.aggregation = aggregation
131 groupby_part_name=None,
134 """Process the figures of merit"""
136 name = self.name or self.default_name
137 title = self.title or self.default_title
138 contact = self.contact or self.default_contact
139 description = self.description or self.default_description
140 check = self.check or self.default_check
142 aggregation = self.aggregation or self.default_aggregation
144 replacement_dict = dict(
146 module=harvesting_module,
147 aggregation=aggregation,
148 groupby_key='_' + groupby_part_name + groupby_value if groupby_part_name else "",
149 groupby=groupby_part_name, # deprecated
150 groupby_value=groupby_value, # deprecated
153 name = formatter.format(name, **replacement_dict)
154 title = formatter.format(title, **replacement_dict)
155 contact = formatter.format(contact, **replacement_dict)
157 figures_of_merit = ValidationFiguresOfMerit(name,
161 for part_name, parts in iter_items_sorted_for_key(crops):
162 key = self.key or self.default_key
163 key = formatter.format(key, part_name=part_name, **replacement_dict)
164 figures_of_merit[key] = aggregation(parts)
166 keys = list(figures_of_merit.keys())
168 description = formatter.format(description, keys=keys, **replacement_dict)
169 check = formatter.format(check, keys=keys, **replacement_dict)
171 figures_of_merit.description = description
172 figures_of_merit.check = check
175 figures_of_merit.write(tdirectory)
177 print(figures_of_merit)
180class SaveHistogramsRefiner(Refiner):
181 """Refiner for histograms"""
182 ## default name for this histogram
183 default_name = "{module.id}_{part_name}_histogram{groupby_key}{stackby_key}"
184 ## default title for this histogram
185 default_title = "Histogram of {part_name}{groupby_key}{stackby_key} from {module.title}"
186 ## default contact person for this histogram
187 default_contact = "{module.contact}"
188 ## default description for this histogram
189 default_description = "This is a histogram of {part_name}{groupby_key}{stackby_key}."
190 ## default user-check action for this histogram
191 default_check = "Check if the distribution is reasonable"
202 outlier_z_score=None,
203 allow_discrete=False,
207 """Constructor for this refiner"""
211 ## cached user-defined name for this histogram
213 ## cached user-defined title for this histogram
216 ## cached user-defined description for this histogram
217 self.description = description
218 ## cached user-defined user-check action for this histogram
220 ## cached user-defined contact person for this histogram
221 self.contact = contact
223 ## cached lower bound for this histogram
224 self.lower_bound = lower_bound
225 ## cached upper bound for this histogram
226 self.upper_bound = upper_bound
227 ## cached number of bins for this histogram
230 ## cached Z-score (for outlier detection) for this histogram
231 self.outlier_z_score = outlier_z_score
232 ## cached flag to allow discrete values for this histogram
233 self.allow_discrete = allow_discrete
234 ## cached stacking selection for this histogram
235 self.stackby = stackby
237 ## cached fit for this histogram
239 ## cached fit Z-score (for outlier detection) for this histogram
240 self.fit_z_score = fit_z_score
246 groupby_part_name=None,
249 """Process the histogram"""
251 stackby = self.stackby
253 stackby_parts = crops[stackby]
257 replacement_dict = dict(
259 module=harvesting_module,
260 stackby_key=' stacked by ' + stackby if stackby else "",
261 groupby_key=' in group ' + groupby_part_name + groupby_value if groupby_part_name else "",
264 contact = self.contact or self.default_contact
265 contact = formatter.format(contact, **replacement_dict)
267 for part_name, parts in iter_items_sorted_for_key(crops):
268 name = self.name or self.default_name
269 title = self.title or self.default_title
270 description = self.description or self.default_description
271 check = self.check or self.default_check
273 name = formatter.format(name, part_name=part_name, **replacement_dict)
274 title = formatter.format(title, part_name=part_name, **replacement_dict)
275 description = formatter.format(description, part_name=part_name, **replacement_dict)
276 check = formatter.format(check, part_name=part_name, **replacement_dict)
278 histogram = ValidationPlot(name)
279 histogram.hist(parts,
280 lower_bound=self.lower_bound,
281 upper_bound=self.upper_bound,
283 outlier_z_score=self.outlier_z_score,
284 allow_discrete=self.allow_discrete,
285 stackby=stackby_parts)
287 histogram.title = title
288 histogram.contact = contact
289 histogram.description = description
290 histogram.check = check
292 histogram.xlabel = compose_axis_label(part_name)
295 if self.fit_z_score is None:
298 kwds = dict(z_score=self.fit_z_score)
300 fit_method_name = 'fit_' + str(self.fit)
302 fit_method = getattr(histogram, fit_method_name)
303 except AttributeError:
304 histogram.fit(str(self.fit), **kwds)
309 histogram.write(tdirectory)
180class SaveHistogramsRefiner(Refiner):
…
312class Plot2DRefiner(Refiner):
313 """Refiner for profile histograms and 2D scatterplots"""
314 ## by default, this refiner is for profile histograms
315 plot_kind = "profile"
332 outlier_z_score=None,
335 skip_single_valued=False,
336 allow_discrete=False):
337 """Constructor for this refiner"""
341 ## cached user-defined name for this profile histogram / scatterplot
343 ## cached user-defined title for this profile histogram / scatterplot
346 ## cached user-defined description for this profile histogram / scatterplot
347 self.description = description
348 ## cached user-defined user-check action for this profile histogram / scatterplot
350 ## cached user-defined contact person for this profile histogram / scatterplot
351 self.contact = contact
353 ## cached value of abscissa
355 ## cached value of ordinate
357 ## cached stacking selection for this profile histogram / scatterplot
358 self.stackby = stackby
359 ## cached measurement unit for ordinate
362 ## cached lower bound for this profile histogram / scatterplot
363 self.lower_bound = lower_bound
364 ## cached upper bound for this profile histogram / scatterplot
365 self.upper_bound = upper_bound
366 ## cached number of bins for this profile histogram / scatterplot
368 ## cached flag for probability y axis (range 0.0 .. 1.05) for this profile histogram / scatterplot
369 self.y_binary = y_binary
370 ## cached flag for logarithmic y axis for this profile histogram / scatterplot
373 ## cached Z-score (for outlier detection) for this profile histogram / scatterplot
374 self.outlier_z_score = outlier_z_score
375 ## cached flag to allow discrete values for this profile histogram / scatterplot
376 self.allow_discrete = allow_discrete
378 ## cached fit for this profile histogram / scatterplot
380 ## cached fit Z-score (for outlier detection) for this profile histogram / scatterplot
381 self.fit_z_score = fit_z_score
383 ## cached flag to skip single-valued bins for this profile histogram / scatterplot
384 self.skip_single_valued = skip_single_valued
390 groupby_part_name=None,
393 """Process the profile histogram / scatterplot"""
395 stackby = self.stackby
397 stackby_parts = crops[stackby]
401 replacement_dict = dict(
403 module=harvesting_module,
404 stackby_key=' stacked by ' + stackby if stackby else "",
405 groupby_key=' in group ' + groupby_part_name + groupby_value if groupby_part_name else "",
408 contact = self.contact or self.default_contact
409 contact = formatter.format(contact, **replacement_dict)
411 y_crops = select_crop_parts(crops, select=self.y)
412 x_crops = select_crop_parts(crops, select=self.x, exclude=self.y)
414 for y_part_name, y_parts in iter_items_sorted_for_key(y_crops):
415 for x_part_name, x_parts in iter_items_sorted_for_key(x_crops):
417 if self.skip_single_valued and not self.has_more_than_one_value(x_parts):
418 get_logger().info('Skipping "%s" by "%s" profile because x has only a single value "%s"',
424 if self.skip_single_valued and not self.has_more_than_one_value(y_parts):
425 get_logger().info('Skipping "%s" by "%s" profile because y has only a single value "%s"',
431 name = self.name or self.default_name
432 title = self.title or self.default_title
433 description = self.description or self.default_description
434 check = self.check or self.default_check
436 name = formatter.format(name,
437 x_part_name=x_part_name,
438 y_part_name=y_part_name,
441 title = formatter.format(title,
442 x_part_name=x_part_name,
443 y_part_name=y_part_name,
446 description = formatter.format(description,
447 x_part_name=x_part_name,
448 y_part_name=y_part_name,
451 check = formatter.format(check,
452 x_part_name=x_part_name,
453 y_part_name=y_part_name,
456 profile_plot = ValidationPlot(name)
458 plot_kind = self.plot_kind
459 if plot_kind == "profile":
460 profile_plot.profile(x_parts,
462 lower_bound=self.lower_bound,
463 upper_bound=self.upper_bound,
465 y_binary=self.y_binary,
467 outlier_z_score=self.outlier_z_score,
468 allow_discrete=self.allow_discrete,
469 stackby=stackby_parts)
472 if self.fit_z_score is None:
475 kwds = dict(z_score=self.fit_z_score)
477 fit_method_name = 'fit_' + str(self.fit)
479 fit_method = getattr(profile_plot, fit_method_name)
480 except BaseException:
481 profile_plot.fit(str(self.fit), **kwds)
485 elif plot_kind == "scatter":
486 profile_plot.scatter(x_parts,
488 lower_bound=self.lower_bound,
489 upper_bound=self.upper_bound,
490 outlier_z_score=self.outlier_z_score,
491 stackby=stackby_parts)
493 profile_plot.title = title
494 profile_plot.contact = contact
495 profile_plot.description = description
496 profile_plot.check = check
498 profile_plot.xlabel = compose_axis_label(x_part_name)
499 profile_plot.ylabel = compose_axis_label(y_part_name, self.y_unit)
502 profile_plot.write(tdirectory)
505 def has_more_than_one_value(xs):
506 """check if a list has at least two unique values"""
505 def has_more_than_one_value(xs):
…
312class Plot2DRefiner(Refiner):
…
class SaveProfilesRefiner(Plot2DRefiner):
    """Plot2DRefiner specialisation that produces profile histograms

    Only class-level defaults are provided here; the processing is inherited
    from Plot2DRefiner. The placeholders in curly braces inside the default
    strings are filled in by the formatter when the refiner is applied.
    """

    ## selects the profile-histogram branch (not the scatterplot branch) of Plot2DRefiner
    plot_kind = "profile"

    ## default name template of a profile histogram
    default_name = (
        "{module.id}_{y_part_name}_by_{x_part_name}"
        "_profile{groupby_key}{stackby_key}"
    )
    ## default title template of a profile histogram
    default_title = (
        "Profile of {y_part_name} by {x_part_name} "
        "from {module.title}"
    )
    ## default description template of a profile histogram
    default_description = "This is a profile of {y_part_name} over {x_part_name}."
    ## default instruction for the user check of a profile histogram
    default_check = "Check if the trend line is reasonable."
    ## default contact person template of a profile histogram
    default_contact = "{module.contact}"
515class SaveProfilesRefiner(Plot2DRefiner):
…
class SaveScatterRefiner(Plot2DRefiner):
    """Plot2DRefiner specialisation that produces 2D scatterplots

    Only class-level defaults are provided here; the processing is inherited
    from Plot2DRefiner. The placeholders in curly braces inside the default
    strings are filled in by the formatter when the refiner is applied.
    """

    ## selects the scatterplot branch (not the profile-histogram branch) of Plot2DRefiner
    plot_kind = "scatter"

    ## default name template of a scatterplot
    default_name = (
        "{module.id}_{y_part_name}_by_{x_part_name}"
        "_scatter{groupby_key}{stackby_key}"
    )
    ## default title template of a scatterplot
    default_title = (
        "Scatter of {y_part_name} by {x_part_name} "
        "from {module.title}"
    )
    ## default description template of a scatterplot
    default_description = "This is a scatter of {y_part_name} over {x_part_name}."
    ## default instruction for the user check of a scatterplot
    default_check = "Check if the distributions is reasonable."
    ## default contact person template of a scatterplot
    default_contact = "{module.contact}"
532class SaveScatterRefiner(Plot2DRefiner):
…
549class SaveClassificationAnalysisRefiner(Refiner):
550 """Refiner for truth-classification analyses"""
552 ## default contact person for this truth-classification analysis
553 default_contact = "{module.contact}"
555 ## default name for the truth-classification analysis truth-values collection
556 default_truth_name = "{part_name}_truth"
557 ## default name for the truth-classification analysis estimates collection
558 default_estimate_name = "{part_name}_estimate"
569 outlier_z_score=None,
570 allow_discrete=False,
572 """Constructor for this refiner"""
574 ## cached part name for this truth-classification analysis
575 self.part_name = part_name
576 ## cached contact person for this truth-classification analysis
577 self.contact = contact
578 ## cached estimates-collection name for this truth-classification analysis
579 self.estimate_name = estimate_name
580 ## cached truth-values-collection name for this truth-classification analysis
581 self.truth_name = truth_name
583 ## cached threshold of estimates for this truth-classification analysis
585 ## cached cut direction (> or <) of estimates for this truth-classification analysis
586 self.cut_direction = cut_direction
588 ## cached lower bound of estimates for this truth-classification analysis
589 self.lower_bound = lower_bound
590 ## cached upper bound of estimates for this truth-classification analysis
591 self.upper_bound = upper_bound
592 ## cached Z-score (for outlier detection) of estimates for this truth-classification analysis
593 self.outlier_z_score = outlier_z_score
594 ## cached discrete-value flag of estimates for this truth-classification analysis
595 self.allow_discrete = allow_discrete
596 ## cached measurement unit of estimates for this truth-classification analysis
603 groupby_part_name=None,
606 """Process the truth-classification analysis"""
608 replacement_dict = dict(
610 module=harvesting_module,
611 groupby_key='_' + groupby_part_name + groupby_value if groupby_part_name else "",
612 groupby=groupby_part_name, # deprecated
613 groupby_value=groupby_value, # deprecated
616 contact = self.contact or self.default_contact
617 contact = formatter.format(contact, **replacement_dict)
619 if self.truth_name is not None:
620 truth_name = self.truth_name
622 truth_name = self.default_truth_name
624 truth_name = formatter.format(truth_name, part_name=self.part_name)
625 truths = crops[truth_name]
627 if self.estimate_name is not None:
628 estimate_name = self.estimate_name
630 estimate_name = self.default_estimate_name
632 if isinstance(estimate_name, str):
633 estimate_names = [estimate_name, ]
635 estimate_names = estimate_name
637 for estimate_name in estimate_names:
638 estimate_name = formatter.format(estimate_name, part_name=self.part_name)
639 estimates = crops[estimate_name]
641 classification_analysis = ClassificationAnalysis(quantity_name=estimate_name,
643 cut_direction=self.cut_direction,
645 lower_bound=self.lower_bound,
646 upper_bound=self.upper_bound,
647 outlier_z_score=self.outlier_z_score,
648 allow_discrete=self.allow_discrete,
651 classification_analysis.analyse(estimates, truths)
654 classification_analysis.write(tdirectory)
549class SaveClassificationAnalysisRefiner(Refiner):
…
657class SavePullAnalysisRefiner(Refiner):
658 """Refiner for pull analyses"""
660 ## default name for this pull analysis
661 default_name = "{module.id}_{quantity_name}"
662 ## default contact person for this pull analysis
663 default_contact = "{module.contact}"
664 ## default suffix for the title of this pull analysis
665 default_title_postfix = " from {module.title}"
667 ## default name for the pull analysis truth-values collection
668 default_truth_name = "{part_name}_truth"
669 ## default name for the pull analysis estimates collection
670 default_estimate_name = "{part_name}_estimate"
671 ## default name for the pull analysis variances collection
672 default_variance_name = "{part_name}_variance"
686 outlier_z_score=None,
689 """Constructor for this refiner"""
690 if aux_names is None:
692 ## cached name for this pull analysis
694 ## cached contact person for this pull analysis
695 self.contact = contact
696 ## cached suffix for the title of this pull analysis
697 self.title_postfix = title_postfix
699 ## cached array of part names for this pull analysis
701 if part_names is not None:
702 self.part_names = part_names
704 if part_name is not None:
705 self.part_names.append(part_name)
707 ## cached name for the pull analysis truth-values collection
708 self.truth_name = truth_name
709 ## cached name for the pull analysis estimates collection
710 self.estimate_name = estimate_name
711 ## cached name for the pull analysis variances collection
712 self.variance_name = variance_name
714 ## cached name of the quantity for the pull analysis
715 self.quantity_name = quantity_name
716 ## cached measurement unit for the pull analysis
719 ## cached auxiliary names for the pull analysis
720 self.aux_names = aux_names
722 ## cached Z-score (for outlier detection) for the pull analysis
723 self.outlier_z_score = outlier_z_score
724 ## cached absolute-value-comparison flag for the pull analysis
725 self.absolute = absolute
726 ## cached list of plots produced by the pull analysis
727 self.which_plots = which_plots
733 groupby_part_name=None,
736 """Process the pull analysis"""
738 replacement_dict = dict(
740 module=harvesting_module,
741 # stackby_key='_' + stackby if stackby else "",
742 groupby_key='_' + groupby_part_name + groupby_value if groupby_part_name else "",
743 groupby=groupby_part_name, # deprecated
744 groupby_value=groupby_value, # deprecated
747 contact = self.contact or self.default_contact
748 contact = formatter.format(contact, **replacement_dict)
750 name = self.name or self.default_name
753 auxiliaries = select_crop_parts(crops, self.aux_names)
757 for part_name in self.part_names:
758 name = formatter.format(name, part_name=part_name, **replacement_dict)
759 plot_name = name + "_{subplot_name}"
761 title_postfix = self.title_postfix
762 if title_postfix is None:
763 title_postfix = self.default_title_postfix
765 title_postfix = formatter.format(title_postfix, part_name=part_name, **replacement_dict)
766 plot_title = "{subplot_title} of {quantity_name}" + title_postfix
768 if self.truth_name is not None:
769 truth_name = self.truth_name
771 truth_name = self.default_truth_name
773 if self.estimate_name is not None:
774 estimate_name = self.estimate_name
776 estimate_name = self.default_estimate_name
778 if self.variance_name is not None:
779 variance_name = self.variance_name
781 variance_name = self.default_variance_name
783 truth_name = formatter.format(truth_name, part_name=part_name)
784 estimate_name = formatter.format(estimate_name, part_name=part_name)
785 variance_name = formatter.format(variance_name, part_name=part_name)
787 truths = crops[truth_name]
788 estimates = crops[estimate_name]
790 variances = crops[variance_name]
794 quantity_name = self.quantity_name or part_name
796 which_plots = self.which_plots
798 pull_analysis = PullAnalysis(quantity_name,
800 absolute=self.absolute,
801 outlier_z_score=self.outlier_z_score,
803 plot_title=plot_title)
805 pull_analysis.analyse(truths,
808 auxiliaries=auxiliaries,
809 which_plots=which_plots)
811 pull_analysis.contact = contact
814 pull_analysis.write(tdirectory)
657class SavePullAnalysisRefiner(Refiner):
…
817class SaveTreeRefiner(Refiner):
818 """Refiner for ROOT TTrees"""
820 ## default name for this TTree
821 default_name = "{module.id}_tree"
822 ## default title for this TTree
823 default_title = "Tree of {module.id}"
828 """Constructor for this refiner"""
831 ## cached name for this TTree
833 ## cached title for this TTree
840 groupby_part_name=None,
843 """Process the TTree"""
845 replacement_dict = dict(
847 module=harvesting_module,
848 groupby_key='_' + groupby_part_name + groupby_value if groupby_part_name else "",
849 groupby=groupby_part_name, # deprecated
850 groupby_value=groupby_value, # deprecated
853 with root_cd(tdirectory):
854 name = self.name or self.default_name
855 title = self.title or self.default_title
857 name = formatter.format(name, **replacement_dict)
858 title = formatter.format(title, **replacement_dict)
860 output_ttree = ROOT.TTree(root_save_name(name), title)
861 for part_name, parts in iter_items_sorted_for_key(crops):
862 self.add_branch(output_ttree, part_name, parts)
864 output_ttree.FlushBaskets()
867 def add_branch(self, output_ttree, part_name, parts):
868 """Add a TBranch to the TTree"""
869 input_value = np.zeros(1, dtype=float)
871 branch_type_spec = f'{part_name}/D'
872 tbranch = output_ttree.Branch(part_name, input_value, branch_type_spec)
874 if output_ttree.GetNbranches() == 1:
875 # On filling of the first branch we need to use the fill method of the TTree
876 # For all other branches we can use the one of the branch
879 input_value[0] = value
884 input_value[0] = value
887 output_ttree.GetEntry(0)
888 output_ttree.ResetBranchAddress(tbranch)
889 also_subbranches = True # No subbranches here but we drop the buffers just in case.
890 output_ttree.DropBranchFromCache(tbranch, also_subbranches)
867 def add_branch(self, output_ttree, part_name, parts):
…
817class SaveTreeRefiner(Refiner):
…
893class FilterRefiner(Refiner):
894 """Refiner for filters"""
896 def __init__(self, wrapped_refiner, filter=None, on=None):
897 """Constructor for this refiner"""
899 ## cached value of the wrapped refiner
900 self.wrapped_refiner = wrapped_refiner
903 ## cached value of the filter
904 self.filter = np.nonzero
908 ## cached value of the part name to filter on
896 def __init__(self, wrapped_refiner, filter=None, on=None):
…
def refine(self, harvesting_module, crops, *args, **kwds):
    """Filter the crops and hand the remainder to the wrapped refiner

    The crops are reduced with filter_crops() using the cached filter and the
    cached part name to filter on. All further positional and keyword
    arguments are forwarded unchanged to the wrapped refiner.
    """
    # Reduce the crops first, then delegate the actual refinement
    crops_passing_filter = filter_crops(crops, self.filter, part_name=self.on)
    delegate = self.wrapped_refiner
    delegate(harvesting_module, crops_passing_filter, *args, **kwds)
911 def refine(self, harvesting_module, crops, *args, **kwds): …
893class FilterRefiner(Refiner):
…
917class SelectRefiner(Refiner):
918 """Refiner for selection"""
920 def __init__(self, wrapped_refiner, select=None, exclude=None):
921 """Constructor for this refiner"""
926 ## cached value of the wrapped refiner
927 self.wrapped_refiner = wrapped_refiner
928 ## cached value of the selector
930 ## cached value of the exclusion flag
931 self.exclude = exclude
920 def __init__(self, wrapped_refiner, select=None, exclude=None):
…
def refine(self, harvesting_module, crops, *args, **kwds):
    """Restrict the crops to the selected parts and hand them to the wrapped refiner

    The cached select / exclude settings are passed to select_crop_parts().
    All further positional and keyword arguments are forwarded unchanged to
    the wrapped refiner.
    """
    # Collect the cached selection settings for select_crop_parts()
    part_selection = dict(select=self.select, exclude=self.exclude)
    crops_in_selection = select_crop_parts(crops, **part_selection)

    delegate = self.wrapped_refiner
    delegate(harvesting_module, crops_in_selection, *args, **kwds)
933 def refine(self, harvesting_module, crops, *args, **kwds): …
917class SelectRefiner(Refiner): …
939class GroupByRefiner(Refiner):
940 """Refiner for grouping"""
942 ## default value of the exclude-by classifier
943 default_exclude_by = True
949 """Constructor for this refiner"""
952 ## cached value of the wrapped refiner
953 self.wrapped_refiner = wrapped_refiner
954 ## cached value of the group-by classifier
956 ## cached value of the exclude-by classifier
957 self.exclude_by = exclude_by if exclude_by is not None else self.default_exclude_by
962 groupby_part_name=None,
966 """Process this grouping"""
970 # A single name to do the group by
971 if isinstance(by, str) or by is None:
973 # Wrap it into a list an continue with the general case
976 for groupby_spec in by:
977 if groupby_spec is None:
978 # Using empty string as groupby_value to indicate that all values have been selected
980 self.wrapped_refiner(harvesting_module,
982 groupby_part_name=None,
988 elif isinstance(groupby_spec, str):
989 part_name = groupby_spec
990 groupby_parts = crops[part_name]
991 unique_values, index_of_values = np.unique(groupby_parts, return_inverse=True)
992 groupby_values = [f" = {value}]" for value in unique_values]
994 elif isinstance(groupby_spec, tuple):
995 part_name = groupby_spec[0]
996 cuts = groupby_spec[1]
998 groupby_parts = crops[part_name]
1001 digitization_cuts = list(np.sort(cuts))
1002 if digitization_cuts[-1] != np.inf:
1003 digitization_cuts.append(np.inf)
1004 index_of_values = np.digitize(groupby_parts, digitization_cuts, right=True)
1006 groupby_values = [f"below {digitization_cuts[0]}"]
1007 bin_bounds = list(zip(digitization_cuts[0:], digitization_cuts[1:]))
1008 for lower_bound, upper_bound in bin_bounds:
1009 if lower_bound == upper_bound:
1010 # degenerated bin case
1011 groupby_values.append(f"= {lower_bound}")
1012 elif upper_bound == np.inf:
1013 groupby_values.append(f"above {lower_bound}")
1015 groupby_values.append(f"between {lower_bound} and {upper_bound}")
1016 groupby_values.append("is nan")
1017 assert len(groupby_values) == len(digitization_cuts) + 1
1020 raise ValueError(f"Unknown groupby specification {groupby_spec}")
1022 # Exclude the groupby variable if desired
1023 selected_crops = select_crop_parts(crops, exclude=part_name if self.exclude_by else None)
1024 for index_of_value, groupby_value in enumerate(groupby_values):
1025 indices_for_value = index_of_values == index_of_value
1026 if not np.any(indices_for_value):
1029 filtered_crops = filter_crops(selected_crops, indices_for_value)
1031 self.wrapped_refiner(harvesting_module,
1033 groupby_part_name=part_name,
1034 groupby_value=groupby_value,
939class GroupByRefiner(Refiner):
…
1039class CdRefiner(Refiner):
1040 """Refiner for change-directory"""
1042 ## Folder name to be used if a groupby selection is active.
1043 default_folder_name = ""
1044 ## Default suffix for a groupby selection
1045 default_groupby_addition = "_groupby_{groupby}_{groupby_value}"
1050 groupby_addition=None):
1051 """Constructor for this refiner"""
1053 ## cached value of the wrapped refiner
1054 self.wrapped_refiner = wrapped_refiner
1055 ## cached value of the folder name
1056 self.folder_name = folder_name
1057 ## cached value of the suffix for a groupby selection
1058 self.groupby_addition = groupby_addition
1047 def __init__(self,
…
1064 groupby_part_name=None,
1068 """Process the change-directory"""
1070 folder_name = self.folder_name
1071 if folder_name is None:
1072 if groupby_value is not None:
1073 folder_name = "{groupby_addition}"
1075 folder_name = self.default_folder_name
1077 groupby_addition = self.groupby_addition
1079 if groupby_addition is None:
1080 groupby_addition = self.default_groupby_addition
1082 if groupby_part_name is None and groupby_value is None:
1083 groupby_addition = ""
1085 groupby_addition = formatter.format(groupby_addition,
1086 groupby=groupby_part_name,
1087 groupby_value=groupby_value)
1089 folder_name = formatter.format(folder_name,
1090 groupby_addition=groupby_addition,
1091 groupby=groupby_part_name,
1092 groupby_value=groupby_value)
1094 folder_name = '/'.join(root_save_name(name) for name in folder_name.split('/'))
1096 with root_cd(tdirectory):
1097 with root_cd(folder_name) as tdirectory:
1098 self.wrapped_refiner(harvesting_module,
1100 tdirectory=tdirectory,
1101 groupby_part_name=groupby_part_name,
1102 groupby_value=groupby_value,
1039class CdRefiner(Refiner):
…
1107class ExpertLevelRefiner(Refiner):
1108 """Refiner for expert-level categorization"""
def __init__(self, wrapped_refiner, above_expert_level=None, below_expert_level=None):
    """Store the wrapped refiner and the two optional expert-level bounds

    A bound left at None is not applied when the refiner is executed.
    """
    ## cached expert level that the harvesting module has to exceed (None: no such bound)
    self.above_expert_level = above_expert_level
    ## cached expert level that the harvesting module has to stay below (None: no such bound)
    self.below_expert_level = below_expert_level

    ## cached refiner that is invoked if the expert-level conditions hold
    self.wrapped_refiner = wrapped_refiner
1110 def __init__(self, wrapped_refiner, above_expert_level=None, below_expert_level=None):
…
1120 def refine(self, harvesting_module, crops, *args, **kwds):
1121 """Process the expert-level categorization"""
1123 above_expert_level = self.above_expert_level
1124 below_expert_level = self.below_expert_level
1127 if above_expert_level is not None:
1128 proceed = proceed and harvesting_module.expert_level > above_expert_level
1130 if below_expert_level is not None:
1131 proceed = proceed and harvesting_module.expert_level < below_expert_level
1134 self.wrapped_refiner(harvesting_module, crops, *args, **kwds)
1137# Meta refiner decorators
1120 def refine(self, harvesting_module, crops, *args, **kwds): …
1107class ExpertLevelRefiner(Refiner):
…
1138def groupby(refiner=None, **kwds):
1139 def group_decorator(wrapped_refiner):
1140 return GroupByRefiner(wrapped_refiner, **kwds)
1142 return group_decorator
1144 return group_decorator(refiner)
1147def select(refiner=None, **kwds):
1148 def select_decorator(wrapped_refiner):
1149 return SelectRefiner(wrapped_refiner, **kwds)
1151 return select_decorator
1153 return select_decorator(refiner)
1156def filter(refiner=None, **kwds):
1157 def filter_decorator(wrapped_refiner):
1158 return FilterRefiner(wrapped_refiner, **kwds)
1160 return filter_decorator
1162 return filter_decorator(refiner)
1165def cd(refiner=None, **kwds):
1166 def cd_decorator(wrapped_refiner):
1167 return CdRefiner(wrapped_refiner, **kwds)
1171 return cd_decorator(refiner)
def context(refiner=None,
            above_expert_level=None, below_expert_level=None,
            folder_name=None, folder_groupby_addition=None,
            filter=None, filter_on=None,
            groupby=None, exclude_groupby=None,
            select=None, exclude=None):
    """Wrap a refiner in the meta refiners requested by the given options.

    Each group of options that is not ``None`` adds one wrapper around the
    refiner. The wrappers are applied innermost first (select, cd, group-by,
    filter, expert level), i.e. in the reverse of the order in which they are
    meant to be executed.

    Can be used with a refiner (returns the wrapped refiner) or without one
    (returns a decorator that wraps a refiner later).

    Parameters
    ----------
    refiner : callable, optional
        The refiner to be wrapped.
    above_expert_level, below_expert_level : optional
        Forwarded to ExpertLevelRefiner.
    folder_name, folder_groupby_addition : optional
        Forwarded to CdRefiner; a CdRefiner is also added when only
        ``groupby`` is given.
    filter, filter_on : optional
        Forwarded to FilterRefiner.
    groupby, exclude_groupby : optional
        Forwarded to GroupByRefiner.
    select, exclude : optional
        Forwarded to SelectRefiner.

    Returns
    -------
    Refiner or callable
        The wrapped refiner (always a Refiner instance), or a decorator.
    """

    def context_decorator(wrapped_refiner):
        # Apply meta refiners in the reverse order that they shall be executed
        if exclude is not None or select is not None:
            wrapped_refiner = SelectRefiner(wrapped_refiner,
                                            select=select, exclude=exclude)

        if folder_name is not None or groupby is not None or folder_groupby_addition is not None:
            wrapped_refiner = CdRefiner(wrapped_refiner,
                                        folder_name=folder_name,
                                        groupby_addition=folder_groupby_addition)

        # NOTE(review): the `by=`, `filter=` and `on=` keyword lines were missing
        # from the extracted listing and are restored here - confirm against the
        # GroupByRefiner and FilterRefiner constructor signatures.
        if groupby is not None:
            wrapped_refiner = GroupByRefiner(wrapped_refiner,
                                             by=groupby,
                                             exclude_by=exclude_groupby)

        if filter is not None or filter_on is not None:
            wrapped_refiner = FilterRefiner(wrapped_refiner,
                                            filter=filter,
                                            on=filter_on)

        if above_expert_level is not None or below_expert_level is not None:
            wrapped_refiner = ExpertLevelRefiner(wrapped_refiner,
                                                 above_expert_level=above_expert_level,
                                                 below_expert_level=below_expert_level)

        # Guarantee that a plain function ends up as a Refiner instance
        if not isinstance(wrapped_refiner, Refiner):
            wrapped_refiner = Refiner(wrapped_refiner)

        return wrapped_refiner

    # No refiner given: act as a decorator factory
    if refiner is None:
        return context_decorator
    return functools.wraps(refiner)(context_decorator(refiner))
def refiner_with_context(refiner_factory):
    """Extend a refiner factory with the keyword options understood by context().

    The returned function accepts the context options explicitly, passes all
    remaining keyword arguments to ``refiner_factory`` and hands the created
    refiner together with the context options to ``context``.

    Parameters
    ----------
    refiner_factory : callable
        Factory that builds a refiner from keyword arguments.

    Returns
    -------
    callable
        Function carrying the metadata of ``refiner_factory`` (via
        ``functools.wraps``) that returns the result of ``context``.
    """
    @functools.wraps(refiner_factory)
    def module_decorator_with_context(above_expert_level=None, below_expert_level=None,
                                      folder_name=None, folder_groupby_addition=None,
                                      filter=None, filter_on=None,
                                      groupby=None, exclude_groupby=None,
                                      select=None, exclude=None,
                                      **kwds_for_refiner_factory):
        # Build the bare refiner first, then describe the context around it
        bare_refiner = refiner_factory(**kwds_for_refiner_factory)

        context_options = {
            "above_expert_level": above_expert_level,
            "below_expert_level": below_expert_level,
            "folder_name": folder_name,
            "folder_groupby_addition": folder_groupby_addition,
            "filter": filter,
            "filter_on": filter_on,
            "groupby": groupby,
            "exclude_groupby": exclude_groupby,
            "select": select,
            "exclude": exclude,
        }
        return context(bare_refiner, **context_options)

    return module_decorator_with_context
@refiner_with_context
def save_fom(**kwds):
    """Create a SaveFiguresOfMeritRefiner from the given keyword arguments.

    Through the ``refiner_with_context`` decorator the context options
    (e.g. ``groupby``, ``select``, ``filter``) are accepted as well.
    """
    return SaveFiguresOfMeritRefiner(**kwds)
@refiner_with_context
def save_histograms(**kwds):
    """Create a SaveHistogramsRefiner from the given keyword arguments.

    Through the ``refiner_with_context`` decorator the context options
    (e.g. ``groupby``, ``select``, ``filter``) are accepted as well.
    """
    return SaveHistogramsRefiner(**kwds)
@refiner_with_context
def save_profiles(**kwds):
    """Create a SaveProfilesRefiner from the given keyword arguments.

    Through the ``refiner_with_context`` decorator the context options
    (e.g. ``groupby``, ``select``, ``filter``) are accepted as well.
    """
    return SaveProfilesRefiner(**kwds)
@refiner_with_context
def save_scatters(**kwds):
    """Create a SaveScatterRefiner from the given keyword arguments.

    Through the ``refiner_with_context`` decorator the context options
    (e.g. ``groupby``, ``select``, ``filter``) are accepted as well.
    """
    return SaveScatterRefiner(**kwds)
@refiner_with_context
def save_classification_analysis(**kwds):
    """Create a SaveClassificationAnalysisRefiner from the given keyword arguments.

    Through the ``refiner_with_context`` decorator the context options
    (e.g. ``groupby``, ``select``, ``filter``) are accepted as well.
    """
    return SaveClassificationAnalysisRefiner(**kwds)
@refiner_with_context
def save_pull_analysis(**kwds):
    """Create a SavePullAnalysisRefiner from the given keyword arguments.

    Through the ``refiner_with_context`` decorator the context options
    (e.g. ``groupby``, ``select``, ``filter``) are accepted as well.
    """
    return SavePullAnalysisRefiner(**kwds)
@refiner_with_context
def save_tree(**kwds):
    """Create a SaveTreeRefiner from the given keyword arguments.

    Through the ``refiner_with_context`` decorator the context options
    (e.g. ``groupby``, ``select``, ``filter``) are accepted as well.
    """
    return SaveTreeRefiner(**kwds)
def select_crop_parts(crops, select=None, exclude=None):
    """Select, exclude and optionally rename the parts of a crops mapping.

    Parameters
    ----------
    crops : collections.abc.MutableMapping
        Mapping from part names to parts.
    select : str, iterable or mapping, optional
        Part name(s) to keep. If a mapping is given, its keys are the part
        names to keep and its values the names they are stored under in the
        result. A callable key is invoked as ``key(**crops)`` and its return
        value is stored under the mapped name.
    exclude : str or iterable, optional
        Part name(s) to drop.

    Returns
    -------
    collections.abc.MutableMapping
        ``crops`` itself when neither ``select`` nor ``exclude`` is given,
        otherwise a modified shallow copy.

    Raises
    ------
    ValueError
        If ``crops`` is not a mutable mapping.
    """
    if select is None:
        select = []
    if exclude is None:
        exclude = []

    # Allow a single part name as shorthand for a one-element list
    if isinstance(select, str):
        select = [select]
    if isinstance(exclude, str):
        exclude = [exclude]

    if not isinstance(crops, collections.abc.MutableMapping):
        raise ValueError(f"Unrecognised crop {crops} of type {type(crops)}")

    # Nothing to do - hand back the input unchanged (not a copy)
    if not select and not exclude:
        return crops

    part_names = list(crops.keys())
    part_names_to_drop = set()

    if select:
        part_names_to_drop.update(name for name in part_names if name not in select)

        # if the selection item is a callable function do not count it as not selectable yet
        select_not_in_part_names = [name for name in select
                                    if not callable(name) and name not in part_names]
        if select_not_in_part_names:
            get_logger().warning("Cannot select %s, because they are not in crop part names %s",
                                 select_not_in_part_names, sorted(part_names))

    if exclude:
        part_names_to_drop.update(name for name in part_names if name in exclude)

    # Make a shallow copy
    selected_crops = copy.copy(crops)
    for part_name in part_names_to_drop:
        del selected_crops[part_name]

    if isinstance(select, collections.abc.Mapping):
        # select is a rename mapping
        for part_name, new_part_name in list(select.items()):
            if callable(part_name):
                # Computed part: evaluated on the complete, unselected crops
                selected_crops[new_part_name] = part_name(**crops)
            elif part_name in selected_crops:
                parts = selected_crops[part_name]
                del selected_crops[part_name]
                selected_crops[new_part_name] = parts

    return selected_crops
def filter_crops(crops, filter_function, part_name=None):
    """Apply a filter to crops given as an array or a mapping of arrays.

    Parameters
    ----------
    crops : np.ndarray or collections.abc.MutableMapping
        The crops to be filtered. For a mapping every part is indexed with the
        same filter indices.
    filter_function : np.ndarray or callable
        Either ready-made indices (used as is), or a callable that is invoked
        with ``crops[part_name]`` and returns the indices.
    part_name : optional
        Key of the part handed to ``filter_function``; only used when
        ``filter_function`` is not an array.

    Returns
    -------
    np.ndarray or collections.abc.MutableMapping
        The filtered array, or a shallow copy of the mapping with every part
        filtered.

    Raises
    ------
    ValueError
        If ``crops`` is neither an array nor a mutable mapping.
    """
    if isinstance(filter_function, np.ndarray):
        filter_indices = filter_function
    else:
        filter_indices = filter_function(crops[part_name])

    if isinstance(crops, np.ndarray):
        return crops[filter_indices]

    if isinstance(crops, collections.abc.MutableMapping):
        # Make a shallow copy
        filtered_crops = copy.copy(crops)
        # Distinct loop names so the part_name parameter is not shadowed
        for name, values in list(crops.items()):
            filtered_crops[name] = values[filter_indices]
        return filtered_crops

    raise ValueError(f"Unrecognised crop {crops} of type {type(crops)}")
def iter_items_sorted_for_key(crops):
    """Return the items of crops, sorted by key if crops is a dict.

    Parameters
    ----------
    crops : dict or object with an ``items()`` method
        The crops whose items shall be iterated.

    Returns
    -------
    generator or list
        For a ``dict`` (or subclass) a generator of ``(key, value)`` pairs in
        sorted key order; otherwise a list of ``crops.items()`` in the order
        the object provides.
    """
    # If crops is a dictionary, assume that it should be sorted by key.
    # In all other cases the user's class has to take care of the ordering.
    if isinstance(crops, dict):
        keys = sorted(crops.keys())
        return ((key, crops[key]) for key in keys)
    else:
        return list(crops.items())
__call__(self, harvesting_module, crops=None, *args, **kwds)
refiner_function
cached copy of the instance's refiner function
refine(self, harvesting_module, *args, **kwds)
__get__(self, harvesting_module, cls=None)
__init__(self, refiner_function=None)