
Technical: Learning the temperature parameter#

  • This notebook demonstrates how to use the learnable temperature feature in CEBRA, using the rat hippocampus dataset.

  • Install note: make sure cebra and the demo dependencies are installed before running this notebook:

[ ]:
!pip install --pre 'cebra[datasets,demos]'
[1]:
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib as jl
import cebra.datasets
from cebra import CEBRA

from matplotlib.collections import LineCollection

Load the data:#

  • The data will be downloaded automatically into a data/ folder relative to the working directory.

[2]:
hippocampus_pos = cebra.datasets.init('rat-hippocampus-single-achilles')
100%|██████████| 10.0M/10.0M [00:01<00:00, 9.19MB/s]
Download complete. Dataset saved in 'data/rat_hippocampus/achilles.jl'
[6]:
max_iterations = 5000

Fixed temperature model:#

  • temperature_mode is set to 'constant' by default, which keeps the temperature fixed at the value of temperature for the whole training run.

  • Initialize and train a model.

[9]:
cebra_fixed_temperature_model = CEBRA(model_architecture='offset10-model',
                        batch_size=512,
                        learning_rate=3e-4,
                        temperature_mode = 'constant',
                        temperature=1,
                        output_dimension=3,
                        max_iterations=max_iterations,
                        distance='cosine',
                        conditional='time_delta',
                        device='cuda_if_available',
                        verbose=True,
                        time_offsets=10)
[10]:
cebra_fixed_temperature_model.fit(hippocampus_pos.neural, hippocampus_pos.continuous_index)
fixed_temp_emb = cebra_fixed_temperature_model.transform(hippocampus_pos.neural)
pos:  0.0898 neg:  5.4158 total:  5.5056 temperature:  1.0000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [03:06<00:00, 26.83it/s]
[10]:
CEBRA(batch_size=512, conditional='time_delta', max_iterations=5000,
      model_architecture='offset10-model', output_dimension=3, temperature=1,
      time_offsets=10, verbose=True)
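To make the role of the temperature concrete, here is a minimal, self-contained sketch of one common form of the InfoNCE loss for a single reference sample. It is not CEBRA's implementation: the function name info_nce and the similarity values are hypothetical and chosen only for illustration. The temperature divides every similarity before the loss is computed, so a smaller value sharpens the contrast between the positive and the negatives.

```python
import math


def info_nce(pos_similarity, neg_similarities, temperature=1.0):
    """InfoNCE loss for one reference sample (illustrative sketch only).

    pos_similarity: similarity (e.g. cosine) between reference and positive.
    neg_similarities: similarities between the reference and each negative.
    temperature: positive scalar that divides every similarity.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    pos = pos_similarity / temperature
    negs = [s / temperature for s in neg_similarities]
    # Numerically stable log-sum-exp over the negatives.
    m = max(negs)
    log_sum_exp = m + math.log(sum(math.exp(n - m) for n in negs))
    # Pull the positive closer (first term), push negatives away (second term).
    return -pos + log_sum_exp


# Hypothetical cosine similarities for one reference sample.
pos_sim = 0.9
neg_sims = [0.1, -0.2, 0.3, 0.0]

loss_t1 = info_nce(pos_sim, neg_sims, temperature=1.0)
loss_t02 = info_nce(pos_sim, neg_sims, temperature=0.2)
print(f"temperature=1.0 -> loss={loss_t1:.4f}")
print(f"temperature=0.2 -> loss={loss_t02:.4f}")
```

With these made-up similarities, lowering the temperature from 1.0 to 0.2 lowers the loss, because the positive pair is already more similar than every negative and the smaller temperature amplifies that gap.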

Learnable temperature model:#

  • Setting temperature_mode to 'auto' makes the temperature a learnable parameter that is optimized during training.

  • min_temperature is the lower bound for the temperature. If it is not specified, the lower bound is asymptotically 0.

[42]:
cebra_learnable_temperature_model = CEBRA(model_architecture='offset10-model',
                        batch_size=512,
                        learning_rate=3e-4,
                        temperature_mode = 'auto',
                        temperature=1,
                        min_temperature = 1e-1,
                        output_dimension=3,
                        max_iterations=max_iterations,
                        distance='cosine',
                        conditional='time_delta',
                        device='cuda_if_available',
                        verbose=True,
                        time_offsets=10)
[43]:
cebra_learnable_temperature_model.fit(hippocampus_pos.neural, hippocampus_pos.continuous_index)
learnable_temp_emb = cebra_learnable_temperature_model.transform(hippocampus_pos.neural)
pos:  0.4400 neg:  4.2548 total:  4.6948 temperature:  0.2280: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [03:08<00:00, 26.58it/s]
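The progress bar above ends at a temperature of about 0.23, which stays above the min_temperature of 0.1 passed to the model. The sketch below illustrates what a lower bound and an "asymptotic 0" bound can mean in practice. It assumes one plausible parameterization (an unconstrained parameter mapped through an exponential, then clipped from below); this is an assumption for illustration, not a description of CEBRA's internals, and the name temperature_from_parameter is hypothetical.

```python
import math


def temperature_from_parameter(log_inverse_temperature, min_temperature=None):
    """Map an unconstrained real parameter to a positive temperature.

    Hypothetical parameterization: temperature = exp(-parameter).
    The exponential is always positive, so without a bound the temperature
    can approach 0 but never reach it. With a bound, it is clipped from below.
    """
    temperature = math.exp(-log_inverse_temperature)
    if min_temperature is not None:
        temperature = max(temperature, min_temperature)
    return temperature


# As the parameter grows, the unbounded temperature shrinks toward 0,
# while the bounded one stops at min_temperature.
parameters = [0.0, 2.0, 5.0, 10.0]
unbounded = [temperature_from_parameter(p) for p in parameters]
bounded = [temperature_from_parameter(p, min_temperature=0.1) for p in parameters]

for p, u, b in zip(parameters, unbounded, bounded):
    print(f"parameter={p:5.1f}  unbounded={u:.6f}  bounded={b:.6f}")
```

A bound of this kind keeps the optimizer from driving the temperature to very small values, where the scaled similarities become large and the loss can turn numerically unstable.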

Visualize the loss and the temperature:#

[46]:
fig = plt.figure(figsize=(10,5))
ax1 = plt.subplot(121)
ax1.plot(cebra_learnable_temperature_model.state_dict_['loss'], c= 'lightskyblue', label='Learnable temperature')
ax1.plot(cebra_fixed_temperature_model.state_dict_['loss'], c='gray', label = 'Fixed temperature')
ax1.set_ylabel('InfoNCE Loss')
ax1.set_xlabel('Steps')
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)
plt.legend(loc = 'lower left', frameon = False)
ax2= plt.subplot(122)
ax2.plot(cebra_learnable_temperature_model.state_dict_['log']['temperature'], c= 'lightskyblue', label='Learnable temperature')
ax2.plot(cebra_fixed_temperature_model.state_dict_['log']['temperature'], c='gray', label = 'Fixed temperature')
ax2.set_ylabel('Temperature')
ax2.set_xlabel('Steps')
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)
plt.legend(loc = 'lower left', frameon = False)
[Figure: InfoNCE loss (left) and temperature (right) versus training steps for the learnable-temperature and fixed-temperature models.]
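Alongside the plots, it can help to compare the two runs numerically. The helper below is a small sketch that averages the last few logged values. It assumes only the layout used in the plotting cell above, that is, a 'loss' sequence and a 'log' entry holding a 'temperature' sequence. The function name summarize_training and the fake_state values are hypothetical stand-ins, not real training results.

```python
def _tail_mean(values, tail):
    """Mean of the last `tail` entries of a sequence of numbers."""
    values = [float(v) for v in values]
    tail = max(1, min(tail, len(values)))
    return sum(values[-tail:]) / tail


def summarize_training(state_dict, tail=100):
    """Average loss and temperature over the last `tail` logged steps."""
    return {
        "final_loss": _tail_mean(state_dict["loss"], tail),
        "final_temperature": _tail_mean(state_dict["log"]["temperature"], tail),
    }


# Hypothetical stand-in with the same layout as the logs plotted above.
fake_state = {
    "loss": [6.0, 5.5, 5.0, 4.8],
    "log": {"temperature": [1.0, 0.6, 0.3, 0.25]},
}
summary = summarize_training(fake_state, tail=2)
print(summary)
```

In the notebook, the same helper could be called on cebra_fixed_temperature_model.state_dict_ and cebra_learnable_temperature_model.state_dict_ to put the end-of-training values side by side.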

Visualize the embeddings from the models#

[47]:
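def plot_hippocampus(ax, embedding, label):
    # label[:, 0] is used as the color value (position along the track);
    # label[:, 1] and label[:, 2] flag the two running directions.
    r_ind = label[:, 1] == 1
    l_ind = label[:, 2] == 1

    # First running direction, colored by position with the 'cool' colormap.
    ax.scatter(embedding[r_ind, 0],
               embedding[r_ind, 1],
               embedding[r_ind, 2],
               c=label[r_ind, 0],
               cmap='cool', s=1)
    # Second running direction, colored by position with the 'viridis' colormap.
    ax.scatter(embedding[l_ind, 0],
               embedding[l_ind, 1],
               embedding[l_ind, 2],
               c=label[l_ind, 0],
               cmap='viridis', s=1)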

    ax.grid(False)
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False
    ax.xaxis.pane.set_edgecolor('w')
    ax.yaxis.pane.set_edgecolor('w')
    ax.zaxis.pane.set_edgecolor('w')
    return ax

[50]:
fig=plt.figure(figsize=(10,5))
ax1 = plt.subplot(121, projection = '3d')
ax1=plot_hippocampus(ax1, fixed_temp_emb, hippocampus_pos.continuous_index)
ax1.set_title('Fixed temperature', fontsize=15)
ax2 = plt.subplot(122, projection = '3d')
ax2=plot_hippocampus(ax2, learnable_temp_emb, hippocampus_pos.continuous_index)
ax2.set_title('Learnable temperature', fontsize=15)
[Figure: 3D embeddings from the fixed-temperature model (left panel, "Fixed temperature") and the learnable-temperature model (right panel, "Learnable temperature").]