Solvers
Variants of CEBRA solvers for single- and multi-session training.
This package contains wrappers around training loops. If you want to customize how encoder models transform the reference, positive and negative samples, how the loss functions are applied to the data, or how results are logged, extending the classes in this package is the right way to go.
The module contains the base class cebra.solver.base.Solver along with multiple variants for single- and multi-session datasets.
This module is a registry and currently contains the options [‘multi-session’, ‘multi-session-aux’, ‘supervised-solver-xcebra’, ‘multiobjective-solver’, ‘single-session’, ‘single-session-aux’, ‘single-session-hybrid’, ‘single-session-full’, ‘regularized-solver’].
To retrieve a list of options, call:
>>> print(cebra.solver.get_options())
['multi-session', 'multi-session-aux', 'supervised-solver-xcebra', ...]
To obtain an initialized instance, call cebra.solver.init, defined in cebra.registry.add_helper_functions(). The first parameter is the name of the solver to use, which must be one of the options listed above. It is followed by the positional arguments specific to that solver, if any.
You can register additional options by defining classes and registering them under a name. To do so, add the decorator @cebra.solver.register("my-cebra-solver") on top of the class definition.
Later, initialize your class in the same way as the pre-defined options, by calling cebra.solver.init with the solver name set to my-cebra-solver.
Note that these customized options will not be automatically added to this docstring.
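The registry pattern described above can be illustrated with a small, standalone sketch. It is not CEBRA's implementation: the module-level _REGISTRY dictionary and the register, get_options and init functions below are hypothetical stand-ins for cebra.solver.register, cebra.solver.get_options and cebra.solver.init, and MyCebraSolver is a placeholder class.

```python
from typing import Callable, Dict, List

# Hypothetical registry: maps option names to classes.
_REGISTRY: Dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    """Return a class decorator that stores the class under `name`."""

    def decorator(cls: type) -> type:
        # Sketch choice: refuse to silently overwrite an existing name.
        if name in _REGISTRY:
            raise ValueError(f"Option {name!r} is already registered.")
        _REGISTRY[name] = cls
        return cls

    return decorator


def get_options() -> List[str]:
    """List the registered names in insertion order."""
    return list(_REGISTRY)


def init(name: str, *args, **kwargs):
    """Instantiate the class registered under `name` with the given arguments."""
    try:
        cls = _REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown option {name!r}; choose one of {get_options()}.") from None
    return cls(*args, **kwargs)


@register("my-cebra-solver")
class MyCebraSolver:
    """Placeholder; a real solver would subclass cebra.solver.base.Solver."""

    def __init__(self, model, criterion, optimizer):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer


# Positional arguments after the name are forwarded to the class constructor.
solver = init("my-cebra-solver", "model", "criterion", "optimizer")
```

With CEBRA installed, the documented route is to decorate the class with @cebra.solver.register("my-cebra-solver") and create it with cebra.solver.init("my-cebra-solver", ...). Rejecting duplicate names is a choice made for this sketch only.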
This package contains abstract base classes for different solvers.
Solvers are used to package models, criteria and optimizers and implement training loops. When subclassing the abstract solvers, in the simplest case only Solver._inference() needs to be overridden.
For more complex use cases, the Solver.step() and Solver.fit() methods can be overridden to implement larger changes to the training loop.
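To illustrate the division of labour between step() and _inference(), here is a minimal sketch in plain Python, assuming scalar samples and a toy loss instead of tensors and a contrastive criterion. Batch, ToySolver, SymmetricToySolver, toy_criterion and double are hypothetical names, and the optimizer update that a real Solver.step() performs is omitted.

```python
import abc
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Batch:
    """Hypothetical stand-in for cebra.data.Batch."""

    reference: List[float]
    positive: List[float]
    negative: List[float]
    index: Optional[List[int]] = None


class ToySolver(abc.ABC):
    """Hypothetical base class: step() is fixed, _inference() is abstract."""

    def __init__(self, model: Callable[[float], float], criterion: Callable[..., float]):
        self.model = model
        self.criterion = criterion

    @abc.abstractmethod
    def _inference(self, batch: Batch) -> Batch:
        """Return aligned model outputs for reference, positive and negative samples."""

    def step(self, batch: Batch) -> float:
        # A real solver would also zero gradients, backpropagate and update parameters.
        prediction = self._inference(batch)
        return self.criterion(prediction.reference, prediction.positive, prediction.negative)


class SymmetricToySolver(ToySolver):
    """Only _inference() is overridden: one encoder processes all three sample sets."""

    def _inference(self, batch: Batch) -> Batch:
        # Outputs are aligned, so the returned batch keeps index=None.
        return Batch(
            reference=[self.model(x) for x in batch.reference],
            positive=[self.model(x) for x in batch.positive],
            negative=[self.model(x) for x in batch.negative],
        )


def toy_criterion(ref: List[float], pos: List[float], neg: List[float]) -> float:
    """Illustrative loss (not InfoNCE): positive distance minus negative distance."""
    pos_dist = sum((r - p) ** 2 for r, p in zip(ref, pos)) / len(ref)
    neg_dist = sum((r - n) ** 2 for r, n in zip(ref, neg)) / len(ref)
    return pos_dist - neg_dist


def double(x: float) -> float:
    """Hypothetical encoder."""
    return 2.0 * x


solver = SymmetricToySolver(model=double, criterion=toy_criterion)
loss = solver.step(Batch(reference=[1.0, 2.0], positive=[1.0, 2.5], negative=[3.0, 0.0]))
```

Because step() is defined once in the base class, a subclass that only changes how samples are encoded never has to touch the training logic.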
- class cebra.solver.base.Solver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)
Solver base class.
A solver contains helper methods for bundling a model, criterion and optimizer.
- model
The encoder for transforming reference, positive and negative samples.
- Type:
torch.nn.Module
- criterion
The criterion computed from the similarities between positive pairs and negative pairs. The criterion can have trainable parameters of its own.
- Type:
torch.nn.Module
- optimizer
A PyTorch optimizer for updating model and criterion parameters.
- Type:
torch.optim.Optimizer
- decode_history
Deprecated since 0.0.2. Use a hook during training for validation and decoding; see the arguments of fit().
- Type:
List
- log
The logs recorded during training. These typically contain the total loss as well as the logs for positive (pos) and negative (neg) pairs. For the standard criteria in CEBRA, the logs also contain the value of the temperature.
- Type:
Dict
- state_dict()
Return a dictionary fully describing the current solver state.
- Return type:
dict
- Returns:
State dictionary, including the state dictionaries of the models and optimizer. It also contains the training history and the CEBRA version the model was trained with.
- load_state_dict(state_dict, strict=True)
Update the solver state with the given state_dict.
- parameters()
Iterate over all parameters.
- fit(loader, valid_loader=None, *, save_frequency=None, valid_frequency=None, decode=False, logdir=None, save_hook=None)
Train the model for the specified number of steps.
- Parameters:
loader (Loader) – Data loader, which is an iterator over cebra.data.Batch instances. Each batch contains reference, positive and negative input samples.
valid_loader (Optional[Loader]) – Data loader used for validation of the model.
save_frequency (Optional[int]) – If not None, the frequency for automatically saving model checkpoints to logdir.
valid_frequency (Optional[int]) – The frequency for running validation on the valid_loader instance.
logdir (Optional[str]) – The logging directory for writing model checkpoints. The checkpoints can be read again using the solver.load function, or manually via loading the state dict.
- step(batch)
Perform a single gradient update.
- validation(loader, session_id=None)
Compute the score of the model on the given data.
- Parameters:
loader (Loader) – Data loader, which is an iterator over cebra.data.Batch instances. Each batch contains reference, positive and negative input samples.
session_id (Optional[int]) – The session ID, an integer between 0 and the number of sessions in the multi-session model; set to None for single-session models.
- Returns:
The loss averaged over all iterations of the data loader.
- decoding(train_loader, valid_loader)
Deprecated since 0.0.2.
- transform(inputs)
Compute the embedding.
By default, this function only applies the forward function of the given model, after switching the model into eval mode.
- abstract _inference(batch)
Given a batch of input examples, return the model outputs.
TODO: make this a public function?
- Parameters:
batch (Batch) – The input data, not necessarily aligned across the batch dimension. If batch.index is not None, it specifies the mapping between reference and positive samples.
- Return type:
Batch
- Returns:
Processed batch of data. While the input data might not be aligned across the sample dimension, the output data should be aligned and batch.index should be set to None.
- load(logdir, filename='checkpoint.pth')
Load the experiment from its checkpoint file.
- Parameters:
logdir – Logging directory containing the checkpoint.
filename (str) – Checkpoint name for loading the experiment.
- save(logdir, filename='checkpoint_last.pth')
Save the model and optimizer parameters.
- Parameters:
logdir – Logging directory for this model.
filename – Checkpoint name for saving the experiment.
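The interplay of fit(), step() and validation() with save_frequency and valid_frequency can be sketched as follows. This is an illustrative stand-in, assuming that validation and checkpointing are triggered every valid_frequency and save_frequency steps respectively; ToyFitLoop, step_fn, valid_fn, mean and halve are hypothetical, batches are plain floats, and checkpoints are only recorded as step numbers instead of being written to logdir.

```python
from dataclasses import dataclass, field
from typing import Callable, Iterable, List, Optional


@dataclass
class ToyFitLoop:
    """Hypothetical stand-in for the control flow of Solver.fit()."""

    step_fn: Callable[[float], float]              # plays the role of step(batch)
    valid_fn: Callable[[Iterable[float]], float]   # plays the role of validation(loader)
    history: List[float] = field(default_factory=list)
    valid_scores: List[float] = field(default_factory=list)
    checkpoints: List[int] = field(default_factory=list)

    def fit(self, loader: Iterable[float], valid_loader: Optional[Iterable[float]] = None, *,
            save_frequency: Optional[int] = None,
            valid_frequency: Optional[int] = None) -> None:
        for num_steps, batch in enumerate(loader, start=1):
            self.history.append(self.step_fn(batch))  # one update per batch
            if valid_loader is not None and valid_frequency and num_steps % valid_frequency == 0:
                self.valid_scores.append(self.valid_fn(valid_loader))
            if save_frequency and num_steps % save_frequency == 0:
                # A real solver would write a checkpoint to logdir here, e.g. via save().
                self.checkpoints.append(num_steps)


def mean(values: Iterable[float]) -> float:
    """Average over all iterations, like the documented validation() return value."""
    values = list(values)
    return sum(values) / len(values)


def halve(batch: float) -> float:
    """Hypothetical per-batch loss."""
    return 0.5 * batch


loop = ToyFitLoop(step_fn=halve, valid_fn=mean)
loop.fit([4.0, 2.0, 1.0, 0.5], valid_loader=[1.0, 3.0], save_frequency=2, valid_frequency=3)
```

After this call, loop.history holds one loss per step, validation ran once (at step 3) and checkpoints were recorded at steps 2 and 4.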
- class cebra.solver.base.MultiobjectiveSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, num_behavior_features=3, renormalize_features=False, output_mode='overlapping', ignore_deprecation_warning=False)
Bases:
Solver
Train models to satisfy multiple learning objectives.
This variant of the standard cebra.solver.base.Solver implements multi-objective or “hybrid” training.
- model
A multi-objective CEBRA model.
- optimizer
The optimizer used for training.
- Type:
torch.optim.Optimizer
- num_behavior_features
The feature dimension for the features dedicated to satisfying the behavior contrastive objective. The remainder is used for time contrastive learning.
- Type:
int
- renormalize_features
If True, normalize the behavior and time contrastive features individually before computing similarity scores.
- Type:
bool
Note
This solver will be deprecated in a future version. Please use the functionality in cebra.solver.multiobjective instead, which provides more versatile multi-objective training capabilities. Instantiating this solver will raise a deprecation warning.
- _check_dimensions()
Check the feature dimensions for behavior/time contrastive learning.
- Raises:
ValueError – If the feature dimensions are larger than the model features, or not sufficiently large for renormalization.
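The feature split described by num_behavior_features and renormalize_features can be sketched on a single embedding vector. The sketch follows the attribute descriptions above (the first num_behavior_features entries serve the behavior objective, the remainder the time objective) and is not CEBRA's code; output_mode is not modeled, and the requirement of at least two dimensions per part in the hypothetical check_dimensions helper is an assumption about what "sufficiently large for renormalization" means.

```python
import math
from typing import List, Sequence, Tuple


def check_dimensions(num_features: int, num_behavior_features: int, renormalize: bool) -> None:
    """Hypothetical checks in the spirit of MultiobjectiveSolver._check_dimensions()."""
    # Sketch choice: also reject equality, which leaves nothing for the time objective.
    if num_behavior_features >= num_features:
        raise ValueError("num_behavior_features must be smaller than the number of model features.")
    # Assumption: a 1-d part normalized to unit length carries only its sign.
    if renormalize and min(num_behavior_features, num_features - num_behavior_features) < 2:
        raise ValueError("Each feature part needs at least 2 dimensions for renormalization.")


def _unit(vector: Sequence[float]) -> List[float]:
    """Scale a vector to unit Euclidean norm (assumes a non-zero vector)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def split_features(embedding: Sequence[float], num_behavior_features: int,
                   renormalize: bool = False) -> Tuple[List[float], List[float]]:
    """Split one embedding into behavior features and the time-contrastive remainder."""
    check_dimensions(len(embedding), num_behavior_features, renormalize)
    behavior = list(embedding[:num_behavior_features])
    time_part = list(embedding[num_behavior_features:])
    if renormalize:
        # Normalize each part individually, as renormalize_features describes.
        behavior, time_part = _unit(behavior), _unit(time_part)
    return behavior, time_part


behavior, time_part = split_features([3.0, 4.0, 0.0, 2.0, 0.0],
                                     num_behavior_features=2, renormalize=True)
```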
Single session training
Single session solvers embed a single pair of time series.
- class cebra.solver.single_session.SingleSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)
Bases:
Solver
Single session training with a symmetric encoder.
This solver assumes that reference, positive and negative samples are processed by the same feature encoder.
- get_embedding(data)
Return the embedding of the given input data.
Note
This function assumes that the input data is sliced according to the receptive field of the model. The input data needs to match batch x dims x len(self.model.get_offset()), which is internally reduced to batch x dims x 1. The last dimension is squeezed, and the output is of shape time x features.
This function does not perform checks for correctness of the input.
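The input layout expected by get_embedding() can be illustrated with nested lists. In this sketch, window stands in for len(self.model.get_offset()), and slice_windows, toy_encoder and embed are hypothetical helpers: each (dims x window) slice is reduced to (dims x 1) by averaging, and the last axis is squeezed. No padding is applied here, so a signal with time samples yields time - window + 1 output rows.

```python
from typing import List

Matrix = List[List[float]]  # rows are dims, columns are time


def slice_windows(signal: Matrix, window: int) -> List[Matrix]:
    """Cut a (dims x time) signal into overlapping (dims x window) slices.

    The result has time - window + 1 entries along the batch axis.
    """
    time = len(signal[0])
    if window > time:
        raise ValueError("The signal is shorter than the receptive field.")
    return [[row[t:t + window] for row in signal] for t in range(time - window + 1)]


def toy_encoder(window_slice: Matrix) -> Matrix:
    """Hypothetical encoder: reduce each (dims x window) slice to (dims x 1)."""
    return [[sum(row) / len(row)] for row in window_slice]


def embed(signal: Matrix, window: int) -> Matrix:
    """Encode every slice and squeeze the trailing axis, giving (batch x features)."""
    return [[value[0] for value in toy_encoder(s)] for s in slice_windows(signal, window)]


signal = [[1.0, 2.0, 3.0, 4.0, 5.0],
          [0.0, 0.0, 3.0, 3.0, 3.0]]
embedding = embed(signal, window=3)
```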
- class cebra.solver.single_session.SingleSessionAuxVariableSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, reference_model=None)
Bases:
Solver
Single session training for reference and positive/negative samples.
This solver processes reference samples with a different model than the one used for the positive and negative samples. It requires that the reference_model is initialized to be different from the model used to process the positive and negative samples.
Besides using an asymmetric encoder for the same modality, this solver also allows, for example, time-contrastive learning across modalities, by using a reference model on modality A and a different model processing the signal from modality B.
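The difference between the symmetric and the asymmetric encoder setups can be sketched with scalar encoders. symmetric_inference, asymmetric_inference, encoder_a and encoder_b are hypothetical names; encoder_a plays the role of a reference model for modality A and encoder_b processes modality B. Rejecting a reference_model that is None or identical to model is an assumption based on the requirement stated above.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Encoder = Callable[[float], float]
Outputs = Tuple[List[float], List[float], List[float]]


def symmetric_inference(model: Encoder, reference: Sequence[float],
                        positive: Sequence[float], negative: Sequence[float]) -> Outputs:
    """SingleSessionSolver-style: one encoder for all three sample sets."""
    return ([model(x) for x in reference],
            [model(x) for x in positive],
            [model(x) for x in negative])


def asymmetric_inference(model: Encoder, reference_model: Optional[Encoder],
                         reference: Sequence[float], positive: Sequence[float],
                         negative: Sequence[float]) -> Outputs:
    """SingleSessionAuxVariableSolver-style: a separate encoder for the reference."""
    if reference_model is None or reference_model is model:
        raise ValueError("reference_model must be a different model than model.")
    return ([reference_model(x) for x in reference],
            [model(x) for x in positive],
            [model(x) for x in negative])


def encoder_a(x: float) -> float:
    """Hypothetical reference encoder for modality A."""
    return x + 1.0


def encoder_b(x: float) -> float:
    """Hypothetical encoder for modality B."""
    return 10.0 * x


sym = symmetric_inference(encoder_b, [1.0], [2.0], [3.0])
asym = asymmetric_inference(encoder_b, encoder_a, [1.0], [2.0], [3.0])
```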
- class cebra.solver.single_session.SingleSessionHybridSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True, num_behavior_features=3, renormalize_features=False, output_mode='overlapping', ignore_deprecation_warning=False)
Bases:
MultiobjectiveSolver
Single session training, contrasting neural data against behavior.
- class cebra.solver.single_session.BatchSingleSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)
Bases:
SingleSessionSolver
Optimize a model with batch gradient descent.
Usage of this solver requires a sufficient amount of GPU memory. Using this solver is equivalent to using a single session solver with the batch size set to the dataset size, but requires less computation.
- fit(loader, *args, **kwargs)
TODO
- get_embedding(data)
Compute the embedding of a full input dataset.
For convolutional models that implement cebra.models.model.ConvolutionalModelMixin, the embedding is computed via SingleSessionSolver.get_embedding().
For all other models, it is assumed that the data has shape (1, dim, time) and is transformed into (time, dim) format.
- Parameters:
data – The input data.
- Returns:
The output embedding of shape (time, dimension).
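For the non-convolutional branch, the conversion from (1, dim, time) to (time, dim) amounts to dropping the leading batch axis and transposing. The hypothetical to_time_major helper below shows this on nested lists; it illustrates the shape change only, not CEBRA's code.

```python
from typing import List, Sequence


def to_time_major(data: Sequence[Sequence[Sequence[float]]]) -> List[List[float]]:
    """Turn a (1, dim, time) nested list into (time, dim)."""
    if len(data) != 1:
        raise ValueError(f"Expected a leading batch axis of size 1, got {len(data)}.")
    (channels,) = data  # drop the batch axis; shape is now (dim, time)
    # zip(*channels) yields one tuple per time step, holding all dims.
    return [list(sample) for sample in zip(*channels)]


embedding = to_time_major([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
```

For a PyTorch tensor of shape (1, dim, time), data[0].T gives the same (time, dim) layout.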
Multi session training
Solver implementations for multi-session datasets.
- class cebra.solver.multi_session.MultiSessionSolver(model, criterion, optimizer, history=<factory>, decode_history=<factory>, log=<factory>, tqdm_on=True)
Bases:
Solver
Multi session training, contrasting pairs of neural data.
- validation(loader, session_id=None)
Compute the score of the model on the given data.
Note
Overrides cebra.solver.base.Solver.validation() in cebra.solver.base.Solver.
- Parameters:
loader – Data loader, which is an iterator over cebra.data.datatypes.Batch instances. Each batch contains reference, positive and negative input samples.
session_id (Optional[int]) – The session ID, an integer between 0 and the number of sessions in the multi-session model; set to None for single-session models.
- Returns:
The loss averaged over all iterations of the data loader.
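A sketch of per-session validation, assuming that each session has its own encoder and that session_id selects it. ToyMultiSessionValidator, toy_criterion, identity and double are hypothetical: each loader item is a (reference, positive, negative) triplet of scalars, and raising an error when session_id is None is a choice made for this sketch rather than documented behavior.

```python
from typing import Callable, Iterable, Optional, Sequence, Tuple

Triplet = Tuple[float, float, float]  # (reference, positive, negative)
Encoder = Callable[[float], float]


class ToyMultiSessionValidator:
    """Hypothetical sketch: one encoder per session and a shared criterion."""

    def __init__(self, models: Sequence[Encoder],
                 criterion: Callable[[float, float, float], float]):
        self.models = list(models)
        self.criterion = criterion

    def validation(self, loader: Iterable[Triplet], session_id: Optional[int] = None) -> float:
        if session_id is None:
            raise ValueError("session_id is required for multi-session validation.")
        if not 0 <= session_id < len(self.models):
            raise ValueError(f"session_id must be in [0, {len(self.models)}).")
        model = self.models[session_id]
        # Average the loss over all iterations of the loader.
        losses = [self.criterion(model(ref), model(pos), model(neg)) for ref, pos, neg in loader]
        return sum(losses) / len(losses)


def toy_criterion(ref: float, pos: float, neg: float) -> float:
    """Illustrative loss: distance to the positive minus distance to the negative."""
    return abs(ref - pos) - abs(ref - neg)


def identity(x: float) -> float:
    return x


def double(x: float) -> float:
    return 2.0 * x


validator = ToyMultiSessionValidator([identity, double], toy_criterion)
loader = [(1.0, 1.5, 3.0), (0.0, 0.5, -1.0)]
score_session_0 = validator.validation(loader, session_id=0)
score_session_1 = validator.validation(loader, session_id=1)
```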
Training utilities
Utility functions for solvers and their training loops.
- class cebra.solver.util.Meter
Bases:
object
Track statistics of a metric.
- add(value, num_elements=1)
Add the value to the meter.
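A running-average meter in the spirit of Meter.add(value, num_elements=1) can be sketched as follows. RunningMeter is hypothetical, and weighting value by num_elements (treating value as a mean over num_elements items) is an assumption; CEBRA's Meter may interpret num_elements differently.

```python
class RunningMeter:
    """Hypothetical running-average meter (not CEBRA's Meter implementation)."""

    def __init__(self) -> None:
        self._total = 0.0
        self._count = 0

    def add(self, value: float, num_elements: int = 1) -> None:
        # Assumption: `value` is the mean over `num_elements` items, so weight it.
        if num_elements <= 0:
            raise ValueError("num_elements must be positive.")
        self._total += value * num_elements
        self._count += num_elements

    @property
    def average(self) -> float:
        if self._count == 0:
            raise ZeroDivisionError("No values have been added yet.")
        return self._total / self._count


meter = RunningMeter()
meter.add(2.0, num_elements=3)  # e.g. a batch of 3 samples with mean loss 2.0
meter.add(4.0)                  # a single sample with loss 4.0
```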
- class cebra.solver.util.ProgressBar(loader, log_format)
Bases:
object
Log and display values during training.
- set_description(stats)
Update the progress bar description.
The description is updated by computing a formatted string from the given stats in the format {key}: {value: .4f}, with a space as the divider between dictionary elements.
The behavior depends on the selected progress bar.
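The description format given for set_description() can be reproduced with an f-string. format_stats is a hypothetical helper, not part of cebra.solver.util. The space in the " .4f" format specification reserves a sign position, so positive and negative values line up.

```python
from typing import Dict


def format_stats(stats: Dict[str, float]) -> str:
    """Format stats as "{key}: {value: .4f}" entries separated by single spaces."""
    return " ".join(f"{key}: {value: .4f}" for key, value in stats.items())


description = format_stats({"pos": 0.5, "neg": -0.25, "total": 1.0})
```

With tqdm, such a string could be passed to the progress bar's set_description() method.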