Solvers#
Variants of CEBRA solvers for single- and multi-session training.
This package contains wrappers around training loops. If you want to customize how the encoder models transform the reference, positive and negative samples, how the loss functions are applied to the data, or how results are logged, extending the classes in this package is the right way to go.
The module contains the base cebra.solver.base.Solver class along with multiple variants to deal with single- and multi-session datasets.
This module is a registry and currently contains the options ['multi-session', 'multi-session-aux', 'single-session', 'single-session-aux', 'single-session-hybrid', 'single-session-full', 'supervised-solver-xcebra', 'multiobjective-solver', 'regularized-solver'].
To retrieve a list of options, call:
>>> print(cebra.solver.get_options())
['multi-session', 'multi-session-aux', 'single-session', ...]
To obtain an initialized instance, call cebra.solver.init, defined in cebra.registry.add_helper_functions().
The first parameter to provide is the solver name, which is one of the available options presented above. Then provide the required positional arguments specific to the solver, if any.
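For example, a single-session solver can be constructed from a model, criterion and optimizer. The following is a minimal sketch; the model name, criterion choice and optimizer settings are illustrative, not prescriptive:
>>> import torch
>>> import cebra.models
>>> import cebra.models.criterions
>>> import cebra.solver
>>> model = cebra.models.init(
...     "offset10-model", num_neurons=100, num_units=32, num_output=8
... )
>>> criterion = cebra.models.criterions.LearnableCosineInfoNCE()
>>> optimizer = torch.optim.Adam(
...     list(model.parameters()) + list(criterion.parameters()), lr=3e-4
... )
>>> solver = cebra.solver.init(
...     "single-session", model=model, criterion=criterion, optimizer=optimizer
... )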
You can register additional options by defining and registering classes with a name. To do that, decorate the class with @cebra.solver.register("my-cebra-solver"). Later, initialize your class like the pre-defined options, by calling cebra.solver.init with the solver name set to my-cebra-solver.
Note that these customized options will not be automatically added to this docstring.
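As a sketch, a custom solver can subclass one of the existing options and be registered under its own name; the class body below is a placeholder, and model, criterion and optimizer are assumed to be configured as in the example above:
>>> import cebra.solver
>>> @cebra.solver.register("my-cebra-solver")
... class MyCebraSolver(cebra.solver.single_session.SingleSessionSolver):
...     """Custom single-session solver; override methods as needed."""
>>> solver = cebra.solver.init(
...     "my-cebra-solver", model=model, criterion=criterion, optimizer=optimizer
... )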
This package contains abstract base classes for different solvers. Solvers are used to package models, criteria and optimizers, and implement training loops. When subclassing abstract solvers, in the simplest case only Solver._inference() needs to be overridden. For more complex use cases, the Solver.step() and Solver.fit() methods can be overridden to implement larger changes to the training loop.
- cebra.solver.base._check_indices(batch_start_idx, batch_end_idx, offset, num_samples)#
Check that the indices of a batch are within the valid range.
The first and last index must be positive integers, smaller than the total length of inputs in the dataset; the first index must be smaller than the last; and the batch size cannot be smaller than the offset of the model.
- cebra.solver.base._add_batched_zero_padding(batched_data, offset, batch_start_idx, batch_end_idx, num_samples)#
Add zero padding to the input data before inference.
- Parameters:
  - batched_data (Tensor) – Data to apply the inference on.
  - offset (Offset) – Offset of the model to consider when padding.
  - batch_start_idx (int) – Index of the first sample in the batch.
  - batch_end_idx (int) – Index of the last sample in the batch.
  - num_samples (int) – Total number of samples in the data.
- Return type:
  Tensor
- Returns:
  The padded batch.
- cebra.solver.base._get_batch(inputs, offset, batch_start_idx, batch_end_idx, pad_before_transform)#
Get a batch of samples between the batch_start_idx and batch_end_idx.
- Parameters:
  - inputs (Tensor) – Data to apply the inference on.
  - offset (Offset) – Offset of the model to consider when padding.
  - batch_start_idx (int) – Index of the first sample in the batch.
  - batch_end_idx (int) – Index of the last sample in the batch.
  - pad_before_transform (bool) – If True, zero padding is applied to the batch before inference.
- Return type:
  Tensor
- Returns:
  The batch.
- cebra.solver.base._inference_transform(model, inputs)#
Compute the embedding on the inputs using the model provided.
- cebra.solver.base._not_batched_transform(model, inputs, pad_before_transform, offset)#
Compute the embedding.
- Parameters:
  - model (Module) – The model to use for inference.
  - inputs (Tensor) – Data to apply the inference on.
  - pad_before_transform (bool) – If True, zero padding is applied to the input before inference.
  - offset (Offset) – Offset of the model to consider when padding.
- Return type:
  Tensor
- Returns:
  The (potentially) padded data.
- Raises:
  ValueError – If pad_before_transform is True and offset is not provided.
- cebra.solver.base._batched_transform(model, inputs, batch_size, pad_before_transform, offset)#
Compute the embedding on batched inputs.
- Parameters:
  - model (Module) – The model to use for inference.
  - inputs (Tensor) – Data to apply the inference on.
  - batch_size (int) – Number of samples per batch.
  - pad_before_transform (bool) – If True, zero padding is applied to the input before inference.
  - offset (Offset) – Offset of the model to consider when padding.
- Return type:
  Tensor
- Returns:
  The embedding.
- class cebra.solver.base.Solver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#
Solver base class.
A solver contains helper methods for bundling a model, criterion and optimizer.
- model#
The encoder for transforming reference, positive and negative samples.
- Type:
  torch.nn.Module
- criterion#
The criterion computed from the similarities between positive pairs and negative pairs. The criterion can have trainable parameters on its own.
- Type:
  torch.nn.Module
- optimizer#
A PyTorch optimizer for updating model and criterion parameters.
- Type:
  torch.optim.Optimizer
- decode_history#
Deprecated since 0.0.2. Use a hook during training for validation and decoding. See the arguments of fit().
- Type:
  List
- log#
The logs recorded during training. Typically contains the total loss as well as the logs for positive (pos) and negative (neg) pairs. For the standard criteria in CEBRA, it also contains the value of the temperature.
- Type:
  Dict
- state_dict()#
Return a dictionary fully describing the current solver state.
- Return type:
  dict
- Returns:
State dictionary, including the state dictionary of the models and optimizer. Also contains the training history and the CEBRA version the model was trained with.
- load_state_dict(state_dict, strict=True)#
Update the solver state with the given state_dict.
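For example, a solver state can be saved and restored with standard PyTorch serialization (a sketch; the filename is arbitrary):
>>> import torch
>>> torch.save(solver.state_dict(), "checkpoint.pth")
>>> solver.load_state_dict(torch.load("checkpoint.pth"))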
- abstract parameters(session_id=None)#
Iterate over all parameters of the model.
- abstract _set_fitted_params(loader)#
Set parameters once the solver is fitted.
- Parameters:
  loader (Loader) – Loader used to fit the solver.
- fit(loader, valid_loader=None, *, save_frequency=None, valid_frequency=None, decode=False, logdir=None, save_hook=None)#
Train model for the specified number of steps.
- Parameters:
  - loader (Loader) – Data loader, which is an iterator over cebra.data.Batch instances. Each batch contains reference, positive and negative input samples.
  - valid_loader (Optional[Loader]) – Data loader used for validation of the model.
  - save_frequency (Optional[int]) – If not None, the frequency for automatically saving model checkpoints to logdir.
  - valid_frequency (Optional[int]) – The frequency for running validation on the valid_loader instance.
  - logdir (Optional[str]) – The logging directory for writing model checkpoints. The checkpoints can be read again using the solver.load function, or manually via loading the state dict.
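A typical call might look as follows (a sketch; loader and valid_loader are assumed to be pre-configured cebra.data loaders, and the frequencies and directory are arbitrary):
>>> solver.fit(
...     loader,
...     valid_loader=valid_loader,
...     save_frequency=500,
...     valid_frequency=100,
...     logdir="runs/example",
... )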
- step(batch)#
Perform a single gradient update.
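fit() internally iterates the loader and calls step() once per batch; a manual training loop is a thin wrapper around it (a sketch, assuming loader yields cebra.data.Batch instances):
>>> for batch in loader:
...     stats = solver.step(batch)  # one gradient update; stats holds the logged values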
- validation(loader, session_id=None)#
Compute score of the model on data.
- Parameters:
  - loader – Data loader, which is an iterator over cebra.data.Batch instances. Each batch contains reference, positive and negative input samples.
  - session_id (Optional[int]) – The session ID, an int between 0 and the number of sessions - 1 for multisession, and set to None for single session.
- Returns:
Loss averaged over iterations on data batch.
- decoding(train_loader, valid_loader)#
Deprecated since 0.0.2.
- _check_is_inputs_valid(inputs, session_id)#
Check that the inputs can be inferred using the selected model.
Note: This method checks that the number of neurons in the input matches the input dimension of the selected model.
- abstract _check_is_session_id_valid(session_id=None)#
Check that the session ID provided is valid for the solver instance.
- _select_model(inputs, session_id)#
Select the model based on the input dimension and session ID.
- Parameters:
  - inputs (Tensor) – The input data.
  - session_id (Optional[int]) – The session ID, an int between 0 and the number of sessions - 1 for multisession, and set to None for single session.
- Return type:
  Tuple[Module, Offset]
- Returns:
  The model (first element) and the offset of the model (second element).
- abstract _get_model(session_id=None)#
Get the model to use for inference.
- _check_is_fitted()#
Check if the model is fitted.
If the model is fitted, the solver should have an n_features attribute.
- Raises:
ValueError – If the model is not fitted.
- transform(inputs, pad_before_transform=True, session_id=None, batch_size=None)#
Compute the embedding.
This function by default only applies the forward function of the given model, after switching it into eval mode.
- Parameters:
  - inputs (Tensor) – The input signal (T, N).
  - pad_before_transform (Optional[bool]) – If False, no padding is applied to the input sequence and the output sequence will be smaller than the input sequence due to the receptive field of the model. If the input sequence is n steps long, and a model with receptive field m is used, the output sequence will only be n-m+1 steps long.
  - session_id (Optional[int]) – The session ID, an int between 0 and the number of sessions - 1 for multisession, and set to None for single session.
  - batch_size (Optional[int]) – If not None, the inference is applied in batches of this size.
- Return type:
  Tensor
- Returns:
  The output embedding.
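For example, to embed a time series after fitting (a sketch; the sizes are arbitrary and the output dimension depends on the configured model):
>>> import torch
>>> timeseries = torch.randn(1000, 30)  # (T, N): 1000 time steps, 30 neurons
>>> embedding = solver.transform(timeseries, batch_size=512)  # (T, output_dim)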
- _transform(model, inputs, pad_before_transform, offset, batch_size)#
Compute the embedding on the inputs using the model provided.
- Parameters:
  - model (Module) – The model to use for inference.
  - inputs (Tensor) – The input signal.
  - pad_before_transform (bool) – If True, zero padding is applied to the input before inference.
  - offset (Offset) – Offset of the model to consider when padding.
  - batch_size (Optional[int]) – If not None, the inference is applied in batches of this size.
- Return type:
  Tensor
- Returns:
  The embedding.
- abstract _inference(batch)#
Given a batch of input examples, return the model outputs.
- Parameters:
  batch (Batch) – The input data, not necessarily aligned across the batch dimension. This means that batch.index specifies the map between reference/positive samples, if not equal to None.
- Return type:
  Batch
- Returns:
  Processed batch of data. While the input data might not be aligned across the sample dimensions, the output data should be aligned, and batch.index should be set to None.
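This contract can be illustrated with a minimal sketch of a symmetric single-session implementation (illustrative only; the shipped solvers differ in detail):
>>> import cebra.data
>>> import cebra.solver
>>> class MySymmetricSolver(cebra.solver.single_session.SingleSessionSolver):
...     def _inference(self, batch):
...         # Apply the same encoder to reference, positive and negative
...         # samples, returning an aligned batch whose index is None.
...         return cebra.data.Batch(
...             self.model(batch.reference),
...             self.model(batch.positive),
...             self.model(batch.negative),
...         )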
- load(logdir, filename='checkpoint.pth')#
Load the experiment from its checkpoint file.
- class cebra.solver.base.MultiobjectiveSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, num_behavior_features=3, renormalize_features=False, output_mode='overlapping', ignore_deprecation_warning=False)#
Bases: Solver
Train models to satisfy multiple learning objectives.
This variant of the standard cebra.solver.base.Solver implements multi-objective or "hybrid" training.
- model#
A multi-objective CEBRA model.
- Type:
  torch.nn.Module
- optimizer#
The optimizer used for training.
- Type:
  torch.optim.Optimizer
- num_behavior_features#
The feature dimension for the features dedicated to satisfy the behavior contrastive objective. The remainder is used for time contrastive learning.
- Type:
  int
- renormalize_features#
If True, normalize the behavior and time contrastive features individually before computing similarity scores.
- Type:
  bool
Note
This solver will be deprecated in a future version. Please use the functionality in cebra.solver.multiobjective instead, which provides more versatile multi-objective training capabilities. Instantiation of this solver will raise a deprecation warning.
- _check_dimensions()#
Check the feature dimensions for behavior/time contrastive learning.
- Raises:
ValueError – If feature dimensions are larger than the model features, or not sufficiently large for renormalization.
- class cebra.solver.base.AuxiliaryVariableSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#
Bases: Solver
- transform(inputs, pad_before_transform=True, session_id=None, batch_size=None, use_reference_model=False)#
Compute the embedding. This function by default uses the model that was trained to encode the positive and negative samples. To use reference_model instead of model, use_reference_model should be set to True.
- Parameters:
  - inputs (Tensor) – The input signal.
  - use_reference_model (bool) – Flag for using reference_model.
- Return type:
  Tensor
- Returns:
  The output embedding.
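For example (a sketch; inputs is assumed to be a (T, N) tensor and solver a fitted instance):
>>> embedding = solver.transform(inputs, use_reference_model=True)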
Single session training#
Single session solvers embed a single pair of time series.
- class cebra.solver.single_session.SingleSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#
Bases: Solver
Single session training with a symmetric encoder.
This solver assumes that reference, positive and negative samples are processed by the same feature encoder and that a single session is provided to that encoder.
- parameters(session_id=None)#
Iterate over all parameters.
- get_embedding(data)#
Return the embedding of the given input data.
Note
This function assumes that the input data is sliced according to the receptive field of the model. The input data needs to match batch x dims x len(self.model.get_offset()), which is internally reduced to batch x dims x 1. The last dimension is squeezed, and the output is of shape time x features.
This function does not perform checks for correctness of the input.
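As an illustration of this shape convention (hypothetical sizes, assuming a model whose offset spans 10 samples):
>>> import torch
>>> window = torch.randn(500, 30, 10)  # batch x dims x len(offset)
>>> embedding = solver.get_embedding(window)  # -> shape (500, num_output)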
- class cebra.solver.single_session.SingleSessionAuxVariableSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, reference_model=None)#
Bases: SingleSessionSolver, AuxiliaryVariableSolver
Single session training for reference and positive/negative samples.
This solver processes reference samples with a model different from the one processing the positive and negative samples. Requires that the reference_model is initialized to be different from the model used to process the positive and negative samples.
Besides using an asymmetric encoder for the same modality, this solver also allows for e.g. time-contrastive learning across modalities, by using a reference model on modality A, and a different model processing the signal from modality B.
- class cebra.solver.single_session.SingleSessionHybridSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, num_behavior_features=3, renormalize_features=False, output_mode='overlapping', ignore_deprecation_warning=False)#
Bases: MultiobjectiveSolver, SingleSessionSolver
Single session training, contrasting neural data against behavior.
- class cebra.solver.single_session.BatchSingleSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#
Bases: SingleSessionSolver
Optimize a model with batch gradient descent.
Usage of this solver requires a sufficient amount of GPU memory. Using this solver is equivalent to using a single session solver with batch size set to dataset size, but requires less computation.
- fit(loader, *args, **kwargs)#
TODO
- get_embedding(data)#
Compute the embedding of a full input dataset.
For convolutional models that implement cebra.models.model.ConvolutionalModelMixin, the embedding is computed via SingleSessionSolver.get_embedding().
For all other models, it is assumed that the data has shape (1, dim, time) and is transformed into (time, dim) format.
- Parameters:
  data – The input data.
- Returns:
  The output embedding of shape (time, dimension).
Multi session training#
Solver implementations for multi-session datasets.
- class cebra.solver.multi_session.MultiSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#
Bases: Solver
Multi session training, contrasting pairs of neural data.
- parameters(session_id=None)#
Iterate over all parameters.
- validation(loader, session_id=None)#
Compute score of the model on data.
Note
Overrides cebra.solver.base.Solver.validation() in cebra.solver.base.Solver.
- Parameters:
  - loader – Data loader, which is an iterator over cebra.data.datatypes.Batch instances. Each batch contains reference, positive and negative input samples.
  - session_id (Optional[int]) – The session ID, an integer between 0 and the number of sessions in the multisession model, set to None for single session.
- Returns:
Loss averaged over iterations on data batch.
- class cebra.solver.multi_session.MultiSessionAuxVariableSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)#
Bases: MultiSessionSolver, AuxiliaryVariableSolver
Multi session training, contrasting neural data against behavior.
Training utilities#
Utility functions for solvers and their training loops.
- class cebra.solver.util.Meter#
Bases: object
Track statistics of a metric.
- add(value, num_elements=1)#
Add the value to the meter.
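For example (a sketch; the average accessor is an assumption about the Meter implementation):
>>> from cebra.solver.util import Meter
>>> meter = Meter()
>>> meter.add(0.5)
>>> meter.add(0.3)
>>> meter.average
0.4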
- class cebra.solver.util.ProgressBar(loader, log_format)#
Bases: object
Log and display values during training.
- set_description(stats)#
Update the progress bar description.
The description is updated by computing a formatted string from the given stats in the format {key}: {value: .4f}, with a space as the divider between dictionary elements.
Behavior depends on the selected progress bar.