Forelimb dynamics, somatosensory (S1)
This notebook demonstrates how to use CEBRA on the primate reaching data (shown in Fig. 3 and Extended Data Fig. 8).
Specifically, it shows how to train CEBRA with the standard InfoNCE loss.
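For intuition, here is a minimal NumPy sketch of an InfoNCE loss that uses cosine similarity and a temperature, the two settings (`distance='cosine'`, `temperature`) passed to every model in this notebook. It is an illustrative assumption, not CEBRA's implementation: the function name `infonce_loss` is hypothetical, it uses the textbook form (positive included in the softmax denominator), and CEBRA's own loss may differ in how negatives are drawn and how the terms are reported in its training logs.

```python
import numpy as np


def infonce_loss(ref, pos, neg, temperature=1.0):
    """Textbook InfoNCE loss with cosine similarity (hypothetical sketch).

    ref, pos: arrays of shape (batch, dim); row i of `pos` is the positive for row i of `ref`.
    neg: array of shape (num_neg, dim), used as negatives for every reference.
    """
    def normalize(x):
        # Project onto the unit sphere so dot products become cosine similarities.
        return x / np.linalg.norm(x, axis=1, keepdims=True)

    ref, pos, neg = normalize(ref), normalize(pos), normalize(neg)
    pos_sim = np.sum(ref * pos, axis=1) / temperature  # (batch,)
    neg_sim = ref @ neg.T / temperature                # (batch, num_neg)

    # Softmax cross-entropy where the positive (column 0) is the correct "class".
    logits = np.concatenate([pos_sim[:, None], neg_sim], axis=1)
    m = logits.max(axis=1, keepdims=True)              # for numerical stability
    log_sum_exp = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return float(np.mean(log_sum_exp - pos_sim))


# Usage: a positive identical to the reference and one opposite negative.
ref = np.array([[1.0, 0.0]])
print(infonce_loss(ref, ref.copy(), -ref))  # ≈ 0.1269, i.e. log(1 + e^-2)
```

Lowering the temperature sharpens the softmax over similarities; the models below use `temperature=1`.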
Install note: make sure the demo dependencies are installed before running this notebook:
[ ]:
!pip install --pre 'cebra[dev,demos]'
[1]:
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib as jl
import cebra.datasets
from cebra import CEBRA
Let’s load the data:
The data will be downloaded automatically into a /data folder.
[2]:
monkey_pos = cebra.datasets.init('area2-bump-pos-active')
monkey_target = cebra.datasets.init('area2-bump-target-active')
Download complete. Dataset saved in 'data/monkey_reaching_preload_smth_40/active_all.jl'
For a quick demo on the CPU, you can lower max_iterations to between 50 and 1000; for full results, set it to 5000 or more.
[4]:
max_iterations = 500
Define a model that uses position information as the auxiliary variable:
[5]:
cebra_pos_model = CEBRA(model_architecture='offset10-model',
                        batch_size=512,
                        learning_rate=0.0001,
                        temperature=1,
                        output_dimension=3,
                        max_iterations=max_iterations,
                        distance='cosine',
                        conditional='time_delta',
                        device='cuda_if_available',
                        verbose=True,
                        time_offsets=10)
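The `conditional='time_delta'` setting controls how positive pairs are drawn from the continuous position label. As we understand it, a behavior change observed over `time_offsets` steps is added to the reference sample's label, and the sample whose label lies closest to that shifted target becomes the positive. The NumPy sketch below illustrates that idea under this assumption; the function `sample_time_delta_positives` is hypothetical and is not CEBRA's internal sampler.

```python
import numpy as np


def sample_time_delta_positives(labels, ref_idx, time_offset, rng):
    """Pick one positive index per reference (hypothetical sketch).

    labels: (T, d) continuous auxiliary variable, e.g. 2-D hand position.
    ref_idx: (n,) integer indices of the reference samples.
    """
    # Empirical label changes over `time_offset` steps across the recording.
    deltas = labels[time_offset:] - labels[:-time_offset]
    # Draw one random delta per reference and shift the reference label by it.
    picked = deltas[rng.integers(0, len(deltas), size=len(ref_idx))]
    targets = labels[ref_idx] + picked
    # Nearest neighbour (Euclidean) of each shifted target among all labels.
    dists = np.linalg.norm(labels[None, :, :] - targets[:, None, :], axis=-1)
    return dists.argmin(axis=1)


# Usage on a synthetic 1-D ramp: every delta equals the offset, so positives are ref + 2.
labels = np.arange(20, dtype=float)[:, None]
print(sample_time_delta_positives(labels, np.array([0, 3, 5]), 2, np.random.default_rng(0)))
```

The brute-force distance matrix keeps the sketch short; for a real recording a spatial index would scale better.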
Fit the model:
[6]:
cebra_pos_model.fit(monkey_pos.neural, monkey_pos.continuous_index.numpy())
cebra_pos = cebra_pos_model.transform(monkey_pos.neural)
pos: 0.1635 neg: 5.4339 total: 5.5974: 100%|███████████████████████████████████████████████████| 500/500 [08:24<00:00, 1.01s/it]
Plot the results:
[7]:
%matplotlib notebook
fig = plt.figure(figsize=(12, 5))
plt.suptitle('CEBRA-behavior trained with position label',
             fontsize=20)
# Left panel: embedding colored by the x coordinate of the hand position.
ax = plt.subplot(121, projection='3d')
ax.set_title('x', fontsize=20, y=0)
x = ax.scatter(cebra_pos[:, 0],
               cebra_pos[:, 1],
               cebra_pos[:, 2],
               c=monkey_pos.continuous_index[:, 0],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
# Right panel: the same embedding colored by the y coordinate.
ax = plt.subplot(122, projection='3d')
y = ax.scatter(cebra_pos[:, 0],
               cebra_pos[:, 1],
               cebra_pos[:, 2],
               c=monkey_pos.continuous_index[:, 1],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax.set_title('y', fontsize=20, y=0)
# Shared color scale in centimeters (both panels use vmin=-15, vmax=15).
yc = plt.colorbar(y, fraction=0.03, pad=0.05, ticks=np.linspace(-15, 15, 7))
yc.ax.tick_params(labelsize=15)
yc.ax.set_title("(cm)", fontsize=10)
plt.show()
Define a model that uses target information as the auxiliary variable:
[8]:
cebra_target_model = CEBRA(model_architecture='offset10-model',
                           batch_size=512,
                           learning_rate=0.0001,
                           temperature=1,
                           output_dimension=3,
                           max_iterations=max_iterations,
                           distance='cosine',
                           conditional='time_delta',
                           device='cuda_if_available',
                           verbose=True,
                           time_offsets=10)
[9]:
cebra_target_model.fit(monkey_target.neural,
                       monkey_target.discrete_index.numpy())
cebra_target = cebra_target_model.transform(monkey_target.neural)
pos: 0.2610 neg: 5.4770 total: 5.7379: 100%|███████████████████████████████████████████████████| 500/500 [00:28<00:00, 17.46it/s]
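Here the auxiliary variable is the discrete reach target, passed as `monkey_target.discrete_index` (the plotting code below treats it as eight directions). A natural way to form positive pairs from a discrete label is to draw each positive uniformly from the samples that share the reference's label. The sketch below shows that idea; the helper `sample_same_label_positives` is hypothetical, and whether CEBRA samples exactly this way for this configuration is an assumption.

```python
import numpy as np


def sample_same_label_positives(labels, ref_idx, rng):
    """For each reference, draw a random index with the same discrete label (sketch)."""
    labels = np.asarray(labels)
    positives = np.empty(len(ref_idx), dtype=int)
    for k, i in enumerate(ref_idx):
        # All samples sharing the reference's class (this includes i itself).
        candidates = np.flatnonzero(labels == labels[i])
        positives[k] = rng.choice(candidates)
    return positives


# Usage with a toy label sequence of three classes.
labels = np.array([0, 1, 0, 1, 2, 2, 0])
print(sample_same_label_positives(labels, np.arange(len(labels)), np.random.default_rng(0)))
```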
[10]:
fig = plt.figure(figsize=(4, 2), dpi=300)
plt.suptitle('CEBRA-behavior trained with target label',
             fontsize=5)
ax = plt.subplot(121, projection='3d')
ax.set_title('All trials embedding', fontsize=5, y=-0.1)
x = ax.scatter(cebra_target[:, 0],
               cebra_target[:, 1],
               cebra_target[:, 2],
               c=monkey_target.discrete_index,
               cmap=plt.cm.hsv,
               s=0.01)
ax.axis('off')
ax = plt.subplot(122, projection='3d')
ax.set_title('direction-averaged embedding', fontsize=5, y=-0.1)
# Average the embedding over all trials of each of the 8 reach directions.
for i in range(8):
    direction_trial = (monkey_target.discrete_index == i)
    # Assumes every trial spans 600 samples: (trials * 600, 3) -> (trials, 600, 3), mean over trials.
    trial_avg = cebra_target[direction_trial, :].reshape(-1, 600, 3).mean(axis=0)
    # Rescale each averaged point back to unit length.
    trial_avg_normed = trial_avg / np.linalg.norm(trial_avg, axis=1)[:, None]
    ax.scatter(trial_avg_normed[:, 0],
               trial_avg_normed[:, 1],
               trial_avg_normed[:, 2],
               color=plt.cm.hsv(1 / 8 * i),
               s=0.01)
ax.axis('off')
plt.show()
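The direction-averaged panel relies on two array operations. The samples for one direction are reshaped into (trials, 600, 3), which assumes every trial contributes exactly 600 consecutive samples, and averaged across trials; each averaged point is then rescaled to unit length, presumably to put the trajectory back on the unit sphere, since averaging unit vectors shrinks their norm. The synthetic example below (random numbers, trials of 4 samples instead of 600) checks what those lines compute.

```python
import numpy as np

samples_per_trial = 4   # stands in for 600 in the notebook
n_trials, dim = 3, 3

rng = np.random.default_rng(0)
# Synthetic embedding for one direction: trials stacked back to back, unit-norm rows.
emb = rng.normal(size=(n_trials * samples_per_trial, dim))
emb /= np.linalg.norm(emb, axis=1, keepdims=True)

# (trials * samples, 3) -> (trials, samples, 3), then average over trials.
trial_avg = emb.reshape(-1, samples_per_trial, dim).mean(axis=0)

# Averaging unit vectors shrinks their norm; rescale each time step back to length 1.
trial_avg_normed = trial_avg / np.linalg.norm(trial_avg, axis=1)[:, None]

print(trial_avg.shape)                                    # (4, 3)
print(np.linalg.norm(trial_avg_normed, axis=1).round(6))  # [1. 1. 1. 1.]
```

Row j of `trial_avg` is the mean of samples j, j + 4 and j + 8, i.e. the same time step taken from each trial.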
Define a model that uses only time information:
[11]:
cebra_time_model = CEBRA(model_architecture='offset10-model',
                         batch_size=512,
                         learning_rate=0.0003,
                         temperature=1,
                         output_dimension=3,
                         max_iterations=max_iterations,
                         distance='cosine',
                         conditional='time',
                         device='cuda_if_available',
                         verbose=True,
                         time_offsets=5)
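CEBRA-Time uses no behavioral label: with `conditional='time'`, positive pairs come from temporal proximity alone. The sketch below assumes the positive is the sample exactly `time_offsets` steps after the reference (5 here, matching the model above), that references are drawn so the positive stays inside the recording, and that negatives are drawn uniformly; CEBRA's actual sampler may differ, and the function name is hypothetical.

```python
import numpy as np


def sample_time_contrastive_pairs(num_samples, batch_size, time_offset, rng):
    """Return reference, positive and negative index arrays (hypothetical sketch)."""
    # Draw references so that ref + time_offset is still a valid index.
    ref = rng.integers(0, num_samples - time_offset, size=batch_size)
    # Positive: the sample `time_offset` steps later in the same recording.
    pos = ref + time_offset
    # Negatives: uniformly random samples from anywhere in the recording.
    neg = rng.integers(0, num_samples, size=batch_size)
    return ref, pos, neg


# Usage: one batch of 512 pairs from a 100-sample recording.
ref, pos, neg = sample_time_contrastive_pairs(100, 512, 5, np.random.default_rng(1))
print((pos - ref)[:5])  # [5 5 5 5 5]
```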
[12]:
cebra_time_model.fit(monkey_target.neural)
cebra_time = cebra_time_model.transform(monkey_target.neural)
pos: 0.0012 neg: 5.4089 total: 5.4101: 100%|███████████████████████████████████████████████████| 500/500 [00:27<00:00, 18.49it/s]
[13]:
fig = plt.figure(figsize=(4, 2), dpi=300)
plt.suptitle('CEBRA-time', fontsize=5)
# Left panel: time-only embedding colored by the x coordinate of the hand position.
ax = plt.subplot(121, projection='3d')
ax.set_title('x', fontsize=4, y=-0.1)
x = ax.scatter(cebra_time[:, 0],
               cebra_time[:, 1],
               cebra_time[:, 2],
               c=monkey_pos.continuous_index[:, 0],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
# Right panel: the same embedding colored by the y coordinate.
ax = plt.subplot(122, projection='3d')
y = ax.scatter(cebra_time[:, 0],
               cebra_time[:, 1],
               cebra_time[:, 2],
               c=monkey_pos.continuous_index[:, 1],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax.set_title('y', fontsize=5, y=-0.1)
yc = plt.colorbar(y, fraction=0.03, pad=0.05, ticks=np.linspace(-15, 15, 7))
yc.ax.tick_params(labelsize=3)
yc.ax.set_title("(cm)", fontsize=5)
plt.show()