import dataclasses
import enum
import logging
import numbers
import typing
import numpy as np
import pandas as pd
from nnpdf_data.utils import parse_yaml_inp
from reportengine.floatformatting import format_number
from reportengine.utils import ChainMap
from validphys.core import CommonDataSpec, DataSetSpec
from validphys.coredata import CommonData
from validphys.plotoptions.plottingoptions import PlottingOptions, default_labels, labeler_functions
from validphys.plotoptions.utils import apply_to_all_columns
log = logging.getLogger(__name__)
[docs]
def get_info(data, *, normalize=False, cuts=None, use_plotfiles=True):
    """Retrieve and process the plotting information for the input data.

    Parameters
    ----------
    data: DataSetSpec or CommonDataSpec
        The object to build the plotting information for. A ``DataSetSpec``
        is reduced to its ``commondata`` attribute.
    normalize: bool, default=False
        If True, the specialization for ratio plots will be used to generate
        the PlotInfo object.
    cuts: None, False, or an object describing the cuts, default=None
        * ``None``: the cuts of the dataset are used when ``data`` is a
          ``DataSetSpec``; no cuts are applied for a ``CommonDataSpec``.
        * ``False``: no cuts are used.
        * an object with a ``load`` method (e.g. a Cuts instance): the result
          of calling ``load()`` is used.
        * anything else is used as given (it must support ``len``).
    use_plotfiles: bool, default=True
        Accepted for backward compatibility. This function does not currently
        read it.

    Returns
    -------
    info: PlotInfo
        The object built by :py:meth:`PlotInfo.from_commondata`.

    Raises
    ------
    NotImplementedError
        If cuts are given but they select no points.
    TypeError
        If ``data`` is neither a ``DataSetSpec`` nor a ``CommonDataSpec``.
    """
    if cuts is None:
        if isinstance(data, DataSetSpec):
            cuts = data.cuts.load() if data.cuts else None
    elif cuts is False:
        # Documented way of explicitly disabling the cuts. Without this branch
        # ``len(False)`` below would raise a TypeError.
        cuts = None
    elif hasattr(cuts, 'load'):
        cuts = cuts.load()

    # ``len`` rather than truthiness: loaded cuts may be an array-like.
    if cuts is not None and not len(cuts):
        raise NotImplementedError("No point passes the cuts. Cannot retieve info")

    if isinstance(data, DataSetSpec):
        data = data.commondata
    if not isinstance(data, CommonDataSpec):
        raise TypeError(f"Unrecognized data type: {type(data)}")

    return PlotInfo.from_commondata(data, cuts=cuts, normalize=normalize)
[docs]
class PlotInfo:
    """Plotting configuration of a single dataset.

    Holds the options read from the plotting files or from the dataset
    metadata (labels, axis scales, which column goes in the x axis, ...) and
    offers helpers to turn column names and grouped values into labels.
    Instances are normally created with :py:meth:`PlotInfo.from_commondata`.
    """

    # Names of the kinematic columns, in the same order as ``kinlabels``.
    _KIN_COLUMNS = ('k1', 'k2', 'k3')

    def __init__(
        self,
        kinlabels,
        dataset_label,
        *,
        experiment=None,
        x=None,
        extra_labels=None,
        func_labels=None,
        figure_by=None,
        line_by=None,
        kinematics_override=None,
        result_transform=None,
        y_label=None,
        x_label=None,
        x_scale=None,
        y_scale=None,
        ds_metadata=None,
        process_description='-',
        nnpdf31_process,
        **kwargs,
    ):
        """Store the plotting options.

        Parameters
        ----------
        kinlabels: sequence
            Labels of the kinematic columns ``k1``, ``k2`` and ``k3``, in
            that order.
        dataset_label: str
            Label of the dataset.
        experiment:
            Experiment the dataset belongs to. Reading the ``experiment``
            property fails if this is left as None.
        x: str, optional
            Name of the column used as x axis. Defaults to ``'idat'``, in
            which case the index of the table is used (see ``get_xcol``).
        extra_labels, func_labels: mapping, optional
            Stored as given; not interpreted by this class.
        figure_by, line_by:
            Stored as given; not interpreted by this class.
        kinematics_override:
            Mandatory: a ValueError is raised if it is None.
        result_transform, y_label, x_scale, y_scale, process_description:
            Stored as given.
        x_label: str, optional
            Explicit label for the x axis; when None, the label is derived
            from ``x`` (see ``xlabel``).
        ds_metadata:
            Metadata of the dataset, or None if not available.
        nnpdf31_process:
            Required keyword; stored as given.
        **kwargs:
            Any other option is accepted and ignored.

        Raises
        ------
        ValueError
            If ``kinematics_override`` is None.
        """
        if kinematics_override is None:
            raise ValueError(f'A kinematics_override must be set for {dataset_label}')

        self.kinlabels = kinlabels
        self._experiment = experiment
        self.nnpdf31_process = nnpdf31_process
        self.x = 'idat' if x is None else x
        self.extra_labels = extra_labels
        self.func_labels = func_labels
        self.figure_by = figure_by
        self.line_by = line_by
        self.kinematics_override = kinematics_override
        self.result_transform = result_transform
        self._x_label = x_label
        self.y_label = y_label
        self.x_scale = x_scale
        self.y_scale = y_scale
        self.dataset_label = dataset_label
        self.process_description = process_description
        # Metadata of the dataset
        self.ds_metadata = ds_metadata

    def name_to_label(self, name):
        """Return the label that corresponds to the column ``name``.

        A registered labeler function takes precedence (its ``label``
        attribute is used if present), then the kinematic columns are mapped
        to ``kinlabels``. Any other name is returned unchanged.
        """
        if name in labeler_functions:
            func = labeler_functions[name]
            return getattr(func, 'label', name)
        try:
            ix = self._KIN_COLUMNS.index(name)
        except ValueError:
            return name
        return self.kinlabels[ix]

    @property
    def experiment(self):
        """The experiment of the dataset. Raises ValueError if it was not set."""
        if self._experiment is None:
            raise ValueError(
                "Somehow PlotInfo was loaded with an empty experiment, this should not happen"
            )
        return self._experiment

    @property
    def process_type(self):
        """The process type as given by the dataset metadata.

        Raises AttributeError when ``ds_metadata`` is None.
        """
        return self.ds_metadata.process_type

    @property
    def xlabel(self):
        """Label of the x axis: the explicit ``x_label`` if given, otherwise
        the label derived from the ``x`` column name."""
        if self._x_label is not None:
            return self._x_label
        return self.name_to_label(self.x)

    def get_xcol(self, table):
        """Return a numpy array with the x column or the index as appropriate"""
        if self.x == 'idat':
            return np.array(table.index)
        return np.asarray(table[self.x])

    def group_label(self, same_vals, groupby):
        """Build the label describing a group of points.

        Parameters
        ----------
        same_vals: sequence
            The values shared by the points of the group, one per column in
            ``groupby``.
        groupby: sequence of str
            The names of the columns the points were grouped by.

        Returns
        -------
        label: str
            Empty if ``groupby`` is empty; ``(value)`` for a single string
            value; otherwise the space-separated labels of every column.
        """
        if not groupby:
            return ''
        if len(same_vals) == 1 and isinstance(same_vals[0], str):
            return f'({same_vals[0]})'

        # Only new-style commondata (it has metadata) which is not simply an
        # automatic port of the old dataset knows how to format its
        # kinematic variables.
        has_kin_labels = (
            self.ds_metadata is not None and not self.ds_metadata.is_ported_dataset
        )
        pieces = []
        for column, val in zip(groupby, same_vals):
            if has_kin_labels and column in self._KIN_COLUMNS:
                ix = self._KIN_COLUMNS.index(column)
                var_key = self.ds_metadata.kinematic_coverage[ix]
                pieces.append(self.ds_metadata.kinematics.apply_label(var_key, val))
            else:
                label = self.name_to_label(column)
                if isinstance(val, numbers.Real):
                    val = format_number(val)
                pieces.append(f'{label} = {val}')
        return ' '.join(pieces)

    @classmethod
    def from_commondata(cls, commondata, cuts=None, normalize=False):
        """Create a PlotInfo from the plotting options of ``commondata``.

        Parameters
        ----------
        commondata:
            Object providing ``plot_kinlabels``, ``legacy``, ``name`` and
            ``metadata``, plus ``plotfiles`` when ``legacy`` is true.
        cuts: iterable of int, optional
            If given, only the entries of each ``extra_labels`` list at
            these positions are kept.
        normalize: bool, default=False
            If True, the options under ``normalize`` (when present) are
            layered on top of the regular ones.

        Returns
        -------
        info: PlotInfo
        """
        plot_params = ChainMap()
        kinlabels = commondata.plot_kinlabels

        if commondata.legacy:
            if commondata.plotfiles:
                # Later files take precedence over earlier ones
                for file in commondata.plotfiles:
                    pf = parse_yaml_inp(file, PlottingFile)
                    config_params = dataclasses.asdict(pf, dict_factory=dict_factory)
                    plot_params = plot_params.new_child(config_params)
                if normalize and 'normalize' in plot_params:
                    # Read the key from the mapping that was just checked:
                    # ``dict_factory`` drops None values, so the last file
                    # alone is not guaranteed to contain it.
                    plot_params = plot_params.new_child(plot_params['normalize'])
                if 'dataset_label' not in plot_params:
                    log.warning("'dataset_label' key not found in %s", file)
                    plot_params['dataset_label'] = commondata.name
            else:
                # NOTE(review): no 'kinematics_override' is available in this
                # case, so the lookup below raises KeyError -- confirm whether
                # legacy datasets without plotting files must be supported.
                plot_params = {'dataset_label': commondata.name}
        else:
            pcd = commondata.metadata.plotting_options
            config_params = dataclasses.asdict(pcd, dict_factory=dict_factory)
            plot_params = plot_params.new_child(config_params)
            # Add a reference to the metadata to the plot_params so that it is stored in PlotInfo
            plot_params["ds_metadata"] = commondata.metadata
            # If normalize, we need to update some of the parameters
            if normalize and pcd.normalize is not None:
                plot_params = plot_params.new_child(pcd.normalize)

        kinlabels = plot_params['kinematics_override'].new_labels(*kinlabels)
        plot_params["process_type"] = commondata.metadata.process_type

        if "extra_labels" in plot_params and cuts is not None:
            # Keep only the labels of the points that pass the cuts
            plot_params["extra_labels"] = {
                k: [v[i] for i in cuts] for k, v in plot_params["extra_labels"].items()
            }

        return cls(kinlabels=kinlabels, **plot_params)
[docs]
def dict_factory(key_value_pairs):
    """Build a dictionary from ``(key, value)`` pairs, to be passed as the
    ``dict_factory`` argument of ``dataclasses.asdict``.

    Pairs whose value is None are discarded and enum members are replaced
    by their name; every other value is kept as it is.
    https://stackoverflow.com/questions/59481989/dict-from-nested-dataclasses
    """
    return {
        key: value.name if isinstance(value, enum.Enum) else value
        for key, value in key_value_pairs
        if value is not None
    }
[docs]
class KinLabel(enum.Enum):
    """Names of the three kinematic columns."""

    k1 = 1
    k2 = 2
    k3 = 3
[docs]
@dataclasses.dataclass
class PlottingFile(PlottingOptions):
    """Schema used to parse a plotting file: every option of
    ``PlottingOptions`` plus an optional ``normalize`` section.
    """

    # Options layered on top of the regular ones when a normalized plot is
    # requested (see ``PlotInfo.from_commondata``). None if the file does not
    # define them.
    normalize: typing.Optional[PlottingOptions] = None
[docs]
def kitable(data, info, *, cuts=None):
    """Build a DataFrame holding the kinematics of every data point.

    Parameters
    ----------
    data: (DataSetSpec, CommonDataSpec, Dataset, CommonData)
        The object the kinematics are read from. Specs are loaded first.
    info: PlotInfo
        Describes the transformations applied to the kinematics and the
        additional columns to add. See :py:func:`get_info`
    cuts: Cuts or None, default=None
        Object providing the cuts through its ``load`` method. Passing it
        together with a dataset is an error. If `data` is a CommonData, these
        **must** be the same as those passed to :py:func:`get_info`.

    Returns
    -------
    table: pd.DataFrame
        The kinematics of all the points that survive the cuts, followed by
        the extra and function-derived columns requested in ``info``.
    """
    is_dataset = isinstance(data, DataSetSpec)
    if is_dataset and cuts is not None:
        raise TypeError("Cuts must be None when a dataset is given")

    if is_dataset:
        data = data.load_commondata()
    elif isinstance(data, CommonDataSpec):
        data = data.load()

    table = pd.DataFrame(data.get_kintable(), columns=default_labels[1:])
    if cuts is not None and isinstance(data, CommonData):
        table = table.loc[cuts.load()]
    table.index.name = default_labels[0]

    override = info.kinematics_override
    if override:
        transformed = np.array(apply_to_all_columns(table, override)).T
        table = pd.DataFrame(transformed, columns=table.columns, index=table.index)

    # TODO: not very elegant; ideally the label functions would be called
    # with all the extra labels.
    extra_columns = tuple(info.extra_labels.items()) if info.extra_labels else ()
    derived_columns = tuple(info.func_labels.items()) if info.func_labels else ()

    for name, values in extra_columns:
        table[name] = values

    # The label functions only see the columns present at this point, not the
    # ones computed by other label functions.
    n_plain_columns = len(table.columns)
    for name, label_func in derived_columns:
        table[name] = apply_to_all_columns(table.iloc[:, :n_plain_columns], label_func)

    return table
[docs]
def get_xq2map(kintable, info):
    """Return a tuple of (x,Q²) from the kinematic values defined in kitable
    (usually obtained by calling ``kitable``) using the process type if available.
    Otherwise it will fall back to the legacy mode, i.e., using the machinery
    specified in ``info.kinematics_override``.

    Raises
    ------
    NotImplementedError
        If the legacy fallback does not implement the map either.
    """
    try:
        return info.process_type.xq2map(kintable, info.ds_metadata)
    except AttributeError:
        try:
            return apply_to_all_columns(kintable, info.kinematics_override.xq2map)
        except NotImplementedError as e:
            # The fallback is reached precisely when the process type or the
            # metadata may be missing, so the message must not assume that
            # they can be accessed.
            process_type = getattr(info, "process_type", "<unknown>")
            name = getattr(getattr(info, "ds_metadata", None), "name", None)
            if name is None:
                name = getattr(info, "dataset_label", "<unknown dataset>")
            raise NotImplementedError(
                f"The process type {process_type} for {name} is not implemented"
            ) from e