# Source code for validphys.closuretest.multiclosure_nsigma_helpers

"""
This module contains some helper functions that are used for the computation of nsigma in the
context of a multi-closure test.
"""

import dataclasses
import logging

import numpy as np
import pandas as pd

from nnpdf_data.coredata import CommonData
from reportengine import collect
from validphys.calcutils import calc_chi2
from validphys.convolution import central_predictions
from validphys.core import DataSetSpec

log = logging.getLogger(__name__)


@dataclasses.dataclass
class CentralChi2Data:
    """
    Container for the chi2 of a single dataset.

    Attributes
    ----------
    value: float
        The chi2 value.
    ndata: int
        The number of data points the chi2 refers to.
    dataset: DataSetSpec
        The dataset the chi2 was computed for.
    """

    value: float
    ndata: int
    dataset: DataSetSpec

    @property
    def reduced(self) -> float:
        """The chi2 divided by the number of data points."""
        return self.value / self.ndata
def central_member_chi2(
    central_predictions: pd.DataFrame,
    sqrt_covmat: np.ndarray,
    dataset: DataSetSpec,
    loaded_commondata_with_cuts: CommonData,
) -> CentralChi2Data:
    """
    Computes the chi2 of the central predictions with respect to the data
    central values of a dataset.

    Parameters
    ----------
    central_predictions: pd.DataFrame
        The central predictions for the dataset.

    sqrt_covmat: np.ndarray
        The square root of the covariance matrix, forwarded to ``calc_chi2``.

    dataset: DataSetSpec
        The dataset; stored unchanged in the returned object.

    loaded_commondata_with_cuts: nnpdf_data.coredata.CommonData
        The commondata whose ``central_values`` are compared to the predictions.

    Returns
    -------
    CentralChi2Data
        The chi2 value together with the number of data points (the length of
        ``central_predictions``) and the dataset.
    """
    # Flatten both sides to 1D so that the subtraction is element by element.
    theory = central_predictions.values.ravel()
    data = loaded_commondata_with_cuts.central_values.values.ravel()

    return CentralChi2Data(
        value=calc_chi2(sqrt_covmat, theory - data),
        ndata=len(central_predictions),
        dataset=dataset,
    )
def chi2_nsigma_deviation(central_member_chi2: CentralChi2Data) -> float:
    """
    Computes the deviation of a chi2 from its expected value in units of sigma:

        n_sigma = (chi2 - ndata) / sqrt(2 * ndata)

    i.e. the chi2 is compared to the mean (``ndata``) and standard deviation
    (``sqrt(2 * ndata)``) of a chi2 distribution with ``ndata`` degrees of freedom.

    Parameters
    ----------
    central_member_chi2: CentralChi2Data
        Provides the chi2 ``value`` and the number of data points ``ndata``.

    Returns
    -------
    float
        The deviation in units of sigma.
    """
    ndata = central_member_chi2.ndata
    return (central_member_chi2.value - ndata) / np.sqrt(2 * ndata)
# Each name below is built with reportengine's ``collect``: the first argument
# is the name of the provider being collected and the second is the tuple of
# namespace keys it is collected over.

# NOTE(review): the string below mentions ``dataset_input`` while the key
# passed to ``collect`` is "data_input" -- confirm which name is intended.
datasets_chi2_nsigma_deviation = collect("chi2_nsigma_deviation", ("data_input",))
""" Collect the n_sigma values over list of ``dataset_input``. """

# Collects the per-dataset collection defined above once more, over
# ("fits", "fitcontext").
fits_datasets_chi2_nsigma_deviation = collect(
    "datasets_chi2_nsigma_deviation", ("fits", "fitcontext")
)
""" Collects over fits and for all datasets the n_sigma values. """

# Consumed by ``is_weighted`` through its ``fits_data`` parameter.
fits_data = collect("data", ("fits", "fitinputcontext"))
""" Collects the data for each fit. """
def is_weighted(fits_data: list) -> bool:
    """
    Returns whether the considered multiclosure test has been weighted or not.

    A dataset counts as weighted when its ``weight`` attribute differs from 1.
    All fits must agree on the set of weighted datasets and at most one
    dataset may be weighted.

    Parameters
    ----------
    fits_data: list
        List of data for each fit. Each element must expose ``datasets``, an
        iterable of objects with ``name`` and ``weight`` attributes.

    Returns
    -------
    bool
        True if exactly one dataset is weighted (consistently across fits),
        False if no dataset is weighted or if ``fits_data`` is empty.

    Raises
    ------
    ValueError
        If the weighted datasets are not the same for all fits, or if more
        than one dataset is weighted.
    """
    # One (hashable) set of weighted dataset names per fit.
    weighted_ds_sets = [
        frozenset(ds.name for ds in data.datasets if ds.weight != 1) for data in fits_data
    ]

    # Without fits there is nothing that could be weighted; answer explicitly
    # instead of letting a bare StopIteration escape from next() further down.
    if not weighted_ds_sets:
        return False

    # Ensure all fits have the same set of weighted datasets
    if len(set(weighted_ds_sets)) > 1:
        error_msg = "Weighted datasets are not the same for all fits in the same multiclosure test (dataspec)."
        log.error(error_msg)
        raise ValueError(error_msg)

    # All sets are identical at this point, so any of them is representative.
    weighted_ds = weighted_ds_sets[0]

    # Ensure there is at most one weighted dataset (zero means "not weighted")
    if len(weighted_ds) > 1:
        error_msg = "Only one dataset can be weighted in a multiclosure test."
        log.error(error_msg)
        raise ValueError(error_msg)

    return bool(weighted_ds)
def n_fits(dataspecs):
    """
    Returns the number of fits of the multiclosure test, i.e. the length of
    the ``"fits"`` entry that is common to all dataspecs.

    Parameters
    ----------
    dataspecs:
        Iterable of mappings, each with a ``"fits"`` entry supporting ``len``.

    Returns
    -------
    int
        The number of fits shared by every dataspec.

    Raises
    ------
    ValueError
        If the number of fits is not the same across dataspecs.
    """
    # Distinct fit counts seen across the dataspecs; more than one is an error.
    fit_counts = {len(dataspec["fits"]) for dataspec in dataspecs}

    if len(fit_counts) > 1:
        error_msg = "The number of fits is not the same across dataspecs."
        log.error(error_msg)
        raise ValueError(error_msg)

    return next(iter(fit_counts))