Technical: S1 training with MSE loss
This notebook demonstrates how to use CEBRA on the primate reaching data (shown in Fig. 3 and Extended Data Fig. 8).
Specifically, it shows how to use the MSE loss with CEBRA.
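As a rough illustration of what the pos, neg and total values printed during training refer to, here is a minimal NumPy sketch of an InfoNCE-style objective whose similarity is the negative squared Euclidean distance divided by a temperature. This is a simplified assumption for illustration, not CEBRA's actual implementation; the function names `pairwise_similarity` and `euclidean_infonce` are hypothetical.

```python
import numpy as np


def pairwise_similarity(a, b, temperature):
    """Negative squared Euclidean distance between rows of a and b, divided by temperature."""
    sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return -sq_dist / temperature


def euclidean_infonce(ref, pos, neg, temperature=0.01):
    """Return (pos_term, neg_term, total) for one batch (hypothetical sketch).

    ref, pos: (batch, dim) arrays where pos[i] is the positive for ref[i].
    neg: (n_neg, dim) array of negative samples shared by the batch.
    """
    # Alignment term: pull each reference towards its own positive
    pos_sim = -((ref - pos) ** 2).sum(axis=1) / temperature
    pos_term = -pos_sim.mean()
    # Uniformity term: numerically stable log-sum-exp over similarities to all negatives
    neg_sim = pairwise_similarity(ref, neg, temperature)
    row_max = neg_sim.max(axis=1, keepdims=True)
    lse = row_max[:, 0] + np.log(np.exp(neg_sim - row_max).sum(axis=1))
    neg_term = lse.mean()
    return pos_term, neg_term, pos_term + neg_term


rng = np.random.default_rng(0)
ref = rng.normal(size=(512, 2))
pos = ref + 0.05 * rng.normal(size=(512, 2))  # positives close to their references
neg = rng.normal(size=(512, 2))               # negatives drawn independently
print(euclidean_infonce(ref, pos, neg))
```

In this sketch the total is simply the sum of the two terms, which is consistent with the pos, neg and total numbers in the training logs of this notebook.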
Install note: make sure the demo dependencies are installed before using this notebook:
[ ]:
!pip install --pre 'cebra[datasets,demos]'
[1]:
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib as jl
import cebra.datasets
from cebra import CEBRA
Let’s load the data:
The data will be downloaded automatically into a data/ folder.
[2]:
monkey_pos = cebra.datasets.init('area2-bump-pos-active')
monkey_target = cebra.datasets.init('area2-bump-target-active')
100%|██████████| 63.9M/63.9M [00:11<00:00, 5.38MB/s]
Download complete. Dataset saved in 'data/monkey_reaching_preload_smth_40/active_all.jl'
For a quick CPU run-time demo, you can reduce max_iterations to 50-500; otherwise, set it to 5000.
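To gauge what lowering max_iterations buys you: wall-clock time scales roughly linearly with the iteration count. The sketch below turns an iterations-per-second figure into an mm:ss estimate. The two rates are read off the progress bars recorded in this notebook (about 1.53 it/s for the first fit, about 38.9 it/s for the later ones) and will differ on your hardware; the helper names are hypothetical.

```python
def estimate_runtime_seconds(max_iterations, iterations_per_second):
    """Estimated training time, assuming a constant iteration rate."""
    return max_iterations / iterations_per_second


def format_mm_ss(seconds):
    """Format a duration in seconds as mm:ss (rounded to the nearest second)."""
    minutes, secs = divmod(int(round(seconds)), 60)
    return f"{minutes:02d}:{secs:02d}"


# Rates taken from the progress bars in this notebook; substitute your own.
for rate in (1.53, 38.86):
    for n_iter in (50, 500, 5000):
        eta = format_mm_ss(estimate_runtime_seconds(n_iter, rate))
        print(f"{n_iter:>5} iterations at {rate:>5} it/s -> {eta}")
```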
[4]:
max_iterations = 5000  # default is 5000
Define a model that uses positional information as the auxiliary variable:
[5]:
cebra_pos_model = CEBRA(model_architecture='offset10-model-mse',
                        batch_size=512,
                        learning_rate=5e-5,
                        temperature=0.01,
                        output_dimension=2,
                        max_iterations=max_iterations,
                        distance='euclidean',
                        conditional='time_delta',
                        device='cuda_if_available',
                        verbose=True,
                        time_offsets=10)
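With conditional='time_delta', positive samples are chosen according to how the auxiliary variable (here, hand position) changes over time. The NumPy sketch below reflects a simplified understanding of that idea and is not CEBRA's actual sampler: for each reference, it adds a randomly drawn empirical label change over time_offset samples to the reference label and takes the sample nearest to the shifted label as the positive. The function name `sample_time_delta_positives` and the toy trajectory are hypothetical.

```python
import numpy as np


def sample_time_delta_positives(labels, ref_idx, time_offset, rng):
    """Hypothetical sketch of time-delta positive sampling.

    labels: (n_samples,) or (n_samples, n_dims) continuous auxiliary variable.
    ref_idx: indices of the reference samples.
    Returns one positive index per reference.
    """
    labels = np.asarray(labels, dtype=float)
    if labels.ndim == 1:
        labels = labels[:, None]
    # Empirical distribution of label changes over `time_offset` samples
    deltas = labels[time_offset:] - labels[:-time_offset]
    # Draw one change per reference and shift the reference label by it
    drawn = deltas[rng.integers(0, len(deltas), size=len(ref_idx))]
    query = labels[ref_idx] + drawn
    # Positive = sample whose label is closest (squared Euclidean) to the shifted label
    sq_dist = ((query[:, None, :] - labels[None, :, :]) ** 2).sum(axis=-1)
    return sq_dist.argmin(axis=1)


rng = np.random.default_rng(0)
position = np.cumsum(rng.normal(size=(1000, 2)), axis=0)  # toy 2D trajectory
ref_idx = rng.integers(0, len(position), size=8)
print(sample_time_delta_positives(position, ref_idx, time_offset=10, rng=rng))
```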
Fit the model:
[6]:
cebra_pos_model.fit(monkey_pos.neural, monkey_pos.continuous_index.numpy())
cebra_pos = cebra_pos_model.transform(monkey_pos.neural)
pos: 0.8407 neg: 4.1718 total: 5.0125: 100%|█████████████████████████████████████████████████████| 5000/5000 [54:33<00:00, 1.53it/s]
Plot the results:
[16]:
%matplotlib notebook
fig = plt.figure(figsize=(10, 3))
plt.suptitle('CEBRA-behavior trained with position label using MSE loss',
             fontsize=20)
ax = plt.subplot(121)
ax.set_title('x', fontsize=20, y=0)
x = ax.scatter(cebra_pos[:, 0],
               cebra_pos[:, 1],
               c=monkey_pos.continuous_index[:, 0],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax = plt.subplot(122)
y = ax.scatter(cebra_pos[:, 0],
               cebra_pos[:, 1],
               c=monkey_pos.continuous_index[:, 1],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax.set_title('y', fontsize=20, y=0)
yc = plt.colorbar(y, fraction=0.03, pad=0.05, ticks=np.linspace(-15, 15, 7))
yc.ax.tick_params(labelsize=15)
yc.ax.set_title("(cm)", fontsize=10)
plt.show()
Define a model that uses target information as the auxiliary variable:
[8]:
cebra_target_model = CEBRA(model_architecture='offset10-model-mse',
                           batch_size=512,
                           learning_rate=5e-5,
                           temperature=0.01,
                           output_dimension=2,
                           max_iterations=max_iterations,
                           distance='euclidean',
                           conditional='time_delta',
                           device='cuda_if_available',
                           verbose=True,
                           time_offsets=10)
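The position and target models use identical hyperparameters; only the labels passed to fit() differ, and the time-only model later in this notebook changes just the conditional. One way to keep such variants in sync is to build the keyword arguments from a single shared dictionary. In the sketch below, `SHARED_PARAMS` and `make_params` are hypothetical names; the resulting dictionaries would be passed as `CEBRA(**params)`.

```python
# Hyperparameters shared by all models in this notebook (copied from the cells above)
SHARED_PARAMS = dict(
    model_architecture='offset10-model-mse',
    batch_size=512,
    learning_rate=5e-5,
    temperature=0.01,
    output_dimension=2,
    max_iterations=5000,  # mirrors max_iterations defined earlier
    distance='euclidean',
    device='cuda_if_available',
    verbose=True,
    time_offsets=10,
)


def make_params(conditional, **overrides):
    """Return a copy of the shared settings with the given conditional and overrides."""
    params = dict(SHARED_PARAMS, conditional=conditional)
    params.update(overrides)
    return params


pos_params = make_params('time_delta')
target_params = make_params('time_delta')
time_params = make_params('time')
# e.g. cebra_time_model = CEBRA(**time_params)

changed = {k for k in pos_params if pos_params[k] != time_params[k]}
print(changed)
```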
[9]:
cebra_target_model.fit(monkey_target.neural,
                       monkey_target.discrete_index.numpy())
cebra_target = cebra_target_model.transform(monkey_target.neural)
pos: 0.7948 neg: 4.8042 total: 5.5989: 100%|█████████████████████████████████████████████████████| 5000/5000 [02:08<00:00, 38.86it/s]
[10]:
%matplotlib notebook
fig = plt.figure(figsize=(12, 5), dpi=100)
plt.suptitle('CEBRA-behavior trained with target label using MSE loss',
             fontsize=20)
ax = plt.subplot(121)
ax.set_title('All trials embedding', fontsize=20, y=-0.1)
x = ax.scatter(cebra_target[:, 0],
               cebra_target[:, 1],
               c=monkey_target.discrete_index,
               cmap=plt.cm.hsv,
               s=0.05)
ax.axis('off')
ax = plt.subplot(122)
ax.set_title('Post-averaged by direction', fontsize=20, y=-0.1)
for i in range(8):
    # Select all samples belonging to reach direction i
    direction_trial = (monkey_target.discrete_index == i)
    # Reshape to (n_trials, 600, 2), i.e. 600 samples per trial, and average over trials
    trial_avg = cebra_target[direction_trial, :].reshape(-1, 600, 2).mean(axis=0)
    ax.scatter(trial_avg[:, 0],
               trial_avg[:, 1],
               color=plt.cm.hsv(1 / 8 * i),
               s=3)
ax.axis('off')
plt.show()
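The post-averaging step selects all samples of one reach direction, reshapes them into (trials, 600, 2), treating every 600 consecutive samples as one trial, and averages across trials. The toy example below applies the same reshape-and-mean to a tiny array through a hypothetical `trial_average` helper; like the notebook code, it assumes the selected samples are whole trials of equal length stored back to back.

```python
import numpy as np


def trial_average(embedding, labels, direction, trial_length):
    """Average an embedding over all trials with the given direction label.

    Assumes samples of each trial are stored contiguously and every trial
    has exactly `trial_length` samples.
    """
    samples = embedding[labels == direction]
    if len(samples) % trial_length:
        raise ValueError("number of samples is not a multiple of trial_length")
    # (n_trials * trial_length, dim) -> (n_trials, trial_length, dim)
    trials = samples.reshape(-1, trial_length, embedding.shape[1])
    return trials.mean(axis=0)


# Toy data: 2 trials of direction 0 (3 samples each) and 1 trial of direction 1
embedding = np.arange(18, dtype=float).reshape(9, 2)
labels = np.array([0, 0, 0, 1, 1, 1, 0, 0, 0])
print(trial_average(embedding, labels, direction=0, trial_length=3))
```

Here the two direction-0 trials are rows 0-2 and rows 6-8, so the result is their element-wise mean, one averaged point per time step within the trial.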
Define a model that uses time information only:
[11]:
cebra_time_model = CEBRA(model_architecture='offset10-model-mse',
                         batch_size=512,
                         learning_rate=5e-5,
                         temperature=0.01,
                         output_dimension=2,
                         max_iterations=max_iterations,
                         distance='euclidean',
                         conditional='time',
                         device='cuda_if_available',
                         verbose=True,
                         time_offsets=10)
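With conditional='time', no behavioral labels are passed to fit(), so positive pairs can only come from temporal proximity in the neural data. The sketch below assumes the simplest variant, in which the positive for a reference at index t is the sample at t + time_offset and negatives are drawn uniformly from the recording. This is an illustrative assumption and may not match CEBRA's sampler exactly; the function name `sample_time_pairs` is hypothetical.

```python
import numpy as np


def sample_time_pairs(n_samples, batch_size, time_offset, rng):
    """Return (ref, pos, neg) index arrays for time-contrastive sampling.

    Hypothetical sketch: positive = reference shifted by `time_offset`,
    negatives drawn uniformly at random from the whole recording.
    """
    # References are drawn so that ref + time_offset stays inside the recording
    ref = rng.integers(0, n_samples - time_offset, size=batch_size)
    pos = ref + time_offset
    neg = rng.integers(0, n_samples, size=batch_size)
    return ref, pos, neg


rng = np.random.default_rng(0)
ref, pos, neg = sample_time_pairs(n_samples=1000, batch_size=512, time_offset=10, rng=rng)
# The indices would then select rows of a (time, neurons) neural data array
print(ref[:5], pos[:5])
```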
[12]:
cebra_time_model.fit(monkey_target.neural)
cebra_time = cebra_time_model.transform(monkey_target.neural)
pos: 34206528.0000 neg: 155542336.0000 total: 189748864.0000: 100%|██████████████████████████████| 5000/5000 [02:08<00:00, 38.98it/s]
[14]:
fig = plt.figure(figsize=(12, 5), dpi=100)
plt.suptitle('CEBRA-time trained using MSE loss', fontsize=20)
ax = plt.subplot(121)
ax.set_title('x', fontsize=20, y=-0.1)
x = ax.scatter(cebra_time[:, 0],
               cebra_time[:, 1],
               c=monkey_pos.continuous_index[:, 0],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax = plt.subplot(122)
y = ax.scatter(cebra_time[:, 0],
               cebra_time[:, 1],
               c=monkey_pos.continuous_index[:, 1],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax.set_title('y', fontsize=20, y=-0.1)
yc = plt.colorbar(y, fraction=0.03, pad=0.05, ticks=np.linspace(-15, 15, 7))
yc.ax.tick_params(labelsize=15)
yc.ax.set_title("(cm)", fontsize=10)
plt.show()