Distributions#
Distributions and indexing helper functions for training CEBRA models.
This package contains classes for sampling and indexing of datasets. Typically, the functionality of classes in this module is guided by the auxiliary variables of CEBRA. A dataset would pass auxiliary variables to a sampler, and within the sampler the indices of reference, negative and positive samples will be sampled based on the auxiliary information. Custom ways of sampling should therefore be implemented in this package. Functionality in this package is fully agnostic to the actual signal to be analysed, and only considers the auxiliary information of a dataset (called “index”).
Distributions take data samples and allow sampling or re-sampling from the dataset. Sampling from the prior distribution is done via “sample_prior”, sampling from the conditional distribution via “sample_conditional”.
For fast lookups in datasets, indexing classes provide 1-nearest-neighbor searches with L2 and cosine similarity metrics (recommended on GPU), or standard multi-threaded dataloading with FAISS as the backend for retrieving data.
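As a minimal sketch of this interface, using the DiscreteUniform distribution documented below (data shapes are illustrative; the call pattern follows the docstrings in this section):

>>> import torch
>>> from cebra.distributions.discrete import DiscreteUniform
>>> index = torch.randint(0, 5, (1000,))            # auxiliary variable ("index")
>>> distribution = DiscreteUniform(index)
>>> reference_idx = distribution.sample_prior(128)  # indices into the dataset
>>> positive_idx = distribution.sample_conditional(reference_idx)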
Base classes#
Abstract base classes for distributions and indices.
Contrastive learning in CEBRA requires a prior and conditional distribution. Distributions are defined in terms of _indices_ that reference samples within the dataset.
The appropriate base classes are defined in this module: An Index is the part of the dataset used to inform the prior and conditional distributions; it could, for example, be time, or information about an experimental condition.
- class cebra.distributions.base.HasGenerator(device, seed)#
Bases:
HasDevice
Base class for all distributions implementing seeding.
- Parameters:
device – The device (e.g., cpu or cuda) used for random number generation.
seed – The seed for the random generator.
Note
This class is not fully functional yet. Functionality and API might slightly change in upcoming versions. Do not rely (yet) on seeding provided by this class, but start using it to integrate seeding in all parts of cebra.distributions.
- property generator: torch.Generator#
The generator object.
Can be used in many sampling methods provided by torch.
- Return type:
torch.Generator
- to(device)#
Move the instance to the specified device.
- randint(*args, **kwargs)#
Generate random integers.
See the docs of torch.randint for information on the arguments.
- Return type:
Tensor
- class cebra.distributions.base.Index(device=None)#
Bases:
HasDevice
Base class for indexed datasets.
Indexes contain functionality to pass a query vector, and return the indices of the closest matches within the index.
- abstract search(query)#
Return index of entry closest to query.
- Parameters:
query – The query tensor to look for. The index computes the closest element to this query tensor and returns its location within the index. (TODO: add type)
- Return type:
- Returns:
The index of the element closest to the query in the dataset.
- class cebra.distributions.base.PriorDistribution#
Bases:
ABC
Mixin for all prior distributions.
Prior distributions return a batch of indices. Indexing the dataset with these indices will return samples from the prior distribution.
- class cebra.distributions.base.ConditionalDistribution#
Bases:
ABC
Mixin for all conditional distributions.
Conditional distributions return a batch of indices, based on a given batch of indices. Indexing the dataset with these indices will return samples from the conditional distribution.
- abstract sample_conditional(query)#
Return indices for the conditional distribution samples.
- class cebra.distributions.base.JointDistribution#
Bases:
PriorDistribution, ConditionalDistribution
Mixin for joint distributions.
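As a minimal sketch (not part of the library), a custom distribution can combine these mixins by implementing sample_prior and sample_conditional:

>>> import torch
>>> import cebra.distributions.base as cebra_distributions_base
>>> class ShiftByOne(cebra_distributions_base.JointDistribution):
...     """Toy joint distribution: uniform prior, positive pair = next time step."""
...     def __init__(self, num_samples):
...         self.num_samples = num_samples
...     def sample_prior(self, num_samples):
...         # uniform prior over all valid reference indices
...         return torch.randint(0, self.num_samples - 1, (num_samples,))
...     def sample_conditional(self, reference_idx):
...         # the positive sample is simply the next time step
...         return reference_idx + 1
>>> distribution = ShiftByOne(num_samples=1000)
>>> reference_idx = distribution.sample_prior(10)
>>> positive_idx = distribution.sample_conditional(reference_idx)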
Index#
Index operations for conditional sampling.
Indexing operations—in contrast to data distributions—exhibit deterministic behavior by returning an element closest in the dataset to a given query sample. This module contains helper functions for mixed and continuously indexed datasets (i.e., containing discrete and/or continuous data).
Discrete data has to come in the format of a single label for each datapoint. Multidimensional discrete labels should be converted accordingly.
- class cebra.distributions.index.DistanceMatrix(samples)#
Bases:
HasDevice
Compute shortest distances between dataset samples.
- Parameters:
samples (
Tensor
) – The continuous values that will be used to index the dataset and specify the conditional distribution.
Note
This implementation is not particularly efficient on very large datasets. For these cases, packages like FAISS offer more optimized retrieval functions.
As a rule of thumb, this class is suitable for datasets that can be hosted in GPU memory.
- class cebra.distributions.index.OffsetDistanceMatrix(samples, offset=1)#
Bases:
DistanceMatrix
Compute shortest distances, ignoring samples close to the boundary.
Compared to the standard DistanceMatrix, this class should be used for datasets and learning setups where multiple timesteps are fed into the network at once; the samples close to the time-series boundary should be ignored in the sampling process in these cases.
- Parameters:
samples – The continuous values that will be used to index the dataset and specify the conditional distribution.
offset (int) – The number of timesteps to ignore at each side of the dataset.
- class cebra.distributions.index.ContinuousIndex(index)#
Bases:
Index
Naive nearest neighbor search implementation.
- index: tensor(N, d) – the values used for kNN search
- offset: int or (int, int) – the time offset in each direction
- search(query)#
Return index location closest to query.
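A minimal usage sketch, assuming the (N, d) tensor layout described above and batched queries (as in ConditionalIndex.search); the index tensor can be moved to GPU for faster search:

>>> import torch
>>> from cebra.distributions.index import ContinuousIndex
>>> values = torch.rand(1000, 3)        # continuous index, shape (N, d)
>>> index = ContinuousIndex(values)
>>> query = torch.rand(10, 3)           # batch of query vectors
>>> location = index.search(query)      # indices of the closest entries in values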
- class cebra.distributions.index.ConditionalIndex(discrete, continuous)#
Bases:
Index
Index a dataset based on both continuous and discrete information.
In contrast to the standard base.Index class, the ConditionalIndex accepts both discrete and continuous indexing information.
This index considers the discrete indexing information first to identify possible positive pairs. Then, among these candidate samples, it behaves like a base.Index and returns the samples closest in terms of the information in the continuous index.
- Parameters:
discrete – The discrete indexing information, which should be limited to a 1d feature space. If higher dimensional discrete vectors are used, they should be first re-formatted to fit this structure.
continuous – The continuous indexing information, which can be a vector of arbitrary dimension and will be used to define the distance between the samples that share the same discrete index.
- search(continuous, discrete=None)#
Search closest sample based on continuous and discrete indexing information.
- Parameters:
continuous – Samples from the continuous index
discrete – Optionally matching samples from the discrete index, used to pre-select matching indices.
- search_naive(continuous, discrete)#
Brute-force search. Fast especially for small indices.
- Parameters:
continuous – TODO
discrete – TODO
- search_iterative(continuous, discrete)#
Iterative search; gets faster especially for indices with more than ~1e6 samples.
- Parameters:
continuous – TODO
discrete – TODO
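A usage sketch with hypothetical data; the search signature follows the documentation above:

>>> import torch
>>> from cebra.distributions.index import ConditionalIndex
>>> discrete = torch.randint(0, 3, (1000,))   # 1d discrete labels
>>> continuous = torch.rand(1000, 2)          # continuous index values
>>> index = ConditionalIndex(discrete, continuous)
>>> query_continuous = torch.rand(10, 2)
>>> query_discrete = torch.randint(0, 3, (10,))
>>> # pre-select candidates by discrete label, then search the continuous index
>>> matches = index.search(query_continuous, discrete=query_discrete)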
- class cebra.distributions.index.MultiSessionIndex(*indices)#
Bases:
Index
Index multiple sessions.
- Parameters:
indices – Indices for the different sessions. Indices of multi-session datasets should have matching feature dimension.
- search(query)#
Return closest element in each of the datasets.
- Parameters:
query – The query which is applied to each index of the dataset.
- Returns:
A list of indices from each session.
Discrete#
Discrete indices.
- class cebra.distributions.discrete.Discrete(samples, device='cpu', seed=None)#
Bases:
ConditionalDistribution, HasGenerator
Resample 1-dimensional discrete data.
The distribution is fully specified by an array of discrete samples. Samples can be drawn either from the dataset directly (i.e., output samples will have the same distribution of class labels as the samples used to specify the distribution), or from a resampled data distribution where the occurrence of each class label is balanced.
- Parameters:
samples (
Tensor
) – Discrete index used for sampling
- sample_uniform(num_samples)#
Draw samples from the uniform distribution over values.
This will change the likelihood of values depending on the values in the given (discrete) index. When reindexing the dataset with the returned indices, all values in the index will appear with equal probability.
- sample_empirical(num_samples)#
Draw samples from the empirical distribution.
- sample_conditional(reference_index)#
Draw samples conditional on template samples.
- Parameters:
reference_index – batch of indices, typically drawn from a prior distribution. Conditional samples will match the values of these indices.
- Return type:
Tensor
- Returns:
batch of indices, whose values match the values corresponding to the given indices.
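For illustration, a sketch with an imbalanced binary index (following the sampling methods documented above; shapes are illustrative):

>>> import torch
>>> from cebra.distributions.discrete import Discrete
>>> labels = torch.cat([torch.zeros(900), torch.ones(100)]).long()
>>> distribution = Discrete(labels)
>>> balanced = distribution.sample_uniform(100)    # both classes equally likely
>>> empirical = distribution.sample_empirical(100) # roughly 9:1, as in the data
>>> positive = distribution.sample_conditional(empirical)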
- class cebra.distributions.discrete.DiscreteUniform(samples, device='cpu', seed=None)#
Bases:
Discrete, PriorDistribution
Re-sample the given indices and produce samples from a uniform distribution.
- sample_prior(num_samples)#
Draw samples from the uniform distribution over values.
This will change the likelihood of values depending on the values in the given (discrete) index. When reindexing the dataset with the returned indices, all values in the index will appear with equal probability.
- class cebra.distributions.discrete.DiscreteEmpirical(samples, device='cpu', seed=None)#
Bases:
Discrete, PriorDistribution
Draw samples from the empirical distribution defined by the passed index.
Continuous#
Distributions for sampling from continuously indexed datasets.
- class cebra.distributions.continuous.Prior(continuous, device='cpu', seed=0)#
Bases:
PriorDistribution, HasGenerator
An empirical prior distribution for continuous datasets.
Given the index, uniformly sample across time steps, i.e., sample from the empirical distribution.
- Parameters:
continuous (
Tensor
) – The multi-dimensional continuous index.
- sample_prior(num_samples, offset=None)#
Return uniformly sampled indices.
- Parameters:
num_samples (int) – The number of samples to draw from the prior distribution. This will be the length of the returned tensor.
offset (Optional[Offset]) – The cebra.data.datatypes.Offset offset to be respected when sampling indices. The minimum index sampled will be offset.left (inclusive), the maximum index will be the index length minus offset.right (exclusive).
- Return type:
Tensor
- Returns:
An integer tensor of shape (num_samples,) containing random indices.
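A minimal sketch with hypothetical data:

>>> import torch
>>> from cebra.distributions.continuous import Prior
>>> behavior = torch.rand(1000, 3)           # multi-dimensional continuous index
>>> prior = Prior(behavior, seed=42)
>>> reference_idx = prior.sample_prior(128)  # uniform over the 1000 time steps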
- class cebra.distributions.continuous.TimeContrastive(continuous=None, time_offset=1, num_samples=None, device='cpu', seed=None)#
Bases:
JointDistribution, HasGenerator
Time contrastive learning.
Positive samples will have a distance of exactly time_offset samples in time.
- continuous#
The multi-dimensional continuous index.
- time_offset#
The time delay between samples that form a positive pair.
- num_samples#
TODO(stes) remove?
- device#
Device (cpu or gpu)
- seed#
The seed for sampling from the prior and negative distribution (TODO: currently not used)
- sample_prior(num_samples, offset=None)#
Return a random index sample, respecting the given time offset.
Prior samples are uniformly sampled from [0, T - t), where T is the total number of samples in the index, and t is the time offset used for sampling.
- Parameters:
- Return type:
Tensor
- Returns:
A (num_samples,) shaped tensor containing time indices from the uniform prior distribution.
- sample_conditional(reference_idx)#
Return samples from the time-contrastive conditional distribution.
The returned indices will be given by incrementing the reference indices by the specified time_offset. When the reference indices are sampled with sample_prior(), it is ensured that the indices all lie within the bounds of the dataset.
- Parameters:
reference_idx (Tensor) – The time indices of the reference samples.
- Return type:
Tensor
- Returns:
A (len(reference_idx),) shaped tensor containing time indices from the time-contrastive conditional distribution. The samples will be simply offset by time_offset from reference_idx.
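A usage sketch with hypothetical data; per the documentation above, positives are offset from their references by exactly time_offset:

>>> import torch
>>> from cebra.distributions.continuous import TimeContrastive
>>> neural = torch.rand(1000, 30)                  # continuous index, e.g. neural data
>>> distribution = TimeContrastive(neural, time_offset=10)
>>> reference_idx = distribution.sample_prior(128) # sampled from [0, 1000 - 10)
>>> positive_idx = distribution.sample_conditional(reference_idx)
>>> # positive_idx equals reference_idx + 10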
- class cebra.distributions.continuous.DirectTimedeltaDistribution(continuous, time_offset=1)#
Bases:
TimeContrastive, HasGenerator
Look up indices with
- class cebra.distributions.continuous.TimedeltaDistribution(continuous, time_delta=1, device='cpu', seed=None)#
Bases:
JointDistribution, HasGenerator
Define a conditional distribution based on behavioral changes over time.
Takes a continuous index, and uses the empirical distribution of differences between samples in this index.
- Parameters:
continuous – The multi-dimensional continuous index.
time_delta – The time delay between samples used to estimate the empirical distribution of differences.
Note
For best results, the given continuous index should contain independent factors; positive pairs will be formed by adding a _random_ difference estimated within the dataset to the reference samples. Factors should ideally also be within the same range (since the Euclidean distance is used in the search). A simple solution is to perform a PCA or ICA, or apply CEBRA first before passing the index to this function.
- sample_prior(num_samples)#
See Prior.sample_prior().
- Return type:
Tensor
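A usage sketch with hypothetical data; sample_conditional is assumed to take reference indices, as in the related joint distributions in this module:

>>> import torch
>>> from cebra.distributions.continuous import TimedeltaDistribution
>>> behavior = torch.rand(1000, 2)   # ideally independent, similarly scaled factors
>>> distribution = TimedeltaDistribution(behavior, time_delta=1)
>>> reference_idx = distribution.sample_prior(128)
>>> positive_idx = distribution.sample_conditional(reference_idx)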
- class cebra.distributions.continuous.DeltaNormalDistribution(continuous, delta=0.1, device='cpu', seed=None)#
Bases:
JointDistribution, HasGenerator
Define a conditional distribution based on behavioral changes over time.
Takes a continuous index, and samples from a Gaussian distribution to obtain positive pairs. Note that if the continuous index is multidimensional, the Gaussian distribution will have an isotropic covariance matrix, i.e., Σ = σ² I.
- Parameters:
continuous – The multi-dimensional continuous index.
delta – Standard deviation of the Gaussian distribution used to sample positive pairs.
- sample_prior(num_samples)#
See Prior.sample_prior().
- Return type:
Tensor
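A usage sketch with hypothetical data and delta; the positive pair is drawn by perturbing the reference index value with Gaussian noise:

>>> import torch
>>> from cebra.distributions.continuous import DeltaNormalDistribution
>>> behavior = torch.rand(1000, 1)
>>> distribution = DeltaNormalDistribution(behavior, delta=0.05)
>>> reference_idx = distribution.sample_prior(128)
>>> positive_idx = distribution.sample_conditional(reference_idx)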
- class cebra.distributions.continuous.CEBRADistribution#
Bases:
JointDistribution
Use CEBRA embeddings for defining a conditional distribution.
Mixed#
Distributions with a mix of continuous/discrete auxiliary variables.
- class cebra.distributions.mixed.ConfigurableDistribution#
Bases:
object
Experimental. Do not use yet.
- configure_prior(distribution='empirical')#
Not implemented yet.
- configure_conditional()#
Not implemented yet.
- class cebra.distributions.mixed.Mixed(discrete, continuous)#
Bases:
HasDevice
Distribution over behavior variables.
This class combines sampling across continuous and discrete variables.
- sample_conditional_discrete(discrete)#
Sample conditional on the discrete samples, marginalized across continuous.
- Return type:
Tensor
- sample_conditional_continuous(continuous)#
Sample conditional on the continuous samples, marginalized across discrete.
- Return type:
Tensor
- sample_conditional(discrete, continuous)#
Sample conditional on the continuous and discrete samples.
- Return type:
Tensor
- class cebra.distributions.mixed.MixedTimeDeltaDistribution(discrete, continuous, time_delta=1)#
Bases:
TimedeltaDistribution
Combination of a time delta and discrete distribution for sampling.
Sampling from the prior uses the DiscreteUniform distribution. For sampling the conditional, it is ensured that the positive pairs share their behavior variable, and are then sampled according to the TimedeltaDistribution.
See also
TimedeltaDistribution for the conditional distribution.
- sample_prior(num_samples)#
Return indices from the uniform prior distribution.
- sample_conditional(reference_idx)#
Return indices from the conditional distribution.
- Parameters:
reference_idx (Tensor) – The reference indices.
- Return type:
Tensor
- Returns:
The positive indices. The positive samples will match the reference samples in their discrete variable, and will otherwise be drawn from the TimedeltaDistribution.
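A usage sketch with hypothetical data; per the documentation above, positives share the discrete variable of their reference:

>>> import torch
>>> from cebra.distributions.mixed import MixedTimeDeltaDistribution
>>> discrete = torch.randint(0, 3, (1000,))
>>> continuous = torch.rand(1000, 2)
>>> distribution = MixedTimeDeltaDistribution(discrete, continuous, time_delta=1)
>>> reference_idx = distribution.sample_prior(128)
>>> positive_idx = distribution.sample_conditional(reference_idx)
>>> # positives share the discrete label of their reference sample
>>> assert (discrete[positive_idx] == discrete[reference_idx]).all()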
Multi-session#
Continuous variable multi-session sampling.
- class cebra.distributions.multisession.MultisessionSampler(dataset, time_offset)#
Bases:
PriorDistribution, ConditionalDistribution
Continuous multi-session sampling.
Align embeddings across multiple sessions, using a continuous index. The transitions between index samples are computed across all sessions.
Note
The batch dimension of positive samples is shuffled. Before applying the contrastive loss, either the reference samples need to be aligned with the positive samples, or vice versa:
>>> import cebra.distributions.multisession as cebra_distributions_multisession
>>> import cebra.integrations.sklearn.dataset as cebra_sklearn_dataset
>>> import cebra.data
>>> import torch
>>> from torch import nn
>>> # Multisession training: one model per dataset (different input dimensions)
>>> session1 = torch.rand(100, 30)
>>> session2 = torch.rand(100, 50)
>>> index1 = torch.rand(100)
>>> index2 = torch.rand(100)
>>> num_features = 8
>>> dataset = cebra.data.DatasetCollection(
...     cebra_sklearn_dataset.SklearnDataset(session1, (index1, )),
...     cebra_sklearn_dataset.SklearnDataset(session2, (index2, )))
>>> model = nn.ModuleList([
...     cebra.models.init(
...         name="offset1-model",
...         num_neurons=dataset.input_dimension,
...         num_units=32,
...         num_output=num_features,
...     ) for dataset in dataset.iter_sessions()]).to("cpu")
>>> sampler = cebra_distributions_multisession.MultisessionSampler(dataset, time_offset=10)

>>> # ref and pos samples from all datasets
>>> ref = sampler.sample_prior(100)
>>> pos, idx, rev_idx = sampler.sample_conditional(ref)
>>> ref = torch.LongTensor(ref)
>>> pos = torch.LongTensor(pos)

>>> # Then the embedding spaces can be concatenated
>>> refs, poss = [], []
>>> for i in range(len(model)):
...     refs.append(model[i](dataset._datasets[i][ref[i]]))
...     poss.append(model[i](dataset._datasets[i][pos[i]]))
>>> ref = torch.stack(refs, dim=0)
>>> pos = torch.stack(poss, dim=0)

>>> # Now the index can be applied to the stacked features,
>>> # to align reference to positive samples
>>> aligned_ref = sampler.mix(ref, idx)
>>> reference = aligned_ref.view(-1, num_features)
>>> positive = pos.view(-1, num_features)
>>> loss = (reference - positive)**2

>>> # .. or the reverse index, to align positive to reference samples
>>> aligned_pos = sampler.mix(pos, rev_idx)
>>> reference = ref.view(-1, num_features)
>>> positive = aligned_pos.view(-1, num_features)
>>> loss = (reference - positive)**2
The reason for this implementation is that dataset[i] will in general have different dimensions (for example, number of neurons) per session. In contrast to the reference and positive indices, this data cannot be stacked, and models need to be applied session by session.
After data processing, the dimensionality of the returned features matches. The resulting embeddings can be concatenated, and shuffling (across the session axis) can be applied to the reference samples, or reversed for the positive samples.
Note
This function currently does not support explicitly selected discrete indices. They should be added as dimensions to the continuous index. More weight can be added to the discrete dimensions by using larger values in one-hot coding.
- mix(array, idx)#
Re-order array elements according to the given index mapping.
The given array should be of shape (session, batch, ...) and the indices should have length session x batch, representing a mapping between indices.
The resulting array will be rearranged such that
out.reshape(session*batch, -1)[i] = array.reshape(session*batch, -1)[idx[i]]
For the inverse mapping, convert the indices first using the _invert_index function.
- sample_prior(num_samples)#
Return indices for the prior distribution samples.
- Parameters:
num_samples – The batch size
- Returns:
A tensor of indices. Indexing the dataset with these indices will return samples from the desired prior distribution.
- sample_conditional(idx)#
Sample from the conditional distribution.
Note
Reference samples are sampled equally between sessions.
Queries are computed for each reference as in the single-session case, i.e., by adding a random time shift to each reference sample.
In order to guarantee the same number of positive samples per session, queries are randomly assigned to a session and its corresponding positive sample is searched in that session only.
As a result, ref/pos pairing is shuffled and can be recovered by applying the reverse shuffle operation.
- Parameters:
idx (Tensor) – Reference indices, with dimension (session, batch).
- Return type:
- Returns:
Positive indices (1st return value), which will be grouped by session and not match the reference indices. In addition, a mapping will be returned to apply the same shuffle operation that was applied to assign a query to a session along the session/batch dimension to the reference indices (2nd return value), or to reverse the shuffle operation (3rd return value). Returned shapes are (session, batch), (session, batch), (session, batch).
- class cebra.distributions.multisession.DiscreteMultisessionSampler(dataset)#
Bases:
PriorDistribution, ConditionalDistribution
Discrete multi-session sampling.
Discrete indices don’t need to be aligned. Positive pairs are found by matching the discrete index in randomly assigned sessions.
After data processing, the dimensionality of the returned features matches. The resulting embeddings can be concatenated, and shuffling (across the session axis) can be applied to the reference samples, or reversed for the positive samples.
- mix(array, idx)#
Re-order array elements according to the given index mapping.
The given array should be of shape (session, batch, ...) and the indices should have length session x batch, representing a mapping between indices.
The resulting array will be rearranged such that
out.reshape(session*batch, -1)[i] = array.reshape(session*batch, -1)[idx[i]]
For the inverse mapping, convert the indices first using the _invert_index function.
- sample_prior(num_samples)#
Return indices for the prior distribution samples.
- Parameters:
num_samples – The batch size
- Returns:
A tensor of indices. Indexing the dataset with these indices will return samples from the desired prior distribution.
- sample_conditional(idx)#
Sample from the conditional distribution.
Note
Reference samples are sampled equally between sessions.
In order to guarantee the same number of positive samples per session, reference samples are randomly assigned to a session and its corresponding positive sample is searched in that session only.
As a result, ref/pos pairing is shuffled and can be recovered by applying the reverse shuffle operation.
- Parameters:
idx (Tensor) – Reference indices, with dimension (session, batch).
- Return type:
- Returns:
Positive indices (1st return value), which will be grouped by session and not match the reference indices. In addition, a mapping will be returned to apply the same shuffle operation that was applied to assign reference samples to a session along the session/batch dimension (2nd return value), or to reverse the shuffle operation (3rd return value). Returned shapes are (session, batch), (session, batch), (session, batch).
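A usage sketch mirroring the continuous multi-session example above, with hypothetical data (assuming discrete integer labels are accepted by SklearnDataset):

>>> import torch
>>> import cebra.data
>>> import cebra.integrations.sklearn.dataset as cebra_sklearn_dataset
>>> from cebra.distributions.multisession import DiscreteMultisessionSampler
>>> session1, session2 = torch.rand(100, 30), torch.rand(100, 50)
>>> labels1 = torch.randint(0, 5, (100,))
>>> labels2 = torch.randint(0, 5, (100,))
>>> dataset = cebra.data.DatasetCollection(
...     cebra_sklearn_dataset.SklearnDataset(session1, (labels1, )),
...     cebra_sklearn_dataset.SklearnDataset(session2, (labels2, )))
>>> sampler = DiscreteMultisessionSampler(dataset)
>>> ref = sampler.sample_prior(100)
>>> pos, idx, rev_idx = sampler.sample_conditional(ref)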