Belle II Software  release-08-01-10
flavorTaggerTrainingNtuple.py
1 #!/usr/bin/env python3
2 
3 
10 
11 
12 """
13 This file tests the functionality of sampling needed to train the flavor tagger.
14 """
15 
16 import b2test_utils
17 import basf2
18 from basf2 import set_random_seed, create_path, process
19 import modularAnalysis as ma
20 import flavorTagger as ft
21 import ROOT
22 import os
23 import math
24 
# make logging more reproducible by replacing some strings
b2test_utils.configure_logging_for_tests()
# Fixed seed so repeated runs of this test produce the same output.
set_random_seed("1337")
# Input sample; require_file resolves the path of the test file.
testinput = [b2test_utils.require_file('analysis/tests/Btonunubar.root')]
29 
30 
# Main path: read the input sample, build the B0:sig candidate from MC truth
# and attach a rest of event (ROE) to it. Module order matters here.
testpath = create_path()
testpath.add_module('RootInput', inputFileNames=testinput)

# Signal side: both neutrinos are taken directly from the MC particle list.
signalDecay = 'B0:sig -> nu_tau anti-nu_tau'
ma.fillParticleListFromMC('nu_tau', '', path=testpath)
ma.reconstructMCDecay(decayString=signalDecay, cut='', path=testpath)

# Exercise ROE building around the MC B0 decaying to two neutrinos.
ma.buildRestOfEvent('B0:sig', path=testpath)

# Exercise the MC association: keep candidates with
# abs(isRelatedRestOfEventB0Flavor) == 1.
ma.applyCuts('B0:sig', ' abs(isRelatedRestOfEventB0Flavor) == 1', path=testpath)
43 
44 
# Sub-path executed per RestOfEvent (see the for_each call further down).
# deadEndPath is left empty; presumably ROEs rejected by the signal-side
# filter are routed there and receive no further processing.
roe_path = create_path()
deadEndPath = create_path()

signalSideCut = 'nROE_Charged(all, 0) > 0 and abs(qrCombined) == 1'
ma.signalSideParticleListsFilter(['B0:sig'], signalSideCut, roe_path, deadEndPath)

# Electron candidates from the ROE that pass the ROE mask and have a finite momentum.
roeElectronCut = ('isInRestOfEvent > 0.5 and passesROEMask() > 0.5 and '
                  'isNAN(p) !=1 and isInfinity(p) != 1')
ma.fillParticleList('e+:inRoe', roeElectronCut, path=roe_path)
58 
# Pseudo training sample: the output file and tree names derive from this prefix.
methodPrefixEventLevel = "FlavorTagger_Belle2_B2nunuBGx1EventLevelElectronFBDT"
targetVariable = 'isRightCategory(Electron)'
ma.applyCuts('e+:inRoe', 'isRightCategory(mcAssociated) > 0', path=roe_path)

# Gate the ntuple-writing sub-path on e+:inRoe: the SkimFilter condition
# selects eventLevelpath only when the list is non-empty, and with
# AfterConditionPath.CONTINUE processing returns to roe_path afterwards.
eventLevelpath = create_path()
SkipEmptyParticleList = basf2.register_module("SkimFilter")
SkipEmptyParticleList.set_name('SkimFilter_EventLevelElectron')
SkipEmptyParticleList.param('particleLists', 'e+:inRoe')
SkipEmptyParticleList.if_true(eventLevelpath, basf2.AfterConditionPath.CONTINUE)
roe_path.add_module(SkipEmptyParticleList)

# Ntuple writer; its variables and input list are set further down,
# just before it is appended to eventLevelpath.
ntupleFileName = f"{methodPrefixEventLevel}sampled0.root"
ntupleTreeName = f"{methodPrefixEventLevel}_tree"
ntuple = basf2.register_module('VariablesToNtuple')
ntuple.param('fileName', ntupleFileName)
ntuple.param('treeName', ntupleTreeName)
75 
# Register the flavor tagger's PID variable aliases; the eid_* names listed
# below are presumably among them.
ft.set_FlavorTagger_pid_aliases()

# The saved columns, grouped for readability. The concatenation order below
# reproduces the branch order of the training ntuple.
kinematicVariables = ['useCMSFrame(p)', 'useCMSFrame(pt)', 'p', 'pt', 'cosTheta']
pidVariables = ['electronID', 'eid_TOP', 'eid_ARICH', 'eid_ECL']
wBosonVariables = [f'BtagToWBosonVariables({name})'
                   for name in ('recoilMassSqrd', 'pMissCMS', 'cosThetaMissCMS', 'EW90')]
otherInputVariables = ['cosTPTO',
                       'chiProb',
                       'hasHighestProbInCat(e+:inRoe, isRightTrack(Electron))']
# Training target followed by MC-truth bookkeeping.
truthVariables = [targetVariable, 'ancestorHasWhichFlavor',
                  'isSignal', 'mcPDG', 'mcErrors', 'genMotherPDG',
                  'nMCMatches', 'B0mcErrors']

variablesToBeSaved = (kinematicVariables + pidVariables + wBosonVariables
                      + otherInputVariables + truthVariables)

ntuple.param('variables', variablesToBeSaved)
ntuple.param('particleList', 'e+:inRoe')
eventLevelpath.add_module(ntuple)

# Execute roe_path once for every RestOfEvent object in the event.
testpath.for_each('RestOfEvent', 'RestOfEvents', roe_path)
104 
105 
106 
# Run and validate inside b2test_utils.clean_working_directory() (presumably a
# temporary, empty working directory) so the ntuple written by VariablesToNtuple
# does not end up in the caller's directory.
with b2test_utils.clean_working_directory():
    # Process the first 5 events of the input sample.
    process(testpath, 5)

    # Testing
    ntupleFileName = methodPrefixEventLevel + "sampled0.root"
    treeName = methodPrefixEventLevel + "_tree"

    assert os.path.isfile(ntupleFileName), ntupleFileName + " wasn't created"
    f = ROOT.TFile(ntupleFileName)
    t1 = f.Get(treeName)
    assert bool(t1), treeName + " isn't contained in file"
    assert t1.GetEntries() > 0, treeName + " contains zero entries"

    # Every requested variable must be stored as a branch under its ROOT-compatible name.
    branches = t1.GetListOfBranches()
    for iVariable in variablesToBeSaved:
        iROOTVariable = str(ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(iVariable))
        assert branches.Contains(iROOTVariable), iROOTVariable +\
            " branch is missing from " + treeName

    assert t1.GetEntries() == 40, "40 entries should be saved in the test training ntuple, otherwise some problem happened."

    # Expected |mcPDG| of each saved candidate, in tree-entry order (40 entries).
    mcPDGCodes = [
        211.0, 211.0, 211.0, 211.0, 211.0, 321.0, 13.0, 11.0,
        211.0, 211.0, 211.0, 11.0, 211.0, 211.0, 211.0, 211.0,
        211.0, 11.0, 211.0, 211.0, 211.0, 211.0, 211.0, 211.0, 211.0,
        2212.0, 211.0, 211.0, 2212.0, 211.0, 13.0, 211.0, 211.0,
        211.0, 211.0, 211.0, 211.0, 321.0, 321.0, 211.0]

    for iEntry in range(t1.GetEntries()):
        t1.GetEntry(iEntry)
        assert abs(t1.useCMSFrame__bop__bc) > 0, " p* should be greater than 0"
        assert abs(t1.useCMSFrame__bopt__bc) > 0, " pt* should be greater than 0"
        assert abs(t1.p) > 0, " p should be greater than 0"
        assert abs(t1.pt) > 0, " pt should be greater than 0"
        assert abs(t1.cosTheta) > 0, " cosTheta should be greater than 0"
        assert abs(t1.electronID) > 0, " electronID should be greater than 0"
        assert abs(t1.BtagToWBosonVariables__borecoilMassSqrd__bc) > 0, " recoilMassSqrd should be greater than 0"
        assert abs(t1.BtagToWBosonVariables__bopMissCMS__bc) > 0, " pMissCMS should be greater than 0"
        assert abs(t1.BtagToWBosonVariables__bocosThetaMissCMS__bc) > 0, " cosThetaMissCMS should be greater than 0"
        assert abs(t1.BtagToWBosonVariables__boEW90__bc) > 0, " EW90 should be greater than 0"
        assert abs(t1.cosTPTO) > 0, " cosTPTO should be greater than 0"
        assert abs(t1.chiProb) > 0, " chiProb should be greater than 0"
        if math.isnan(
            t1.hasHighestProbInCat__boe__pl__clinRoe__cm__spisRightTrack__boElectron__bc__bc
        ):
            basf2.B2FATAL(" hasHighestProbInCat Electron should not be nan")
        if math.isnan(t1.isRightCategory__boElectron__bc):
            basf2.B2FATAL(" isRightCategory Electron should not be nan ")
        if math.isnan(t1.ancestorHasWhichFlavor):
            basf2.B2FATAL(" ancestorHasWhichFlavor should not be nan")
        if math.isnan(t1.isSignal):
            basf2.B2FATAL(" isSignal should not be nan")
        # str() is required: concatenating the int index directly would raise
        # TypeError exactly when this assertion fails, hiding the real failure.
        assert abs(t1.mcPDG) == mcPDGCodes[iEntry], \
            " Some mismatch between PDG codes happened in entry " + str(iEntry)
        if math.isnan(t1.mcErrors):
            basf2.B2FATAL(" mcErrors should not be equal to nan")
        assert abs(t1.genMotherPDG) > 0, " genMotherPDG should be greater than 0"
        assert abs(t1.nMCMatches) > 0, " nMCMatches should be greater than 0"

    # Release the ROOT file before the working directory is cleaned up.
    f.Close()
def configure_logging_for_tests(user_replacements=None)
Definition: __init__.py:106
def require_file(filename, data_type="", py_case=None)
Definition: __init__.py:54
def clean_working_directory()
Definition: __init__.py:189