Belle II Software development
flavorTaggerTrainingNtuple.py
1#!/usr/bin/env python3
2
3
10
11
12"""
13This file tests the functionality of sampling needed to train the flavor tagger.
14"""
15
16import b2test_utils
17import basf2
18from basf2 import set_random_seed, create_path, process
19import modularAnalysis as ma
20import flavorTagger as ft
21import ROOT
22import os
23import math
# make logging more reproducible by replacing some strings
b2test_utils.configure_logging_for_tests()

# Fixed seed: the checks at the end of this test expect an exact number of
# ntuple entries and exact PDG codes, so the run must be reproducible.
set_random_seed("1337")
testinput = [b2test_utils.require_file('analysis/tests/Btonunubar.root')]

# Main path: read the test sample and reconstruct B0:sig -> nu_tau anti-nu_tau
# directly from MC particles.
testpath = create_path()
testpath.add_module('RootInput', inputFileNames=testinput)

ma.fillParticleListFromMC('nu_tau', '', path=testpath)
ma.reconstructMCDecay(decayString='B0:sig -> nu_tau anti-nu_tau', cut='', path=testpath)

# Test to build a rest of event from the MC B0 decaying to two neutrinos
ma.buildRestOfEvent('B0:sig', path=testpath)

# Test MC association of the rest of event: keep only B0:sig candidates with
# |isRelatedRestOfEventB0Flavor| == 1.
ma.applyCuts('B0:sig', ' abs(isRelatedRestOfEventB0Flavor) == 1', path=testpath)

# Names of the pseudo training sample and of its target variable; both are
# reused by the ntuple configuration and the checks further down.
methodPrefixEventLevel = "FlavorTagger_Belle2_B2nunuBGx1EventLevelElectronFBDT"
targetVariable = 'isRightCategory(Electron)'

# Sub-path that testpath.for_each later runs once per RestOfEvent object.
roe_path = create_path()
# Second path handed to the signal-side filter; nothing is added to it in this
# script (presumably it receives the ROEs failing the selection).
deadEndPath = create_path()

# Continue only with ROEs that have at least one charged particle and whose
# signal B0 satisfies abs(qrCombined) == 1.
ma.signalSideParticleListsFilter(['B0:sig'],
                                 'nROE_Charged(all, 0) > 0 and abs(qrCombined) == 1',
                                 roe_path,
                                 deadEndPath)

# ROE electrons that pass the ROE mask and have a finite momentum.
electronInRoeCut = ('isInRestOfEvent > 0.5 and passesROEMask() > 0.5 and '
                    'isNAN(p) !=1 and isInfinity(p) != 1')
ma.fillParticleList('e+:inRoe', electronInRoeCut, path=roe_path)

# Save pseudo training sample: keep only MC-associated candidates.
ma.applyCuts('e+:inRoe', 'isRightCategory(mcAssociated) > 0', path=roe_path)

# Skip electrons if list is empty: eventLevelpath is executed only when the
# SkimFilter condition is true, after which roe_path continues (CONTINUE).
eventLevelpath = create_path()
SkipEmptyParticleList = basf2.register_module("SkimFilter")
SkipEmptyParticleList.set_name('SkimFilter_EventLevelElectron')
SkipEmptyParticleList.param('particleLists', 'e+:inRoe')
SkipEmptyParticleList.if_true(eventLevelpath, basf2.AfterConditionPath.CONTINUE)
roe_path.add_module(SkipEmptyParticleList)

# Define the flavor tagger's PID variable aliases (presumably the eid_* names
# listed below) before they are handed to the ntuple writer.
ft.set_FlavorTagger_pid_aliases()

# Variables stored for every e+:inRoe candidate, grouped by theme and
# concatenated in the original order.
kinematicVariables = ['useCMSFrame(p)', 'useCMSFrame(pt)', 'p', 'pt', 'cosTheta']
pidVariables = ['electronID', 'eid_TOP', 'eid_ARICH', 'eid_ECL']
btagToWVariables = ['BtagToWBosonVariables(recoilMassSqrd)',
                    'BtagToWBosonVariables(pMissCMS)',
                    'BtagToWBosonVariables(cosThetaMissCMS)',
                    'BtagToWBosonVariables(EW90)']
otherVariables = ['cosTPTO',
                  'chiProb',
                  'hasHighestProbInCat(e+:inRoe, isRightTrack(Electron))']
targetAndTruthVariables = [targetVariable, 'ancestorHasWhichFlavor',
                           'isSignal', 'mcPDG', 'mcErrors', 'genMotherPDG',
                           'nMCMatches', 'B0mcErrors']
variablesToBeSaved = (kinematicVariables + pidVariables + btagToWVariables
                      + otherVariables + targetAndTruthVariables)

# Ntuple writer; it sits on eventLevelpath, which is only reached through the
# SkimFilter condition attached to roe_path.
ntuple = basf2.register_module('VariablesToNtuple')
ntupleSettings = {
    'fileName': methodPrefixEventLevel + "sampled0.root",
    'treeName': methodPrefixEventLevel + "_tree",
    'variables': variablesToBeSaved,
    'particleList': 'e+:inRoe',
}
for settingName, settingValue in ntupleSettings.items():
    ntuple.param(settingName, settingValue)
eventLevelpath.add_module(ntuple)

# Run roe_path once for each RestOfEvent stored in the RestOfEvents array.
testpath.for_each('RestOfEvent', 'RestOfEvents', roe_path)
104
105
106
108 process(testpath, 5)
109
110 # Testing
111 assert os.path.isfile(methodPrefixEventLevel + "sampled0.root"), methodPrefixEventLevel + "sampled0.root" + " wasn't created"
112 f = ROOT.TFile(methodPrefixEventLevel + "sampled0.root")
113 t1 = f.Get(methodPrefixEventLevel + "_tree")
114 assert bool(t1), methodPrefixEventLevel + "_tree" + " isn't contained in file"
115 assert t1.GetEntries() > 0, methodPrefixEventLevel + "_tree" + "contains zero entries"
116 for iVariable in variablesToBeSaved:
117 iROOTVariable = str(ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(iVariable))
118 assert t1.GetListOfBranches().Contains(iROOTVariable), iROOTVariable +\
119 " branch is missing from " + methodPrefixEventLevel + "_tree"
120
121 assert t1.GetEntries() == 40, "40 entries should be saved in the test training ntuple, otherwise some problem happened."
122
123 mcPDGCodes = [
124 211.0, 211.0, 211.0, 211.0, 211.0, 321.0, 13.0, 11.0,
125 211.0, 211.0, 211.0, 11.0, 211.0, 211.0, 211.0, 211.0,
126 211.0, 11.0, 211.0, 211.0, 211.0, 211.0, 211.0, 211.0, 211.0,
127 2212.0, 211.0, 211.0, 2212.0, 211.0, 13.0, 211.0, 211.0,
128 211.0, 211.0, 211.0, 211.0, 321.0, 321.0, 211.0]
129
130 for iEntry in range(t1.GetEntries()):
131 t1.GetEntry(iEntry)
132 assert abs(t1.useCMSFrame__bop__bc) > 0, " p* should be greater than 0"
133 assert abs(t1.useCMSFrame__bopt__bc) > 0, " pt* should be greater than 0"
134 assert abs(t1.p) > 0, " p should be greater than 0"
135 assert abs(t1.pt) > 0, " pt should be greater than 0"
136 assert abs(t1.cosTheta) > 0, " cosTheta should be greater than 0"
137 assert abs(t1.electronID) > 0, " electronID should be greater than 0"
138 assert abs(t1.BtagToWBosonVariables__borecoilMassSqrd__bc) > 0, " recoilMassSqrd should be greater than 0"
139 assert abs(t1.BtagToWBosonVariables__bopMissCMS__bc) > 0, " pMissCMS should be greater than 0"
140 assert abs(t1.BtagToWBosonVariables__bocosThetaMissCMS__bc) > 0, " cosThetaMissCMS should be greater than 0"
141 assert abs(t1.BtagToWBosonVariables__boEW90__bc) > 0, " EW90 should be greater than 0"
142 assert abs(t1.cosTPTO) > 0, " cosTPTO should be greater than 0"
143 assert abs(t1.chiProb) > 0, " chiProb should be greater than 0"
144 if math.isnan(
145 t1.hasHighestProbInCat__boe__pl__clinRoe__cm__spisRightTrack__boElectron__bc__bc
146 ):
147 basf2.B2FATAL(" hasHighestProbInCat Electron should not be nan")
148 if math.isnan(t1.isRightCategory__boElectron__bc):
149 basf2.B2FATAL(" isRightCategory Electron should not be nan ")
150 if math.isnan(t1.ancestorHasWhichFlavor):
151 basf2.B2FATAL(" ancestorHasWhichFlavor should not be nan")
152 if math.isnan(t1.isSignal):
153 basf2.B2FATAL(" isSignal should not be nan")
154 assert abs(t1.mcPDG) == mcPDGCodes[iEntry], " Some mismatch between PDG codes happened in entry " + iEntry
155 if math.isnan(t1.mcErrors):
156 basf2.B2FATAL(" mcErrors should not be equal to nan")
157 assert abs(t1.genMotherPDG) > 0, " genMotherPDG should be greater than 0"
158 assert abs(t1.nMCMatches) > 0, " nMCMatches should be greater than 0"
def require_file(filename, data_type="", py_case=None)
Definition: __init__.py:54
def clean_working_directory()
Definition: __init__.py:189
def configure_logging_for_tests(user_replacements=None)
Definition: __init__.py:106