Belle II Software  release-05-01-25
setup_modules_ml.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 from basf2 import *
5 
6 
def add_fbdtclassifier_training(path,
                                networkInputName,
                                outputFileName='FBDTClassifier.dat',
                                train=True,
                                storeSamples=True,
                                useSamples=False,
                                samplesFileName='FBDTClassifier_samples.dat',
                                nTrees=100,
                                treeDepth=3,
                                shrinkage=0.15,
                                randRatio=0.5,
                                logLevel=LogLevel.INFO,
                                dbgLvl=1):
    """Register a FastBDTClassifierTraining module, configure it and append it to a path.

    Every argument except path, logLevel and dbgLvl is forwarded unchanged to the
    module parameter of the same name.

    @param path the path to which the configured module is appended
    @param networkInputName value for the module parameter 'networkInputName'
    @param outputFileName the file name under which the trained FBDTClassifier is stored
    @param train whether the training is actually performed
    @param storeSamples value for 'storeSamples' (by its name, whether the collected
           samples are written to samplesFileName -- confirm against the module)
    @param useSamples value for 'useSamples' (by its name, whether samples are taken
           from samplesFileName -- confirm against the module)
    @param samplesFileName value for the module parameter 'samplesFileName'
    @param nTrees the number of trees
    @param treeDepth the number of layers in each tree
    @param shrinkage the shrinkage parameter
    @param randRatio the fraction of all samples used for the training of each tree
    @param logLevel the LogLevel assigned to the module
    @param dbgLvl the debug level assigned to the module
    """
    # Assemble the module parameters up front; building the dict has no side effects.
    trainingParams = dict(networkInputName=networkInputName,
                          outputFileName=outputFileName,
                          train=train,
                          nTrees=nTrees,
                          treeDepth=treeDepth,
                          shrinkage=shrinkage,
                          randRatio=randRatio,
                          storeSamples=storeSamples,
                          useSamples=useSamples,
                          samplesFileName=samplesFileName)

    trainerModule = register_module('FastBDTClassifierTraining')
    trainerModule.logging.log_level = logLevel
    trainerModule.logging.debug_level = dbgLvl
    trainerModule.param(trainingParams)

    path.add_module(trainerModule)
48 
49 
def add_ml_threehitfilters(path,
                           networkInputName,
                           fbdtFileName='FBDTClassifier.dat',
                           cutVal=0.5,
                           logLevel=LogLevel.INFO,
                           dbgLvl=1):
    """Register an MLSegmentNetworkProducer module, configure it and append it to a path.

    @param path the path to which the configured module is appended
    @param networkInputName value for the module parameter 'networkInputName'
    @param fbdtFileName the file name where the FBDT classifier is stored
           (passed as module parameter 'FBDTFileName')
    @param cutVal value for the module parameter 'cutValue' (presumably the
           classifier-output threshold -- confirm against the module)
    @param logLevel the LogLevel assigned to the module
    @param dbgLvl the debug level assigned to the module
    """
    # Note the differing spellings: Python argument names vs. module parameter names.
    producerParams = dict(networkInputName=networkInputName,
                          FBDTFileName=fbdtFileName,
                          cutValue=cutVal)

    producerModule = register_module('MLSegmentNetworkProducer')
    producerModule.logging.log_level = logLevel
    producerModule.logging.debug_level = dbgLvl
    producerModule.param(producerParams)

    path.add_module(producerModule)
70 
71 
def add_fbdtclassifier_analyzer(path,
                                fbdtFileName,
                                trainSamp,
                                testSamp,
                                outputFN='FBDTAnalyzer_out.root',
                                logLevel=LogLevel.DEBUG,
                                dbgLvl=50):
    """Register a FastBDTClassifierAnalyzer module, configure it and append it to a path.

    The module analyses the given training and test samples and stores the outputs
    in a ROOT file for later analysis.

    @param path the path to which the configured module is appended
    @param fbdtFileName the file name of the FBDTClassifier
    @param trainSamp the file name where the training samples are stored
    @param testSamp the file name where the test samples are stored
    @param outputFN the file name of the ROOT file that is created
    @param logLevel the LogLevel assigned to the module (defaults to DEBUG here)
    @param dbgLvl the debug level assigned to the module (defaults to 50 here)
    """
    # Map the short Python argument names onto the module's parameter names.
    analyzerParams = dict(fbdtFileName=fbdtFileName,
                          testSamples=testSamp,
                          trainSamples=trainSamp,
                          outputFileName=outputFN)

    analyzerModule = register_module('FastBDTClassifierAnalyzer')
    analyzerModule.logging.log_level = logLevel
    analyzerModule.logging.debug_level = dbgLvl
    analyzerModule.param(analyzerParams)

    path.add_module(analyzerModule)