Solvers#

Variants of CEBRA solvers for single- and multi-session training.

This package contains wrappers around training loops. If you want to customize how different encoder models are used to transform the reference, positive and negative samples, how the loss functions are applied to the data, or adapt specifics on how results are logged, extending the classes in this package is the right way to go.

The module contains the base cebra.solver.base.Solver class along with multiple variants to deal with single- and multi-session datasets.

This module is a registry and currently contains the options [‘multi-session’, ‘multi-session-aux’, ‘single-session’, ‘single-session-aux’, ‘single-session-hybrid’, ‘single-session-full’, ‘supervised-solver-xcebra’, ‘multiobjective-solver’, ‘regularized-solver’].

To retrieve a list of options, call:

>>> print(cebra.solver.get_options())
['multi-session', 'multi-session-aux', 'single-session', ...]

To obtain an initialized instance, call cebra.solver.init, defined in cebra.registry.add_helper_functions(). The first parameter to provide is the solver name to use, which is one of the available options presented above. Then the required positional arguments specific to the module are provided, if needed.

You can register additional options by defining and registering classes with a name. To do that, you can add a decorator on top of it: @cebra.solver.register("my-cebra-solver").

Later, initialize your class similarly to the pre-defined options, using cebra.solver.init with the solver name set to my-cebra-solver.

Note that these customized options will not be automatically added to this docstring.
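The register / init / get_options workflow described above can be illustrated with a small stand-alone registry. This is only a sketch of the general pattern, not the code in cebra.registry: the names _REGISTRY, MySolver and learning_rate are hypothetical and chosen for illustration.

```python
# Toy registry illustrating the register / init / get_options pattern.
# This is NOT the cebra.registry implementation; all names are hypothetical.
_REGISTRY = {}


def register(name):
    """Return a class decorator that stores the class under ``name``."""

    def decorator(cls):
        if name in _REGISTRY:
            raise ValueError(f"Name {name!r} is already registered.")
        _REGISTRY[name] = cls
        return cls

    return decorator


def get_options():
    """List all registered names in registration order."""
    return list(_REGISTRY)


def init(name, *args, **kwargs):
    """Instantiate the class registered under ``name``."""
    if name not in _REGISTRY:
        raise ValueError(f"Unknown option {name!r}; choose from {get_options()}.")
    return _REGISTRY[name](*args, **kwargs)


@register("my-cebra-solver")
class MySolver:
    def __init__(self, learning_rate=1e-3):
        self.learning_rate = learning_rate


solver = init("my-cebra-solver", learning_rate=0.01)
```

In this sketch, the first argument to init is the registered name, and all remaining arguments are forwarded to the class constructor.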

This package contains abstract base classes for different solvers.

Solvers package models, criterions and optimizers and implement training loops. When subclassing abstract solvers, in the simplest case only Solver._inference() needs to be overridden.

For more complex use cases, the Solver.step() and Solver.fit() methods can be overridden to implement larger changes to the training loop.
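The subclassing pattern can be sketched without any CEBRA dependency. The classes below are hypothetical stand-ins (ToySolver is not cebra.solver.base.Solver, and the squared-error loss is an arbitrary choice); they only show how a shared step() can delegate to an overridable _inference() hook.

```python
from abc import ABC, abstractmethod


class ToySolver(ABC):
    """Hypothetical stand-in for a solver base class (not cebra's Solver)."""

    def step(self, batch):
        # Shared training-step logic: run inference, then compute a loss.
        outputs = self._inference(batch)
        loss = sum((o - t) ** 2 for o, t in zip(outputs, batch["target"]))
        return {"total": loss / len(outputs)}

    @abstractmethod
    def _inference(self, batch):
        """Map a batch of inputs to model outputs."""


class DoublingSolver(ToySolver):
    def _inference(self, batch):
        # Only the inference hook is customized.
        return [2 * x for x in batch["input"]]


stats = DoublingSolver().step({"input": [1.0, 2.0], "target": [2.0, 5.0]})
```

A subclass that needs a different update rule would override step() as well, which mirrors the more complex use case mentioned above.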

cebra.solver.base._check_indices(batch_start_idx, batch_end_idx, offset, num_samples)#

Check that indices in a batch are in a correct range.

The first and last indices must be positive integers smaller than the total length of the inputs in the dataset, the first index must be smaller than the last, and the batch size cannot be smaller than the offset of the model.

Parameters:
  • batch_start_idx (int) – Index of the first sample in the batch.

  • batch_end_idx (int) – Index of the last sample in the batch.

  • offset (Offset) – Model offset.

  • num_samples (int) – Total number of samples in the input.
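The checks listed above could look roughly like the following. This is a sketch under assumptions, not the library code: offset_len is a hypothetical plain integer standing in for the length of the model offset, index 0 is accepted as a valid start, and the end index is treated as exclusive so that it may equal num_samples.

```python
def check_indices(batch_start_idx, batch_end_idx, offset_len, num_samples):
    """Validate batch bounds; ``offset_len`` is a hypothetical plain int."""
    if batch_start_idx < 0 or batch_end_idx < 0:
        raise ValueError("Batch indices must be non-negative.")
    if batch_start_idx > num_samples or batch_end_idx > num_samples:
        raise ValueError("Batch indices must not exceed num_samples.")
    if batch_start_idx >= batch_end_idx:
        raise ValueError("batch_start_idx must be smaller than batch_end_idx.")
    if batch_end_idx - batch_start_idx < offset_len:
        raise ValueError("The batch must be at least as long as the offset.")
```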

cebra.solver.base._add_batched_zero_padding(batched_data, offset, batch_start_idx, batch_end_idx, num_samples)#

Add zero padding to the input data before inference.

Parameters:
  • batched_data (Tensor) – Data to apply the inference on.

  • offset (Offset) – Offset of the model to consider when padding.

  • batch_start_idx (int) – Index of the first sample in the batch.

  • batch_end_idx (int) – Index of the last sample in the batch.

  • num_samples (int) – Total number of samples in the data.

Return type:

Tensor

Returns:

The padded batch.
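One way to picture batch-wise zero padding is shown below. The sketch assumes that only the first and last batch need padding, since interior batches have real neighboring samples on both sides. The parameters left and right are hypothetical integers for the amount of padding on each side, and the data is a plain list of rows rather than a tensor.

```python
def add_batched_zero_padding(batch, left, right, batch_start_idx,
                             batch_end_idx, num_samples):
    """Zero-pad a batch (list of rows) only at the edges of the full input.

    ``left`` and ``right`` are hypothetical ints giving the number of zero
    rows to add before the first and after the last sample of the input.
    """
    num_features = len(batch[0])
    zero_row = [0.0] * num_features
    padded = [list(row) for row in batch]
    if batch_start_idx == 0:
        # First batch: no earlier samples exist, so pad on the left.
        padded = [list(zero_row) for _ in range(left)] + padded
    if batch_end_idx == num_samples:
        # Last batch: no later samples exist, so pad on the right.
        padded = padded + [list(zero_row) for _ in range(right)]
    return padded
```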

cebra.solver.base._get_batch(inputs, offset, batch_start_idx, batch_end_idx, pad_before_transform)#

Get a batch of samples between the batch_start_idx and batch_end_idx.

Parameters:
  • inputs (Tensor) – Input data.

  • offset (Optional[Offset]) – Model offset.

  • batch_start_idx (int) – Index of the first sample in the batch.

  • batch_end_idx (int) – Index of the last sample in the batch.

  • pad_before_transform (bool) – If True, zero-pad the batched data.

Return type:

Tensor

Returns:

The batch.

cebra.solver.base._inference_transform(model, inputs)#

Compute the embedding on the inputs using the model provided.

Parameters:
  • model (Model) – Model to use for inference.

  • inputs (Tensor) – Data.

Return type:

Tensor

Returns:

The embedding.

cebra.solver.base._not_batched_transform(model, inputs, pad_before_transform, offset)#

Compute the embedding.

Parameters:
  • model (Model) – The model to use for inference.

  • inputs (Tensor) – Input data.

  • pad_before_transform (bool) – If True, the input data is zero padded before inference.

  • offset (Offset) – Model offset.

Returns:

The embedding, computed on the (potentially) padded data.

Return type:

torch.Tensor

Raises:

ValueError – If pad_before_transform is True and offset is not provided.

cebra.solver.base._batched_transform(model, inputs, batch_size, pad_before_transform, offset)#

Compute the embedding on batched inputs.

Parameters:
  • model (Model) – The model to use for inference.

  • inputs (Tensor) – Input data.

  • batch_size (int) – Integer corresponding to the batch size.

  • pad_before_transform (bool) – If True, the input data is zero padded before inference.

  • offset (Offset) – Model offset.

Return type:

Tensor

Returns:

The embedding.
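The core idea of batched inference, slicing the input into consecutive index ranges and concatenating the per-batch outputs, can be sketched as follows. The function name batched_apply is hypothetical, and the example uses a per-sample function in place of a model. A model with a receptive field larger than one sample would additionally need context samples or padding at the batch boundaries, which this sketch leaves out.

```python
def batched_apply(fn, inputs, batch_size):
    """Apply ``fn`` to consecutive slices of ``inputs`` and join the results."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")
    outputs = []
    for start in range(0, len(inputs), batch_size):
        end = min(start + batch_size, len(inputs))
        outputs.extend(fn(inputs[start:end]))
    return outputs


# A per-sample function stands in for the model here.
embedding = batched_apply(lambda batch: [2 * x for x in batch], list(range(7)), 3)
```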

class cebra.solver.base.Solver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#

Bases: ABC, HasDevice

Solver base class.

A solver contains helper methods for bundling a model, criterion and optimizer.

model#

The encoder for transforming reference, positive and negative samples.

Type:

torch.nn.Module

criterion#

The criterion computed from the similarities between positive pairs and negative pairs. The criterion can have trainable parameters of its own.

Type:

torch.nn.Module

optimizer#

A PyTorch optimizer for updating model and criterion parameters.

Type:

torch.optim.Optimizer

history#

Deprecated since 0.0.2. Use log.

Type:

List

decode_history#

Deprecated since 0.0.2. Use a hook during training for validation and decoding. See the arguments of fit().

Type:

List

log#

The logs recorded during training. Typically contains the total loss as well as the logs for positive (pos) and negative (neg) pairs. For the standard criterions in CEBRA, it also contains the value of the temperature.

Type:

Dict

tqdm_on#

Use tqdm for showing a progress bar during training.

Type:

bool

state_dict()#

Return a dictionary fully describing the current solver state.

Return type:

dict

Returns:

State dictionary, including the state dictionary of the models and optimizer. Also contains the training history and the CEBRA version the model was trained with.

load_state_dict(state_dict, strict=True)#

Update the solver state with the given state_dict.

Parameters:
  • state_dict (dict) – Dictionary with parameters for the model, optimizer, and the past loss history for the solver.

  • strict (bool) – Make sure all states can be loaded. Set to False to allow partial loading of the state for the keys that are given.
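The round trip between state_dict() and load_state_dict(), including the effect of strict, can be sketched with a toy class. ToyState and its keys "weights" and "loss" are hypothetical and do not reflect the keys used by the actual solver.

```python
class ToyState:
    """Hypothetical object with a dict-based state (not cebra's Solver)."""

    def __init__(self):
        self.weights = [0.0, 0.0]
        self.loss_history = []

    def state_dict(self):
        # Copy the lists so later updates do not mutate the snapshot.
        return {"weights": list(self.weights), "loss": list(self.loss_history)}

    def load_state_dict(self, state_dict, strict=True):
        expected = {"weights", "loss"}
        missing = expected - set(state_dict)
        if strict and missing:
            raise KeyError(f"Missing keys in state_dict: {sorted(missing)}")
        if "weights" in state_dict:
            self.weights = list(state_dict["weights"])
        if "loss" in state_dict:
            self.loss_history = list(state_dict["loss"])


source = ToyState()
source.weights = [0.5, -1.0]
source.loss_history = [2.0, 1.5]

restored = ToyState()
restored.load_state_dict(source.state_dict())

partial = ToyState()
partial.load_state_dict({"weights": [3.0, 4.0]}, strict=False)
```

With strict=True, the same partial dictionary would raise an error in this sketch, because the "loss" key is missing.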

property num_parameters: int#

Total number of parameters in the encoder and criterion.

Return type:

int

abstract parameters(session_id=None)#

Iterate over all parameters of the model.

Parameters:

session_id (Optional[int]) – The session ID, an int between 0 and the number of sessions -1 for multisession, and set to None for single session.

Yields:

The parameters of the model.

abstract _set_fitted_params(loader)#

Set parameters once the solver is fitted.

Parameters:

loader (Loader) – Loader used to fit the solver.

fit(loader, valid_loader=None, *, save_frequency=None, valid_frequency=None, decode=False, logdir=None, save_hook=None)#

Train model for the specified number of steps.

Parameters:
  • loader (Loader) – Data loader, which is an iterator over cebra.data.Batch instances. Each batch contains reference, positive and negative input samples.

  • valid_loader (Optional[Loader]) – Data loader used for validation of the model.

  • save_frequency (Optional[int]) – If not None, the frequency for automatically saving model checkpoints to logdir.

  • valid_frequency (Optional[int]) – The frequency for running validation on the valid_loader instance.

  • logdir (Optional[str]) – The logging directory for writing model checkpoints. The checkpoints can be read again using the solver.load function, or manually via loading the state dict.

step(batch)#

Perform a single gradient update.

Parameters:

batch (Batch) – The input samples

Return type:

dict

Returns:

Dictionary containing training metrics.

validation(loader, session_id=None)#

Compute score of the model on data.

Parameters:
  • loader (Loader) – Data loader, which is an iterator over cebra.data.Batch instances. Each batch contains reference, positive and negative input samples.

  • session_id (Optional[int]) – The session ID, an int between 0 and the number of sessions -1 for multisession, and set to None for single session.

Returns:

Loss averaged over iterations on data batch.

decoding(train_loader, valid_loader)#

Deprecated since 0.0.2.

_check_is_inputs_valid(inputs, session_id)#

Check that the inputs can be inferred using the selected model.

Note: This method checks that the number of neurons in the input matches the input dimension of the selected model.

Parameters:
  • inputs (Tensor) – Data to infer using the selected model.

  • session_id (int) – The session ID, an int between 0 and the number of sessions -1 for multisession, and set to None for single session.

abstract _check_is_session_id_valid(session_id=None)#

Check that the session ID provided is valid for the solver instance.

Parameters:

session_id (Optional[int]) – The session ID to check.

_select_model(inputs, session_id)#

Select the model based on the input dimension and session ID.

Parameters:
  • inputs (Union[Tensor, List[Tensor]]) – Data to infer using the selected model.

  • session_id (Optional[int]) – The session ID, an int between 0 and the number of sessions -1 for multisession, and set to None for single session.

Return type:

Tuple[Union[List[Module], Module], Offset]

Returns:

The model (first return value) and the offset of the model (second return value).

abstract _get_model(session_id=None)#

Get the model to use for inference.

Parameters:

session_id (Optional[int]) – The session ID, an int between 0 and the number of sessions -1 for multisession, and set to None for single session.

Return type:

Model

Returns:

The model.

_check_is_fitted()#

Check if the model is fitted.

If the model is fitted, the solver should have an n_features attribute.

Raises:

ValueError – If the model is not fitted.

transform(inputs, pad_before_transform=True, session_id=None, batch_size=None)#

Compute the embedding.

This function by default only applies the forward function of the given model, after switching it into eval mode.

Parameters:
  • inputs (Tensor) – The input signal (T, N).

  • pad_before_transform (Optional[bool]) – If False, no padding is applied to the input sequence and the output sequence will be smaller than the input sequence due to the receptive field of the model. If the input sequence is n steps long, and a model with receptive field m is used, the output sequence would only be n-m+1 steps long.

  • session_id (Optional[int]) – The session ID, an int between 0 and the number of sessions -1 for multisession, and set to None for single session.

  • batch_size (Optional[int]) – If not None, batched inference is applied using this batch size. If None, the embedding is computed on the full input at once.

Return type:

Tensor

Returns:

The output embedding.
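The effect of pad_before_transform on the output length can be checked with a toy sliding-window function. Here sliding_mean is a hypothetical stand-in for a model with receptive field m, and the way the m - 1 padding samples are split between the two ends is an assumption made for this sketch.

```python
def sliding_mean(signal, m, pad=True):
    """Average over windows of length ``m``, a stand-in for a model with
    receptive field ``m``. The left/right split of the padding is an
    assumption made for this sketch."""
    if pad:
        left = (m - 1) // 2
        right = (m - 1) - left
        signal = [0.0] * left + list(signal) + [0.0] * right
    return [sum(signal[i:i + m]) / m for i in range(len(signal) - m + 1)]


signal = [float(i) for i in range(10)]           # n = 10 time steps
unpadded = sliding_mean(signal, m=4, pad=False)  # n - m + 1 = 7 outputs
padded = sliding_mean(signal, m=4, pad=True)     # n = 10 outputs
```

Without padding, 10 input steps and a receptive field of 4 give 7 output steps; with m - 1 = 3 zero samples added, the output has the same length as the input.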

_transform(model, inputs, pad_before_transform, offset, batch_size)#

Compute the embedding on the inputs using the model provided.

Parameters:
  • model (Model) – Model to use for inference.

  • inputs (Tensor) – Data.

  • pad_before_transform (bool) – If True, zero-pad the batched data.

  • offset (Offset) – Offset of the model to consider when padding.

  • batch_size (Optional[int]) – If not None, batched inference is applied using this batch size. If None, the embedding is computed on the full input at once.

Return type:

Tensor

Returns:

The embedding.

abstract _inference(batch)#

Given a batch of input examples, return the model outputs.

Parameters:

batch (Batch) – The input data, not necessarily aligned across the batch dimension. This means that batch.index specifies the map between reference/positive samples if it is not None.

Return type:

Batch

Returns:

Processed batch of data. While the input data might not be aligned across the sample dimensions, the output data should be aligned and batch.index should be set to None.

load(logdir, filename='checkpoint.pth')#

Load the experiment from its checkpoint file.

Parameters:
  • logdir (str) – Logging directory.

  • filename (str) – Checkpoint name for loading the experiment.

save(logdir, filename='checkpoint_last.pth')#

Save the model and optimizer params.

Parameters:
  • logdir (str) – Logging directory for this model.

  • filename (str) – Checkpoint name for saving the experiment.

class cebra.solver.base.MultiobjectiveSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, num_behavior_features=3, renormalize_features=False, output_mode='overlapping', ignore_deprecation_warning=False)#

Bases: Solver

Train models to satisfy multiple learning objectives.

This variant of the standard cebra.solver.base.Solver implements multi-objective or “hybrid” training.

model#

A multi-objective CEBRA model

Type:

torch.nn.Module

optimizer#

The optimizer used for training.

Type:

torch.optim.Optimizer

num_behavior_features#

The feature dimension for the features dedicated to satisfy the behavior contrastive objective. The remainder is used for time contrastive learning.

Type:

int

renormalize_features#

If True, normalize the behavior and time contrastive features individually before computing similarity scores.

Type:

bool

ignore_deprecation_warning#

If True, suppress the deprecation warning.

Type:

bool

Note

This solver will be deprecated in a future version. Please use the functionality in cebra.solver.multiobjective instead, which provides more versatile multi-objective training capabilities. Instantiation of this solver will raise a deprecation warning.
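The role of num_behavior_features and renormalize_features can be illustrated on a single embedding vector. The sketch below assumes a non-overlapping split in which the behavior features come first and the remainder is used for time contrastive learning. Both the ordering and the function names are hypothetical, and the default output_mode='overlapping' may behave differently.

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit Euclidean norm."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def split_features(embedding, num_behavior_features, renormalize=False):
    """Split one embedding vector into behavior and time parts.

    Assumes the behavior features come first; this ordering is hypothetical.
    """
    if not 0 < num_behavior_features < len(embedding):
        raise ValueError("num_behavior_features must leave room for both parts.")
    behavior = embedding[:num_behavior_features]
    time = embedding[num_behavior_features:]
    if renormalize:
        behavior, time = l2_normalize(behavior), l2_normalize(time)
    return behavior, time


behavior, time = split_features([3.0, 4.0, 0.0, 0.0, 5.0], 3, renormalize=True)
```

Normalizing each part separately means that similarity scores for the two objectives are computed on unit-length vectors, independent of how the norm is distributed across the full embedding.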

_check_dimensions()#

Check the feature dimensions for behavior/time contrastive learning.

Raises:

ValueError – If feature dimensions are larger than the model features, or not sufficiently large for renormalization.

step(batch)#

Perform a single gradient update with multiple objectives.

Parameters:

batch (Batch) – The input samples

Return type:

dict

Returns:

Dictionary containing training metrics.

class cebra.solver.base.AuxiliaryVariableSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#

Bases: Solver

transform(inputs, pad_before_transform=True, session_id=None, batch_size=None, use_reference_model=False)#

Compute the embedding.

By default, this function uses the model that was trained to encode the positive and negative samples. To use reference_model instead, set use_reference_model to True.

Parameters:
  • inputs (Tensor) – The input signal.

  • use_reference_model (bool) – Flag for using reference_model.

Return type:

Tensor

Returns:

The output embedding.

Single session training#

Single session solvers embed a single pair of time series.

class cebra.solver.single_session.SingleSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#

Bases: Solver

Single session training with a symmetric encoder.

This solver assumes that reference, positive and negative samples are processed by the same features encoder and that a single session is provided to that encoder.

parameters(session_id=None)#

Iterate over all parameters.

Parameters:

session_id (Optional[int]) – The session ID, an int between 0 and the number of sessions -1 for multisession, and set to None for single session.

Yields:

The parameters of the model.

get_embedding(data)#

Return the embedding of the given input data.

Note

This function assumes that the input data is sliced according to the receptive field of the model. The input data needs to match batch x dims x len(self.model.get_offset()) which is internally reduced to batch x dims x 1. The last dimension is squeezed, and the output is of shape time x features.

This function does not perform checks for correctness of the input.

Parameters:

data (Tensor) – The input data tensor of shape batch_time x dims x time

Return type:

Tensor

Returns:

The output data tensor of shape batch_time x features.

class cebra.solver.single_session.SingleSessionAuxVariableSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, reference_model=None)#

Bases: SingleSessionSolver, AuxiliaryVariableSolver

Single session training for reference and positive/negative samples.

This solver processes reference samples with a different model than the one used for the positive and negative samples. It requires that the reference_model is initialized to be different from the model used to process the positive and negative samples.

Besides using an asymmetric encoder for the same modality, this solver also allows for e.g. time-contrastive learning across modalities, by using a reference model on modality A, and a different model processing the signal from modality B.

class cebra.solver.single_session.SingleSessionHybridSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, num_behavior_features=3, renormalize_features=False, output_mode='overlapping', ignore_deprecation_warning=False)#

Bases: MultiobjectiveSolver, SingleSessionSolver

Single session training, contrasting neural data against behavior.

class cebra.solver.single_session.BatchSingleSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#

Bases: SingleSessionSolver

Optimize a model with batch gradient descent.

Usage of this solver requires a sufficient amount of GPU memory. Using this solver is equivalent to using a single session solver with batch size set to dataset size, but requires less computation.

fit(loader, *args, **kwargs)#

TODO

get_embedding(data)#

Compute the embedding of a full input dataset.

For convolutional models (those that implement cebra.models.model.ConvolutionalModelMixin), the embedding is computed via SingleSessionSolver.get_embedding().

For all other models, it is assumed that the data has shape (1, dim, time) and is transformed into (time, dim) format.

Parameters:

data – The input data

Returns:

The output embedding of shape (time, dimension)

Multi session training#

Solver implementations for multi-session datasets.

class cebra.solver.multi_session.MultiSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#

Bases: Solver

Multi session training, contrasting pairs of neural data.

parameters(session_id=None)#

Iterate over all parameters.

Parameters:

session_id (Optional[int]) – The session ID, an int between 0 and the number of sessions -1 for multisession, and set to None for single session.

Yields:

The parameters of the model.

validation(loader, session_id=None)#

Compute score of the model on data.

Parameters:
  • loader – Data loader, which is an iterator over cebra.data.datatypes.Batch instances. Each batch contains reference, positive and negative input samples.

  • session_id (Optional[int]) – The session ID, an integer between 0 and the number of sessions in the multisession model, set to None for single session.

Returns:

Loss averaged over iterations on data batch.

class cebra.solver.multi_session.MultiSessionAuxVariableSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#

Bases: MultiSessionSolver, AuxiliaryVariableSolver

Multi session training, contrasting neural data against behavior.

Training utilities#

Utility functions for solvers and their training loops.

class cebra.solver.util.Meter#

Bases: object

Track statistics of a metric.

add(value, num_elements=1)#

Add the value to the meter.

Parameters:
  • value (float) – The value to add to the meter.

  • num_elements (int) – Optional. The number of elements that were summed to obtain value (for example, loss values within a batch of data samples). Set this when the average should be computed per element.

property average: float#

Return the average value of the tracked metric.

Return type:

float

property sum: float#

Return the sum of all tracked values.

Return type:

float
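A meter with this interface could be sketched as follows. ToyMeter is a hypothetical re-implementation for illustration, not the code of cebra.solver.util.Meter; it assumes that value is accumulated as given and that num_elements only affects the denominator of the average.

```python
class ToyMeter:
    """Running average of a metric (a sketch, not cebra.solver.util.Meter)."""

    def __init__(self):
        self._total = 0.0
        self._num_elements = 0

    def add(self, value, num_elements=1):
        # ``value`` may already be a sum over ``num_elements`` elements.
        self._total += value
        self._num_elements += num_elements

    @property
    def average(self):
        # Guard against division by zero before any value is added.
        return self._total / max(self._num_elements, 1)

    @property
    def sum(self):
        return self._total


meter = ToyMeter()
meter.add(6.0, num_elements=3)  # summed loss over a batch of 3 samples
meter.add(2.0)                  # a single additional value
```

In this example the sum is 8.0 over 4 elements, so the average is 2.0.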

class cebra.solver.util.ProgressBar(loader, log_format)#

Bases: object

Log and display values during training.

property use_tqdm: bool#

Display tqdm as the progress bar.

Return type:

bool

set_description(stats)#

Update the progress bar description.

The description is updated by computing a formatted string from the given stats in the format {key}: {value: .4f} with a space as the divider between dictionary elements.

Behavior depends on the selected progress bar.
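The description format can be reproduced with a Python format string. The helper name format_description and the example keys are hypothetical; the format specification " .4f" reserves a leading space for the sign of positive numbers, so positive and negative values line up.

```python
def format_description(stats):
    """Join ``{key}: {value: .4f}`` entries with a single space."""
    return " ".join(f"{key}: {value: .4f}" for key, value in stats.items())


description = format_description({"pos": 0.25, "neg": -1.5, "total": 2.0})
```

For the dictionary above, the resulting string is "pos:  0.2500 neg: -1.5000 total:  2.0000".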