57 {
58
59 Weightfile weightfile;
60
61 GeneralOptions general_options;
62 auto expert_weightfile = Weightfile::load(m_specific_options.m_weightfile);
63 expert_weightfile.getOptions(general_options);
64
65
66 GeneralOptions mod_general_options = m_general_options;
67 mod_general_options.m_variables = general_options.m_variables;
68 mod_general_options.m_spectators = general_options.m_spectators;
69 mod_general_options.m_target_variable = general_options.m_target_variable;
70 mod_general_options.m_weight_variable = general_options.m_weight_variable;
71
72
73 if (m_specific_options.m_variable != "") {
74 if (std::find(mod_general_options.m_variables.begin(), mod_general_options.m_variables.end(),
75 m_specific_options.m_variable) == mod_general_options.m_variables.end() and
76 std::find(mod_general_options.m_spectators.begin(), mod_general_options.m_spectators.end(),
77 m_specific_options.m_variable) == mod_general_options.m_spectators.end() and
78 mod_general_options.m_target_variable != m_specific_options.m_variable and
79 mod_general_options.m_weight_variable != m_specific_options.m_variable) {
80 mod_general_options.m_spectators.push_back(m_specific_options.m_variable);
81 }
82 }
83
84 AbstractInterface::initSupportedInterfaces();
85 auto supported_interfaces = AbstractInterface::getSupportedInterfaces();
86 if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
87 B2ERROR("Couldn't find method named " + general_options.m_method);
88 throw std::runtime_error("Couldn't find method named " + general_options.m_method);
89 }
90 auto expert = supported_interfaces[general_options.m_method]->getExpert();
91 expert->load(expert_weightfile);
92
93 auto prediction = expert->apply(training_data);
94
95 double data_fraction = expert_weightfile.getSignalFraction();
96 double data_over_mc_fraction = data_fraction / (1 - data_fraction);
97
98 double sum_reweights = 0;
99 unsigned long int count_reweights = 0;
100
101 auto isSignal = training_data.getSignals();
102
103 if (m_specific_options.m_variable != "") {
104 auto variable = training_data.getSpectator(training_data.getSpectatorIndex(m_specific_options.m_variable));
105 for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
106
107
108 if (isSignal[iEvent]) {
109 continue;
110 }
111
112 if (variable[iEvent] == 1.0) {
113 if (prediction[iEvent] > 0.995)
114 prediction[iEvent] = 0.995;
115 if (prediction[iEvent] < 0.005)
116 prediction[iEvent] = 0.005;
117
118 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
119 sum_reweights += prediction[iEvent];
120 count_reweights++;
121 }
122 }
123 } else {
124 for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
125
126
127 if (isSignal[iEvent]) {
128 continue;
129 }
130
131 if (prediction[iEvent] > 0.995)
132 prediction[iEvent] = 0.995;
133 if (prediction[iEvent] < 0.005)
134 prediction[iEvent] = 0.005;
135
136 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
137 sum_reweights += prediction[iEvent];
138 count_reweights++;
139 }
140 }
141
142 double norm = sum_reweights / count_reweights / data_over_mc_fraction;
143
144 weightfile.addOptions(mod_general_options);
145 weightfile.addOptions(m_specific_options);
146 weightfile.addFile("Reweighter_Weightfile", m_specific_options.m_weightfile);
147 weightfile.addSignalFraction(data_fraction);
148 weightfile.addElement("Reweighter_norm", norm);
149
150 return weightfile;
151
152 }