#!/usr/bin/env python3
##########################################################################
# basf2 (Belle II Analysis Software Framework)                           #
# Author: The Belle II Collaboration                                     #
#                                                                        #
# See git log for contributors and copyright holders.                    #
# This file is licensed under LGPL-3.0, see LICENSE.md.                  #
##########################################################################
import re
import sys
from pathlib import Path
import json
import jsonschema
import yaml
from basf2 import find_file
[docs]
class Sample:
    """Base class for skim test samples."""
    def __init__(self, **kwargs):
        """
        Initialise Sample. Passing any unrecognised keywords will raise an error.
        """
        if kwargs:
            keys = ", ".join(kwargs.keys())
            raise ValueError(
                f"Unrecognised arguments in test sample initialisation: {keys}"
            )
    location = NotImplemented
    """Path of the test file."""
    @property
    def encodeable_name(self):
        """
        Identifying string which is safe to be included as a filename component or as a
        key in the skim stats JSON file.
        As a rough naming convention, data samples should start with 'Data-', MC sample
        with 'MC-', and custom samples with 'Custom-'.
        """
        return NotImplemented
    @property
    def printable_name(self):
        """
        Human-readable name for displaying in printed tables.
        """
        return NotImplemented
[docs]
    @staticmethod
    def resolve_path(location):
        """
        Replace ``'${SampleDirectory}'`` with ``Sample.SampleDirectory``, and resolve
        the path.
        Parameters:
            location (str, pathlib.Path): Filename to be resolved.
        Returns:
            pathlib.Path: Resolved path.
        """
        SampleDirectory = "/group/belle2/dataprod/MC/SkimTraining"
        location = str(location).replace("${SampleDirectory}", SampleDirectory)
        return Path(location).expanduser().resolve() 
    @property
    def as_dict(self):
        """
        Sample serialised as a dictionary.
        """
        return NotImplemented
    def __str__(self):
        return self.encodeable_name 
[docs]
class DataSample(Sample):
    def __init__(
        self,
        *,
        location,
        processing,
        experiment,
        beam_energy="4S",
        general_skim="all",
        **kwargs,
    ):
        # Pass unrecognised kwargs to base class
        super().__init__(**kwargs)
        self.location = self.resolve_path(location)
        self.processing = processing
        if isinstance(experiment, int) or not experiment.startswith("exp"):
            experiment = f"exp{experiment}"
        self.experiment = experiment
        self.beam_energy = beam_energy
        self.general_skim = general_skim
    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"location={repr(self.location)}, "
            f"processing={repr(self.processing)}, "
            f"experiment={repr(self.experiment)}, "
            f"beam_energy={repr(self.beam_energy)}, "
            f"general_skim={repr(self.general_skim)})"
        )
    @property
    def as_dict(self):
        return {
            "location": str(self.location),
            "processing": self.processing,
            "experiment": self.experiment,
            "beam_energy": self.beam_energy,
            "general_skim": self.general_skim,
        }
    @property
    def encodeable_name(self):
        return "-".join(
            (
                "Data",
                self.processing,
                self.experiment,
                self.beam_energy,
                self.general_skim,
            )
        )
    @property
    def printable_name(self):
        name = f"{self.processing} {self.experiment}"
        # Only print additional info in non-default situations
        if self.beam_energy != "4S":
            name += f", {self.beam_energy}"
        if self.general_skim != "all":
            name += f", ({self.general_skim})"
        return name 
[docs]
class MCSample(Sample):
    def __init__(
        self,
        *,
        location,
        process,
        campaign,
        beam_energy="4S",
        beam_background="BGx1",
        **kwargs,
    ):
        # Pass unrecognised kwargs to base class
        super().__init__(**kwargs)
        self.location = self.resolve_path(location)
        self.process = process
        self.beam_energy = beam_energy
        if isinstance(campaign, int) or not campaign.startswith("MC"):
            campaign = f"MC{campaign}"
        self.campaign = campaign
        if isinstance(beam_background, int) or not beam_background.startswith("BGx"):
            beam_background = f"BGx{beam_background}"
        self.beam_background = beam_background
    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"location={repr(self.location)}, "
            f"process={repr(self.process)}, "
            f"campaign={repr(self.campaign)}, "
            f"beam_energy={repr(self.beam_energy)}, "
            f"beam_background={repr(self.beam_background)})"
        )
    @property
    def as_dict(self):
        return {
            "location": str(self.location),
            "process": self.process,
            "campaign": self.campaign,
            "beam_energy": self.beam_energy,
            "beam_background": self.beam_background,
        }
    @property
    def encodeable_name(self):
        return "-".join(
            ("MC", self.campaign, self.beam_energy, self.process, self.beam_background)
        )
    @property
    def printable_name(self):
        name = f"{self.campaign} {self.process}"
        # Only print additional info in non-default situations
        if self.beam_background != "BGx1":
            name += f" {self.beam_background}"
        if self.beam_energy != "4S":
            name += f", {self.beam_energy}"
        return name 
[docs]
class CustomSample(Sample):
    def __init__(self, *, location, label=None, **kwargs):
        # Pass unrecognised kwargs to base class
        super().__init__(**kwargs)
        self.location = self.resolve_path(location)
        if label is None:
            self.label = str(location)
        else:
            self.label = label
        self.sanitised_label = re.sub(r"[^A-Za-z0-9]", "", self.label)
    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"location={repr(self.location)}, "
            f"label={repr(self.label)})"
        )
    @property
    def as_dict(self):
        return {"location": str(self.location), "label": self.label}
    @property
    def encodeable_name(self):
        return f"Custom-{self.sanitised_label}"
    @property
    def printable_name(self):
        return self.label 
[docs]
class TestSampleList:
    """Container class for lists of MC, data, and custom samples."""
    DefaultSampleYAML = (
        "/group/belle2/dataprod/MC/SkimTraining/SampleLists/TestFiles.yaml"
    )
    def __init__(self, *, SampleYAML=None, SampleDict=None, SampleList=None):
        """
        Initialise a list of test samples. Three methods are of initialisation are
        allowed. If no arguments are given this function will default to a standard list
        of samples defined in
        ``/group/belle2/dataprod/MC/SkimTraining/SampleLists/TestFiles.yaml``.
        Parameters:
            SampleYAML (str, pathlib.path): Path to a YAML file containing sample
                specifications.
            SampleDict (dict): Dict containing sample specifications.
            SampleList (list(Sample)): List of Sample objects.
        """
        if sum(p is not None for p in (SampleYAML, SampleDict, SampleList)) > 1:
            raise ValueError(
                "Only one out of SampleYAML, SampleDict, or SampleList can be passed."
            )
        if SampleList is not None:
            # Initialise from list of Sample objects
            self.mc_samples = [s for s in SampleList if isinstance(s, MCSample)]
            self.data_samples = [s for s in SampleList if isinstance(s, DataSample)]
            self.custom_samples = [s for s in SampleList if isinstance(s, CustomSample)]
            return
        if SampleDict is None:
            if SampleYAML is None:
                SampleYAML = self.DefaultSampleYAML
            with open(SampleYAML) as f:
                SampleDict = yaml.safe_load(f)
        self.validate_schema(SampleDict, SampleYAML)
        self._parse_all_samples(SampleDict)
    @property
    def _all_samples(self):
        return [*self.mc_samples, *self.data_samples, *self.custom_samples]
    def __iter__(self):
        yield from self._all_samples
    def __getitem__(self, i):
        return self._all_samples[i]
    def __len__(self):
        return len(self._all_samples)
    def __repr__(self):
        return f"{self.__class__.__name__}(" f"SampleList={repr(list(self))})"
    @property
    def SampleDict(self):
        return {
            "MC": [s.as_dict for s in self.mc_samples],
            "Data": [s.as_dict for s in self.data_samples],
            "Custom": [s.as_dict for s in self.custom_samples],
        }
[docs]
    def validate_schema(self, SampleDict, InputYAML=None):
        """
        Validate YAML input against JSON schema defined in
        ``skim/tools/resources/test_samples_schema.json``.
        """
        schema_file = find_file("skim/tools/resources/test_samples_schema.json")
        with open(schema_file) as f:
            schema = json.load(f)
        try:
            jsonschema.validate(SampleDict, schema)
        except jsonschema.exceptions.ValidationError as e:
            if InputYAML:
                raise ValueError(
                    f"Error in sample list configuration file {InputYAML}"
                ) from e
            raise e 
    @staticmethod
    def _parse_samples(SampleDict, BlockName, SampleClass):
        if SampleDict is None:
            return []
        try:
            InputSampleList = SampleDict[BlockName]
        except KeyError:
            return []
        if InputSampleList is None:
            return []
        samples = []
        for sample in InputSampleList:
            samples.append(SampleClass(**sample))
        return samples
    def _parse_all_samples(self, SampleDict):
        """Read in each block of the YAML and create lists of sample objects."""
        MissingParams = (
            "Error in '{block}' block of test sample yaml file.\n"
            "The following must all have defined values: {params}"
        )
        try:
            self.data_samples = self._parse_samples(SampleDict, "Data", DataSample)
        except TypeError as e:
            required = ", ".join(
                f"'{p}'"
                for p in ("location", "processing", "beam_energy", "experiment")
            )
            raise ValueError(MissingParams.format(block="Data", params=required)) from e
        try:
            self.mc_samples = self._parse_samples(SampleDict, "MC", MCSample)
        except TypeError as e:
            required = ", ".join(f"'{p}'" for p in ("location", "process", "campaign"))
            raise ValueError(MissingParams.format(block="MC", params=required)) from e
        try:
            self.custom_samples = self._parse_samples(
                SampleDict, "Custom", CustomSample
            )
        except TypeError as e:
            required = ", ".join(f"'{p}'" for p in ("location",))
            raise ValueError(
                MissingParams.format(block="Custom", params=required)
            ) from e
[docs]
    def query_mc_samples(
        self,
        *,
        process=None,
        campaign=None,
        beam_energy=None,
        beam_background=None,
        exact_match=False,
        inplace=False,
    ):
        """
        Find all MC samples matching query.
        Parameters:
            process (str): Simulated MC process to query.
            campaign (str, int): MC campaign number to query.
            beam_energy (str): Beam energy to query.
            beam_background (str, int): Nominal beam background to query.
            exact_match (bool): If passed, an error is raised if there is not exactly
                one matching sample. If there is exactly one matching sample, then the
                single sample is returned, rather than a list.
            inplace (bool): Replace MC samples with the list obtained from query.
        """
        if inplace and exact_match:
            raise ValueError(
                "Incompatible arguments passed: `inplace` and `exact_match`"
            )
        samples = [
            s
            for s in self.mc_samples
            if (process is None or s.process == process)
            and (campaign is None or s.campaign == campaign)
            and (beam_energy is None or s.beam_energy == beam_energy)
            and (beam_background is None or s.beam_background == beam_background)
        ]
        if exact_match:
            if len(samples) == 1:
                return samples[0]
            else:
                raise ValueError(
                    "`exact_match=True` was specified, but did not find exactly one match."
                )
        else:
            if inplace:
                self.mc_samples = samples
            else:
                return samples 
[docs]
    def query_data_samples(
        self,
        *,
        processing=None,
        experiment=None,
        beam_energy=None,
        general_skim=None,
        exact_match=False,
        inplace=False,
    ):
        """
        Find all MC samples matching query.
        Parameters:
            processing (str): Data processing campaign number to query.
            experiment (str, int): Experiment number to query.
            beam_energy (str): Beam energy to query.
            general_skim (str): ``GeneralSkimName`` to query.
            exact_match (bool): If passed, an error is raised if there is not exactly
                one matching sample. If there is exactly one matching sample, then the
                single sample is returned, rather than a list.
            inplace (bool): Replace MC samples with the list obtained from query.
        """
        if inplace and exact_match:
            raise ValueError(
                "Incompatible arguments passed: `inplace` and `exact_match`"
            )
        samples = [
            s
            for s in self.data_samples
            if (processing is None or s.processing == processing)
            and (experiment is None or s.experiment == experiment)
            and (beam_energy is None or s.beam_energy == beam_energy)
            and (general_skim is None or s.general_skim == general_skim)
        ]
        if exact_match:
            if len(samples) == 1:
                return samples[0]
            else:
                raise ValueError(
                    "`exact_match=True` was specified, but did not find exactly one match."
                )
        else:
            if inplace:
                self.data_samples = samples
            else:
                return samples 
 
[docs]
def get_test_file(process, *, SampleYAML=None):
    """
    Attempt to find a test sample of the given MC process.
    Parameters:
        process (str): Physics process, e.g. mixed, charged, ccbar, eemumu.
        SampleYAML (str, pathlib.Path): Path to a YAML file containing sample
                specifications.
    Returns:
        str: Path to test sample file.
    Raises:
        FileNotFoundError: Raised if no sample can be found.
    """
    samples = TestSampleList(SampleYAML=SampleYAML)
    matches = samples.query_mc_samples(process=process)
    try:
        # Return the first match found
        return matches[0].location
    except IndexError as e:
        raise ValueError(f"No test samples found for MC process '{process}'.") from e 
if __name__ == "__main__":
    # Print the parsed contents of the YAML file
    try:
        samples = TestSampleList(SampleYAML=sys.argv[1])
    except IndexError:
        samples = TestSampleList()
    print("Samples defined in YAML file:")
    for sample in samples:
        print(f"  * {repr(sample)}")