Belle II Software  release-08-01-10
setup_modules_ml.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 import basf2 as b2
13 
14 
def add_fbdtclassifier_training(path,
                                networkInputName,
                                outputFileName='FBDTClassifier.dat',
                                train=True,
                                storeSamples=True,
                                useSamples=False,
                                samplesFileName='FBDTClassifier_samples.dat',
                                nTrees=100,
                                treeDepth=3,
                                shrinkage=0.15,
                                randRatio=0.5,
                                logLevel=b2.LogLevel.INFO,
                                dbgLvl=1):
    """This function adds the FastBDTClassifierTraining-module to the given path and exposes all its parameters

    Every argument except path, logLevel and dbgLvl is forwarded unchanged to the
    module parameter of the same name.

    @param path the path to which the module should be added
    @param networkInputName parameter passed to the module
    @param outputFileName the filename to which the FBDTClassifier will be stored
    @param train actually do the training
    @param storeSamples forwarded as 'storeSamples' (presumably: write the collected
           samples to samplesFileName -- confirm against the module documentation)
    @param useSamples forwarded as 'useSamples' (presumably: read previously stored
           samples from samplesFileName -- confirm against the module documentation)
    @param samplesFileName the samples filename forwarded as 'samplesFileName'
    @param nTrees number of trees in the module
    @param treeDepth the number of layers in the trees
    @param shrinkage the shrinkage parameter
    @param randRatio the ratio of all samples used for training of each tree
    @param logLevel the LogLevel of the module (default: b2.LogLevel.INFO)
    @param dbgLvl the debugLevel of the module (default: 1)
    """
    fbdtTrainer = b2.register_module('FastBDTClassifierTraining')
    # Logging is configured per module instance, independent of the global basf2 log level.
    fbdtTrainer.logging.log_level = logLevel
    fbdtTrainer.logging.debug_level = dbgLvl
    fbdtTrainer.param({'networkInputName': networkInputName,
                       'outputFileName': outputFileName,
                       'train': train,
                       'nTrees': nTrees,
                       'treeDepth': treeDepth,
                       'shrinkage': shrinkage,
                       'randRatio': randRatio,
                       'storeSamples': storeSamples,
                       'useSamples': useSamples,
                       'samplesFileName': samplesFileName
                       })

    path.add_module(fbdtTrainer)
56 
57 
def add_ml_threehitfilters(path,
                           networkInputName,
                           fbdtFileName='FBDTClassifier.dat',
                           cutVal=0.5,
                           logLevel=b2.LogLevel.INFO,
                           dbgLvl=1):
    """This function adds the MLSegmentNetworkProducer module to the given path and exposes its parameters

    @param path the path to which the module should be added
    @param networkInputName parameter passed to the module
    @param fbdtFileName the filename where the FBDT is stored (forwarded as 'FBDTFileName')
    @param cutVal value forwarded as the module parameter 'cutValue' (presumably the
           threshold applied to the classifier output -- confirm against the module documentation)
    @param logLevel the LogLevel of the module (default: b2.LogLevel.INFO)
    @param dbgLvl the debugLevel of the module (default: 1)
    """
    ml_segment = b2.register_module('MLSegmentNetworkProducer')
    ml_segment.logging.log_level = logLevel
    ml_segment.logging.debug_level = dbgLvl
    # NOTE: the Python-side argument names differ from the module parameter names
    # ('FBDTFileName', 'cutValue'); the keys below must match the module exactly.
    ml_segment.param({'networkInputName': networkInputName,
                      'FBDTFileName': fbdtFileName,
                      'cutValue': cutVal,
                      })

    path.add_module(ml_segment)
78 
79 
def add_fbdtclassifier_analyzer(path,
                                fbdtFileName,
                                trainSamp,
                                testSamp,
                                outputFN='FBDTAnalyzer_out.root',
                                logLevel=b2.LogLevel.DEBUG,
                                dbgLvl=50):
    """This function adds the FastBDTClassifierAnalyzer module to the given path; the module analyses all presented
    training and test samples and stores the outputs into a root file for later analysis

    @param path the path to which the module should be added
    @param fbdtFileName the filename of the FBDTClassifier
    @param trainSamp the file name where the training samples are stored (forwarded as 'trainSamples')
    @param testSamp the file name where the test samples are stored (forwarded as 'testSamples')
    @param outputFN the file name of the root file which is created (forwarded as 'outputFileName')
    @param logLevel the LogLevel of the module (default: b2.LogLevel.DEBUG, unlike the other helpers here)
    @param dbgLvl the debugLevel of the module (default: 50)
    """
    fbdtAnalyzer = b2.register_module('FastBDTClassifierAnalyzer')
    fbdtAnalyzer.logging.log_level = logLevel
    fbdtAnalyzer.logging.debug_level = dbgLvl
    fbdtAnalyzer.param({'fbdtFileName': fbdtFileName,
                        'testSamples': testSamp,
                        'trainSamples': trainSamp,
                        'outputFileName': outputFN
                        })

    path.add_module(fbdtAnalyzer)