Distributions#

Distributions and indexing helper functions for training CEBRA models.

This package contains classes for sampling and indexing of datasets. Typically, the functionality of classes in this module is guided by the auxiliary variables of CEBRA. A dataset would pass auxiliary variables to a sampler, and within the sampler the indices of reference, negative and positive samples will be sampled based on the auxiliary information. Custom ways of sampling should therefore be implemented in this package. Functionality in this package is fully agnostic to the actual signal to be analysed, and only considers the auxiliary information of a dataset (called “index”).

Distributions take data samples and allow sampling or re-sampling from the dataset. Sampling from the prior distribution is done via “sample_prior”, sampling from the conditional distribution via “sample_conditional”.
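
For illustration, a minimal sketch of this two-step sampling pattern, using the TimedeltaDistribution documented below (the data shapes and values are made up for illustration):

>>> import torch
>>> import cebra.distributions.continuous as cebra_distributions_continuous
>>> behavior = torch.rand(1000, 3)       # continuous auxiliary variables ("index")
>>> distribution = cebra_distributions_continuous.TimedeltaDistribution(behavior, time_delta=1)
>>> reference_idx = distribution.sample_prior(128)                  # indices of reference samples
>>> positive_idx = distribution.sample_conditional(reference_idx)   # indices of positive samples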

For fast lookups in datasets, indexing classes provide 1-nearest-neighbor search with L2 and cosine similarity metrics (recommended on GPU), or standard multi-threaded dataloading with FAISS as the backend for retrieving data.

Base classes#

Abstract base classes for distributions and indices.

Contrastive learning in CEBRA requires a prior and conditional distribution. Distributions are defined in terms of _indices_ that reference samples within the dataset.

The appropriate base classes are defined in this module: An Index is the part of the dataset used to inform the prior and conditional distributions; and could for example be time, or information about an experimental condition.

class cebra.distributions.base.HasGenerator(device, seed)#

Bases: HasDevice

Base class for all distributions implementing seeding.

Parameters:
  • device (str) – The device the instance resides on, can be cpu or cuda.

  • seed (int) – The seed to use for initializing the random number generator.

Note

This class is not fully functional yet. Functionality and API might slightly change in upcoming versions. Do not rely (yet) on seeding provided by this class, but start using it to integrate seeding in all parts of cebra.distributions.

property generator: torch.Generator#

The generator object.

Can be used in many sampling methods provided by torch.

Return type:

Generator

property seed: int#

The seed used for generating random numbers in this class.

Return type:

int

to(device)#

Move the instance to the specified device.

randint(*args, **kwargs)#

Generate random integers.

See docs of torch.randint for information on the arguments.

Return type:

Tensor

property device: str#

The device of all attributes.

Can be cpu or cuda.

Return type:

str

class cebra.distributions.base.Index(device=None)#

Bases: ABC, HasDevice

Base class for indexed datasets.

An index provides functionality to take a query vector and return the indices of the closest matches within the index.

abstract search(query)#

Return index of entry closest to query.

Parameters:

query – The query tensor to look for. The index computes the closest element to this query tensor and returns its location within the index. (TODO: add type)

Return type:

Tensor

Returns:

The index of the element closest to the query in the dataset.

class cebra.distributions.base.PriorDistribution#

Bases: ABC

Mixin for all prior distributions.

Prior distributions return a batch of indices. Indexing the dataset with these indices will return samples from the prior distribution.

abstract sample_prior(num_samples)#

Return indices for the prior distribution samples

Parameters:

num_samples (int) – The batch size

Return type:

Tensor

Returns:

A tensor of indices. Indexing the dataset with these indices will return samples from the desired prior distribution.

class cebra.distributions.base.ConditionalDistribution#

Bases: ABC

Mixin for all conditional distributions.

Conditional distributions return a batch of indices, based on a given batch of indices. Indexing the dataset with these indices will return samples from the conditional distribution.

abstract sample_conditional(query)#

Return indices for the conditional distribution samples

Parameters:

query (Tensor) – Indices of reference samples

Return type:

Tensor

Returns:

A tensor of indices. Indexing the dataset with these indices will return samples from the desired conditional distribution.

class cebra.distributions.base.JointDistribution#

Bases: PriorDistribution, ConditionalDistribution

Mixin for joint distributions.

sample_joint(num_samples)#

Return indices from the joint distribution.

Parameters:

num_samples (int) – Desired batch size

Returns:

tuple containing indices of the reference and positive samples
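
As a sketch of how a custom sampling scheme can plug into these interfaces, the hypothetical subclass below (not part of CEBRA) implements the two abstract methods; sample_joint then composes them as documented above:

>>> import torch
>>> import cebra.distributions.base as cebra_distributions_base
>>> class UniformShiftDistribution(cebra_distributions_base.JointDistribution):
...     """Hypothetical example: positive samples are shifted by a fixed number of steps."""
...     def __init__(self, num_datapoints, shift=1):
...         self.num_datapoints = num_datapoints
...         self.shift = shift
...     def sample_prior(self, num_samples):
...         # Uniformly sample reference indices, leaving room for the shift.
...         return torch.randint(0, self.num_datapoints - self.shift, (num_samples,))
...     def sample_conditional(self, query):
...         # Positive indices are the reference indices offset by a fixed shift.
...         return query + self.shift
>>> distribution = UniformShiftDistribution(num_datapoints=1000, shift=5)
>>> reference_idx, positive_idx = distribution.sample_joint(64)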

Index#

Index operations for conditional sampling.

Indexing operations, in contrast to data distributions, are deterministic: they return the element in the dataset closest to a given query sample. This module contains helper functions for mixed and continuously indexed datasets (i.e., containing discrete and/or continuous data).

Discrete data has to come in the format of a single label for each datapoint. Multidimensional discrete labels should be converted accordingly.
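
For example, a 2D discrete label can be flattened into a single integer label per datapoint by enumerating the unique label combinations (a minimal sketch; the label values are made up):

>>> import torch
>>> labels = torch.tensor([[0, 1], [1, 0], [0, 1], [1, 1]])   # two discrete variables per datapoint
>>> _, flat_labels = torch.unique(labels, dim=0, return_inverse=True)
>>> # flat_labels assigns one integer per unique combination, here tensor([0, 1, 0, 2])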

class cebra.distributions.index.DistanceMatrix(samples)#

Bases: HasDevice

Compute shortest distances between dataset samples.

Parameters:

samples (Tensor) – The continuous values that will be used to index the dataset and specify the conditional distribution.

Note

This implementation is not particularly efficient on very large datasets. For these cases, packages like FAISS offer more optimized retrieval functions.

As a rule of thumb, this class is suitable for datasets that can be hosted in GPU memory.

class cebra.distributions.index.OffsetDistanceMatrix(samples, offset=1)#

Bases: DistanceMatrix

Compute shortest distances, ignoring samples close to the boundary.

Compared to the standard DistanceMatrix, this class should be used for datasets and learning setups where multiple timesteps are fed into the network at once — the samples close to the time-series boundary should be ignored in the sampling process in these cases.

Parameters:
  • samples – The continuous values that will be used to index the dataset and specify the conditional distribution.

  • offset (int) – The number of timesteps to ignore at each side of the dataset

class cebra.distributions.index.ContinuousIndex(index)#

Bases: Index, HasDevice

Naive nearest neighbor search implementation.

index: tensor(N, d)

the values used for kNN search

offset: int or (int,int)

the time offset in each direction

search(query)#

Return index location closest to query.
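
A minimal usage sketch (the index values, query, and shapes are illustrative assumptions):

>>> import torch
>>> import cebra.distributions.index as cebra_distributions_index
>>> values = torch.rand(1000, 3)                 # the continuous index, shape (N, d)
>>> knn_index = cebra_distributions_index.ContinuousIndex(values)
>>> query = torch.rand(10, 3)                    # a batch of query vectors
>>> neighbors = knn_index.search(query)          # locations of the closest entries in the index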

class cebra.distributions.index.ConditionalIndex(discrete, continuous)#

Bases: Index

Index a dataset based on both continuous and discrete information.

In contrast to the standard base.Index class, the ConditionalIndex accepts both discrete and continuous indexing information.

This index considers the discrete indexing information first to identify possible positive pairs. Then, among these candidate samples, it behaves like a base.Index and returns the samples closest in terms of the information in the continuous index.

Parameters:
  • discrete – The discrete indexing information, which should be limited to a 1d feature space. If higher dimensional discrete vectors are used, they should be first re-formatted to fit this structure.

  • continuous – The continuous indexing information, which can be a vector of arbitrary dimension and will be used to define the distance between the samples that share the same discrete index.

search(continuous, discrete=None)#

Search closest sample based on continuous and discrete indexing information.

Parameters:
  • continuous – Samples from the continuous index

  • discrete – Optionally matching samples from the discrete index, used to pre-select matching indices.
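
For illustration, a sketch of building and querying a conditional index (the data, shapes, and label ranges are assumptions):

>>> import torch
>>> import cebra.distributions.index as cebra_distributions_index
>>> discrete = torch.randint(0, 3, (1000,))      # 1D discrete index
>>> continuous = torch.rand(1000, 2)             # continuous index of arbitrary dimension
>>> index = cebra_distributions_index.ConditionalIndex(discrete, continuous)
>>> query_continuous = torch.rand(10, 2)
>>> query_discrete = torch.randint(0, 3, (10,))
>>> matches = index.search(query_continuous, discrete=query_discrete)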

search_naive(continuous, discrete)#

Brute-force search. Fast, especially for small indices.

Parameters:
  • continuous – TODO

  • discrete – TODO

search_iterative(continuous, discrete)#

Iterative search. Gets faster, especially for more than 1e6 samples in the index.

Parameters:
  • continuous – TODO

  • discrete – TODO

class cebra.distributions.index.MultiSessionIndex(*indices)#

Bases: Index

Index multiple sessions.

Parameters:

indices – Indices for the different sessions. Indices of multi-session datasets should have matching feature dimension.

search(query)#

Return closest element in each of the datasets.

Parameters:

query – The query which is applied to each index of the dataset.

Returns:

A list of indices from each session.

Discrete#

Discrete indices.

class cebra.distributions.discrete.Discrete(samples, device='cpu', seed=None)#

Bases: ConditionalDistribution, HasGenerator

Resample 1-dimensional discrete data.

The distribution is fully specified by an array of discrete samples. Samples can be drawn either from the dataset directly (i.e., output samples will have the same distribution of class labels as the samples used to specify the distribution), or from a resampled data distribution where the occurrence of each class label is balanced.

Parameters:

samples (Tensor) – Discrete index used for sampling

property num_samples: int#

Number of samples in the index.

Return type:

int

sample_uniform(num_samples)#

Draw samples from the uniform distribution over values.

This will change the likelihood of values depending on the values in the given (discrete) index. When reindexing the dataset with the returned indices, all values in the index will appear with equal probability.

Parameters:

num_samples (int) – Number of uniform samples to be drawn.

Return type:

Tensor

Returns:

A batch of indices from the distribution. Reindexing the index samples of this instance with the returned indices will yield a uniform distribution across the discrete values.

sample_empirical(num_samples)#

Draw samples from the empirical distribution.

Parameters:

num_samples (int) – Number of samples to be drawn.

Return type:

Tensor

Returns:

A batch of indices from the empirical distribution, which is the uniform distribution over [0, N-1].

sample_conditional(reference_index)#

Draw samples conditional on template samples.

Parameters:

reference_index – Batch of indices, typically drawn from a prior distribution. Conditional samples will match the values of these indices.

Return type:

Tensor

Returns:

A batch of indices whose values match the values corresponding to the given reference indices.
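
A short usage sketch (the label array is made up for illustration):

>>> import torch
>>> import cebra.distributions.discrete as cebra_distributions_discrete
>>> labels = torch.randint(0, 5, (1000,))                    # 1D discrete index
>>> distribution = cebra_distributions_discrete.Discrete(labels, seed=42)
>>> uniform_idx = distribution.sample_uniform(128)           # class-balanced re-sampling
>>> empirical_idx = distribution.sample_empirical(128)       # keeps the empirical label frequencies
>>> balanced_labels = labels[uniform_idx]                    # approximately uniform over the 5 label values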

class cebra.distributions.discrete.DiscreteUniform(samples, device='cpu', seed=None)#

Bases: Discrete, PriorDistribution

Re-sample the given indices and produce samples from a uniform distribution.

sample_prior(num_samples)#

Draw samples from the uniform distribution over values.

This will change the likelihood of values depending on the values in the given (discrete) index. When reindexing the dataset with the returned indices, all values in the index will appear with equal probability.

Parameters:

num_samples (int) – Number of uniform samples to be drawn.

Return type:

Tensor

Returns:

A batch of indices from the distribution. Reindexing the index samples of this instance with the returned indices will yield a uniform distribution across the discrete values.

class cebra.distributions.discrete.DiscreteEmpirical(samples, device='cpu', seed=None)#

Bases: Discrete, PriorDistribution

Draw samples from the empirical distribution defined by the passed index.

sample_prior(num_samples)#

Draw samples from the empirical distribution.

Parameters:

num_samples (int) – Number of samples to be drawn.

Return type:

Tensor

Returns:

A batch of indices from the empirical distribution, which is the uniform distribution over [0, N-1].

Continuous#

Distributions for sampling from continuously indexed datasets.

class cebra.distributions.continuous.Prior(continuous, device='cpu', seed=0)#

Bases: PriorDistribution, HasGenerator

An empirical prior distribution for continuous datasets.

Given the index, uniformly sample across time steps, i.e., sample from the empirical distribution.

Parameters:

continuous (Tensor) – The multi-dimensional continuous index.

sample_prior(num_samples, offset=None)#

Return uniformly sampled indices.

Parameters:
  • num_samples (int) – The number of samples to draw from the prior distribution. This will be the length of the returned tensor.

  • offset (Optional[Offset]) – The cebra.data.datatypes.Offset offset to be respected when sampling indices. The minimum index sampled will be offset.left (inclusive), the maximum index will be the index length minus offset.right (exclusive).

Return type:

Tensor

Returns:

An integer tensor of shape (num_samples,) containing random indices
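
A usage sketch (the index data are illustrative); an offset can additionally be passed to keep sampled indices away from the dataset boundaries:

>>> import torch
>>> import cebra.distributions.continuous as cebra_distributions_continuous
>>> behavior = torch.rand(1000, 3)                # multi-dimensional continuous index
>>> prior = cebra_distributions_continuous.Prior(behavior, seed=42)
>>> reference_idx = prior.sample_prior(128)       # uniform over the 1000 time steps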

class cebra.distributions.continuous.TimeContrastive(continuous=None, time_offset=1, num_samples=None, device='cpu', seed=None)#

Bases: JointDistribution, HasGenerator

Time contrastive learning.

Positive samples will have a distance of exactly time_offset samples in time.

continuous#

The multi-dimensional continuous index.

time_offset#

The time delay between samples that form a positive pair

num_samples#

TODO(stes) remove?

device#

Device (cpu or gpu)

seed#

The seed for sampling from the prior and negative distribution (TODO: currently not used)

sample_prior(num_samples, offset=None)#

Return a random index sample, respecting the given time offset.

Prior samples are uniformly sampled from [0, T - t) where T is the total number of samples in the index, and t is the time offset used for sampling.

Parameters:
  • num_samples (int) – Number of time steps to draw uniformly from the number of available time steps in the dataset

  • offset (Optional[Offset]) – The model offset to respect for sampling from the prior. TODO not yet implemented

Return type:

Tensor

Returns:

A (num_samples,) shaped tensor containing time indices from the uniform prior distribution.

sample_conditional(reference_idx)#

Return samples from the time-contrastive conditional distribution.

The returned indices will be given by incrementing the reference indices by the specified time_offset. When the reference indices are sampled with sample_prior(), it is ensured that the indices all lie within the bounds of the dataset.

Parameters:

reference_idx (Tensor) – The time indices of the reference samples

Return type:

Tensor

Returns:

A (len(reference_idx),) shaped tensor containing time indices from the time-contrastive conditional distribution. The samples will be simply offset by time_offset from reference_idx.
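
A usage sketch (data and offset value are illustrative); per the description above, positive indices are exactly time_offset steps after their reference indices:

>>> import torch
>>> import cebra.distributions.continuous as cebra_distributions_continuous
>>> behavior = torch.rand(1000, 3)
>>> distribution = cebra_distributions_continuous.TimeContrastive(behavior, time_offset=10)
>>> reference_idx = distribution.sample_prior(128)                  # uniform over [0, 1000 - 10)
>>> positive_idx = distribution.sample_conditional(reference_idx)
>>> assert (positive_idx - reference_idx == 10).all()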

class cebra.distributions.continuous.DirectTimedeltaDistribution(continuous, time_offset=1)#

Bases: TimeContrastive, HasGenerator

Look up indices with

sample_conditional(reference_idx)#

Samples from the conditional distribution.

Return type:

Tensor

class cebra.distributions.continuous.TimedeltaDistribution(continuous, time_delta=1, device='cpu', seed=None)#

Bases: JointDistribution, HasGenerator

Define a conditional distribution based on behavioral changes over time.

Takes a continuous index, and uses the empirical distribution of differences between samples in this index.

Parameters:
  • continuous – The multidimensional, continuous index

  • time_delta (int) – The time delay between samples that should form a positive pair.

  • device (Literal[‘cpu’, ‘cuda’]) – TODO

  • seed (Optional[int]) – TODO

Note

For best results, the given continuous index should contain independent factors; positive pairs will be formed by adding a _random_ difference estimated within the dataset to the reference samples. Factors should ideally also be within the same range (since the Euclidean distance is used in the search). A simple solution is to perform a PCA or ICA, or apply CEBRA first before passing the index to this function.

sample_prior(num_samples)#

See Prior.sample_prior().

Return type:

Tensor

sample_conditional(reference_idx)#

Return indices from the conditional distribution.

Return type:

Tensor
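
Following the note above, one way to roughly decorrelate and re-scale the index before passing it to this distribution is a PCA projection. A minimal sketch (the preprocessing choice and data are illustrative, not a prescribed pipeline):

>>> import torch
>>> import cebra.distributions.continuous as cebra_distributions_continuous
>>> behavior = torch.rand(1000, 6)                 # raw, possibly correlated continuous index
>>> centered = behavior - behavior.mean(dim=0)
>>> U, S, V = torch.pca_lowrank(centered, q=3)
>>> components = centered @ V / S                  # project and bring the factors to a comparable range
>>> distribution = cebra_distributions_continuous.TimedeltaDistribution(components, time_delta=1)
>>> reference_idx = distribution.sample_prior(128)
>>> positive_idx = distribution.sample_conditional(reference_idx)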

class cebra.distributions.continuous.DeltaNormalDistribution(continuous, delta=0.1, device='cpu', seed=None)#

Bases: JointDistribution, HasGenerator

Define a conditional distribution based on behavioral changes over time.

Takes a continuous index, and samples from a Gaussian distribution to obtain positive pairs. Note that if the continuous index is multidimensional, the Gaussian distribution has an isotropic covariance matrix, i.e., Σ = sigma^2 * I, where sigma is given by delta.

Parameters:
  • continuous (Tensor) – The multidimensional, continuous index.

  • delta (float) – Standard deviation of the Gaussian distribution used to sample positive pairs.

sample_prior(num_samples)#

See Prior.sample_prior().

Return type:

Tensor

sample_conditional(reference_idx)#

Return indices from the conditional distribution.

Return type:

Tensor
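
A usage sketch (data and delta are made up); the continuous values of the positive samples should lie roughly within delta of their reference values:

>>> import torch
>>> import cebra.distributions.continuous as cebra_distributions_continuous
>>> behavior = torch.rand(1000, 1)
>>> distribution = cebra_distributions_continuous.DeltaNormalDistribution(behavior, delta=0.05)
>>> reference_idx = distribution.sample_prior(128)
>>> positive_idx = distribution.sample_conditional(reference_idx)
>>> typical_shift = (behavior[positive_idx] - behavior[reference_idx]).abs().mean()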

class cebra.distributions.continuous.CEBRADistribution#

Bases: JointDistribution

Use CEBRA embeddings for defining a conditional distribution.

Mixed#

Distributions with a mix of continuous/discrete auxiliary variables.

class cebra.distributions.mixed.ConfigurableDistribution#

Bases: object

Experimental. Do not use yet.

configure_prior(distribution='empirical')#

Not implemented yet.

configure_conditional()#

Not implemented yet.

class cebra.distributions.mixed.Mixed(discrete, continuous)#

Bases: HasDevice

Distribution over behavior variables.

Class combines sampling across continuous and discrete variables.

sample_conditional_discrete(discrete)#

Sample conditional on the discrete samples, marginalized across continuous.

Return type:

Tensor

sample_conditional_continuous(continuous)#

Sample conditional on the continuous samples, marginalized across discrete.

Return type:

Tensor

sample_conditional(discrete, continuous)#

Sample conditional on the continuous and discrete samples

Return type:

Tensor

sample_prior(num_samples)#

Sample from the uniform prior distribution.

Return type:

Tensor

class cebra.distributions.mixed.MixedTimeDeltaDistribution(discrete, continuous, time_delta=1)#

Bases: TimedeltaDistribution

Combination of a time delta and discrete distribution for sampling.

Sampling from the prior uses the DiscreteUniform distribution. For sampling the conditional, it is ensured that the positive pairs share their behavior variable, and are then sampled according to the TimedeltaDistribution.

sample_prior(num_samples)#

Return indices from the uniform prior distribution.

Parameters:

num_samples (int) – The number of samples

Return type:

Tensor

Returns:

The reference indices of shape (num_samples, ).

sample_conditional(reference_idx)#

Return indices from the conditional distribution.

Parameters:

reference_idx (Tensor) – The reference indices.

Return type:

Tensor

Returns:

The positive indices. The positive samples will match the reference samples in their discrete variable, and will otherwise be drawn from the TimedeltaDistribution.
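
A usage sketch (data are illustrative); per the description above, each positive sample shares the discrete value of its reference sample:

>>> import torch
>>> import cebra.distributions.mixed as cebra_distributions_mixed
>>> discrete = torch.randint(0, 4, (1000,))       # 1D discrete index
>>> continuous = torch.rand(1000, 3)              # continuous index
>>> distribution = cebra_distributions_mixed.MixedTimeDeltaDistribution(discrete, continuous, time_delta=1)
>>> reference_idx = distribution.sample_prior(128)
>>> positive_idx = distribution.sample_conditional(reference_idx)
>>> assert (discrete[reference_idx] == discrete[positive_idx]).all()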

Multi-session#

Continuous variable multi-session sampling.

class cebra.distributions.multisession.MultisessionSampler(dataset, time_offset)#

Bases: PriorDistribution, ConditionalDistribution

Continuous multi-session sampling.

Align embeddings across multiple sessions, using a continuous index. The transitions between index samples are computed across all sessions.

Note

The batch dimension of the positive samples is shuffled. Before applying the contrastive loss, either the reference samples need to be aligned with the positive samples, or vice versa:

>>> import cebra.distributions.multisession as cebra_distributions_multisession
>>> import cebra.integrations.sklearn.dataset as cebra_sklearn_dataset
>>> import cebra.data
>>> import torch
>>> from torch import nn
>>> # Multisession training: one model per dataset (different input dimensions)
>>> session1 = torch.rand(100, 30)
>>> session2 = torch.rand(100, 50)
>>> index1 = torch.rand(100)
>>> index2 = torch.rand(100)
>>> num_features = 8
>>> dataset = cebra.data.DatasetCollection(
...               cebra_sklearn_dataset.SklearnDataset(session1, (index1, )),
...               cebra_sklearn_dataset.SklearnDataset(session2, (index2, )))
>>> model = nn.ModuleList([
...                cebra.models.init(
...                    name="offset1-model",
...                    num_neurons=dataset.input_dimension,
...                    num_units=32,
...                    num_output=num_features,
...                ) for dataset in dataset.iter_sessions()]).to("cpu")
>>> sampler = cebra_distributions_multisession.MultisessionSampler(dataset, time_offset=10)
>>> # ref and pos samples from all datasets
>>> ref = sampler.sample_prior(100)
>>> pos, idx, rev_idx = sampler.sample_conditional(ref)
>>> ref = torch.LongTensor(ref)
>>> pos = torch.LongTensor(pos)
>>> # Then the embedding spaces can be concatenated
>>> refs, poss = [], []
>>> for i in range(len(model)):
...     refs.append(model[i](dataset._datasets[i][ref[i]]))
...     poss.append(model[i](dataset._datasets[i][pos[i]]))
>>> ref = torch.stack(refs, dim=0)
>>> pos = torch.stack(poss, dim=0)
>>> # Now the index can be applied to the stacked features,
>>> # to align reference to positive samples
>>> aligned_ref = sampler.mix(ref, idx)
>>> reference = aligned_ref.view(-1, num_features)
>>> positive = pos.view(-1, num_features)
>>> loss = (reference - positive)**2
>>> # .. or the reverse index, to align positive to reference samples
>>> aligned_pos = sampler.mix(pos, rev_idx)
>>> reference = ref.view(-1, num_features)
>>> positive = aligned_pos.view(-1, num_features)
>>> loss = (reference - positive)**2

The reason for this implementation is that dataset[i] will in general have different dimensions (for example, the number of neurons) per session.

After data processing, the dimensionality of the returned features matches. The resulting embeddings can be concatenated, and shuffling (across the session axis) can be applied to the reference samples, or reversed for the positive samples.

Note

This function currently does not support explicitly selected discrete indices. They should be added as dimensions to the continuous index. More weight can be added to the discrete dimensions by using larger values in one-hot coding.

property num_sessions: int#

The number of sessions in the index.

Return type:

int

mix(array, idx)#

Re-order array elements according to the given index mapping.

The given array should be of the shape (session, batch, ...) and the indices should have length session x batch, representing a mapping between indices.

The resulting array will be rearranged such that out.reshape(session*batch, -1)[i] = array.reshape(session*batch, -1)[idx[i]]

For the inverse mapping, convert the indices first using _invert_index function.

Parameters:
  • array (ndarray) – A 2D matrix containing samples for each session.

  • idx (ndarray) – A list of indices used to re-order the array.
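
The documented reordering rule can also be written directly in torch; a sketch with made-up shapes (mix itself is applied to model outputs in the example above):

>>> import torch
>>> num_sessions, batch, dim = 3, 4, 8
>>> array = torch.rand(num_sessions, batch, dim)
>>> idx = torch.randperm(num_sessions * batch)                 # mapping of length session * batch
>>> flat = array.reshape(num_sessions * batch, -1)
>>> out = flat[idx].reshape(num_sessions, batch, dim)          # out.reshape(S * B, -1)[i] == flat[idx[i]]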

sample_prior(num_samples)#

Return indices for the prior distribution samples

Parameters:

num_samples – The batch size

Returns:

A tensor of indices. Indexing the dataset with these indices will return samples from the desired prior distribution.

sample_conditional(idx)#

Sample from the conditional distribution.

Note

  • Reference samples are sampled equally between sessions.

  • Queries are computed for each reference sample as in the single-session case, i.e., by adding a random time shift to each reference sample.

  • In order to guarantee the same number of positive samples per session, queries are randomly assigned to a session and the corresponding positive sample is searched for in that session only.

  • As a result, the ref/pos pairing is shuffled and can be recovered by applying the reverse shuffle operation.

Parameters:

idx (Tensor) – Reference indices, with dimension (session, batch).

Return type:

Tensor

Returns:

Positive indices (1st return value), which are grouped by session and do not match the reference indices. In addition, a mapping is returned that applies to the reference indices the same shuffle operation used to assign queries to sessions along the session/batch dimension (2nd return value), together with a mapping that reverses this shuffle operation (3rd return value). All three returned tensors have shape (session, batch).

class cebra.distributions.multisession.DiscreteMultisessionSampler(dataset)#

Bases: PriorDistribution, ConditionalDistribution

Discrete multi-session sampling.

Discrete indices don’t need to be aligned. Positive pairs are found by matching the discrete index in randomly assigned sessions.

After data processing, the dimensionality of the returned features matches. The resulting embeddings can be concatenated, and shuffling (across the session axis) can be applied to the reference samples, or reversed for the positive samples.

property num_sessions: int#

The number of sessions in the index.

Return type:

int

mix(array, idx)#

Re-order array elements according to the given index mapping.

The given array should be of the shape (session, batch, ...) and the indices should have length session x batch, representing a mapping between indices.

The resulting array will be rearranged such that out.reshape(session*batch, -1)[i] = array.reshape(session*batch, -1)[idx[i]]

For the inverse mapping, convert the indices first using _invert_index function.

Parameters:
  • array (ndarray) – A 2D matrix containing samples for each session.

  • idx (ndarray) – A list of indices used to re-order the array.

sample_prior(num_samples)#

Return indices for the prior distribution samples

Parameters:

num_samples – The batch size

Returns:

A tensor of indices. Indexing the dataset with these indices will return samples from the desired prior distribution.

sample_conditional(idx)#

Sample from the conditional distribution.

Note

  • Reference samples are sampled equally between sessions.

  • In order to guarantee the same number of positive samples per session, reference samples are randomly assigned to a session and the corresponding positive sample is searched for in that session only.

  • As a result, the ref/pos pairing is shuffled and can be recovered by applying the reverse shuffle operation.

Parameters:

idx (Tensor) – Reference indices, with dimension (session, batch).

Return type:

Tensor

Returns:

Positive indices (1st return value), which are grouped by session and do not match the reference indices. In addition, a mapping is returned that applies the same shuffle operation used to assign reference samples to sessions along the session/batch dimension (2nd return value), together with a mapping that reverses this shuffle operation (3rd return value). All three returned tensors have shape (session, batch).