Solvers#

Variants of CEBRA solvers for single- and multi-session training.

This package contains wrappers around training loops. If you want to customize how the encoder models transform the reference, positive and negative samples, how the loss functions are applied to the data, or how results are logged, extending the classes in this package is the right way to go.

The module contains the base cebra.solver.base.Solver class along with multiple variants to deal with single- and multi-session datasets.

This module is a registry and currently contains the options ['multi-session', 'multi-session-aux', 'supervised-solver-xcebra', 'multiobjective-solver', 'single-session', 'single-session-aux', 'single-session-hybrid', 'single-session-full', 'regularized-solver'].

To retrieve a list of options, call:

>>> print(cebra.solver.get_options())
['multi-session', 'multi-session-aux', 'supervised-solver-xcebra', ...]

To obtain an initialized instance, call cebra.solver.init, defined in cebra.registry.add_helper_functions(). The first parameter is the name of the solver to use, one of the options listed above, followed by any positional arguments the chosen solver requires.
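
For example, a solver can be assembled from a model, criterion and optimizer. The snippet below is a minimal sketch: the model name, its dimensions, the criterion class and the learning rate are illustrative choices, not prescribed defaults.

>>> import torch
>>> import cebra.models
>>> import cebra.solver
>>> model = cebra.models.init(  # illustrative model and dimensions
...     "offset10-model", num_neurons=100, num_units=32, num_output=8)
>>> criterion = cebra.models.criterions.InfoNCE()  # illustrative criterion
>>> optimizer = torch.optim.Adam(
...     list(model.parameters()) + list(criterion.parameters()), lr=3e-4)
>>> solver = cebra.solver.init(
...     "single-session", model=model, criterion=criterion, optimizer=optimizer)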

You can register additional options by defining a class and registering it under a name. To do so, add a decorator on top of the class definition: @cebra.solver.register("my-cebra-solver").

Later, initialize your class similarly to the pre-defined options, using cebra.solver.init with the solver name set to my-cebra-solver.
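
For instance, a minimal sketch of registering a placeholder class (the implementation is intentionally left abstract here):

>>> import cebra.solver
>>> @cebra.solver.register("my-cebra-solver")
... class MyCebraSolver(cebra.solver.base.Solver):
...     def _inference(self, batch):
...         ...  # placeholder; a real solver returns a processed batch
>>> "my-cebra-solver" in cebra.solver.get_options()
True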

Note that these customized options will not be automatically added to this docstring.

This package contains abstract base classes for different solvers.

Solvers are used to package models, criterions and optimizers, and to implement training loops. When subclassing an abstract solver, in the simplest case only the Solver._inference() method needs to be overridden.

For more complex use cases, the Solver.step() and Solver.fit() methods can be overridden to implement larger changes to the training loop.
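
As an illustration, the following sketch overrides Solver._inference() to apply the same encoder to the reference, positive and negative samples, mirroring symmetric single-session training; it is not the implementation shipped with CEBRA.

>>> import cebra.data
>>> import cebra.solver
>>> class SymmetricSolver(cebra.solver.base.Solver):
...     def _inference(self, batch):
...         # Encode all three sample types with the same model. The output
...         # batch is aligned across samples, so no index is attached.
...         return cebra.data.Batch(
...             reference=self.model(batch.reference),
...             positive=self.model(batch.positive),
...             negative=self.model(batch.negative),
...         )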

class cebra.solver.base.Solver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#

Bases: ABC, HasDevice

Solver base class.

A solver contains helper methods for bundling a model, criterion and optimizer.

model#

The encoder for transforming reference, positive and negative samples.

Type:

torch.nn.Module

criterion#

The criterion, computed from the similarities between positive and negative pairs. The criterion can have trainable parameters of its own.

Type:

torch.nn.Module

optimizer#

A PyTorch optimizer for updating model and criterion parameters.

Type:

torch.optim.Optimizer

history#

Deprecated since 0.0.2. Use log.

Type:

List

decode_history#

Deprecated since 0.0.2. Use a hook during training for validation and decoding. See the arguments of fit().

Type:

List

log#

The logs recorded during training; these typically contain the total loss as well as the losses for positive (pos) and negative (neg) pairs. For the standard criterions in CEBRA, the log also contains the value of the temperature.

Type:

Dict

tqdm_on#

Use tqdm for showing a progress bar during training.

Type:

bool

state_dict()#

Return a dictionary fully describing the current solver state.

Return type:

dict

Returns:

State dictionary, including the state dictionary of the models and optimizer. Also contains the training history and the CEBRA version the model was trained with.

load_state_dict(state_dict, strict=True)#

Update the solver state with the given state_dict.

Parameters:
  • state_dict (dict) – Dictionary with parameters for the model, optimizer, and the past loss history for the solver.

  • strict (bool) – Make sure all states can be loaded. Set to False to allow a partial load of the state for the given keys.
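
A sketch of a checkpoint round trip through the state dictionary, assuming an already initialized solver; the file path is illustrative:

>>> import torch
>>> state = solver.state_dict()
>>> torch.save(state, "solver_checkpoint.pth")  # illustrative path
>>> solver.load_state_dict(torch.load("solver_checkpoint.pth"), strict=True)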

property num_parameters: int#

Total number of parameters in the encoder and criterion.

Return type:

int

parameters()#

Iterate over all parameters.

fit(loader, valid_loader=None, *, save_frequency=None, valid_frequency=None, decode=False, logdir=None, save_hook=None)#

Train the model for the specified number of steps.

Parameters:
  • loader (Loader) – Data loader, which is an iterator over cebra.data.Batch instances. Each batch contains reference, positive and negative input samples.

  • valid_loader (Optional[Loader]) – Data loader used for validation of the model.

  • save_frequency (Optional[int]) – If not None, the frequency for automatically saving model checkpoints to logdir.

  • valid_frequency (Optional[int]) – The frequency for running validation on the valid_loader instance.

  • logdir (Optional[str]) – The logging directory for writing model checkpoints. The checkpoints can be read again using the solver.load function, or manually via loading the state dict.
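
A sketch of a training run, assuming a solver and loaders were constructed beforehand; the frequencies and directory name are illustrative:

>>> solver.fit(
...     loader,
...     valid_loader=valid_loader,
...     save_frequency=500,
...     valid_frequency=100,
...     logdir="./checkpoints",  # illustrative directory
... )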

step(batch)#

Perform a single gradient update.

Parameters:

batch (Batch) – The input samples

Return type:

dict

Returns:

Dictionary containing training metrics.

validation(loader, session_id=None)#

Compute the score of the model on data.

Parameters:
  • loader (Loader) – Data loader, which is an iterator over cebra.data.Batch instances. Each batch contains reference, positive and negative input samples.

  • session_id (Optional[int]) – The session ID, an integer between 0 and the number of sessions in the multisession model, set to None for single session.

Returns:

Loss averaged over iterations on the data batches.
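
For example, assuming a validation loader is available (session_id stays None for single session training):

>>> avg_loss = solver.validation(valid_loader)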

decoding(train_loader, valid_loader)#

Deprecated since 0.0.2.

transform(inputs)#

Compute the embedding.

This function by default only applies the forward function of the given model, after switching it into eval mode.

Parameters:

inputs (Tensor) – The input signal

Return type:

Tensor

Returns:

The output embedding.
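
A sketch, assuming a solver trained on 30-dimensional signals; the tensor shape is illustrative and must match what the underlying model expects:

>>> import torch
>>> inputs = torch.randn(1000, 30)  # illustrative: 1000 samples, 30 features
>>> embedding = solver.transform(inputs)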

abstract _inference(batch)#

Given a batch of input examples, return the model outputs.

TODO: make this a public function?

Parameters:

batch (Batch) – The input data, not necessarily aligned across the batch dimension. This means that batch.index specifies the mapping between reference and positive samples, if it is not None.

Return type:

Batch

Returns:

Processed batch of data. While the input data might not be aligned across the sample dimensions, the output data should be aligned and batch.index should be set to None.

load(logdir, filename='checkpoint.pth')#

Load the experiment from its checkpoint file.

Parameters:
  • logdir – Logging directory for loading the model checkpoint.

  • filename (str) – Checkpoint name for loading the experiment.

save(logdir, filename='checkpoint_last.pth')#

Save the model and optimizer params.

Parameters:
  • logdir – Logging directory for this model.

  • filename – Checkpoint name for saving the experiment.
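
A sketch of saving and restoring a checkpoint, with an illustrative logging directory:

>>> solver.save("./logs", filename="checkpoint_last.pth")  # illustrative directory
>>> solver.load("./logs", filename="checkpoint_last.pth")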

class cebra.solver.base.MultiobjectiveSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, num_behavior_features=3, renormalize_features=False, output_mode='overlapping', ignore_deprecation_warning=False)#

Bases: Solver

Train models to satisfy multiple learning objectives.

This variant of the standard cebra.solver.base.Solver implements multi-objective or “hybrid” training.

model#

A multi-objective CEBRA model

Type:

torch.nn.Module

optimizer#

The optimizer used for training.

Type:

torch.optim.Optimizer

num_behavior_features#

The number of feature dimensions dedicated to satisfying the behavior contrastive objective. The remainder is used for time contrastive learning.

Type:

int

renormalize_features#

If True, normalize the behavior and time contrastive features individually before computing similarity scores.

Type:

bool

ignore_deprecation_warning#

If True, suppress the deprecation warning.

Type:

bool

Note

This solver will be deprecated in a future version. Please use the functionality in cebra.solver.multiobjective instead, which provides more versatile multi-objective training capabilities. Instantiation of this solver will raise a deprecation warning.

_check_dimensions()#

Check the feature dimensions for behavior/time contrastive learning.

Raises:

ValueError – If feature dimensions are larger than the model features, or not sufficiently large for renormalization.

step(batch)#

Perform a single gradient update with multiple objectives.

Parameters:

batch (Batch) – The input samples

Return type:

dict

Returns:

Dictionary containing training metrics.

Single session training#

Single session solvers embed a single pair of time series.

class cebra.solver.single_session.SingleSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#

Bases: Solver

Single session training with a symmetric encoder.

This solver assumes that reference, positive and negative samples are processed by the same feature encoder.

get_embedding(data)#

Return the embedding of the given input data.

Note

This function assumes that the input data is sliced according to the receptive field of the model. The input data needs to match batch x dims x len(self.model.get_offset()), which is internally reduced to batch x dims x 1. The last dimension is squeezed, and the output is of shape batch x features.

This function does not perform checks for correctness of the input.

Parameters:

data (Tensor) – The input data tensor of shape batch_time x dims x time

Return type:

Tensor

Returns:

The output data tensor of shape batch_time x features.
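
A sketch of the expected shapes, assuming a trained solver; the batch and channel dimensions are illustrative:

>>> import torch
>>> receptive_field = len(solver.model.get_offset())
>>> data = torch.randn(10, 30, receptive_field)  # batch_time x dims x time
>>> embedding = solver.get_embedding(data)       # shape: batch_time x features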

class cebra.solver.single_session.SingleSessionAuxVariableSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, reference_model=None)#

Bases: Solver

Single session training for reference and positive/negative samples.

This solver processes reference samples with a different model than the one used for the positive and negative samples. It requires that reference_model is initialized to be different from the model used to process the positive and negative samples.

Besides using an asymmetric encoder for the same modality, this solver also allows for, e.g., time-contrastive learning across modalities, by using the reference model on modality A and a different model processing the signal from modality B.

class cebra.solver.single_session.SingleSessionHybridSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, num_behavior_features=3, renormalize_features=False, output_mode='overlapping', ignore_deprecation_warning=False)#

Bases: MultiobjectiveSolver

Single session training, contrasting neural data against behavior.

class cebra.solver.single_session.BatchSingleSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#

Bases: SingleSessionSolver

Optimize a model with batch gradient descent.

Usage of this solver requires a sufficient amount of GPU memory. It is equivalent to using a single session solver with the batch size set to the dataset size, but requires less computation.

fit(loader, *args, **kwargs)#

TODO

get_embedding(data)#

Compute the embedding of a full input dataset.

For convolutional models (i.e., models that implement cebra.models.model.ConvolutionalModelMixin), the embedding is computed via SingleSessionSolver.get_embedding().

For all other models, it is assumed that the data has shape (1, dim, time) and is transformed into (time, dim) format.

Parameters:

data – The input data

Returns:

The output embedding of shape (time, dimension)

Multi session training#

Solver implementations for multi-session datasets.

class cebra.solver.multi_session.MultiSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#

Bases: Solver

Multi session training, contrasting pairs of neural data.

validation(loader, session_id=None)#

Compute the score of the model on data.

Parameters:
  • loader – Data loader, which is an iterator over cebra.data.datatypes.Batch instances. Each batch contains reference, positive and negative input samples.

  • session_id (Optional[int]) – The session ID, an integer between 0 and the number of sessions in the multisession model, set to None for single session.

Returns:

Loss averaged over iterations on the data batches.

class cebra.solver.multi_session.MultiSessionAuxVariableSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#

Bases: Solver

Multi session training, contrasting neural data against behavior.

Training utilities#

Utility functions for solvers and their training loops.

class cebra.solver.util.Meter#

Bases: object

Track statistics of a metric.

add(value, num_elements=1)#

Add the value to the meter.

Parameters:
  • value (float) – The value to add to the meter.

  • num_elements (int) – Optional; the number of elements that were already summed to obtain value (for example, loss values within a batch of data samples). The average is computed with respect to this unit.

property average: float#

Return the average value of the tracked metric.

Return type:

float

property sum: float#

Return the sum of all tracked values.

Return type:

float
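
A minimal usage sketch; following the semantics above, the average is the sum of added values divided by the total number of elements:

>>> from cebra.solver.util import Meter
>>> meter = Meter()
>>> meter.add(1.0)
>>> meter.add(3.0)
>>> meter.average
2.0
>>> meter.sum
4.0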

class cebra.solver.util.ProgressBar(loader, log_format)#

Bases: object

Log and display values during training.

property use_tqdm: bool#

Whether tqdm is used to display the progress bar.

Return type:

bool

set_description(stats)#

Update the progress bar description.

The description is updated by computing a formatted string from the given stats in the format {key}: {value: .4f} with a space as the divider between dictionary elements.

Behavior depends on the selected progress bar.
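
As an illustration of the documented format, independent of the class itself:

>>> stats = {"pos": 0.1234, "neg": 1.5678}
>>> " ".join("{}: {: .4f}".format(key, value) for key, value in stats.items())
'pos:  0.1234 neg:  1.5678'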