14from basf2 
import B2ERROR, B2FATAL
 
   16from variables 
import variables 
as vm
 
   17import modularAnalysis 
as ma
 
   20def get_variables(particle_list, ranked_variable, variables=None, particleNumber=1):
 
   21    """ creates variable name pattern requested by the basf2 variable getVariableByRank() 
   23    :param ranked_variable: 
   25    :param particleNumber: 
   30        for i_num 
in range(1, particleNumber + 1):
 
   31            var_list.append(
'getVariableByRank(' + particle_list + 
', ' + ranked_variable + 
', ' + var + 
', ' +
 
   36def construct_default_variable_names(particle_lists=None, ranked_variable='p', variables=None, particleNumber=5):
 
   37    """ construct default variables (that are sorted by charge and ranked by momentum) 
   38    :param particle_lists: 
   39    :param ranked_variable: 
   41    :param particleNumber: 
   44    from ROOT 
import Belle2  
 
   45    if particle_lists 
is None:
 
   46        particle_lists = [
'pi+:pos_charged', 
'pi+:neg_charged']
 
   49    for p_list 
in particle_lists:
 
   50        variable_names += get_variables(p_list, ranked_variable, variables, particleNumber)
 
   53    root_compatible_list = []
 
   54    for var 
in variable_names:
 
   57    return root_compatible_list
 
   60def DeepFlavorTagger(particle_lists, mode='expert', working_dir='', uniqueIdentifier='standard', variable_list=None,
 
   61                     target='qrCombined', overwrite=False,
 
   62                     transform_to_probability=False, signal_fraction=-1.0, classifier_args=None,
 
   63                     train_valid_fraction=.92, mva_steering_file='analysis/scripts/dft/tensorflow_dnn_interface.py',
 
   67    Interfacing for the DeepFlavorTagger. This function can be used for training (``teacher``), preparation of 
   68    training datasets (``sampler``) and inference (``expert``). 
   70    This function requires reconstructed B meson signal particle list and where an RestOfEvent is built. 
   72    :param particle_lists:  string or list[string], particle list(s) of the reconstructed signal B meson 
   73    :param mode: string, valid modes are ``expert`` (default), ``teacher``, ``sampler`` 
   74    :param working_dir: string, working directory for the method 
   75    :param uniqueIdentifier: string, database identifier for the method 
   76    :param variable_list: list[string], name of the basf2 variables used for discrimination 
   77    :param target: string, target variable 
   78    :param overwrite: bool, overwrite already (locally!) existing training 
   79    :param transform_to_probability: bool, enable a purity transformation to compensate potential over-training, 
   80     can only be set during training 
   81    :param signal_fraction: float, (experimental) signal fraction override, 
   82     transform to output to a probability if an uneven signal/background fraction is used in the training data, 
   83     can only be set during training 
   84    :param classifier_args: dictionary, customized arguments for the mlp 
   85     possible attributes of the dictionary are: 
   86     lr_dec_rate: learning rate decay rate 
   87     lr_init: learning rate initial value 
   88     mom_init: momentum initial value 
   89     min_epochs: minimal number of epochs 
   90     max_epochs: maximal number of epochs 
   91     stop_epochs: epochs to stop without improvements on the validation set for early stopping 
   92     batch_size: batch size 
   93     seed: random seed for tensorflow 
   94     layers: [[layer name, activation function, input_width, output_width, init_bias, init_weights],..] 
   95     wd_coeffs: weight decay coefficients, length of layers 
   96     cuda_visible_devices: selection of cuda devices 
   97     tensorboard_dir: addition directory for logging the training process 
   98    :param train_valid_fraction: float, train-valid fraction (.92). If transform to probability is 
   99     enabled, train valid fraction will be split into a test set (.5) 
  100    :param maskName: get ROE particles from a specified ROE mask 
  101    :param path: basf2 path obj 
  105    if isinstance(particle_lists, str):
 
  106        particle_lists = [particle_lists]
 
  108    if mode 
not in [
'expert', 
'teacher', 
'sampler']:
 
  109        B2FATAL(f
'Invalid mode  {mode}')
 
  111    if variable_list 
is None and mode 
in [
'sampler', 
'teacher']:
 
  114            'useCMSFrame(cosTheta)',
 
  127    if variable_list 
is not None and mode == 
'expert':
 
  128        B2ERROR(
'DFT: Variables from identifier file are used. Input variables will be ignored.')
 
  130    if classifier_args 
is None:
 
  133        assert isinstance(classifier_args, dict)
 
  135    classifier_args[
'transform_to_prob'] = transform_to_probability
 
  137    output_file_name = os.path.join(working_dir, uniqueIdentifier + 
'_training_data.root')
 
  140    roe_path = basf2.create_path()
 
  141    dead_end_path = basf2.create_path()
 
  144    extension = particle_lists[0].replace(
':', 
'_to_')
 
  145    roe_particle_list_cut = 
'' 
  146    roe_particle_list = 
'pi+:dft' + 
'_' + extension
 
  148    tree_name = 
'dft_variables' 
  151    ma.signalSideParticleListsFilter(particle_lists, 
'hasRestOfEventTracks > 0', roe_path, dead_end_path)
 
  156    ma.fillParticleList(roe_particle_list, roe_particle_list_cut, path=roe_path)
 
  158    dft_particle_lists = [
'pi+:pos_charged', 
'pi+:neg_charged']
 
  160    pos_cut = 
'charge > 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName + 
') > 0.5 and p < infinity' 
  161    neg_cut = 
'charge < 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName + 
') > 0.5 and p < infinity' 
  163    ma.cutAndCopyList(dft_particle_lists[0], roe_particle_list, pos_cut, writeOut=
True, path=roe_path)
 
  164    ma.cutAndCopyList(dft_particle_lists[1], roe_particle_list, neg_cut, writeOut=
True, path=roe_path)
 
  172        features = get_variables(dft_particle_lists[0], rank_variable, variable_list, particleNumber=5)
 
  173        features += get_variables(dft_particle_lists[1], rank_variable, variable_list, particleNumber=5)
 
  175    for particles 
in dft_particle_lists:
 
  176        ma.rankByHighest(particles, rank_variable, path=roe_path)
 
  178    if mode == 
'sampler':
 
  179        if os.path.isfile(output_file_name) 
and not overwrite:
 
  180            B2FATAL(f
'Outputfile {output_file_name} already exists. Aborting writeout.')
 
  183        all_variables = features + [target]
 
  186        ma.variablesToNtuple(
'', all_variables, tree_name, output_file_name, roe_path)
 
  189        extern_command = f
'basf2_mva_teacher --datafile {output_file_name} --treename {tree_name}' + \
 
  190                         f
' --identifier {uniqueIdentifier} ' + \
 
  191                         '--variables "{}"  '.format(
'" "'.join(features)) + \
 
  192                         f
'--target_variable {target}' + \
 
  193                         f
' --method Python --training_fraction {train_valid_fraction}' + \
 
  194                         f
" --config '{json.dumps(classifier_args)}' --framework tensorflow" + \
 
  195                         f
' --steering_file {mva_steering_file}' 
  197        with open(os.path.join(working_dir, uniqueIdentifier + 
'_teacher_command'), 
'w') 
as f:
 
  198            f.write(extern_command)
 
  200    elif mode == 
'teacher':
 
  201        if not os.path.isfile(output_file_name):
 
  202            B2FATAL(
'There is no training data file available. Run flavor tagger in sampler mode first.')
 
  203        general_options = basf2_mva.GeneralOptions()
 
  204        general_options.m_datafiles = basf2_mva.vector(output_file_name)
 
  206        general_options.m_treename = tree_name
 
  207        general_options.m_target_variable = target
 
  208        general_options.m_variables = basf2_mva.vector(*features)
 
  210        general_options.m_identifier = uniqueIdentifier
 
  212        specific_options = basf2_mva.PythonOptions()
 
  213        specific_options.m_framework = 
'tensorflow' 
  214        specific_options.m_steering_file = mva_steering_file
 
  215        specific_options.m_training_fraction = train_valid_fraction
 
  217        specific_options.m_config = json.dumps(classifier_args)
 
  219        basf2_mva.teacher(general_options, specific_options)
 
  221    elif mode == 
'expert':
 
  223        flavorTaggerInfoBuilder = basf2.register_module(
'FlavorTaggerInfoBuilder')
 
  224        path.add_module(flavorTaggerInfoBuilder)
 
  226        expert_module = basf2.register_module(
'MVAExpert')
 
  227        expert_module.param(
'listNames', particle_lists)
 
  228        expert_module.param(
'identifier', uniqueIdentifier)
 
  230        expert_module.param(
'extraInfoName', 
'dnn_output')
 
  231        expert_module.param(
'signalFraction', signal_fraction)
 
  233        roe_path.add_module(expert_module)
 
  235        flavorTaggerInfoFiller = basf2.register_module(
'FlavorTaggerInfoFiller')
 
  236        flavorTaggerInfoFiller.param(
'DNNmlp', 
True)
 
  237        roe_path.add_module(flavorTaggerInfoFiller)
 
  240        vm.addAlias(
'DNN_qrCombined', 
'qrOutput(DNN)')
 
  242    path.for_each(
'RestOfEvent', 
'RestOfEvents', roe_path)
 
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.