Belle II Software  release-06-01-15
flavorTaggerTrainingNtuple.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 
13 """
14 This file tests the functionality of sampling needed to train the flavor tagger.
15 """
16 
17 import b2test_utils
18 import basf2
19 from basf2 import set_random_seed, create_path, process
20 import modularAnalysis as ma
21 from variables import variables as vm
22 import flavorTagger as ft
23 import ROOT
24 import os
25 import math
26 
# make logging more reproducible by replacing some strings
b2test_utils.configure_logging_for_tests()
# Fixed random seed: the checks at the end of this script expect an exact
# number of ntuple entries and an exact sequence of PDG codes.
set_random_seed("1337")
# Input sample (presumably simulated B0 -> nu anti-nu events, judging from the file name).
testinput = [b2test_utils.require_file('analysis/tests/Btonunubar.root')]

# Main path: read the input events.
testpath = create_path()
testpath.add_module('RootInput', inputFileNames=testinput)

# Build the signal B0 directly from the generated tau neutrinos taken from MC.
ma.fillParticleListFromMC('nu_tau', '', path=testpath)
ma.reconstructMCDecay(decayString='B0:sig -> nu_tau anti-nu_tau', cut='', path=testpath)

# Test to build a rest of event from the MC B0 decaying to two neutrinos
ma.buildRestOfEvent('B0:sig', path=testpath)

# Test MC association of MC particle: keep only B0:sig candidates whose
# related rest-of-event B0 flavor evaluates to +1 or -1.
ma.applyCuts('B0:sig', ' abs(isRelatedRestOfEventB0Flavor) == 1', path=testpath)
45 
46 
# Naming of the sampled training file/tree for the event-level electron category,
# and the target variable stored for training.
methodPrefixEventLevel = "FlavorTagger_Belle2_B2nunuBGx1EventLevelElectronFBDT"
targetVariable = 'isRightCategory(Electron)'

# Path executed per RestOfEvent object (registered via testpath.for_each later on),
# plus an empty path that the signal-side filter can divert to.
roe_path = create_path()
deadEndPath = create_path()

# Presumably ROE objects whose signal side fails this selection are sent to
# deadEndPath (which contains no modules) — confirm against modularAnalysis docs.
ma.signalSideParticleListsFilter(
    ['B0:sig'],
    'nROE_Charged(, 0) > 0 and abs(qrCombined) == 1',
    roe_path,
    deadEndPath)

# Electrons inside the ROE (and its mask) with a finite, well-defined momentum.
roeElectronCut = ('isInRestOfEvent > 0.5 and passesROEMask() > 0.5 and '
                  'isNAN(p) !=1 and isInfinity(p) != 1')
ma.fillParticleList('e+:inRoe', roeElectronCut, path=roe_path)

# Save pseudo training sample: keep only candidates passing the MC-association cut.
ma.applyCuts('e+:inRoe', 'isRightCategory(mcAssociated) > 0', path=roe_path)

# Enter the ntuple-writing path only when e+:inRoe has candidates.
eventLevelpath = create_path()
skipEmptyElectrons = basf2.register_module("SkimFilter")
skipEmptyElectrons.set_name('SkimFilter_EventLevelElectron')
skipEmptyElectrons.param('particleLists', 'e+:inRoe')
skipEmptyElectrons.if_true(eventLevelpath, basf2.AfterConditionPath.CONTINUE)
roe_path.add_module(skipEmptyElectrons)
73 
# Writer for the training ntuple: one row per candidate of e+:inRoe.
ntuple = basf2.register_module('VariablesToNtuple')
ntuple.param('fileName', methodPrefixEventLevel + "sampled0.root")
ntuple.param('treeName', methodPrefixEventLevel + "_tree")

# Call variable aliases from flavor tagger (presumably defines the eid_* names used below).
ft.set_FlavorTagger_pid_aliases()

# Columns of the ntuple, grouped for readability; the concatenation order
# below is the column order in the output tree.
kinematicVariables = ['useCMSFrame(p)', 'useCMSFrame(pt)', 'p', 'pt', 'cosTheta']
pidVariables = ['electronID', 'eid_TOP', 'eid_ARICH', 'eid_ECL']
wBosonVariables = ['BtagToWBosonVariables(recoilMassSqrd)',
                   'BtagToWBosonVariables(pMissCMS)',
                   'BtagToWBosonVariables(cosThetaMissCMS)',
                   'BtagToWBosonVariables(EW90)']
otherVariables = ['cosTPTO', 'chiProb']
categoryVariables = ['hasHighestProbInCat(e+:inRoe, isRightTrack(Electron))',
                     targetVariable,
                     'ancestorHasWhichFlavor']
mcTruthVariables = ['isSignal', 'mcPDG', 'mcErrors', 'genMotherPDG',
                    'nMCMatches', 'B0mcErrors']
variablesToBeSaved = (kinematicVariables + pidVariables + wBosonVariables +
                      otherVariables + categoryVariables + mcTruthVariables)

ntuple.param('variables', variablesToBeSaved)
ntuple.param('particleList', 'e+:inRoe')
eventLevelpath.add_module(ntuple)

# Execute roe_path for every element of the RestOfEvents array.
testpath.for_each('RestOfEvent', 'RestOfEvents', roe_path)
106 
107 
108 
# Run the processing and validate the produced training ntuple.
# clean_working_directory() (b2test_utils) runs the body in a separate working
# directory — presumably a temporary one removed afterwards; confirm in b2test_utils.
with b2test_utils.clean_working_directory():
    process(testpath, 5)

    # Testing
    outputFile = methodPrefixEventLevel + "sampled0.root"
    treeName = methodPrefixEventLevel + "_tree"

    assert os.path.isfile(outputFile), outputFile + " wasn't created"
    f = ROOT.TFile(outputFile)
    t1 = f.Get(treeName)
    assert bool(t1), treeName + " isn't contained in file"
    assert t1.GetEntries() > 0, treeName + " contains zero entries"

    # Every requested variable must appear as a branch (under its ROOT-safe name).
    for iVariable in variablesToBeSaved:
        iROOTVariable = str(ROOT.Belle2.makeROOTCompatible(iVariable))
        assert t1.GetListOfBranches().Contains(iROOTVariable), \
            iROOTVariable + " branch is missing from " + treeName

    assert t1.GetEntries() == 40, "40 entries should be saved in the test training ntuple, otherwise some problem happened."

    # Expected |mcPDG| for each of the 40 entries, in tree order.
    mcPDGCodes = [
        211.0, 211.0, 211.0, 211.0, 211.0, 321.0, 13.0, 11.0,
        211.0, 211.0, 211.0, 11.0, 211.0, 211.0, 211.0, 211.0,
        211.0, 11.0, 211.0, 211.0, 211.0, 211.0, 211.0, 211.0, 211.0,
        2212.0, 211.0, 211.0, 2212.0, 211.0, 13.0, 211.0, 211.0,
        211.0, 211.0, 211.0, 211.0, 321.0, 321.0, 211.0]

    # Branches whose absolute value must be strictly positive, with failure message.
    nonZeroBranches = [
        ('useCMSFrame__bop__bc', " p* should be greater than 0"),
        ('useCMSFrame__bopt__bc', " pt* should be greater than 0"),
        ('p', " p should be greater than 0"),
        ('pt', " pt should be greater than 0"),
        ('cosTheta', " cosTheta should be greater than 0"),
        ('electronID', " electronID should be greater than 0"),
        ('eid_TOP', " eid_TOP should be greater than 0"),
        ('eid_ARICH', " eid_ARICH should be greater than 0"),
        ('eid_ECL', "eid_ECL should be greater than 0"),
        ('BtagToWBosonVariables__borecoilMassSqrd__bc', " recoilMassSqrd should be greater than 0"),
        ('BtagToWBosonVariables__bopMissCMS__bc', " pMissCMS should be greater than 0"),
        ('BtagToWBosonVariables__bocosThetaMissCMS__bc', " cosThetaMissCMS should be greater than 0"),
        ('BtagToWBosonVariables__boEW90__bc', " EW90 should be greater than 0"),
        ('cosTPTO', " cosTPTO should be greater than 0"),
        ('chiProb', " chiProb should be greater than 0"),
    ]
    # Branches that must never be NaN; a NaN aborts via B2FATAL with this message.
    nanFreeBranches = [
        ('hasHighestProbInCat__boe__pl__clinRoe__cm__spisRightTrack__boElectron__bc__bc',
         " hasHighestProbInCat Electron should not be nan"),
        ('isRightCategory__boElectron__bc', " isRightCategory Electron should not be nan "),
        ('ancestorHasWhichFlavor', " ancestorHasWhichFlavor should not be nan"),
        ('isSignal', " isSignal should not be nan"),
    ]

    for iEntry in range(t1.GetEntries()):
        t1.GetEntry(iEntry)
        for branch, message in nonZeroBranches:
            assert abs(getattr(t1, branch)) > 0, message
        for branch, message in nanFreeBranches:
            if math.isnan(getattr(t1, branch)):
                basf2.B2FATAL(message)
        # str() is required: concatenating the int index directly raised TypeError
        # instead of reporting the mismatch.
        assert abs(t1.mcPDG) == mcPDGCodes[iEntry], \
            " Some mismatch between PDG codes happened in entry " + str(iEntry)
        if math.isnan(t1.mcErrors):
            basf2.B2FATAL(" mcErrors should not be equal to nan")
        assert abs(t1.genMotherPDG) > 0, " genMotherPDG should be greater than 0"
        assert abs(t1.nMCMatches) > 0, " nMCMatches should be greater than 0"

    # Release the ROOT file before leaving the working directory.
    f.Close()
def configure_logging_for_tests(user_replacements=None)
Definition: __init__.py:106
def require_file(filename, data_type="", py_case=None)
Definition: __init__.py:54
def clean_working_directory()
Definition: __init__.py:185