Belle II Software  release-08-02-04
grlneurotrainer.py
#!/usr/bin/env python3

"""
Example script showing how to train the neural networks of the GRL trigger
with the GRLNeuroTrainer module.

Training on realistic amounts of data with several runs per sector takes a
long time. The settings that are active below are reduced (one training run
per sector, at most 1000 epochs). The commented-out alternatives next to them
show larger settings (10 runs per sector, up to 10000 epochs, training data
scaled with the degrees of freedom).
"""

import os

import basf2
from ROOT import Belle2

# ------------ #
# user options #
# ------------ #

# set random seed (fixed, so that repeated trainings are reproducible)
basf2.set_random_seed(1)

# paths for the trained networks, the training data and the log files.
# All three default to the same directory, so it is looked up only once.
# NOTE(review): findFile searches the local/central release; if the directory
# is not found it presumably returns an empty string, in which case the files
# below end up in the current working directory -- confirm before running.
grl_data_dir = Belle2.FileSystem.findFile('trg/grl/data')
mlpdir = grl_data_dir
traindir = grl_data_dir
logdir = grl_data_dir
# filenames for the trained networks, the training data and the log files
mlpname = 'GRLNeuro.root'
trainname = 'GRLNeuroTraindata.root'
logname = 'GRLNeuroLog'  # file extensions are appended automatically
# NOTE(review): logdir/logname are not passed to GRLNeuroTrainer below, so
# they currently have no effect in this script.

# number of threads to be used for parallel training
nthreads = 1


# ------------------------- #
# create path up to trigger #
# ------------------------- #

main = basf2.create_path()

main.add_module('Progress')
# no input file name is set here; it presumably has to be supplied on the
# basf2 command line -- TODO confirm
main.add_module('RootInput')

# ---------------- #
# add the training #
# ---------------- #

main.add_module(
    'GRLNeuroTrainer',
    # output
    filename=os.path.join(mlpdir, mlpname),  # trained networks
    trainFilename=os.path.join(traindir, trainname),  # training data

    # network structure
    # alternative configuration with 20 sectors:
    # nMLP=20,  # total number of sectors
    # multiplyHidden=False,  # set the number of hidden nodes directly
    # nHidden=[[10]],  # 1 hidden layer with 10 nodes for all sectors
    # i_cdc_sector=[0,0,0,0,0,1,1,1,1,1,2,2,2,2,2,3,3,3,3,3],
    # i_ecl_sector=[0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4],
    nMLP=8,  # total number of sectors
    multiplyHidden=False,  # set the number of hidden nodes directly
    nHidden=[[100]],  # 1 hidden layer with 100 nodes for all sectors
    n_cdc_sector=1,
    n_ecl_sector=8,
    # one entry per sector: both lists have nMLP (= 8) entries, as in the
    # 20-sector alternative above
    i_cdc_sector=[0 * 2 + 36 * 3] * 8,
    i_ecl_sector=[i * 3 for i in range(8)],  # 0, 3, 6, ..., 21

    wMax=63.,  # limit weights to [-63, 63]
    # training parameters
    # alternative: training data relative to the degrees of freedom (DoF)
    # multiplyNTrain=True,  # set training data relative to degrees of freedom
    # nTrainMax=10.,  # training data (10x degrees of freedom)
    # nTrainMin=10.,  # don't train if there is less than 10x DoF training data
    multiplyNTrain=False,  # nTrainMax/nTrainMin are absolute sample counts
    nTrainMax=2000,  # use at most 2000 training samples
    nTrainMin=2000,  # don't train if there are fewer than 2000 training samples
    nValid=1000,  # number of validation samples (to avoid overtraining)
    nTest=1000,  # number of test samples (to select best of several runs)
    # repeatTrain=10,  # train each sector 10x with different initial weights
    repeatTrain=1,  # train each sector once
    checkInterval=500,  # stop training if validation error does not improve for 500 epochs
    # maxEpochs=10000,  # stop training after 10000 epochs
    maxEpochs=1000,  # stop training after 1000 epochs
    nThreads=nthreads,  # number of parallel threads
    # log level
    logLevel=basf2.LogLevel.DEBUG,  # show some debug output
    debugLevel=50)

# show only the message of the debug output
# basf2.logging.set_info(basf2.LogLevel.DEBUG, basf2.LogInfo.LEVEL | basf2.LogInfo.MESSAGE)


# Process events
basf2.process(main)

# Print call statistics
print(basf2.statistics)
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if found.
Definition: FileSystem.cc:148