Belle II Software development
setup_modules_ml.py
1#!/usr/bin/env python3
2
3
10
11import basf2 as b2
12
13
def add_fbdtclassifier_training(path,
                                networkInputName,
                                outputFileName='FBDTClassifier.dat',
                                train=True,
                                storeSamples=True,
                                useSamples=False,
                                samplesFileName='FBDTClassifier_samples.dat',
                                nTrees=100,
                                treeDepth=3,
                                shrinkage=0.15,
                                randRatio=0.5,
                                logLevel=b2.LogLevel.INFO,
                                dbgLvl=1):
    """This function adds the FastBDTClassifierTraining-module to the given path and exposes all its parameters

    Every argument except path, logLevel and dbgLvl is forwarded unchanged as the
    module parameter of the same name (see the dict passed to param() below).

    @param path the path to which the module should be added
    @param networkInputName parameter passed to the module
    @param outputFileName the filename to which the FBDTClassifier will be stored
    @param train actually do the training
    @param storeSamples passed as the module's 'storeSamples' parameter
           (presumably: write the collected samples to samplesFileName -- verify against the module docs)
    @param useSamples passed as the module's 'useSamples' parameter
           (presumably: read samples from samplesFileName instead of collecting them -- verify)
    @param samplesFileName passed as the module's 'samplesFileName' parameter (file used for the samples)
    @param nTrees number of trees in the module
    @param treeDepth the number of layers in the trees
    @param shrinkage the shrinkage parameter
    @param randRatio the ratio of all samples used for training of each tree
    @param logLevel the LogLevel of the module
    @param dbgLvl the debugLevel of the module
    """
    fbdtTrainer = b2.register_module('FastBDTClassifierTraining')
    fbdtTrainer.logging.log_level = logLevel
    fbdtTrainer.logging.debug_level = dbgLvl
    fbdtTrainer.param({'networkInputName': networkInputName,
                       'outputFileName': outputFileName,
                       'train': train,
                       'nTrees': nTrees,
                       'treeDepth': treeDepth,
                       'shrinkage': shrinkage,
                       'randRatio': randRatio,
                       'storeSamples': storeSamples,
                       'useSamples': useSamples,
                       'samplesFileName': samplesFileName,
                       })

    path.add_module(fbdtTrainer)
55
56
def add_ml_threehitfilters(path,
                           networkInputName,
                           fbdtFileName='FBDTClassifier.dat',
                           cutVal=0.5,
                           logLevel=b2.LogLevel.INFO,
                           dbgLvl=1):
    """This function adds the MLSegmentNetworkProducer module to the given path and exposes its parameters

    @param path the path to which the module should be added
    @param networkInputName parameter passed to the module
    @param fbdtFileName the filename where the FBDT is stored (passed as the module's 'FBDTFileName')
    @param cutVal passed as the module's 'cutValue' parameter
           (presumably the classifier-output threshold for accepting a combination -- verify against the module docs)
    @param logLevel the LogLevel of the module
    @param dbgLvl the debugLevel of the module
    """
    ml_segment = b2.register_module('MLSegmentNetworkProducer')
    ml_segment.logging.log_level = logLevel
    ml_segment.logging.debug_level = dbgLvl
    ml_segment.param({'networkInputName': networkInputName,
                      'FBDTFileName': fbdtFileName,
                      'cutValue': cutVal,
                      })

    path.add_module(ml_segment)
77
78
def add_fbdtclassifier_analyzer(path,
                                fbdtFileName,
                                trainSamp,
                                testSamp,
                                outputFN='FBDTAnalyzer_out.root',
                                logLevel=b2.LogLevel.DEBUG,
                                dbgLvl=50):
    """This function adds the FastBDTClassifierAnalyzer module, which analyses all presented training and
    test samples and stores the outputs into a root file for later analysis

    @param path the path to which the module should be added
    @param fbdtFileName the filename of the FBDTClassifier
    @param trainSamp the file name where the training samples are stored
    @param testSamp the file name where the test samples are stored
    @param outputFN the file name of the root file which is created
    @param logLevel the LogLevel of the module (note: defaults to DEBUG here, unlike the other helpers)
    @param dbgLvl the debugLevel of the module (defaults to 50)
    """
    fbdtAnalyzer = b2.register_module('FastBDTClassifierAnalyzer')
    fbdtAnalyzer.logging.log_level = logLevel
    fbdtAnalyzer.logging.debug_level = dbgLvl
    fbdtAnalyzer.param({'fbdtFileName': fbdtFileName,
                        'testSamples': testSamp,
                        'trainSamples': trainSamp,
                        'outputFileName': outputFN,
                        })

    path.add_module(fbdtAnalyzer)