17 from basf2 
import conditions, B2FATAL
 
   21 def get_argument_parser() -> argparse.ArgumentParser:
 
   22     """ Parses the command line options of the fliping mva training and returns the corresponding arguments. """ 
   23     parser = argparse.ArgumentParser()
 
   24     parser.add_argument(
'-train', default=
'', type=str,
 
   25                         help=
'Data file containing ROOT TTree used during training. Default: \'\'.')
 
   26     parser.add_argument(
'-data', default=
'', type=str,
 
   27                         help=
'Data file containing ROOT TTree with independent test data. Default: \'\'.')
 
   28     parser.add_argument(
'-tree', default=
'', type=str,
 
   29                         help=
'Treename in data file. Default: \'\'.')
 
   30     parser.add_argument(
'-mva', default=1, type=int,
 
   31                         help=
'index of mva to be trainned. Default: 1')
 
   36 def get_variables(index=1):
    # Provide the input-variable names for the flipping MVA selected by `index`.
    # The caller passes args.mva (checked to be 1 or 2) and unpacks the result
    # into basf2_mva.vector(*variables), so a sequence of str is expected back.
    # NOTE(review): this listing skips many original lines (e.g. 37-38, 40-41,
    # 43-47, 63-66, 69-82). The code that uses `index` to pick a list, the
    # opening of the second list ("flipped_pz_estimate", ...) and the `return`
    # are not visible here; confirm against the full file.
 
   39         var = [
"seed_pz_estimate",
 
   42                "seed_tan_lambda_estimate",
 
   48                "seed_omega_variance",
 
   50                "seed_tan_lambda_variance",
 
   54                "svd_layer3_positionSigma",
 
   57                "InOutArmTimeDifference",
 
   58                "InOutArmTimeDifferenceError",
 
   60                "inGoingArmTimeError",
 
   62                "outGoingArmTimeError"]
 
    # NOTE(review): the names below belong to a second list whose opening line
    # is missing from this view; presumably the variable set for the other MVA index.
   67               "flipped_pz_estimate",
 
   68               "tan_lambda_estimate",
 
   83               "quality_flip_indicator",
 
   89 if __name__ == 
"__main__":
    # Parse the command line and abort (via B2FATAL, "Terminating here") when a
    # required argument is missing or invalid.
 
   91     parser = get_argument_parser()
 
   92     args = parser.parse_args()
 
    # Both input files default to '' in the parser, so '' means "not given".
   94     if args.train == 
'' or args.data == 
'':
 
   95         B2FATAL(
"Missing train or test samples. Terminating here.")
 
    # NOTE(review): the condition guarding this B2FATAL (original line 97,
    # presumably a check that args.tree is empty) is not visible in this
    # listing; as shown it would fire unconditionally. Confirm.
   98         B2FATAL(
"Missing Treename. Terminating here.")
 
    # Only MVA indices 1 and 2 are supported.
  100     if args.mva 
not in [1, 2]:
 
  101         B2FATAL(
"MVA number must be either 1 or 2. Terminating here.")
 
  103     conditions.testing_payloads = [
'localdb/database.txt']
 
  106     training_data = basf2_mva.vector(args.train)
 
  107     test_data = basf2_mva.vector(args.data)
 
  110     variables = get_variables(args.mva)
 
  112     general_options = basf2_mva.GeneralOptions()
 
  113     general_options.m_datafiles = training_data
 
  115     # the official name of the weight file is `TRKTrackFlipAndRefit_MVA1_weightfile` 
  116     # But the evaluation scripts taking `Weightfile` as a default name, so that's why this line was commented 
  117     general_options.m_identifier = "TRKTrackFlipAndRefit_MVA1_weightfile" 
  118     or "TRKTrackFlipAndRefit_MVA2_weightfile" for second MVA 
  120     general_options.m_identifier = f
"TRKTrackFlipAndRefit_MVA{args.mva}_weightfile" 
  121     general_options.m_treename = args.tree
 
  122     general_options.m_variables = basf2_mva.vector(*variables)
 
  123     general_options.m_target_variable = 
"ismatched_WC" 
  124     general_options.m_max_events = 0
 
  126     fastbdt_options = basf2_mva.FastBDTOptions()
 
    # Two FastBDT hyper-parameter sets follow, presumably one per MVA index.
    # NOTE(review): the if/else selecting between them (original lines 127-128
    # and 136-138) is missing from this listing; as shown the second set would
    # overwrite the first. Confirm against the full file.
  129         fastbdt_options.m_nTrees = 150
 
  130         fastbdt_options.m_nCuts = 18
 
  131         fastbdt_options.m_nLevels = 4
 
  132         fastbdt_options.m_shrinkage = 0.2
 
  133         fastbdt_options.m_randRatio = 0.5
 
  134         fastbdt_options.m_purityTransformation = 
False 
  135         fastbdt_options.m_sPlot = 
False 
    # Second hyper-parameter set (presumably for the other MVA index).
  139         fastbdt_options.m_nTrees = 400
 
  140         fastbdt_options.m_nCuts = 25
 
  141         fastbdt_options.m_nLevels = 2
 
  142         fastbdt_options.m_shrinkage = 0.6
 
  143         fastbdt_options.m_randRatio = 0.5
 
  144         fastbdt_options.m_purityTransformation = 
False 
  145         fastbdt_options.m_sPlot = 
False 
    # Hand the general and FastBDT options to basf2_mva.teacher (presumably runs the training).
  146     basf2_mva.teacher(general_options, fastbdt_options)
 
  149     p, t = m.apply_expert(test_data, general_options.m_treename)
 
  153     print(
"Variable importances returned my method")
 
  155     imp = np.array([m.importances.get(v, 0.0) 
for v 
in m.variables])
 
  156     width = (np.max(imp) - np.min(imp))
 
  158     for var 
in m.variables:
 
  159         print(var, (m.importances.get(var, 0.0) - np.min(imp))/width * 100)
 
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)