19 return logging.getLogger(__name__)
23 formatter = TolerateMissingKeyFormatter()
27 """Python module to refine a peeled dictionary"""
30 """Constructor of the Refiner instance"""
34 def __get__(self, harvesting_module, cls=None):
35 """Getter of the Refiner instance"""
36 if harvesting_module
is None:
43 def bound_call(*args, **kwds):
44 return refine(harvesting_module, *args, **kwds)
47 def __call__(self, harvesting_module, crops=None, *args, **kwds):
48 """implementation of the function-call of the Refiner instance
50 r(harvester) # decoration
51 r(harvester, crops, args, keywords) # refinement
55 harvesting_module.refiners.append(self)
56 return harvesting_module
59 return self.
refine(harvesting_module, crops, *args, **kwds)
61 def refine(self, harvesting_module, *args, **kwds):
62 """Apply the instance's refiner function"""
67 """Refiner for figures of merit"""
69 default_name =
"{module.id}_figures_of_merit{groupby_key}"
71 default_title =
"Figures of merit in {module.title}"
73 default_contact =
"{module.contact}"
75 default_description =
"Figures of merit are the {aggregation.__name__} of {keys}"
77 default_check =
"Check for reasonable values"
79 default_key =
"{aggregation.__name__}_{part_name}"
87 default_aggregation = mean
98 """Constructor for this refiner"""
123 groupby_part_name=None,
126 """Process the figures of merit"""
136 replacement_dict = dict(
138 module=harvesting_module,
139 aggregation=aggregation,
140 groupby_key=
'_' + groupby_part_name + groupby_value
if groupby_part_name
else "",
141 groupby=groupby_part_name,
142 groupby_value=groupby_value,
145 name = formatter.format(name, **replacement_dict)
146 title = formatter.format(title, **replacement_dict)
147 contact = formatter.format(contact, **replacement_dict)
153 for part_name, parts
in iter_items_sorted_for_key(crops):
155 key = formatter.format(key, part_name=part_name, **replacement_dict)
158 keys = list(figures_of_merit.keys())
160 description = formatter.format(description, keys=keys, **replacement_dict)
161 check = formatter.format(check, keys=keys, **replacement_dict)
163 figures_of_merit.description = description
164 figures_of_merit.check = check
167 figures_of_merit.write(tdirectory)
169 print(figures_of_merit)
173 """Refiner for histograms"""
175 default_name =
"{module.id}_{part_name}_histogram{groupby_key}{stackby_key}"
177 default_title =
"Histogram of {part_name}{groupby_key}{stackby_key} from {module.title}"
179 default_contact =
"{module.contact}"
181 default_description =
"This is a histogram of {part_name}{groupby_key}{stackby_key}."
183 default_check =
"Check if the distribution is reasonable"
194 outlier_z_score=None,
195 allow_discrete=False,
199 """Constructor for this refiner"""
201 super(SaveHistogramsRefiner, self).
__init__()
238 groupby_part_name=None,
241 """Process the histogram"""
245 stackby_parts = crops[stackby]
249 replacement_dict = dict(
251 module=harvesting_module,
252 stackby_key=
' stacked by ' + stackby
if stackby
else "",
253 groupby_key=
' in group ' + groupby_part_name + groupby_value
if groupby_part_name
else "",
257 contact = formatter.format(contact, **replacement_dict)
259 for part_name, parts
in iter_items_sorted_for_key(crops):
265 name = formatter.format(name, part_name=part_name, **replacement_dict)
266 title = formatter.format(title, part_name=part_name, **replacement_dict)
267 description = formatter.format(description, part_name=part_name, **replacement_dict)
268 check = formatter.format(check, part_name=part_name, **replacement_dict)
271 histogram.hist(parts,
277 stackby=stackby_parts)
279 histogram.title = title
280 histogram.contact = contact
281 histogram.description = description
282 histogram.check = check
284 histogram.xlabel = compose_axis_label(part_name)
292 fit_method_name =
'fit_' + str(self.
fit)
294 fit_method = getattr(histogram, fit_method_name)
295 except AttributeError:
296 histogram.fit(str(fit), **kwds)
301 histogram.write(tdirectory)
305 """Refiner for profile histograms and 2D scatterplots"""
307 plot_kind =
"profile"
324 outlier_z_score=None,
327 skip_single_valued=False,
328 allow_discrete=False):
329 """Constructor for this refiner"""
382 groupby_part_name=None,
385 """Process the profile histogram / scatterplot"""
389 stackby_parts = crops[stackby]
393 replacement_dict = dict(
395 module=harvesting_module,
396 stackby_key=
' stacked by ' + stackby
if stackby
else "",
397 groupby_key=
' in group ' + groupby_part_name + groupby_value
if groupby_part_name
else "",
400 contact = self.
contact or self.default_contact
401 contact = formatter.format(contact, **replacement_dict)
403 y_crops = select_crop_parts(crops, select=self.
y)
404 x_crops = select_crop_parts(crops, select=self.
x, exclude=self.
y)
406 for y_part_name, y_parts
in iter_items_sorted_for_key(y_crops):
407 for x_part_name, x_parts
in iter_items_sorted_for_key(x_crops):
410 get_logger().info(
'Skipping "%s" by "%s" profile because x has only a single value "%s"',
417 get_logger().info(
'Skipping "%s" by "%s" profile because y has only a single value "%s"',
423 name = self.
name or self.default_name
424 title = self.
title or self.default_title
425 description = self.
description or self.default_description
426 check = self.
check or self.default_check
428 name = formatter.format(name,
429 x_part_name=x_part_name,
430 y_part_name=y_part_name,
433 title = formatter.format(title,
434 x_part_name=x_part_name,
435 y_part_name=y_part_name,
438 description = formatter.format(description,
439 x_part_name=x_part_name,
440 y_part_name=y_part_name,
443 check = formatter.format(check,
444 x_part_name=x_part_name,
445 y_part_name=y_part_name,
451 if plot_kind ==
"profile":
452 profile_plot.profile(x_parts,
461 stackby=stackby_parts)
469 fit_method_name =
'fit_' + str(self.
fit)
471 fit_method = getattr(profile_plot, fit_method_name)
472 except BaseException:
473 profile_plot.fit(str(fit), **kwds)
477 elif plot_kind ==
"scatter":
478 profile_plot.scatter(x_parts,
483 stackby=stackby_parts)
485 profile_plot.title = title
486 profile_plot.contact = contact
487 profile_plot.description = description
488 profile_plot.check = check
490 profile_plot.xlabel = compose_axis_label(x_part_name)
491 profile_plot.ylabel = compose_axis_label(y_part_name, self.
y_unit)
494 profile_plot.write(tdirectory)
498 """check if a list has at least two unique values"""
508 """Refiner for profile histograms"""
510 default_name =
"{module.id}_{y_part_name}_by_{x_part_name}_profile{groupby_key}{stackby_key}"
512 default_title =
"Profile of {y_part_name} by {x_part_name} from {module.title}"
514 default_contact =
"{module.contact}"
516 default_description =
"This is a profile of {y_part_name} over {x_part_name}."
518 default_check =
"Check if the trend line is reasonable."
521 plot_kind =
"profile"
525 """Refiner for 2D scatterplots"""
527 default_name =
"{module.id}_{y_part_name}_by_{x_part_name}_scatter{groupby_key}{stackby_key}"
529 default_title =
"Scatter of {y_part_name} by {x_part_name} from {module.title}"
531 default_contact =
"{module.contact}"
533 default_description =
"This is a scatter of {y_part_name} over {x_part_name}."
535 default_check =
"Check if the distributions is reasonable."
538 plot_kind =
"scatter"
542 """Refiner for truth-classification analyses"""
545 default_contact =
"{module.contact}"
548 default_truth_name =
"{part_name}_truth"
550 default_estimate_name =
"{part_name}_estimate"
561 outlier_z_score=None,
562 allow_discrete=False,
564 """Constructor for this refiner"""
595 groupby_part_name=None,
598 """Process the truth-classification analysis"""
600 replacement_dict = dict(
602 module=harvesting_module,
603 groupby_key=
'_' + groupby_part_name + groupby_value
if groupby_part_name
else "",
604 groupby=groupby_part_name,
605 groupby_value=groupby_value,
609 contact = formatter.format(contact, **replacement_dict)
616 truth_name = formatter.format(truth_name, part_name=self.
part_name)
617 truths = crops[truth_name]
624 if isinstance(estimate_name, str):
625 estimate_names = [estimate_name, ]
627 estimate_names = estimate_name
629 for estimate_name
in estimate_names:
630 estimate_name = formatter.format(estimate_name, part_name=self.
part_name)
631 estimates = crops[estimate_name]
643 classification_analysis.analyse(estimates, truths)
646 classification_analysis.write(tdirectory)
650 """Refiner for pull analyses"""
653 default_name =
"{module.id}_{quantity_name}"
655 default_contact =
"{module.contact}"
657 default_title_postfix =
" from {module.title}"
660 default_truth_name =
"{part_name}_truth"
662 default_estimate_name =
"{part_name}_estimate"
664 default_variance_name =
"{part_name}_variance"
678 outlier_z_score=None,
681 """Constructor for this refiner"""
692 if part_names
is not None:
695 if part_name
is not None:
724 groupby_part_name=None,
727 """Process the pull analysis"""
729 replacement_dict = dict(
731 module=harvesting_module,
733 groupby_key=
'_' + groupby_part_name + groupby_value
if groupby_part_name
else "",
734 groupby=groupby_part_name,
735 groupby_value=groupby_value,
739 contact = formatter.format(contact, **replacement_dict)
744 auxiliaries = select_crop_parts(crops, self.
aux_names)
749 name = formatter.format(name, part_name=part_name, **replacement_dict)
750 plot_name = name +
"_{subplot_name}"
753 if title_postfix
is None:
756 title_postfix = formatter.format(title_postfix, part_name=part_name, **replacement_dict)
757 plot_title =
"{subplot_title} of {quantity_name}" + title_postfix
774 truth_name = formatter.format(truth_name, part_name=part_name)
775 estimate_name = formatter.format(estimate_name, part_name=part_name)
776 variance_name = formatter.format(variance_name, part_name=part_name)
778 truths = crops[truth_name]
779 estimates = crops[estimate_name]
781 variances = crops[variance_name]
794 plot_title=plot_title)
796 pull_analysis.analyse(truths,
799 auxiliaries=auxiliaries,
800 which_plots=which_plots)
802 pull_analysis.contact = contact
805 pull_analysis.write(tdirectory)
809 """Refiner for ROOT TTrees"""
812 default_name =
"{module.id}_tree"
814 default_title =
"Tree of {module.id}"
819 """Constructor for this refiner"""
820 super(SaveTreeRefiner, self).
__init__()
831 groupby_part_name=None,
834 """Process the TTree"""
836 replacement_dict = dict(
838 module=harvesting_module,
839 groupby_key=
'_' + groupby_part_name + groupby_value
if groupby_part_name
else "",
840 groupby=groupby_part_name,
841 groupby_value=groupby_value,
844 with root_cd(tdirectory):
848 name = formatter.format(name, **replacement_dict)
849 title = formatter.format(title, **replacement_dict)
851 output_ttree = ROOT.TTree(root_save_name(name), title)
852 for part_name, parts
in iter_items_sorted_for_key(crops):
853 self.
add_branch(output_ttree, part_name, parts)
855 output_ttree.FlushBaskets()
859 """Add a TBranch to the TTree"""
860 input_value = np.zeros(1, dtype=float)
862 branch_type_spec =
'%s/D' % part_name
863 tbranch = output_ttree.Branch(part_name, input_value, branch_type_spec)
865 if output_ttree.GetNbranches() == 1:
870 input_value[0] = value
875 input_value[0] = value
878 output_ttree.GetEntry(0)
879 output_ttree.ResetBranchAddress(tbranch)
880 also_subbranches =
True
881 output_ttree.DropBranchFromCache(tbranch, also_subbranches)
885 """Refiner for filters"""
887 def __init__(self, wrapped_refiner, filter=None, on=None):
888 """Constructor for this refiner"""
902 def refine(self, harvesting_module, crops, *args, **kwds):
903 """Process this filter"""
904 filtered_crops = filter_crops(crops, self.
filter, part_name=self.
on)
909 """Refiner for selection"""
911 def __init__(self, wrapped_refiner, select=[], exclude=[]):
912 """Constructor for this refiner"""
921 def refine(self, harvesting_module, crops, *args, **kwds):
922 """Process this selection"""
923 selected_crops = select_crop_parts(crops, select=self.
select, exclude=self.
exclude)
928 """Refiner for grouping"""
931 default_exclude_by =
True
937 """Constructor for this refiner"""
949 groupby_part_name=None,
953 """Process this grouping"""
958 if isinstance(by, str)
or by
is None:
963 for groupby_spec
in by:
964 if groupby_spec
is None:
969 groupby_part_name=
None,
975 elif isinstance(groupby_spec, str):
976 part_name = groupby_spec
977 groupby_parts = crops[part_name]
978 unique_values, index_of_values = np.unique(groupby_parts, return_inverse=
True)
979 groupby_values = [
" = {value}]".format(value=value)
for value
in unique_values]
981 elif isinstance(groupby_spec, tuple):
982 part_name = groupby_spec[0]
983 cuts = groupby_spec[1]
985 groupby_parts = crops[part_name]
988 digitization_cuts = list(np.sort(cuts))
989 if digitization_cuts[-1] != np.inf:
990 digitization_cuts.append(np.inf)
991 index_of_values = np.digitize(groupby_parts, digitization_cuts, right=
True)
993 groupby_values = [
"below {upper_bound}".format(upper_bound=digitization_cuts[0])]
994 bin_bounds = list(zip(digitization_cuts[0:], digitization_cuts[1:]))
995 for lower_bound, upper_bound
in bin_bounds:
996 if lower_bound == upper_bound:
998 groupby_values.append(
"= {lower_bound}".format(lower_bound=lower_bound))
999 elif upper_bound == np.inf:
1000 groupby_values.append(
"above {lower_bound}".format(lower_bound=lower_bound))
1002 groupby_values.append(
"between {lower_bound} and {upper_bound}".format(lower_bound=lower_bound,
1003 upper_bound=upper_bound))
1004 groupby_values.append(
"is nan")
1005 assert len(groupby_values) == len(digitization_cuts) + 1
1008 raise ValueError(
"Unknown groupby specification %s" % groupby_spec)
1011 selected_crops = select_crop_parts(crops, exclude=part_name
if self.
exclude_by else None)
1012 for index_of_value, groupby_value
in enumerate(groupby_values):
1013 indices_for_value = index_of_values == index_of_value
1014 if not np.any(indices_for_value):
1017 filtered_crops = filter_crops(selected_crops, indices_for_value)
1021 groupby_part_name=part_name,
1022 groupby_value=groupby_value,
1028 """Refiner for change-directory"""
1031 default_folder_name =
""
1033 default_groupby_addition =
"_groupby_{groupby}_{groupby_value}"
1038 groupby_addition=None):
1039 """Constructor for this refiner"""
1052 groupby_part_name=None,
1056 """Process the change-directory"""
1059 if folder_name
is None:
1060 if groupby_value
is not None:
1061 folder_name =
"{groupby_addition}"
1067 if groupby_addition
is None:
1070 if groupby_part_name
is None and groupby_value
is None:
1071 groupby_addition =
""
1073 groupby_addition = formatter.format(groupby_addition,
1074 groupby=groupby_part_name,
1075 groupby_value=groupby_value)
1077 folder_name = formatter.format(folder_name,
1078 groupby_addition=groupby_addition,
1079 groupby=groupby_part_name,
1080 groupby_value=groupby_value)
1082 folder_name =
'/'.join(root_save_name(name)
for name
in folder_name.split(
'/'))
1084 with root_cd(tdirectory):
1085 with root_cd(folder_name)
as tdirectory:
1088 tdirectory=tdirectory,
1089 groupby_part_name=groupby_part_name,
1090 groupby_value=groupby_value,
1096 """Refiner for expert-level categorization"""
1098 def __init__(self, wrapped_refiner, above_expert_level=None, below_expert_level=None):
1099 """Constructor for this refiner"""
1108 def refine(self, harvesting_module, crops, *args, **kwds):
1109 """Process the expert-level categorization"""
1115 if above_expert_level
is not None:
1116 proceed = proceed
and harvesting_module.expert_level > above_expert_level
1118 if below_expert_level
is not None:
1119 proceed = proceed
and harvesting_module.expert_level < below_expert_level
1126 def groupby(refiner=None, **kwds):
1127 def group_decorator(wrapped_refiner):
1130 return group_decorator
1132 return group_decorator(refiner)
1135 def select(refiner=None, **kwds):
1136 def select_decorator(wrapped_refiner):
1139 return select_decorator
1141 return select_decorator(refiner)
1144 def filter(refiner=None, **kwds):
1145 def filter_decorator(wrapped_refiner):
1148 return filter_decorator
1150 return filter_decorator(refiner)
1153 def cd(refiner=None, **kwds):
1154 def cd_decorator(wrapped_refiner):
1155 return CdRefiner(wrapped_refiner, **kwds)
1159 return cd_decorator(refiner)
1162 def context(refiner=None,
1163 above_expert_level=None, below_expert_level=None,
1164 folder_name=None, folder_groupby_addition=None,
1165 filter=None, filter_on=None,
1166 groupby=None, exclude_groupby=None,
1167 select=None, exclude=None):
1169 def context_decorator(wrapped_refiner):
1171 if exclude
is not None or select
is not None:
1173 select=select, exclude=exclude)
1175 if folder_name
is not None or groupby
is not None or folder_groupby_addition
is not None:
1176 wrapped_refiner =
CdRefiner(wrapped_refiner,
1177 folder_name=folder_name,
1178 groupby_addition=folder_groupby_addition)
1180 if groupby
is not None:
1183 exclude_by=exclude_groupby)
1185 if filter
is not None or filter_on
is not None:
1190 if above_expert_level
is not None or below_expert_level
is not None:
1192 above_expert_level=above_expert_level,
1193 below_expert_level=below_expert_level)
1195 if not isinstance(wrapped_refiner, Refiner):
1196 wrapped_refiner =
Refiner(wrapped_refiner)
1198 return wrapped_refiner
1201 return context_decorator
1203 return functools.wraps(refiner)(context_decorator(refiner))
def refiner_with_context(refiner_factory):
    """Turn a refiner factory into a factory that also accepts context options.

    The returned function takes the keyword options understood by
    ``context`` (expert level bounds, folder name, filter, groupby and
    select/exclude settings). Every other keyword argument is handed to
    ``refiner_factory`` to build the refiner, which is then passed through
    ``context`` together with the context options.

    The returned function carries the metadata (name, docstring) of
    ``refiner_factory`` by means of ``functools.wraps``.
    """

    @functools.wraps(refiner_factory)
    def module_decorator_with_context(above_expert_level=None, below_expert_level=None,
                                      folder_name=None, folder_groupby_addition=None,
                                      filter=None, filter_on=None,
                                      groupby=None, exclude_groupby=None,
                                      select=None, exclude=None,
                                      **kwds_for_refiner_factory):
        # Build the bare refiner first, from the factory-specific keywords only.
        bare_refiner = refiner_factory(**kwds_for_refiner_factory)

        # Collect the context options and forward them unchanged.
        context_options = dict(
            above_expert_level=above_expert_level,
            below_expert_level=below_expert_level,
            folder_name=folder_name,
            folder_groupby_addition=folder_groupby_addition,
            filter=filter,
            filter_on=filter_on,
            groupby=groupby,
            exclude_groupby=exclude_groupby,
            select=select,
            exclude=exclude,
        )
        return context(bare_refiner, **context_options)

    return module_decorator_with_context
1227 @refiner_with_context
1228 def save_fom(**kwds):
1232 @refiner_with_context
1233 def save_histograms(**kwds):
1237 @refiner_with_context
1238 def save_profiles(**kwds):
1242 @refiner_with_context
1243 def save_scatters(**kwds):
1247 @refiner_with_context
1248 def save_classification_analysis(**kwds):
1252 @refiner_with_context
1253 def save_pull_analysis(**kwds):
1257 @refiner_with_context
1258 def save_tree(**kwds):
def select_crop_parts(crops, select=None, exclude=None):
    """Select, exclude, rename and compute parts of a crops mapping.

    Parameters
    ----------
    crops : MutableMapping
        Mapping from part name to the parts stored under that name.
    select : str, collection of names, Mapping or None
        Part names to keep. A single string is treated as a one-element
        list. If a mapping is given, its keys are the names to keep and
        its values are the new names they are stored under; a callable key
        is called as ``key(**crops)`` and its result is stored under the
        mapped name.
    exclude : str, collection of names or None
        Part names to drop. A single string is treated as a one-element
        list.

    Returns
    -------
    The crops themselves if neither ``select`` nor ``exclude`` is given,
    otherwise a shallow copy (``copy.copy``) with the requested parts.

    Raises
    ------
    ValueError
        If ``crops`` is not a mutable mapping.
    """
    # The ABC aliases in the top-level ``collections`` module were removed in
    # Python 3.10, so take them from ``collections.abc``.
    from collections.abc import Mapping, MutableMapping

    if isinstance(select, str):
        select = [select, ]

    if isinstance(exclude, str):
        exclude = [exclude, ]

    if not isinstance(crops, MutableMapping):
        raise ValueError("Unrecognised crop %s of type %s" % (crops, type(crops)))

    part_names = list(crops.keys())

    if not select and not exclude:
        # NOTE(review): the statement of this branch is not visible in the
        # reviewed excerpt; handing back the crops unchanged is assumed - confirm.
        return crops

    if select:
        not_selected_part_names = [name for name in part_names if name not in select]

        # Callables in the selection are evaluated from the crops further down,
        # so they are not expected to be among the existing part names.
        select_not_in_part_names = [name for name in select
                                    if not callable(name) and name not in part_names]
        if select_not_in_part_names:
            get_logger().warning("Cannot select %s, because they are not in crop part names %s",
                                 select_not_in_part_names, sorted(part_names))
    else:
        not_selected_part_names = []

    if exclude:
        excluded_part_names = [name for name in part_names if name in exclude]
    else:
        excluded_part_names = []

    excluded_part_names.extend(not_selected_part_names)

    # Work on a shallow copy so that the crops of the caller stay untouched.
    selected_crops = copy.copy(crops)
    for part_name in set(excluded_part_names):
        del selected_crops[part_name]

    if isinstance(select, Mapping):
        # The selection doubles as a rename / compute specification.
        for part_name, new_part_name in list(select.items()):
            if callable(part_name):
                selected_crops[new_part_name] = part_name(**crops)
            elif part_name in selected_crops:
                selected_crops[new_part_name] = selected_crops.pop(part_name)

    return selected_crops
def filter_crops(crops, filter_function, part_name=None):
    """Apply a filter to the crops and return the filtered result.

    Parameters
    ----------
    crops : numpy.ndarray or MutableMapping
        Either a single array or a mapping from part name to parts that
        support indexing with the filter indices.
    filter_function : numpy.ndarray or callable
        An array is used directly as the index. A callable is applied to
        ``crops[part_name]`` and its return value is used as the index.
    part_name : optional
        Key of the part handed to ``filter_function`` when it is a callable.

    Returns
    -------
    ``crops[filter_indices]`` for an array; for a mapping a shallow copy
    (``copy.copy``) in which every part is replaced by
    ``parts[filter_indices]``.

    Raises
    ------
    ValueError
        If ``crops`` is neither an array nor a mutable mapping.
    """
    # ``collections.MutableMapping`` was removed in Python 3.10, so take the
    # ABC from ``collections.abc``.
    from collections.abc import MutableMapping

    if isinstance(filter_function, np.ndarray):
        filter_indices = filter_function
    else:
        parts = crops[part_name]
        filter_indices = filter_function(parts)

    if isinstance(crops, np.ndarray):
        return crops[filter_indices]

    if isinstance(crops, MutableMapping):
        # Fill a shallow copy so the crops of the caller are left unchanged.
        filtered_crops = copy.copy(crops)
        for name, parts in crops.items():
            filtered_crops[name] = parts[filter_indices]
        return filtered_crops

    raise ValueError("Unrecognised crop %s of type %s" % (crops, type(crops)))
def iter_items_sorted_for_key(crops):
    """Give the (key, value) pairs of the crops in a reproducible order.

    For a ``dict`` (including subclasses) a generator is returned that yields
    the pairs ordered by sorted key. For every other object the result of
    ``crops.items()`` is returned as a list in its own order.
    """
    if not isinstance(crops, dict):
        # Non-dict containers keep whatever order their items() provides.
        return list(crops.items())

    ordered_names = sorted(crops.keys())
    return ((name, crops[name]) for name in ordered_names)