
Extended Data Figure 5: Hypothesis testing with CEBRA.#

Example data from a hippocampus recording session (Rat 1).#

  • We test possible relationships between three experimental variables (rat location, velocity, movement direction) and the neural recordings (120 neurons, not shown).

[1]:
import glob
import itertools
import json

import numpy as np
import pandas as pd
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns

def load_emission(path, which):
    """Load one key (e.g. "train", "test", or "losses") from a saved emission file."""
    emission = torch.load(path, map_location="cpu")
    return emission[which]


def load_loss(path):
    """Load the recorded loss curve from a saved emission file."""
    emission = torch.load(path, map_location="cpu")
    return emission["loss"]

def load_args(path):
    """Parse an experiment's args.json; the "data" string encodes which
    variables (position, velocity, direction) were used, whether they were
    shuffled, and the loader type."""
    with open(path, "r") as fh:
        idx = [5, 7, 9, -1]  # positions of the relevant tokens in the "data" string
        args = json.load(fh)
        position, velocity, direction, loader = [
            args["data"].split("-")[i] for i in idx
        ]
        args["direction"] = direction
        args["velocity"] = velocity
        args["position"] = position
        args["loader"] = loader
        return args, (position, velocity, direction, loader, args["num_output"])


def prepare_data():
    """Build the 6x6 grid of hypothesis combinations: each of the three
    behavioral variables in an original and a shuffled variant."""
    keys = ["position", "velocity", "direction"]
    options = list(itertools.product(keys, ["original", "shuffle"]))
    # Reorder so the three original variants come before the three shuffled ones.
    options = (
        np.array(options).reshape(3, 2, 2).transpose(1, 0, 2).reshape(6, 2).tolist()
    )
    options = list(map(tuple, options))
    print(options)

    def rep(k):
        # A variable counts only if it appears exactly once in the selection;
        # empty or conflicting ({original, shuffle}) sets map to "none".
        if len(k) == 0 or len(k) == 2:
            return "none"
        return k.pop()

    args = []
    for selection in itertools.product(options, options):
        kwargs = {opt: set() for opt in keys}
        for k, v in selection:
            kwargs[k].add(v)
        args.append(
            {
                "src": tuple(selection[0]),
                "tgt": tuple(selection[1]),
                "config": tuple(rep(kwargs[k]) for k in keys),
            }
        )

    df = pd.DataFrame(args).pivot(index="src", columns="tgt")
    return df


def scatter(data, ax):
    # Project onto the first two principal components and normalize by the
    # global maximum before plotting.
    pca = PCA(2, whiten=False)
    data = pca.fit_transform(data)
    data /= data.max(keepdims=True)
    x, y = data.T

    ax.scatter(x, y, s=1, alpha=0.25, edgecolor="none", c="k")


def apply_style(ax):
    sns.despine(ax=ax, trim=True, left=True, bottom=True)
    ax.set_xlim([-1, 1])
    ax.set_ylim([-1, 1])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(axis="both", which="both", length=0)
    ax.set_aspect("equal")


def scatter_grid(df, dim):
    fig, axes = plt.subplots(6, 6, figsize=(5, 5), dpi=200)

    # Reorder rows/columns for display.
    df = df.iloc[[0, 2, 4, 1, 3, 5], [0, 2, 4, 1, 3, 5]]

    for i, (index, row) in enumerate(zip(df.index, axes)):
        for j, (column, ax) in enumerate(zip(df.columns, row)):
            apply_style(ax)
            if i == 5:
                ax.set_xlabel("{}\n{}".format(*column[1]), fontsize=6)
            if j == 0:
                ax.set_ylabel("{}\n{}".format(*index), fontsize=6)
            if j > i:
                # Only fill the lower triangle of the grid.
                continue
            entry = df[column][index]
            key = entry + ("continuous", dim, "train")
            if entry == ("none", "none", "none"):
                continue
            emission = files[key]

            scatter(emission, ax=ax)


def plot_heatmap(df, dim):
    def to_heatmap(entry):
        # Report the minimum InfoNCE loss reached during training.
        key = entry + ("continuous", dim, "loss")
        return min(files.get(key, [float("nan")]))

    # Reorder rows/columns for display and mask the upper triangle.
    df = df.iloc[[0, 2, 4, 1, 3, 5], [0, 2, 4, 1, 3, 5]]
    labels = ["{}\n{}".format(*i) for i in df.index]
    hm = df.applymap(to_heatmap)
    for i, j in itertools.product(range(len(hm)), range(len(hm))):
        if j > i:
            hm.iloc[i, j] = float("nan")

    plt.figure(figsize=(4, 4), dpi=150)
    sns.heatmap(
        hm,
        cmap="gray",
        xticklabels=labels,
        yticklabels=labels,
        annot=True,
        fmt=".2f",
        square=True,
    )
    plt.xlabel("")
    plt.ylabel("")
    plt.gca().set_xticklabels(labels, fontsize=8)
    plt.gca().set_yticklabels(labels, fontsize=8)
    plt.show()


df = prepare_data()

emission_files = sorted(glob.glob("../data/EDFigure5/*/emission*"))
arg_files = sorted(glob.glob("../data/EDFigure5/*/args.json"))

losses = [load_emission(e, "losses") for e in emission_files]
emissions = [load_emission(e, "train") for e in emission_files]
emissions_test = [load_emission(e, "test") for e in emission_files]
args = [load_args(path) for path in arg_files]
# Key layout: (position, velocity, direction, loader, num_output, split).
files = {keys + ("train",): values for (_, keys), values in zip(args, emissions)}
files.update(
    {keys + ("test",): values for (_, keys), values in zip(args, emissions_test)}
)
files.update({keys + ("loss",): values for (_, keys), values in zip(args, losses)})
[('position', 'original'), ('velocity', 'original'), ('direction', 'original'), ('position', 'shuffle'), ('velocity', 'shuffle'), ('direction', 'shuffle')]
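
The `files` dictionary built above is keyed by the parsed configuration plus a split tag. As a hypothetical lookup (assuming a model trained on original position only, with three output dimensions, exists in the data):

# Hypothetical key: (position, velocity, direction, loader, num_output, split)
emission = files[("original", "none", "none", "continuous", 3, "train")]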

Relationship between velocity and position.#

[2]:
import os
import joblib


def prepare_dataset(cache="../data/EDFigure5/achilles_behavior.jl"):
    """If you have CEBRA installed, you can adapt this function to load other
    rats or change the data loading in other ways. Otherwise, the pre-computed
    data is loaded from the cache."""
    if not os.path.exists(cache):
        import cebra.datasets

        DATADIR = "/home/stes/projects/neural_cl/cebra_public/data"  # machine-specific path
        dataset = cebra.datasets.init(
            "rat-hippocampus-achilles-3fold-trial-split-0", root=DATADIR
        )
        index = dataset.index
        joblib.dump(index, cache)
    return joblib.load(cache)
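
If CEBRA is installed, the cached index can also be regenerated for a different split by changing the dataset name passed to `cebra.datasets.init`. The name below is hypothetical and only illustrates the pattern; the available names depend on the installed CEBRA version:

# Hypothetical dataset name; check cebra.datasets for the names your version provides.
dataset = cebra.datasets.init(
    "rat-hippocampus-achilles-3fold-trial-split-1", root=DATADIR
)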
[3]:
import scipy.signal
import matplotlib.pyplot as plt
import seaborn as sns


def dataset_plots(index):
    """Make a summary plot of the hippocampus dataset.

    Args:
      index is the [position, left, right] behavior index also used for training
      CEBRA models.
    """

    # preprocess the dataset
    position = index[:, 0]
    velocity = scipy.signal.savgol_filter(
        index[:, 0], window_length=31, polyorder=2, deriv=1
    )
    accel = scipy.signal.savgol_filter(  # computed for reference; not plotted below
        index[:, 0], window_length=11, polyorder=2, deriv=2
    )
    direction = index[:, 1] - index[:, 2]

    # plot figure
    fig, axes = plt.subplots(3, 1, figsize=(12, 3), sharex=True, dpi=150)
    axes[0].plot(index[:, 0], c="k")
    axes[0].set_ylim([-0.05, 1.65])
    axes[0].set_yticks([0, 1.6])
    axes[0].set_yticklabels([0, 1.6])

    axes[1].plot(velocity, c="k")
    axes[2].plot(direction, c="k")

    for i, ax in enumerate(axes):
        sns.despine(ax=ax, trim=True, bottom=i != 2)
    for ax, lbl in zip(axes, ["pos", "vel", "dir"]):
        ax.set_ylabel(lbl)

    plt.show()

    plt.figure(figsize=(3, 3), dpi=150)
    plt.plot(index[:, 0], velocity, linewidth=0.5, color="black", alpha=0.8)
    plt.xlabel("position")
    plt.ylabel("velocity")
    sns.despine(trim=True)
    plt.show()


index = prepare_dataset()
dataset_plots(index)
../../_images/cebra-figures_figures_ExtendedDataFigure5_5_0.png
../../_images/cebra-figures_figures_ExtendedDataFigure5_5_1.png
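
Note that `savgol_filter` with `deriv=1` returns the derivative of the local polynomial fit per sample, so the plotted velocity is in position units per time bin. To express it in physical units, the sampling interval can be passed via the filter's `delta` argument; a minimal sketch, where `dt` is an assumed placeholder rather than the dataset's documented bin size:

dt = 0.025  # assumed sampling interval in seconds; replace with the true bin size
velocity_per_second = scipy.signal.savgol_filter(
    index[:, 0], window_length=31, polyorder=2, deriv=1, delta=dt
)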

Training CEBRA with three-dimensional outputs on every single experimental variable (main diagonal) and every combination of two variables.#

  • All variables are treated as “continuous” in this experiment. As a control, we compare the original variables to shuffled variables, where shuffling permutes all samples over the time dimension (a minimal sketch of this control follows below). Each three-dimensional embedding is projected onto its first two principal components. The confusion matrix shows the minimum InfoNCE loss reached on the trained embedding for every combination (lower is better). Pairing position with either velocity or direction is needed for maximum structure in the embedding (highlighted, colored), and yields the lowest InfoNCE error.
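
A minimal sketch of the shuffle control described above (variable names are illustrative): permuting a behavioral variable over the time dimension preserves its marginal distribution but destroys its temporal alignment with the neural data.

rng = np.random.default_rng(0)
# Permute the position variable over time; the neural data stays untouched.
shuffled_position = index[rng.permutation(len(index)), 0]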

[4]:
scatter_grid(df, 3)
plot_heatmap(df, 3)
../../_images/cebra-figures_figures_ExtendedDataFigure5_7_0.png
../../_images/cebra-figures_figures_ExtendedDataFigure5_7_1.png

Using an eight-dimensional CEBRA embedding does not qualitatively alter the results.#

  • We again report the first two principal components as well as the InfoNCE training error at convergence, and find non-trivial embeddings with the lowest training error for combinations of direction/velocity with position.

[5]:
scatter_grid(df, 8)
plot_heatmap(df, 8)
../../_images/cebra-figures_figures_ExtendedDataFigure5_9_0.png
../../_images/cebra-figures_figures_ExtendedDataFigure5_9_1.png
[6]:
hypothesis_loss = joblib.load("../data/EDFigure5/hypothesis_testing_loss.jl")

fig = plt.figure(figsize=(5, 5))
ax = plt.subplot(111)
ax.plot(hypothesis_loss["posdir"], c="deepskyblue", label="position+direction")
ax.plot(hypothesis_loss["pos"], c="deepskyblue", alpha=0.3, label="position")
ax.plot(hypothesis_loss["dir"], c="deepskyblue", alpha=0.6, label="direction")
ax.plot(hypothesis_loss["posdir-shuffled"], c="gray", label="pos+dir, shuffled")
ax.plot(
    hypothesis_loss["pos-shuffled"], c="gray", alpha=0.3, label="position, shuffled"
)
ax.plot(
    hypothesis_loss["dir-shuffled"], c="gray", alpha=0.6, label="direction, shuffled"
)

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xlabel("Iterations")
ax.set_ylabel("InfoNCE Loss")
plt.legend(bbox_to_anchor=(0.5, 0.3), frameon=False)
plt.show()
../../_images/cebra-figures_figures_ExtendedDataFigure5_10_0.png
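
For reference, InfoNCE here denotes the standard noise-contrastive objective; written generically (the exact similarity function and temperature follow the CEBRA training configuration, which is not restated in this notebook):

$$\mathcal{L}_{\mathrm{InfoNCE}} = -\,\mathbb{E}\left[\log \frac{e^{\,\mathrm{sim}(x,\, y^{+})/\tau}}{\sum_{i=1}^{N} e^{\,\mathrm{sim}(x,\, y_{i})/\tau}}\right]$$

where $y^{+}$ is the positive sample for reference $x$, the sum runs over the positive and the negative samples, and $\tau$ is a temperature parameter. Lower values indicate that positives are easier to distinguish from negatives in the embedding.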
[7]:
hypothesis_decode = joblib.load("../data/EDFigure5/hypothesis_testing_decode.jl")
[8]:
fig = plt.figure(figsize=(10, 4))
ax1 = plt.subplot(121)
ax1.bar(
    np.arange(6),
    [
        hypothesis_decode["posdir"],
        hypothesis_decode["pos"],
        hypothesis_decode["dir"],
        hypothesis_decode["posdir-shuffled"],
        hypothesis_decode["pos-shuffled"],
        hypothesis_decode["dir-shuffled"],
    ],
    width=0.5,
    color="gray",
)
ax1.spines["top"].set_visible(False)
ax1.spines["right"].set_visible(False)
ax1.set_xticks(np.arange(6))
ax1.set_xticklabels(
    ["pos+dir", "pos", "dir", "pos+dir,\nshuffled", "pos,\nshuffled", "dir,\nshuffled"]
)
ax1.set_ylabel("Median err. [m]")

ax2 = plt.subplot(122)
ax2.scatter(
    hypothesis_loss["posdir"][-1],
    hypothesis_decode["posdir"],
    s=50,
    c="deepskyblue",
    label="position+direction",
)
ax2.scatter(
    hypothesis_loss["pos"][-1],
    hypothesis_decode["pos"],
    s=50,
    c="deepskyblue",
    alpha=0.3,
    label="position_only",
)
ax2.scatter(
    hypothesis_loss["dir"][-1],
    hypothesis_decode["dir"],
    s=50,
    c="deepskyblue",
    alpha=0.6,
    label="direction_only",
)
ax2.scatter(
    hypothesis_loss["posdir-shuffled"][-1],
    hypothesis_decode["posdir-shuffled"],
    s=50,
    c="gray",
    label="pos+dir, shuffled",
)
ax2.scatter(
    hypothesis_loss["pos-shuffled"][-1],
    hypothesis_decode["pos-shuffled"],
    s=50,
    c="gray",
    alpha=0.3,
    label="position, shuffled",
)
ax2.scatter(
    hypothesis_loss["dir-shuffled"][-1],
    hypothesis_decode["dir-shuffled"],
    s=50,
    c="gray",
    alpha=0.6,
    label="direction, shuffled",
)

ax2.spines["top"].set_visible(False)
ax2.spines["right"].set_visible(False)
ax2.set_xlabel("InfoNCE Loss")
ax2.set_ylabel("Decoding Median Err.")
plt.legend(bbox_to_anchor=(1, 1), frameon=False)
plt.show()
../../_images/cebra-figures_figures_ExtendedDataFigure5_12_0.png
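
The decoding errors loaded above were precomputed. As a hypothetical sketch of one common way to obtain such numbers (the actual decoder behind hypothesis_testing_decode.jl is not shown in this notebook), a k-nearest-neighbor regression from the embedding to position yields a median decoding error:

from sklearn.neighbors import KNeighborsRegressor


def median_decoding_error(emb_train, pos_train, emb_test, pos_test, k=25):
    # Fit a kNN regressor from embedding space to position and report the
    # median absolute error on held-out data (same units as position, here meters).
    knn = KNeighborsRegressor(n_neighbors=k)
    knn.fit(emb_train, pos_train)
    return np.median(np.abs(knn.predict(emb_test) - pos_test))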