Source code for n3fit.stopping

"""
    Module containing the classes related to the stopping alogirthm

    In this module there are four Classes:

    - FitState: this class contains the information of the fit
            for a given point in history
    - FitHistory: this class contains the information necessary
            in order to reset the state of the fit to the point
            in which the history was saved.
            i.e., a list of FitStates
    - Stopping: this class monitors the chi2 of the validation
            and training sets and decides when to stop
    - Positivity: Decides whether a given point fullfills the positivity conditions
    - Validation: Controls the NNPDF cross-validation algorithm

    Note:
        There are situations in which the validation set is empty, in those cases
    the training set is used as validation set.
    This implies several changes in the behaviour of this class as the training chi2 will
    now be monitored for stability.
        In order to parse the set of loss functions coming from the backend::MetaModel,
    the function `parse_losses` relies on the fact that they are all suffixed with `_loss`
    the validation case, instead, is suffixed with `val_loss`. In the particular casse in
    which both training and validation model correspond to the same backend::MetaModel only
    the `_loss` suffix can be found. This is taken into account by the class `Stopping`
    which will tell `Validation` that no validation set was found and that the training is to
    be used instead.
"""

import logging

import numpy as np

log = logging.getLogger(__name__)

# Put a very big number here so that we for sure discard this run
# AND we have a clear marker that something went wrong, not just a bad fit
TERRIBLE_CHI2 = 1e10
INITIAL_CHI2 = 1e9

# Pass/veto keys
POS_OK = "POS_PASS"
POS_BAD = "POS_VETO"
THRESHOLD_POS = 1e-6


def parse_ndata(all_data):
    """
    Parses the list of dictionaries received from ModelTrainer
    into a dictionary containing only the name of the experiments
    together with the number of points.

    Returns
    -------
        `tr_ndata`
            dictionary of {'exp' : ndata}
        `vl_ndata`
            dictionary of {'exp' : ndata}
        `pos_set`: list of the names of the positivity sets

    Note: if there is no validation (total number of val points == 0)
    then vl_ndata will point to tr_ndata
    """
    tr_ndata_dict = {}
    vl_ndata_dict = {}
    pos_set = []
    for dictionary in all_data:
        exp_name = dictionary["name"]
        if dictionary.get("count_chi2"):
            tr_ndata = dictionary["ndata"]
            vl_ndata = dictionary["ndata_vl"]
            if tr_ndata:
                tr_ndata_dict[exp_name] = tr_ndata
            if vl_ndata:
                vl_ndata_dict[exp_name] = vl_ndata
        if dictionary.get("positivity") and not dictionary.get("integrability"):
            pos_set.append(exp_name)
    if not vl_ndata_dict:
        vl_ndata_dict = None
    return tr_ndata_dict, vl_ndata_dict, pos_set

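# A minimal usage sketch for ``parse_ndata`` (the dictionaries below are made
# up for illustration and the demo function is not part of the module's API):
def _demo_parse_ndata():
    all_data = [
        {"name": "BCDMS", "count_chi2": True, "ndata": 100, "ndata_vl": 30},
        {"name": "POSF2U", "positivity": True},
    ]
    tr_ndata, vl_ndata, pos_sets = parse_ndata(all_data)
    # tr_ndata == {"BCDMS": 100}, vl_ndata == {"BCDMS": 30}, pos_sets == ["POSF2U"]
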
def parse_losses(history_object, data, suffix="loss"):
    """
    Receives an object containing the chi2.
    Usually a history object, but it can come in the form of a dictionary.

    It loops over the dictionary and uses the npoints_data dictionary to
    normalize the chi2 and returns a tuple (`total_loss`, `dict_chi2`).

    Parameters
    ----------
        history_object: dict
            A history object dictionary
        data: dict
            dictionary with the name of the experiments to be taken into account
            and the number of datapoints of the experiments
        suffix: str (default: ``loss``)
            suffix of the loss layer, Keras default is _loss

    Returns
    -------
        total_loss: float
            Total value for the loss
        dict_chi2: dict
            dictionary of {'expname' : loss}
    """
    try:
        hobj = history_object.history
    except AttributeError:  # So it works whether we pass the out or the out.history
        hobj = history_object

    # In the general case epochs = 1.
    # In case we are doing more than 1 epoch, take the last result as it is the result
    # the model is in at the moment.
    # This value is only used for printing output purposes so should not have any significance
    dict_chi2 = {}
    total_points = 0
    total_loss = 0
    for exp_name, npoints in data.items():
        loss = np.array(hobj[exp_name + f"_{suffix}"])
        dict_chi2[exp_name] = loss / npoints
        total_points += npoints
        total_loss += loss

    # By taking the loss from the history object we would be saving the total loss
    # including positivity sets and (if added/enabled) regularizers
    # instead we want to restrict ourselves to the loss coming from experiments
    # total_loss = np.mean(hobj["loss"]) / total_points
    total_loss /= total_points
    dict_chi2["total"] = total_loss
    return total_loss, dict_chi2

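# A minimal usage sketch for ``parse_losses`` (the history dictionary and the
# numbers of points are hypothetical):
def _demo_parse_losses():
    # Mimics the history of a .fit() call over two experiments
    hobj = {"loss": [4.0], "BCDMS_loss": [2.0], "NMC_loss": [2.0]}
    ndata = {"BCDMS": 10, "NMC": 10}
    total_loss, dict_chi2 = parse_losses(hobj, ndata)
    # total_loss == (2.0 + 2.0) / 20 == 0.2
    # dict_chi2 == {"BCDMS": 0.2, "NMC": 0.2, "total": 0.2}
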
class FitState:
    """
    Holds the state of the chi2 during the fit, for all replicas and one epoch

    Note: the training chi2 is computed before the update of the weights
    so it is the chi2 that informed the update corresponding to this state.
    The validation chi2 instead is computed after the update of the weights.

    Parameters
    ----------
        training_info: dict
            all losses for the training model
        validation_info: dict
            all losses for the validation model
        training_loss: float
            total training loss, this can be given if a per-experiment
            ``training_info`` is not available
    """

    vl_ndata = None
    tr_ndata = None
    vl_suffix = None

    def __init__(self, training_info, validation_info, training_loss=None):
        if self.vl_ndata is None or self.tr_ndata is None or self.vl_suffix is None:
            raise ValueError(
                "FitState cannot be instantiated until vl_ndata, tr_ndata and vl_suffix are filled"
            )
        self._training = training_info
        self.validation = validation_info
        self._parsed = False
        self._vl_chi2 = None  # These are per replica
        self._tr_chi2 = None  # This is an overall training chi2
        self._vl_dict = None
        self._tr_dict = None
        # This can be given if ``training_info`` is not given
        self._training_loss = training_loss

    @property
    def vl_loss(self):
        """Return the total validation loss as it comes from the info dictionaries"""
        return self.validation.get("loss")

    @property
    def tr_loss(self):
        """Return the total training loss as it comes from the info dictionaries"""
        if self._training is None:
            return self._training_loss
        return self._training.get("loss")

    def _parse_chi2(self):
        """
        Parses the chi2 from the losses according to the `tr_ndata` and
        `vl_ndata` dictionaries of {dataset: n_points}
        """
        if self._parsed:
            return
        if self._training is not None:
            self._tr_chi2, self._tr_dict = parse_losses(self._training, self.tr_ndata)
        if self.validation is not None:
            self._vl_chi2, self._vl_dict = parse_losses(
                self.validation, self.vl_ndata, suffix=self.vl_suffix
            )

    @property
    def tr_chi2(self):
        self._parse_chi2()
        return self._tr_chi2

    @property
    def vl_chi2(self):
        self._parse_chi2()
        return self._vl_chi2

    @property
    def all_tr_chi2(self):
        self._parse_chi2()
        return self._tr_dict

    @property
    def all_vl_chi2(self):
        self._parse_chi2()
        return self._vl_dict
    def all_tr_chi2_for_replica(self, i_replica):
        """Return the tr chi2 per dataset for a given replica"""
        return {k: np.take(v, i_replica) for k, v in self.all_tr_chi2.items()}

    def all_vl_chi2_for_replica(self, i_replica):
        """Return the vl chi2 per dataset for a given replica"""
        return {k: np.take(v, i_replica) for k, v in self.all_vl_chi2.items()}

    def total_partial_tr_chi2(self):
        """Return the tr chi2 summed over replicas per experiment"""
        return {k: np.sum(v) for k, v in self.all_tr_chi2.items()}

    def total_partial_vl_chi2(self):
        """Return the vl chi2 summed over replicas per experiment"""
        return {k: np.sum(v) for k, v in self.all_vl_chi2.items()}

    def total_tr_chi2(self):
        """Return the total tr chi2 summed over replicas"""
        return np.sum(self.tr_chi2)

    def total_vl_chi2(self):
        """Return the total vl chi2 summed over replicas"""
        return np.sum(self.vl_chi2)

    def __str__(self):
        return f"chi2: tr={self.tr_chi2} vl={self.vl_chi2}"

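# A minimal sketch of ``FitState`` (the class attributes are set by hand here;
# during a fit they are filled in by ``FitHistory``; numbers are illustrative):
def _demo_fitstate():
    FitState.tr_ndata = {"BCDMS": 10}
    FitState.vl_ndata = {"BCDMS": 10}
    FitState.vl_suffix = "val_loss"
    validation_info = {"loss": np.array([2.0]), "BCDMS_val_loss": np.array([2.0])}
    state = FitState(None, validation_info, training_loss=4.0)
    # state.vl_chi2 == 0.2 (2.0 over 10 points) and state.tr_loss == 4.0
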
class FitHistory:
    """
    Keeps a list of FitState items holding the full chi2 history of the fit.

    Parameters
    ----------
        tr_ndata: dict
            dictionary of {dataset: n_points} for the training data
        vl_ndata: dict
            dictionary of {dataset: n_points} for the validation data
    """

    def __init__(self, tr_ndata, vl_ndata):
        if vl_ndata is None:
            vl_ndata = tr_ndata
            vl_suffix = "loss"
        else:
            vl_suffix = "val_loss"
        # All instances of FitState should use these
        FitState.tr_ndata = tr_ndata
        FitState.vl_ndata = vl_ndata
        FitState.vl_suffix = vl_suffix

        # Save a list of states for the entire fit
        self._history = []
        self.final_epoch = None
    def get_state(self, epoch):
        """Get the FitState of the system for a given epoch"""
        try:
            return self._history[epoch]
        except IndexError as e:
            raise ValueError(
                f"Tried to obtain the state for epoch {epoch} when only "
                f"{len(self._history)} epochs have been saved"
            ) from e
    def register(self, epoch, fitstate):
        """Save the current fitstate and the associated epoch
        and set the current epoch as the final one should the fit end now
        """
        self.final_epoch = epoch
        self._history.append(fitstate)

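# A minimal sketch of ``FitHistory`` registering and retrieving a state
# (loss values are illustrative):
def _demo_fithistory():
    history = FitHistory(tr_ndata={"BCDMS": 10}, vl_ndata={"BCDMS": 10})
    vl_info = {"loss": np.array([2.0]), "BCDMS_val_loss": np.array([2.0])}
    history.register(0, FitState(None, vl_info, training_loss=4.0))
    state = history.get_state(0)  # the FitState saved for epoch 0
    # history.final_epoch == 0 and state.vl_chi2 == 0.2
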
class Stopping:
    """
    Driver of the stopping algorithm

    Note, if the total number of points in the validation dictionary is None,
    it is assumed the validation_model actually corresponds to the training model.

    Parameters
    ----------
        validation_model: n3fit.backends.MetaModel
            the model with the validation mask applied
            (and compiled with the validation data and covmat)
        all_data_dicts: dict
            list containing all dictionaries containing all information about
            the experiments/validation/regularizers/etc to be parsed by Stopping
        pdf_model: n3fit.backends.MetaModel
            pdf_model being trained
        threshold_positivity: float
            maximum value allowed for the sum of all positivity losses
        total_epochs: int
            total number of epochs
        stopping_patience: int
            how many epochs to wait for the validation loss to improve
        threshold_chi2: float
            maximum value allowed for chi2
        dont_stop: bool
            ignore early stopping and always run for the total number of epochs
    """

    def __init__(
        self,
        validation_model,
        all_data_dicts,
        pdf_model,
        threshold_positivity=THRESHOLD_POS,
        total_epochs=0,
        stopping_patience=7000,
        threshold_chi2=10.0,
        dont_stop=False,
    ):
        self._pdf_model = pdf_model
        # Save the validation object
        self._validation = validation_model

        # Create the History object
        tr_ndata, vl_ndata, pos_sets = parse_ndata(all_data_dicts)
        self._history = FitHistory(tr_ndata, vl_ndata)

        # And the positivity checker
        self._positivity = Positivity(threshold_positivity, pos_sets)

        # Initialize internal variables for the stopping
        self._n_replicas = pdf_model.num_replicas
        self._threshold_chi2 = threshold_chi2
        self._stopping_degrees = np.zeros(self._n_replicas, dtype=int)
        self._counts = np.zeros(self._n_replicas, dtype=int)
        # Keep track of the replicas that should not be stopped yet
        self._dont_stop_me_now = np.ones(self._n_replicas, dtype=bool)
        self._dont_stop = dont_stop
        self._stop_now = False
        self.stopping_patience = stopping_patience
        self.total_epochs = total_epochs

        self._stop_epochs = [total_epochs - 1] * self._n_replicas
        self._best_epochs = [None] * self._n_replicas
        self.positivity_statuses = [POS_BAD] * self._n_replicas
        self._best_weights = [None] * self._n_replicas
        self._best_val_chi2s = [INITIAL_CHI2] * self._n_replicas

    @property
    def vl_chi2(self):
        """Current validation chi2"""
        validation_info = self._validation.compute_losses()
        fitstate = FitState(None, validation_info)
        return fitstate.vl_chi2

    @property
    def e_best_chi2(self):
        """Epoch of the best chi2; if there is no best epoch, return the last one"""
        best_or_last_epochs = [
            best if best is not None else last
            for best, last in zip(self._best_epochs, self._stop_epochs)
        ]
        return best_or_last_epochs

    @property
    def stop_epoch(self):
        """Epoch in which the fit is stopped"""
        return -1 if self._history.final_epoch is None else self._history.final_epoch + 1

    @property
    def positivity_status(self):
        """Returns POS_PASS if positivity passes or veto if it doesn't for each replica"""
        return self.positivity_statuses
    def evaluate_training(self, training_model):
        """Given the training model, evaluates the model
        and parses the chi2 of the training datasets

        Parameters
        ----------
            training_model: n3fit.backends.MetaModel
                an object implementing the evaluate function

        Returns
        -------
            tr_chi2: float
                chi2 of the given ``training_model``
        """
        training_info = training_model.compute_losses()
        fitstate = FitState(training_info, None)
        return fitstate.tr_chi2
    def monitor_chi2(self, training_info, epoch, print_stats=False):
        """
        Function to be called at the end of every epoch.

        Stores the total chi2 of the training set as well as the
        total chi2 of the validation set.
        If the validation chi2 is below a certain threshold, stores the
        state of the model which gave the minimum chi2
        as well as the epoch in which it occurred.

        Returns True if the run seems ok and False if a NaN is found.

        Parameters
        ----------
            training_info: dict
                output of a .fit() call, dictionary of the total training loss
                (summed over replicas and experiments)
            epoch: int
                index of the epoch
            print_stats: bool
                whether to print the status of the fit for this epoch

        Returns
        -------
            pass_ok: bool
                true/false according to the status of the run
        """
        # Step 1. Check whether the fit has NaN'd and stop it if so
        if np.isnan(training_loss := training_info["loss"]):
            log.warning(" > NaN found, stopping activated")
            self.make_stop()
            return False

        # Step 2. Compute the validation metrics
        validation_info = self._validation.compute_losses()

        # Step 3. Register the current point in the history
        # and set the current epoch as the final one should the fit end now
        fitstate = FitState(None, validation_info, training_loss)
        self._history.register(epoch, fitstate)
        if print_stats:
            self.print_current_stats(epoch, fitstate)

        # Step 4. Check whether this is a better fit,
        # meaning an improved vl_chi2 that passes positivity.
        # Don't start counting until the chi2 of the validation goes below a certain threshold;
        # once we start counting, don't bother anymore
        # (a standalone sketch of this per-replica mask logic follows this class)
        passes = self._counts | (fitstate.vl_chi2 < self._threshold_chi2)
        passes &= fitstate.vl_loss < self._best_val_chi2s
        # And the ones that pass positivity
        passes &= self._positivity(fitstate)
        # Stop replicas that are ok being stopped (because they are finished or otherwise)
        passes &= self._dont_stop_me_now

        self._stopping_degrees += self._counts

        # Step 5. Loop over the valid indices to check whether the vl improved
        for i_replica in np.where(passes)[0]:
            self._best_epochs[i_replica] = epoch
            # By definition, if we have a ``best_epoch`` then positivity passed
            self.positivity_statuses[i_replica] = POS_OK

            self._best_val_chi2s[i_replica] = self._history.get_state(epoch).vl_loss[i_replica]
            self._best_weights[i_replica] = self._pdf_model.get_replica_weights(i_replica)

            self._stopping_degrees[i_replica] = 0
            self._counts[i_replica] = 1

        stop_replicas = self._counts & (self._stopping_degrees > self.stopping_patience)
        for i_replica in np.where(stop_replicas)[0]:
            self._stop_epochs[i_replica] = epoch
            self._counts[i_replica] = 0
            self._dont_stop_me_now[i_replica] = False

        # By using the stopping degree we only stop when none of the replicas are improving anymore
        if min(self._stopping_degrees) > self.stopping_patience:
            self.make_stop()
        return True
    def make_stop(self):
        """Convenience method to set the stop_now flag
        and reload the history to the point of the best model if any
        """
        self._stop_now = True
        self._restore_best_weights()
    def _restore_best_weights(self):
        for i_replica, weights in enumerate(self._best_weights):
            if weights is not None:
                self._pdf_model.set_replica_weights(weights, i_replica)
    def print_current_stats(self, epoch, fitstate):
        """
        Prints the ``fitstate`` validation chi2 for every experiment
        and the current total training loss as well as the validation loss
        after the training step
        """
        epoch_index = epoch + 1
        vl_chi2 = fitstate.total_vl_chi2()
        total_str = f"Epoch {epoch_index}/{self.total_epochs}: loss: {fitstate.tr_loss:.7f}"
        total_str += f"\nValidation loss after training step: {vl_chi2:.7f}."

        # The partial chi2 makes no sense for more than one replica at once:
        if self._n_replicas == 1:
            total_str += "\nValidation chi2s: "
            partial_vl_chi2 = fitstate.total_partial_vl_chi2()
            partials = []
            for experiment, chi2 in partial_vl_chi2.items():
                partials.append(f"{experiment}: {chi2:.3f}")
            total_str += ", ".join(partials)
        log.info(total_str)
    def stop_here(self):
        """Returns the stopping status
        If `dont_stop` is set, always returns False (i.e., never stop)
        """
        if self._dont_stop:
            return False
        else:
            return self._stop_now
    def chi2exps_json(self, i_replica=0, log_each=100):
        """
        Returns an apt-for-json dictionary with the status of the fit
        every ``log_each`` epochs.
        It reports the total training loss and the validation loss
        broken down by experiment.

        Parameters
        ----------
            i_replica: int
                which replica we are writing the log for
            log_each: int
                every how many epochs to log the status

        Returns
        -------
            json_dict: dict
                a dictionary with the status of the fit every ``log_each``
                epochs, to be dumped as the ``chi2exps.log`` json file
        """
        final_epoch = self._history.final_epoch
        json_dict = {}

        for epoch in range(log_each - 1, final_epoch + 1, log_each):
            fitstate = self._history.get_state(epoch)

            # Get the training and validation losses
            tmp = {
                "training_loss": fitstate.tr_loss,
                "validation_loss": fitstate.vl_loss.tolist(),
            }

            # And the validation chi2 broken down by experiment
            tmp["validation_chi2s"] = fitstate.all_vl_chi2_for_replica(i_replica)
            json_dict[epoch + 1] = tmp

        return json_dict

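# A standalone numpy sketch of the per-replica bookkeeping in ``monitor_chi2``
# (all numbers are made up): a replica only starts counting towards the
# patience once its validation chi2 drops below the threshold, and it only
# registers a new best point when the validation loss improves.
def _demo_stopping_masks():
    counts = np.array([0, 1, 1])          # replicas already below the threshold
    vl_chi2 = np.array([12.0, 3.0, 5.0])  # current validation chi2 per replica
    best = np.array([1e9, 2.0, 6.0])      # best validation chi2 so far
    passes = counts | (vl_chi2 < 10.0)    # start counting below the threshold
    passes &= vl_chi2 < best              # ... but only register improvements
    # passes == array([0, 0, 1]): only the third replica improved; in the real
    # method the mask is further combined with positivity and stop flags
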
class Positivity:
    """
    Controls the positivity requirements.

    In order to check whether positivity passes, it looks at the positivity
    losses stored in the history of the fit, since the fit includes
    positivity sets.
    If the sum of all positivity losses is above a certain value the model
    is not accepted and the training continues.

    Parameters
    ----------
        threshold_positivity: float
            maximum value allowed for the sum of all positivity losses
        positivity_sets: list
            list of positivity datasets
    """

    def __init__(self, threshold, positivity_sets):
        self.threshold = threshold
        self.positivity_sets = positivity_sets
    def check_positivity(self, history_object):
        """
        This function receives a history object and loops over the
        positivity_sets to check the value of the positivity loss.

        If the positivity loss is above the threshold the positivity fails,
        otherwise it passes.
        It returns an array of booleans which are True if positivity passed,
        i.e., history_object[key_loss] < self.threshold

        Parameters
        ----------
            history_object: dict
                dictionary of entries in the form {'name': loss},
                output of a MetaModel .fit()
        """
        positivity_pass = True
        for key in self.positivity_sets:
            key_loss = f"{key}_loss"
            positivity_pass &= history_object[key_loss] < self.threshold
        return np.array(positivity_pass)
    def __call__(self, fitstate):
        """
        Checks whether a given FitState object passes the positivity requirement
        """
        return self.check_positivity(fitstate.validation)

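# A minimal usage sketch for ``Positivity`` over two replicas (the dataset
# name and loss values are illustrative):
def _demo_positivity():
    positivity = Positivity(THRESHOLD_POS, ["POSF2U"])
    history_object = {"POSF2U_loss": np.array([1e-7, 1e-2])}
    passes = positivity.check_positivity(history_object)
    # passes == array([ True, False]): the first replica fulfills positivity,
    # the second is vetoed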