Belle II Software  light-2212-foldex
flavorTaggerTrainingNtuple.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 
13 """
This file tests the sampling step (production of the training ntuple) needed to train the flavor tagger.
15 """
16 
17 import b2test_utils
18 import basf2
19 from basf2 import set_random_seed, create_path, process
20 import modularAnalysis as ma
21 import flavorTagger as ft
22 import ROOT
23 import os
24 import math
25 
# make logging more reproducible by replacing some strings
b2test_utils.configure_logging_for_tests()
# Fix the random seed: the expected entry count and PDG codes checked at the
# end of this test depend on a reproducible run.
set_random_seed("1337")
# MC input sample (B0 -> nu nu-bar), resolved through b2test_utils.require_file
testinput = [b2test_utils.require_file('analysis/tests/Btonunubar.root')]
30 
31 
# Main path: read the MC sample of B0 decaying into two neutrinos.
testpath = create_path()
testpath.add_module('RootInput', inputFileNames=testinput)

# Build B0:sig directly from the generated tau neutrinos.
ma.fillParticleListFromMC('nu_tau', '', path=testpath)
ma.reconstructMCDecay(
    decayString='B0:sig -> nu_tau anti-nu_tau',
    cut='',
    path=testpath,
)

# Test building a rest of event (ROE) around the MC B0 that decays into two neutrinos.
ma.buildRestOfEvent('B0:sig', path=testpath)

# Test MC association of the ROE: keep only B0:sig candidates with
# abs(isRelatedRestOfEventB0Flavor) == 1.
ma.applyCuts('B0:sig', ' abs(isRelatedRestOfEventB0Flavor) == 1', path=testpath)
44 
45 
# Per-ROE processing path, plus an empty path used as the "rejected" branch
# of the signal-side filter.
roe_path = create_path()
deadEndPath = create_path()

# NOTE(review): presumably roe_path runs only for ROEs whose B0:sig passes
# this selection (at least one charged ROE particle and |qrCombined| == 1),
# and all other ROEs are routed to deadEndPath.
ma.signalSideParticleListsFilter(['B0:sig'],
                                 'nROE_Charged(all, 0) > 0 and abs(qrCombined) == 1',
                                 roe_path,
                                 deadEndPath)
55 
# Create the list of ROE electron candidates. The cut requires
# isInRestOfEvent and passesROEMask(), and rejects candidates whose
# momentum p is NaN or infinite.
ma.fillParticleList(
    'e+:inRoe',
    'isInRestOfEvent > 0.5 and passesROEMask() > 0.5 and '
    'isNAN(p) !=1 and isInfinity(p) != 1',
    path=roe_path,
)
59 
# Keep only ROE electrons with a positive isRightCategory(mcAssociated).
ma.applyCuts('e+:inRoe', 'isRightCategory(mcAssociated) > 0', path=roe_path)

# Save pseudo training sample: this prefix names both the output ROOT file
# and its tree (event-level Electron category of the flavor tagger).
methodPrefixEventLevel = "FlavorTagger_Belle2_B2nunuBGx1EventLevelElectronFBDT"
# Training target stored in the ntuple.
targetVariable = 'isRightCategory(Electron)'
64 
# Skip electrons if list is empty: a SkimFilter on 'e+:inRoe' gates the
# conditional path eventLevelpath (via if_true). With
# AfterConditionPath.CONTINUE, processing of roe_path resumes afterwards.
SkipEmptyParticleList = basf2.register_module("SkimFilter")
SkipEmptyParticleList.set_name('SkimFilter_EventLevelElectron')
SkipEmptyParticleList.param('particleLists', 'e+:inRoe')

eventLevelpath = create_path()
SkipEmptyParticleList.if_true(eventLevelpath, basf2.AfterConditionPath.CONTINUE)
roe_path.add_module(SkipEmptyParticleList)
72 
# Ntuple writer for the event-level electron sample; it is attached to
# eventLevelpath (the SkimFilter's conditional path).
ntuple = basf2.register_module('VariablesToNtuple')
ntuple.param('fileName', methodPrefixEventLevel + "sampled0.root")
ntuple.param('treeName', methodPrefixEventLevel + "_tree")

# Call variable aliases from flavor tagger.
# NOTE(review): presumably this defines the PID aliases used below
# (e.g. eid_TOP, eid_ARICH, eid_ECL) — confirm in flavorTagger.
ft.set_FlavorTagger_pid_aliases()

variablesToBeSaved = [
    # kinematics (CMS frame, then lab frame)
    'useCMSFrame(p)',
    'useCMSFrame(pt)',
    'p',
    'pt',
    'cosTheta',
    # particle identification
    'electronID',
    'eid_TOP',
    'eid_ARICH',
    'eid_ECL',
    # BtagToWBosonVariables
    'BtagToWBosonVariables(recoilMassSqrd)',
    'BtagToWBosonVariables(pMissCMS)',
    'BtagToWBosonVariables(cosThetaMissCMS)',
    'BtagToWBosonVariables(EW90)',
    # other flavor-tagger inputs
    'cosTPTO',
    'chiProb',
    'hasHighestProbInCat(e+:inRoe, isRightTrack(Electron))',
    # training target and MC bookkeeping
    targetVariable,
    'ancestorHasWhichFlavor',
    'isSignal',
    'mcPDG',
    'mcErrors',
    'genMotherPDG',
    'nMCMatches',
    'B0mcErrors',
]
ntuple.param('variables', variablesToBeSaved)
ntuple.param('particleList', 'e+:inRoe')
eventLevelpath.add_module(ntuple)

# Execute roe_path for each 'RestOfEvent' object in the 'RestOfEvents' array.
testpath.for_each('RestOfEvent', 'RestOfEvents', roe_path)
105 
106 
107 
# Run the configured paths in a temporary working directory, so that the
# ntuple written by VariablesToNtuple does not pollute the caller's directory.
# Then validate the content of the produced training ntuple.
with b2test_utils.clean_working_directory():
    # The second argument is basf2's max_event limit. The expected entry
    # count and PDG codes below are tied to this number of events.
    process(testpath, 5)

    # Testing
    outputFileName = methodPrefixEventLevel + "sampled0.root"
    treeName = methodPrefixEventLevel + "_tree"

    assert os.path.isfile(outputFileName), outputFileName + " wasn't created"
    f = ROOT.TFile(outputFileName)
    try:
        t1 = f.Get(treeName)
        # A missing object comes back as a null pointer, which is falsy.
        assert bool(t1), treeName + " isn't contained in file"
        assert t1.GetEntries() > 0, treeName + " contains zero entries"

        # Every requested variable must exist as a branch, under its
        # ROOT-compatible (escaped) name.
        for iVariable in variablesToBeSaved:
            iROOTVariable = str(ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(iVariable))
            assert t1.GetListOfBranches().Contains(iROOTVariable), \
                iROOTVariable + " branch is missing from " + treeName

        assert t1.GetEntries() == 40, "40 entries should be saved in the test training ntuple, otherwise some problem happened."

        # Expected |mcPDG| of each of the 40 saved entries, in tree order.
        mcPDGCodes = [
            211.0, 211.0, 211.0, 211.0, 211.0, 321.0, 13.0, 11.0,
            211.0, 211.0, 211.0, 11.0, 211.0, 211.0, 211.0, 211.0,
            211.0, 11.0, 211.0, 211.0, 211.0, 211.0, 211.0, 211.0, 211.0,
            2212.0, 211.0, 211.0, 2212.0, 211.0, 13.0, 211.0, 211.0,
            211.0, 211.0, 211.0, 211.0, 321.0, 321.0, 211.0]

        for iEntry in range(t1.GetEntries()):
            t1.GetEntry(iEntry)
            assert abs(t1.useCMSFrame__bop__bc) > 0, " p* should be greater than 0"
            assert abs(t1.useCMSFrame__bopt__bc) > 0, " pt* should be greater than 0"
            assert abs(t1.p) > 0, " p should be greater than 0"
            assert abs(t1.pt) > 0, " pt should be greater than 0"
            assert abs(t1.cosTheta) > 0, " cosTheta should be greater than 0"
            assert abs(t1.electronID) > 0, " electronID should be greater than 0"
            assert abs(t1.BtagToWBosonVariables__borecoilMassSqrd__bc) > 0, " recoilMassSqrd should be greater than 0"
            assert abs(t1.BtagToWBosonVariables__bopMissCMS__bc) > 0, " pMissCMS should be greater than 0"
            assert abs(t1.BtagToWBosonVariables__bocosThetaMissCMS__bc) > 0, " cosThetaMissCMS should be greater than 0"
            assert abs(t1.BtagToWBosonVariables__boEW90__bc) > 0, " EW90 should be greater than 0"
            assert abs(t1.cosTPTO) > 0, " cosTPTO should be greater than 0"
            assert abs(t1.chiProb) > 0, " chiProb should be greater than 0"
            if math.isnan(
                t1.hasHighestProbInCat__boe__pl__clinRoe__cm__spisRightTrack__boElectron__bc__bc
            ):
                basf2.B2FATAL(" hasHighestProbInCat Electron should not be nan")
            if math.isnan(t1.isRightCategory__boElectron__bc):
                basf2.B2FATAL(" isRightCategory Electron should not be nan ")
            if math.isnan(t1.ancestorHasWhichFlavor):
                basf2.B2FATAL(" ancestorHasWhichFlavor should not be nan")
            if math.isnan(t1.isSignal):
                basf2.B2FATAL(" isSignal should not be nan")
            # str() is required: concatenating the int index directly would
            # raise TypeError and hide the real assertion failure.
            assert abs(t1.mcPDG) == mcPDGCodes[iEntry], \
                " Some mismatch between PDG codes happened in entry " + str(iEntry)
            if math.isnan(t1.mcErrors):
                basf2.B2FATAL(" mcErrors should not be equal to nan")
            assert abs(t1.genMotherPDG) > 0, " genMotherPDG should be greater than 0"
            assert abs(t1.nMCMatches) > 0, " nMCMatches should be greater than 0"
    finally:
        # Release the file handle before the temporary directory is removed.
        f.Close()
def configure_logging_for_tests(user_replacements=None)
Definition: __init__.py:106
def require_file(filename, data_type="", py_case=None)
Definition: __init__.py:54
def clean_working_directory()
Definition: __init__.py:185