You can download and run the notebook locally:

Download Jupyter notebook


Extended Data Figure 3: CEBRA produces consistent, highly decodable embeddings

  • Additional rat data are shown for all benchmarked algorithms (see Methods). CEBRA was trained with its output latent on the 2-sphere (the minimum dimensionality), whereas the embeddings of all other methods were obtained with a 2D latent in Euclidean space.

[1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Stack the main Extended Data Fig. 3 table and the additional table (read in
# that order) into one DataFrame with a fresh 0..N-1 row index.
df = pd.concat(
    [
        pd.read_hdf(path, key="data")
        for path in ("../data/EDFigure3.h5", "../data/EDFigure3_addition.h5")
    ],
    axis=0,
    ignore_index=True,
)

def scatter(data, index, ax, s=0.01, alpha=0.5):
    """Draw a 2D scatter of ``data`` on ``ax``, split into two colormap groups.

    Points whose second label column is positive are drawn first with the
    ``viridis`` colormap; all remaining points are drawn with ``cool``. Within
    each group the marker color is taken from the first label column.

    Args:
        data: 2D array of coordinates. Only columns 0 and 1 are plotted, so
            embeddings with more than two columns can be passed directly.
        index: 2D label array. Column 0 sets the color; the sign of column 1
            selects the colormap group.
            NOTE(review): presumably column 0 is track position and column 1
            a direction flag — confirm against the dataset.
        ax: Axes-like object with a matplotlib-style ``scatter`` method.
        s: Marker size.
        alpha: Marker transparency.
    """
    mask = index[:, 1] > 0
    # Pass x and y explicitly: unpacking ``data[mask].T`` would forward any
    # third column as the positional ``s`` argument and clash with ``s=``.
    ax.scatter(data[mask, 0], data[mask, 1], c=index[mask, 0], s=s, cmap="viridis", alpha=alpha)
    ax.scatter(data[~mask, 0], data[~mask, 1], c=index[~mask, 0], s=s, cmap="cool", alpha=alpha)


# Preview grid: one subplot per row of ``df``, laid out with one column per
# animal. The grid shape is derived from the data instead of being hard-coded
# as 7 x 4, so it stays correct if methods or animals are added or removed.
# NOTE(review): assumes rows of ``df`` are grouped by method, with the animals
# in the same order within every method — confirm against the data files.
n_cols = df["animal"].nunique()
n_rows = -(-len(df) // n_cols)  # ceiling division
fig = plt.figure(figsize=(n_cols * 3, n_rows * 3), dpi=600)
# Iterate positionally so the subplot slot does not depend on index labels.
for pos, (_, row) in enumerate(df.iterrows()):
    ax = fig.add_subplot(n_rows, n_cols, pos + 1)
    scatter(row["emission"][:, :2], row["labels"], ax=ax, s=0.5, alpha=0.7)
    # Removing the ticks also removes their labels; then hide the spines.
    ax.set_xticks([])
    ax.set_yticks([])
    sns.despine(bottom=True, left=True, ax=ax)
    # First-row panels are titled with the animal.
    if pos < n_cols:
        ax.set_title(f"Rat {row['animal']}", fontsize=18)
    # First-column panels are labeled with the method.
    if pos % n_cols == 0:
        ax.set_ylabel(row["method"])
../../_images/cebra-figures_figures_ExtendedDataFigure3_2_0.png

For higher-resolution plots, we export each row (one method across all rats) as a separate file:

[2]:
def scatter(data, index, ax, s=0.01, alpha=0.5):
    """Draw a 2D scatter of ``data`` on ``ax``, split into two colormap groups.

    Presumably re-defined here so this cell can run on its own.

    Points whose second label column is positive are drawn first with the
    ``viridis`` colormap; all remaining points are drawn with ``cool``. Within
    each group the marker color is taken from the first label column.

    Args:
        data: 2D array of coordinates. Only columns 0 and 1 are plotted, so
            embeddings with more than two columns can be passed directly.
        index: 2D label array. Column 0 sets the color; the sign of column 1
            selects the colormap group.
            NOTE(review): presumably column 0 is track position and column 1
            a direction flag — confirm against the dataset.
        ax: Axes-like object with a matplotlib-style ``scatter`` method.
        s: Marker size.
        alpha: Marker transparency.
    """
    mask = index[:, 1] > 0
    # Pass x and y explicitly: unpacking ``data[mask].T`` would forward any
    # third column as the positional ``s`` argument and clash with ``s=``.
    ax.scatter(data[mask, 0], data[mask, 1], c=index[mask, 0], s=s, cmap="viridis", alpha=alpha)
    ax.scatter(data[~mask, 0], data[~mask, 1], c=index[~mask, 0], s=s, cmap="cool", alpha=alpha)

def export_highres():
    """Save one high-resolution PNG per method, with one panel per animal.

    For every method in the global ``df`` (in order of first appearance), the
    method name is printed, and the embeddings of all animals are drawn side
    by side in sorted animal order. The figure is written to
    ``edf3_<method>.png`` in the current working directory. Any ``/`` in the
    method name is replaced by ``-`` so it can be used as a file name. The
    figure is then shown and closed.
    """
    for method in df.method.unique():
        print(method)
        entry = df[df.method == method].set_index("animal")
        animals = sorted(entry.index)
        n_panels = len(animals)
        # One 3 x 3 inch panel per animal.
        fig = plt.figure(figsize=(n_panels * 3, 1 * 3), dpi=600)
        for i, animal in enumerate(animals):
            ax = fig.add_subplot(1, n_panels, i + 1)
            scatter(
                entry.loc[animal, "emission"][:, :2],
                entry.loc[animal, "labels"],
                ax=ax, s=0.5, alpha=0.7,
            )
            # Removing the ticks also removes their labels.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_aspect("equal")
            sns.despine(bottom=True, left=True, ax=ax)
            # Only the first panel of the row carries the method name.
            if i == 0:
                ax.set_ylabel(method)
        file_stem = method.replace("/", "-")
        fig.savefig(f"edf3_{file_stem}.png", bbox_inches="tight", transparent=True)
        plt.show()
        # pyplot keeps a reference to every open figure until it is closed;
        # release each 600-dpi figure explicitly instead of relying on the backend.
        plt.close(fig)

# Render and save one PNG per method; export_highres writes to relative paths, i.e. the current working directory.
export_highres()
CEBRA-Behavior
../../_images/cebra-figures_figures_ExtendedDataFigure3_4_1.png
conv-piVAE w/labels
../../_images/cebra-figures_figures_ExtendedDataFigure3_4_3.png
CEBRA-Time
../../_images/cebra-figures_figures_ExtendedDataFigure3_4_5.png
conv-piVAE
../../_images/cebra-figures_figures_ExtendedDataFigure3_4_7.png
tSNE
../../_images/cebra-figures_figures_ExtendedDataFigure3_4_9.png
UMAP
../../_images/cebra-figures_figures_ExtendedDataFigure3_4_11.png
autoLFADS
../../_images/cebra-figures_figures_ExtendedDataFigure3_4_13.png