
Extended Data Figure 9: CEBRA produces consistent, highly decodable embeddings

Import plotting and data loading dependencies

[1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.linear_model
[2]:
data = pd.read_hdf("../data/EDFigure9.h5", key="data")
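
The HDF file holds the precomputed embeddings for each method and session. A quick sanity check on what was loaded (the keys are the ones used in the plotting cell below):

print(type(data))
print(list(data.keys()))                   # e.g. "cebra", "emission_umap", ...
print(np.asarray(data["cebra"][0]).shape)  # one session's embedding: (time, latent dims)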

The additional 4 sessions with the most neurons in the Allen visual dataset calcium recordings, shown for all algorithms we benchmarked

  • For CEBRA-Behavior and CEBRA-Time, we used temperature 1, time offset 10, batch size 128, and 10k training steps. For UMAP, we used the cosine metric, n_neighbors 15, and min_dist 0.1. For tSNE, we used the cosine metric and perplexity 30. For conv-pi-VAE, we trained for 600 epochs with batch size 200 and learning rate \(5\times 10^{-4}\). All methods used 10 time bins as input. CEBRA was trained with a 3D latent space; all other methods used a 2D latent space. A sketch of how these settings map onto each library's API follows below.

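The embeddings plotted below were precomputed with those settings and stored in EDFigure9.h5. For reference, here is a minimal sketch of how the CEBRA, UMAP, and tSNE hyperparameters above map onto each library's API (conv-pi-VAE and autoLFADS are omitted; neural_data and behavior_labels are hypothetical placeholders, not loaded in this notebook):

# Sketch only -- the figure below uses the precomputed embeddings, not this code.
import umap
import sklearn.manifold
from cebra import CEBRA

cebra_model = CEBRA(
    temperature=1,
    time_offsets=10,
    batch_size=128,
    max_iterations=10_000,   # 10k training steps
    output_dimension=3,      # CEBRA used a 3D latent space
)
# cebra_model.fit(neural_data, behavior_labels)   # omit the labels for CEBRA-Time

umap_model = umap.UMAP(n_components=2, metric="cosine", n_neighbors=15, min_dist=0.1)
tsne_model = sklearn.manifold.TSNE(n_components=2, metric="cosine", perplexity=30)
# umap_embedding = umap_model.fit_transform(neural_data)
# tsne_embedding = tsne_model.fit_transform(neural_data)
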
[3]:
embedding_idx = list(range(7))  # one row per embedding method below

fig = plt.figure(figsize=(11, 20 / 7 * len(embedding_idx)), dpi=300)

# Precomputed embeddings for the 4 sessions, one entry per method (in row order).
emissions_list = [
    data["cebra"],
    data["emission_cebra_time"],
    data["emission_umap"],
    data["emission_tsne"],
    data["emission_pivae_w"],
    data["emission_pivae_wo"],
    data["emission_autolfads"],
]


for i in range(4):  # sessions (mice), one per column
    for j_, j in enumerate(embedding_idx):  # methods, one per row

        ax = fig.add_subplot(len(embedding_idx), 4, j_ * 4 + i + 1)
        idx1, idx2 = (0, 1)

        # Align each session's embedding to a reference session with a linear fit
        # (session 0 for the first five methods, session 3 for the remaining ones),
        # then plot the first two aligned dimensions.
        ref = 0 if j <= 4 else 3
        lin = sklearn.linear_model.LinearRegression()
        lin.fit(emissions_list[j][i], emissions_list[j][ref])
        fitted = lin.predict(emissions_list[j][i])


        # Color points by frame index within each movie repeat (900 frames, 10 repeats).
        ax.scatter(
            fitted[:, idx1],
            fitted[:, idx2],
            c=np.tile(np.arange(900), 10)[:len(fitted)],
            cmap="magma",
            s=.75,
        )
        # Hide spines and ticks for a cleaner embedding panel.
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_visible(False)
        ax.spines["left"].set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])

        if j <= 1:  # equal aspect ratio for the two CEBRA rows
            ax.set_aspect("equal")

        if j == 0:
            plt.title(["Mouse 1", "Mouse 2", "Mouse 3", "Mouse 4"][i])
        if i == 0:
            plt.ylabel(
                [
                    "CEBRA-Behavior",
                    "CEBRA-Time",
                    "UMAP",
                    "TSNE",
                    "pi-VAE w/ label",
                    "pi-VAE w/o label",
                    "autoLFADS"
                ][j],
            )

#plt.subplots_adjust(wspace = .1)
#plt.savefig('fig.png', bbox_inches = 'tight', transparent = True)
[Figure output: 7 × 4 grid of 2D embeddings; rows: CEBRA-Behavior, CEBRA-Time, UMAP, TSNE, pi-VAE w/ label, pi-VAE w/o label, autoLFADS; columns: Mouse 1–4, colored by time within the movie repeat.]
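
The panels above show each embedding only after a linear fit to a reference session. A rough way to put a number on that consistency is to score the same fit; a minimal sketch reusing emissions_list from the cell above (the R² scoring here is illustrative, not necessarily the consistency metric used in the paper):

# Score how well a linear map carries one session's embedding onto the reference
# session's embedding (higher R^2 = more consistent across sessions).
def linear_consistency(source, target):
    lin = sklearn.linear_model.LinearRegression()
    lin.fit(source, target)
    return lin.score(source, target)

for name, emissions in zip(["CEBRA-Behavior", "CEBRA-Time", "UMAP", "TSNE"], emissions_list[:4]):
    scores = [linear_consistency(emissions[i], emissions[0]) for i in range(1, 4)]
    print(name, np.round(scores, 3))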