Source code for n3fit.layers.mask

import numpy as np

from n3fit.backends import MetaLayer
from n3fit.backends import operations as op


class Mask(MetaLayer):
    """
    Layer that selects entries of its input with a boolean mask and,
    optionally, rescales the result by a constant.

    The selection is implemented as a gather over the flattened input
    followed by a reshape. The constant multiplier, when given, is registered
    through ``builder_helper`` under the name ``"mask"`` with
    ``trainable=False``.

    Typical usage is to apply training/validation split masks or to apply a
    multiplier to a given layer.

    Parameters
    ----------
        bool_mask: np.array of shape (n_replicas, n_features)
            numpy array with the boolean mask to be applied.
            The size of the masked output is taken from the number of
            ``True`` entries of the first replica (``bool_mask[0, ...]``).
            If ``None`` no masking is performed.
        c: float
            constant multiplier for every output.
            If ``None`` no multiplier is applied.
    """

    def __init__(self, bool_mask=None, c=None, **kwargs):
        # Keep the numpy mask around: build() needs it to compute the indices
        self._raw_mask = bool_mask
        # Both of these are filled in by build() when a mask is present
        self._flattened_indices = None
        self.masked_output_shape = None
        self.c = c

        if bool_mask is not None:
            self.mask = op.numpy_to_tensor(bool_mask, dtype=bool)
            # Number of selected features, counted on the first replica only
            self.last_dim = np.count_nonzero(bool_mask[0, ...])
        else:
            self.mask = None
            self.last_dim = -1

        super().__init__(**kwargs)

    def build(self, input_shape):
        """
        Create the multiplier kernel (if ``c`` was given) and precompute the
        gather indices and the target shape used by ``call`` (if a mask was
        given).

        Parameters
        ----------
            input_shape: sequence of int or None
                shape of the input; ``None`` entries are mapped to ``-1``
                in the shape handed to the reshape operation
        """
        if self.c is not None:
            constant_init = MetaLayer.init_constant(value=self.c)
            self.kernel = self.builder_helper("mask", (1,), constant_init, trainable=False)

        if self.mask is not None:
            # Positions of the True entries, one index array per mask axis
            selected = np.where(self._raw_mask)
            # Number of replicas, read from the before-last axis of the mask
            n_replicas = self.mask.shape[-2]
            # Indices are relative to the mask's own shape (no batch axis).
            # NOTE(review): call() applies them to the fully flattened input,
            # which presumably relies on a batch dimension of size 1 - confirm.
            self._flattened_indices = np.ravel_multi_index(selected, self._raw_mask.shape)

            # So that the reshape succeeds: unknown dimensions become -1, the
            # last axis is the masked data length and the before-last axis is
            # the number of replicas
            target_shape = [-1 if dim is None else dim for dim in input_shape]
            target_shape[-1] = self.last_dim
            target_shape[-2] = n_replicas
            self.masked_output_shape = target_shape

        super().build(input_shape)

    def call(self, ret):
        """
        Apply the mask to the input tensor, and multiply by the constant if present.

        Parameters
        ----------
            ret: Tensor of shape (batch_size, n_replicas, n_features)

        Returns
        -------
            Tensor
                If a mask was given, the output has the shape computed in
                ``build``: the input shape with the last axis replaced by the
                number of selected features and the before-last axis by the
                number of replicas of the mask.
                Otherwise the shape of the input is unchanged.
        """
        if self.mask is not None:
            flat_input = op.flatten(ret)
            picked = op.gather(flat_input, self._flattened_indices)
            ret = op.reshape(picked, self.masked_output_shape)

        if self.c is not None:
            ret = ret * self.kernel

        return ret