You can download and run the notebook locally or run it with Google Colaboratory.


Forelimb dynamics, somatosensory (S1)

  • This notebook demonstrates how to use CEBRA on the primate reaching data (shown in Fig. 3 and Extended Data Fig. 8).

  • Specifically, it shows how to use the standard InfoNCE loss with CEBRA.

  • Install note: make sure the demo dependencies are installed before using this notebook:

[ ]:
!pip install --pre 'cebra[dev,demos]'
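For orientation, the standard InfoNCE loss used throughout this notebook pulls each reference sample toward a positive sample and pushes it away from a set of negatives. The NumPy sketch below is a minimal illustration under stated assumptions, not CEBRA's implementation: it assumes cosine similarity on unit-normalized embeddings (as with `distance='cosine'`), divides by `temperature`, and uses a common variant that leaves the positive pair out of the log-sum-exp. The helpers `normalize` and `infonce_sketch` are hypothetical names.

```python
import numpy as np


# Hypothetical InfoNCE-style sketch; not CEBRA's own loss implementation.
def normalize(x):
    """Scale each row to unit length so dot products equal cosine similarities."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def infonce_sketch(ref, pos, neg, temperature=1.0):
    """InfoNCE-style loss on (batch, dim) arrays.

    ref[i] and pos[i] form a positive pair; every row of `neg` serves as a
    negative for every reference. Returns (alignment, uniformity, total).
    """
    ref, pos, neg = normalize(ref), normalize(pos), normalize(neg)
    # Similarity of each reference to its own positive sample.
    pos_sim = np.sum(ref * pos, axis=1) / temperature
    # Similarity of each reference to all negatives: shape (batch, n_neg).
    neg_sim = ref @ neg.T / temperature
    # Numerically stable log-sum-exp over the negatives of each reference.
    shift = neg_sim.max(axis=1, keepdims=True)
    lse = shift[:, 0] + np.log(np.exp(neg_sim - shift).sum(axis=1))
    alignment = -pos_sim.mean()  # lower when positives are similar to references
    uniformity = lse.mean()      # lower when negatives are dissimilar to references
    return alignment, uniformity, alignment + uniformity


rng = np.random.default_rng(0)
ref = rng.normal(size=(512, 3))
pos = ref + 0.1 * rng.normal(size=(512, 3))  # positives: small perturbations of ref
neg = rng.normal(size=(512, 3))              # negatives: unrelated samples
print(infonce_sketch(ref, pos, neg, temperature=1.0))
```

CEBRA's training logs also report separate `pos` and `neg` values, but its exact bookkeeping may differ from this sketch.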
[1]:
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib as jl
import cebra.datasets
from cebra import CEBRA

Let’s load the data:

  • The data will be downloaded automatically into a data/ folder.

[2]:
monkey_pos = cebra.datasets.init('area2-bump-pos-active')
monkey_target = cebra.datasets.init('area2-bump-target-active')
100%|██████████| 63.9M/63.9M [00:08<00:00, 7.93MB/s]
Download complete. Dataset saved in 'data/monkey_reaching_preload_smth_40/active_all.jl'
  • For a quick demo on CPU, you can reduce max_iterations to 50-1000; otherwise, set it to 5000 or more.

[4]:
max_iterations = 500

Define a model that uses positional information as the auxiliary variable:

[5]:
cebra_pos_model = CEBRA(model_architecture='offset10-model',
                        batch_size=512,
                        learning_rate=0.0001,
                        temperature=1,
                        output_dimension=3,
                        max_iterations=max_iterations,
                        distance='cosine',
                        conditional='time_delta',
                        device='cuda_if_available',
                        verbose=True,
                        time_offsets=10)
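With `conditional='time_delta'`, positive samples are chosen according to how the auxiliary variable (here, hand position) changes over time. The following is a simplified NumPy sketch of that idea, not CEBRA's sampler: it assumes the positive for a reference is found by adding a behavior change observed over `time_offset` steps (at a randomly chosen time) to the reference's behavior, then taking the sample whose behavior is nearest to that target. The function `sample_time_delta_positives` and the synthetic circular trajectory are hypothetical.

```python
import numpy as np


# Hypothetical sketch of time-delta positive sampling; not CEBRA's own sampler.
def sample_time_delta_positives(behavior, ref_idx, time_offset, rng):
    """Return one positive index per reference index.

    behavior: (T, d) continuous auxiliary variable, e.g. x/y hand position.
    ref_idx: (batch,) integer indices of the reference samples.
    """
    # Empirical behavior changes over `time_offset` steps: shape (T - time_offset, d).
    deltas = behavior[time_offset:] - behavior[:-time_offset]
    # Draw one random change per reference and add it to the reference's behavior.
    picked = deltas[rng.integers(0, len(deltas), size=len(ref_idx))]
    targets = behavior[ref_idx] + picked
    # Positive = the sample whose behavior is nearest (Euclidean) to each target.
    dists = np.linalg.norm(behavior[None, :, :] - targets[:, None, :], axis=-1)
    return dists.argmin(axis=1)


rng = np.random.default_rng(0)
t = np.linspace(0, 4 * np.pi, 1000)
behavior = np.stack([np.cos(t), np.sin(t)], axis=1)  # synthetic circular trajectory
ref_idx = rng.integers(0, len(behavior), size=8)
pos_idx = sample_time_delta_positives(behavior, ref_idx, time_offset=10, rng=rng)
print(ref_idx, pos_idx)
```

Because positives are matched in behavior space rather than in time, samples from different trials that pass through similar positions can become positive pairs.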

Fit the model:

[6]:
cebra_pos_model.fit(monkey_pos.neural, monkey_pos.continuous_index.numpy())
cebra_pos = cebra_pos_model.transform(monkey_pos.neural)
_continuous_index
_discrete_index
delta
_seed
_seed
pos:  0.1635 neg:  5.4339 total:  5.5974: 100%|███████████████████████████████████████████████████| 500/500 [08:24<00:00,  1.01s/it]

Plot the results:

[7]:
%matplotlib notebook
fig = plt.figure(figsize=(12, 5))
plt.suptitle('CEBRA-behavior trained with position label',
             fontsize=20)
ax = plt.subplot(121, projection='3d')
ax.set_title('x', fontsize=20, y=0)
x = ax.scatter(cebra_pos[:, 0],
               cebra_pos[:, 1],
               cebra_pos[:, 2],
               c=monkey_pos.continuous_index[:, 0],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax = plt.subplot(122, projection='3d')
y = ax.scatter(cebra_pos[:, 0],
               cebra_pos[:, 1],
               cebra_pos[:, 2],
               c=monkey_pos.continuous_index[:, 1],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax.set_title('y', fontsize=20, y=0)
yc = plt.colorbar(y, fraction=0.03, pad=0.05, ticks=np.linspace(-15, 15, 7))
yc.ax.tick_params(labelsize=15)
yc.ax.set_title("(cm)", fontsize=10)
plt.show()

Define a model that uses target information as the auxiliary variable:

[8]:
cebra_target_model = CEBRA(model_architecture='offset10-model',
                           batch_size=512,
                           learning_rate=0.0001,
                           temperature=1,
                           output_dimension=3,
                           max_iterations=max_iterations,
                           distance='cosine',
                           conditional='time_delta',
                           device='cuda_if_available',
                           verbose=True,
                           time_offsets=10)
[9]:
cebra_target_model.fit(monkey_target.neural,
                       monkey_target.discrete_index.numpy())
cebra_target = cebra_target_model.transform(monkey_target.neural)
_continuous_index
_discrete_index
_seed
pos:  0.2610 neg:  5.4770 total:  5.7379: 100%|███████████████████████████████████████████████████| 500/500 [00:28<00:00, 17.46it/s]
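The target model above passes the discrete reach direction as the auxiliary variable. A simple way to picture label-based positives is to pair each reference with a random sample that shares its label. The NumPy sketch below illustrates only that assumption; it is not CEBRA's sampler, and `sample_same_label_positives` and the synthetic labels (8 directions of 600 samples each) are hypothetical.

```python
import numpy as np


# Hypothetical sketch of label-based positive sampling; not CEBRA's own sampler.
def sample_same_label_positives(labels, ref_idx, rng):
    """For each reference, pick a random sample with the same discrete label.

    labels: (T,) integer labels, e.g. reach target 0..7.
    ref_idx: (batch,) integer indices of the reference samples.
    """
    pos_idx = np.empty_like(ref_idx)
    for k, i in enumerate(ref_idx):
        # All time points that share the reference's label (including itself).
        candidates = np.flatnonzero(labels == labels[i])
        pos_idx[k] = rng.choice(candidates)
    return pos_idx


rng = np.random.default_rng(0)
labels = np.repeat(np.arange(8), 600)  # synthetic: 8 directions, 600 samples each
ref_idx = rng.integers(0, len(labels), size=16)
pos_idx = sample_same_label_positives(labels, ref_idx, rng)
print(labels[ref_idx], labels[pos_idx])  # the two label rows are identical
```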
[10]:

fig = plt.figure(figsize=(4, 2), dpi=300) plt.suptitle('CEBRA-behavior trained with target label', fontsize=5) ax = plt.subplot(121, projection = '3d') ax.set_title('All trials embedding', fontsize=5, y=-0.1) x = ax.scatter(cebra_target[:, 0], cebra_target[:, 1], cebra_target[:, 2], c=monkey_target.discrete_index, cmap=plt.cm.hsv, s=0.01) ax.axis('off') ax = plt.subplot(122,projection = '3d') ax.set_title('direction-averaged embedding', fontsize=5, y=-0.1) for i in range(8): direction_trial = (monkey_target.discrete_index == i) trial_avg = cebra_target[direction_trial, :].reshape(-1, 600, 3).mean(axis=0) trial_avg_normed = trial_avg/np.linalg.norm(trial_avg, axis=1)[:,None] ax.scatter(trial_avg_normed[:, 0], trial_avg_normed[:, 1], trial_avg_normed[:, 2], color=plt.cm.hsv(1 / 8 * i), s=0.01) ax.axis('off') plt.show()
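The direction-averaged panel relies on each trial occupying a contiguous block of 600 samples, so that `reshape(-1, 600, 3)` stacks trials along the first axis before averaging; each averaged time point is then rescaled to unit length. The sketch below reproduces just that averaging step on synthetic data so the shapes are easy to follow. The helper `direction_average`, the label layout, and the random embedding are hypothetical.

```python
import numpy as np


# Hypothetical helper mirroring the notebook's averaging step on synthetic data.
def direction_average(embedding, labels, direction, trial_len=600):
    """Average the embedding over all trials of one reach direction.

    Assumes samples with label `direction` form whole trials of `trial_len`
    consecutive time points, as the notebook's reshape does.
    Returns a (trial_len, dim) trajectory with unit-norm rows.
    """
    dim = embedding.shape[1]
    trials = embedding[labels == direction].reshape(-1, trial_len, dim)
    avg = trials.mean(axis=0)  # mean over trials, time point by time point
    # Averaging unit vectors shrinks them; rescale each time point back to norm 1.
    return avg / np.linalg.norm(avg, axis=1, keepdims=True)


rng = np.random.default_rng(0)
trial_len, dim = 600, 3
labels = np.repeat(np.array([0, 1, 0, 2]), trial_len)  # synthetic: 4 trials, direction 0 twice
embedding = rng.normal(size=(len(labels), dim))
embedding /= np.linalg.norm(embedding, axis=1, keepdims=True)
traj = direction_average(embedding, labels, direction=0, trial_len=trial_len)
print(traj.shape)  # (600, 3)
```

If the number of samples for a direction is not a multiple of the trial length, the reshape raises a `ValueError`, which is a quick way to catch misaligned trials.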

Define a model that uses only time information:

[11]:
cebra_time_model = CEBRA(model_architecture='offset10-model',
                         batch_size=512,
                         learning_rate=0.0003,
                         temperature=1,
                         output_dimension=3,
                         max_iterations=max_iterations,
                         distance='cosine',
                         conditional='time',
                         device='cuda_if_available',
                         verbose=True,
                         time_offsets=5)
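CEBRA-Time uses no behavioral labels: with `conditional='time'`, positives come from samples near the reference in time, and `time_offsets` sets the spacing. The sketch below assumes the simplest version of that idea, pairing reference index t with t + time_offset and drawing negatives uniformly from the recording. It is an illustration, not CEBRA's sampler, and `sample_time_positives` is a hypothetical name.

```python
import numpy as np


# Hypothetical sketch of time-contrastive sampling; not CEBRA's own sampler.
def sample_time_positives(n_samples, batch_size, time_offset, rng):
    """Draw reference, positive, and negative indices for one batch.

    References leave room at the end of the recording so that each
    positive, `time_offset` steps later, is still a valid index.
    """
    ref_idx = rng.integers(0, n_samples - time_offset, size=batch_size)
    pos_idx = ref_idx + time_offset
    # Negatives: drawn uniformly from the whole recording, regardless of time.
    neg_idx = rng.integers(0, n_samples, size=batch_size)
    return ref_idx, pos_idx, neg_idx


rng = np.random.default_rng(0)
ref, pos, neg = sample_time_positives(n_samples=10_000, batch_size=512, time_offset=5, rng=rng)
print(ref[:3], pos[:3])
```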
[12]:
cebra_time_model.fit(monkey_target.neural)
cebra_time = cebra_time_model.transform(monkey_target.neural)
_continuous_index
_discrete_index
_seed
pos:  0.0012 neg:  5.4089 total:  5.4101: 100%|███████████████████████████████████████████████████| 500/500 [00:27<00:00, 18.49it/s]
[13]:
fig = plt.figure(figsize=(4, 2), dpi=300)
plt.suptitle('CEBRA-time', fontsize=5)
ax = plt.subplot(121, projection='3d')
ax.set_title('x', fontsize=4, y=-0.1)
x = ax.scatter(cebra_time[:, 0],
               cebra_time[:, 1],
               cebra_time[:, 2],
               c=monkey_pos.continuous_index[:, 0],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax = plt.subplot(122, projection='3d')
y = ax.scatter(cebra_time[:, 0],
               cebra_time[:, 1],
               cebra_time[:, 2],
               c=monkey_pos.continuous_index[:, 1],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax.set_title('y', fontsize=5, y=-0.1)
yc = plt.colorbar(y, fraction=0.03, pad=0.05, ticks=np.linspace(-15, 15, 7))
yc.ax.tick_params(labelsize=3)
yc.ax.set_title("(cm)", fontsize=5)
plt.show()