"""
n3fit_data_utils.py
This module reads validphys :py:class:`validphys.core.DataSetSpec`
and extracts the relevant information into :py:class:`validphys.n3fit_data_utils.FittableDataSet`
The ``validphys_group_extractor`` will loop over every dataset of a given group
loading their fktables (and applying any necessary cuts).
"""
import dataclasses
from itertools import zip_longest
import numpy as np
@dataclasses.dataclass
class FittableDataSet:
    """
    Representation of the DataSet information necessary to run a fit

    Parameters
    ----------
        name: str
            name of the dataset
        fktables_data: list(:py:class:`validphys.coredata.FKTableData`)
            list of coredata fktable objects
        operation: str
            operation to be applied to the fktables in the dataset, default "NULL"
        frac: float
            fraction of the data to enter the training set, default 1.0
            (stored only; not used by the methods of this class)
        training_mask: array-like of bool, optional
            one entry per datapoint, in the order of the unique values of the
            first level of the first fktable's ``sigma`` index. ``True`` marks a
            training point and ``False`` a validation point. If ``None``
            (default) no split is made, and both the training and validation
            fktables are the full, uncut fktables.
    """

    name: str
    fktables_data: list  # of validphys.coredata.FKTableData objects

    # Things that can have default values:
    operation: str = "NULL"
    frac: float = 1.0
    training_mask: np.ndarray = None  # boolean array, or None for "no split"

    def __post_init__(self):
        # Translate the positional boolean mask into the explicit datapoint
        # labels which are later handed to ``with_cuts`` for each subset.
        self._tr_mask = None
        self._vl_mask = None
        if self.training_mask is not None:
            # Normalise to a numpy boolean array: ``~`` is not defined for plain
            # Python lists, and on integer 0/1 arrays it is a bitwise not
            # (giving -1/-2) rather than the complement of the mask.
            mask = np.asarray(self.training_mask, dtype=bool)
            data_idx = self.fktables_data[0].sigma.index.get_level_values(0).unique()
            self._tr_mask = data_idx[mask].values
            self._vl_mask = data_idx[~mask].values

    @property
    def ndata(self):
        """Number of datapoints in the dataset, as reported by the first fktable"""
        return self.fktables_data[0].ndata

    @property
    def hadronic(self):
        """Returns true if this is a hadronic collision dataset, as reported by the first fktable"""
        return self.fktables_data[0].hadronic

    def fktables(self):
        """Return the list of fktable tensors for the dataset (one per fktable, no cuts)"""
        return [fk.get_np_fktable() for fk in self.fktables_data]

    def _masked_fktables(self, cuts):
        """Return the fktable tensors restricted to ``cuts``, or all of them if ``cuts`` is None"""
        if cuts is None:
            return self.fktables()
        return [fk.with_cuts(cuts).get_np_fktable() for fk in self.fktables_data]

    def training_fktables(self):
        """Return the fktable tensors for the training data

        Without a ``training_mask`` this is the same as :py:meth:`fktables`.
        """
        return self._masked_fktables(self._tr_mask)

    def validation_fktables(self):
        """Return the fktable tensors for the validation data

        Without a ``training_mask`` this is the same as :py:meth:`fktables`.
        """
        return self._masked_fktables(self._vl_mask)