You can download and run the notebook locally:

Download jupyter notebook


Figure 5: Decoding of natural movie features from mouse visual cortical areas#

For reference, here is the full figure

Figure5

import plot and data loading dependencies#

[1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import joblib as jl
[2]:
data = pd.read_hdf("../data/Figure5Revision.h5")
[3]:
frame_np_bayes_baseline_330 = data["frame_id"]["bayes"]["np_baseline_330"]
frame_np_bayes_baseline_33 = data["frame_id"]["bayes"]["np_baseline_33"]
frame_np_knn_baseline_330 = data["frame_id"]["knn"]["np_baseline_330"]
frame_np_knn_baseline_33 = data["frame_id"]["knn"]["np_baseline_33"]
frame_np_cebra_knn_330 = data["frame_id"]["knn"]["np_cebra_330"]
frame_np_cebra_knn_33 = data["frame_id"]["knn"]["np_cebra_33"]
frame_np_joint_cebra_knn_330 = data["frame_id"]["knn"]["np_cebra_joint_330"]
frame_np_joint_cebra_knn_33 = data["frame_id"]["knn"]["np_cebra_joint_33"]


scene_np_bayes_baseline_330 = data["scene_annotation"]["bayes"]["np_baseline_330"]
scene_np_bayes_baseline_33 = data["scene_annotation"]["bayes"]["np_baseline_33"]
scene_np_knn_baseline_330 = data["scene_annotation"]["knn"]["np_baseline_330"]
scene_np_knn_baseline_33 = data["scene_annotation"]["knn"]["np_baseline_33"]
scene_np_cebra_knn_330 = data["scene_annotation"]["knn"]["np_cebra_330"]
scene_np_cebra_knn_33 = data["scene_annotation"]["knn"]["np_cebra_33"]
scene_np_joint_cebra_knn_330 = data["scene_annotation"]["knn"]["np_cebra_joint_330"]
scene_np_joint_cebra_knn_33 = data["scene_annotation"]["knn"]["np_cebra_joint_33"]

frame_np_bayes_baseline_330_err = data["frame_err"]["bayes"]["np_baseline_330"]
frame_np_knn_baseline_330_err = data["frame_err"]["knn"]["np_baseline_330"]
frame_np_cebra_knn_330_err = data["frame_err"]["knn"]["np_cebra_330"]
frame_np_joint_cebra_knn_330_err = data["frame_err"]["knn"]["np_cebra_joint_330"]

frame_np_bayes_baseline_33_err = data["frame_err"]["bayes"]["np_baseline_33"]
frame_np_knn_baseline_33_err = data["frame_err"]["knn"]["np_baseline_33"]
frame_np_cebra_knn_33_err = data["frame_err"]["knn"]["np_cebra_33"]
frame_np_joint_cebra_knn_33_err = data["frame_err"]["knn"]["np_cebra_joint_33"]

Define plotting functions & metrics#

[4]:
LINEWIDTH = 2
[5]:
num_neurons = [10, 30, 50, 100, 200, 400, 600, 800, 900, 1000]


def set_ax(ax, white_c):
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.spines["left"].set_color(white_c)
    ax.spines["bottom"].set_color(white_c)
    ax.set_xticks(
        [10, 200, 400, 600, 800, 1000],
        color=white_c,
    )
    ax.set_xticklabels([10, 200, 400, 600, 800, 1000], rotation=45)
    ax.set_yticks(
        np.linspace(0, 100, 5),
        np.linspace(0, 100, 5, dtype=int),
        color=white_c,
        rotation=45,
    )
    ax.set_xlabel("# Neurons", color=white_c)
    ax.set_ylabel("Acc (%, in 1s time window)", color=white_c)
    ax.tick_params(colors=white_c)


def n_mean_err(dic, ns=num_neurons):
    means = []
    errs = []
    for n in ns:
        means.append(np.mean(dic[n]))
        errs.append(np.std(dic[n]) / np.sqrt(len(dic[n])))
    return np.array(means), np.array(errs)


def n_mean_err_joint(dic, modality, ns=num_neurons):
    means = []
    errs = []
    if modality == "np":
        ind = 1
    elif modality == "ca":
        ind = 0
    for n in ns:
        _d = np.array(dic[n])[:, ind]
        means.append(np.mean(_d))
        errs.append(np.std(_d) / np.sqrt(len(_d)))
    return np.array(means), np.array(errs)


def n_mean_err_frame_diff(err_dict, ns=num_neurons):
    means = []
    errs = []
    for n in ns:
        err_seeds = err_dict[n]
        accs = [np.mean(abs(seed_result)) for seed_result in err_seeds]
        means.append(np.mean(accs))
        errs.append(np.std(accs) / np.sqrt(len(accs)))
    return means, errs


def n_mean_err_frame_diff_joint(err_dict, data_type="np", ns=num_neurons):
    if data_type == "np":
        index = 1
    else:
        index = 0
    means = []
    errs = []
    for n in ns:
        err_seeds = err_dict[n]
        accs = [np.mean(abs(seed_result[index])) for seed_result in err_seeds]
        means.append(np.mean(accs))
        errs.append(np.std(accs) / np.sqrt(len(accs)))
    return means, errs

Fgure 5 c#

  • Decoding accuracy measured by considering a predicted frame being within 1 sec to the true frame as a correct prediction using CEBRA (NP only), jointly trained (2P+NP), or a baseline population-vector plus kNN or Bayes decoder using either a 1 frame (33 ms) receptive field or 10 frames (330 ms); results shown for Neuropixels dataset (V1 data).

[6]:
white = False

if white:
    white_c = "white"
else:
    white_c = "black"

scale = 0.3
fig_330 = plt.figure(figsize=(7 * scale, 10 * scale), dpi=300)
plt.title("Frame identification")
plt.subplots_adjust(wspace=0.5)
ax1 = plt.subplot(111)

c = "#9932EB"

ax1.errorbar(
    num_neurons,
    n_mean_err(frame_np_knn_baseline_330)[0],
    n_mean_err(frame_np_knn_baseline_330)[1],
    ls="--",
    label="kNN baseline",
    color="k",
    alpha=0.7,
    markersize=20,
    linewidth=LINEWIDTH,
)

ax1.errorbar(
    num_neurons,
    n_mean_err(frame_np_bayes_baseline_330)[0],
    n_mean_err(frame_np_bayes_baseline_330)[1],
    ls="--",
    label="Bayes baseline",
    color="k",
    alpha=0.3,
    markersize=20,
    linewidth=LINEWIDTH,
)

ax1.errorbar(
    num_neurons,
    n_mean_err(frame_np_cebra_knn_330)[0],
    n_mean_err(frame_np_cebra_knn_330)[1],
    label="kNN CEBRA",
    color=c,
    alpha=1,
    markersize=20,
    linewidth=LINEWIDTH,
)

ax1.errorbar(
    num_neurons,
    n_mean_err_joint(frame_np_joint_cebra_knn_330, "np")[0],
    n_mean_err_joint(frame_np_joint_cebra_knn_330, "np")[1],
    ls="dotted",
    label="kNN CEBRA joint",
    color=c,
    alpha=1,
    markersize=20,
    linewidth=LINEWIDTH,
)

l1 = ax1.legend(
    loc="best",
    frameon=False,
    fontsize="small",
    title="NP, 10 frames",
    alignment="right",
)
for text in l1.get_texts():
    text.set_color(white_c)

set_ax(ax1, white_c)
ax1.set_ylim(0, 100)
sns.despine(trim=True, ax=ax1)

plt.savefig("FrameID_10frames_np.svg", transparent=True, bbox_inches="tight")
plt.show()
../../_images/cebra-figures_figures_Figure5_9_0.png
[7]:
fig_33 = plt.figure(figsize=(7 * scale, 10 * scale), dpi=300)
fig_33.suptitle("Frame identification")
plt.subplots_adjust(wspace=0.5)
ax1 = plt.subplot(111)

c = "#9932EB"

ax1.errorbar(
    num_neurons,
    n_mean_err(frame_np_knn_baseline_33)[0],
    n_mean_err(frame_np_knn_baseline_33)[1],
    ls="--",
    label="kNN baseline",
    color="k",
    alpha=0.7,
    markersize=20,
)

ax1.errorbar(
    num_neurons,
    n_mean_err(frame_np_bayes_baseline_33)[0],
    n_mean_err(frame_np_bayes_baseline_33)[1],
    ls="--",
    label="Bayes baseline",
    color="k",
    alpha=0.3,
    markersize=20,
)


ax1.errorbar(
    num_neurons,
    n_mean_err(frame_np_cebra_knn_33)[0],
    n_mean_err(frame_np_cebra_knn_33)[1],
    label="kNN CEBRA",
    color=c,
    alpha=1,
    markersize=20,
)

ax1.errorbar(
    num_neurons,
    n_mean_err_joint(frame_np_joint_cebra_knn_33, "np")[0],
    n_mean_err_joint(frame_np_joint_cebra_knn_33, "np")[1],
    ls="dotted",
    label="kNN CEBRA joint",
    color=c,
    alpha=1,
    markersize=20,
)

set_ax(ax1, white_c)

l1 = ax1.legend(
    loc="best", frameon=False, fontsize="small", title="NP, 1 frame", alignment="right"
)
for text in l1.get_texts():
    text.set_color(white_c)

sns.despine(ax=ax1, trim=True)
plt.savefig("FrameID_1frame_np.svg", transparent=True, bbox_inches="tight")
plt.show()
../../_images/cebra-figures_figures_Figure5_10_0.png

Figure 5 d#

  • Decoding accuracy measured by the correct scene prediction using either CEBRA (NP only), jointly trained (2P+NP), or baseline population-vector plus kNN or Bayes decoder using a 1 frame (33 ms) receptive field (V1 data).

[8]:
white = False

if white:
    white_c = "white"
else:
    white_c = "black"

fig_33 = plt.figure(figsize=(7 * scale, 10 * scale), dpi=300)
fig_33.suptitle("Scene annotation")
plt.subplots_adjust(wspace=0.5)
ax1 = plt.subplot(111)
ax1.errorbar(
    num_neurons,
    n_mean_err(scene_np_knn_baseline_33)[0] * 100,
    n_mean_err(scene_np_knn_baseline_33)[1] * 100,
    ls="--",
    label="kNN baseline",
    color="k",
    alpha=0.7,
    markersize=20,
)
ax1.errorbar(
    num_neurons,
    n_mean_err(scene_np_bayes_baseline_33)[0] * 100,
    n_mean_err(scene_np_bayes_baseline_33)[1] * 100,
    ls="--",
    label="Bayes baseline",
    color="k",
    alpha=0.3,
    markersize=20,
)

ax1.errorbar(
    num_neurons,
    n_mean_err(scene_np_cebra_knn_33)[0] * 100,
    n_mean_err(scene_np_cebra_knn_33)[1] * 100,
    label="kNN CEBRA",
    color=c,
    alpha=1,
    markersize=20,
)

ax1.errorbar(
    num_neurons,
    n_mean_err_joint(scene_np_joint_cebra_knn_33, "np")[0] * 100,
    n_mean_err_joint(scene_np_joint_cebra_knn_33, "np")[1] * 100,
    ls="dotted",
    label="kNN CEBRA joint",
    color=c,
    alpha=1,
    markersize=20,
)

set_ax(ax1, white_c)
ax1.set_ylim(20, 100)
ax1.set_ylabel(f"Acc (%)", color=white_c)
sns.despine(ax=ax1, trim=True)
l1 = ax1.legend(
    loc="best", frameon=False, fontsize="small", title="NP, 1 frame", alignment="right"
)
for text in l1.get_texts():
    text.set_color(white_c)
plt.savefig("SceneAnnotation_1frame_np.svg", transparent=True, bbox_inches="tight")
plt.show()
../../_images/cebra-figures_figures_Figure5_12_0.png

Figure 5e#

  • Single frame ground truth frame ID vs predicted frame ID for Neuropixels using a \cebra-Behavior model trained with a 330 ms receptive field (1K V1 neurons across mice used).

[9]:
scale = 0.3
fig = plt.figure(figsize=(7 * scale, 7 * scale), dpi=300)
ax = plt.subplot(111)
frame_err = frame_np_cebra_knn_330_err[1000][3]
ax.scatter(
    np.repeat(np.arange(900), 4),
    np.repeat(np.arange(900), 4) + frame_err,
    s=0.1,
    linewidths=5.0,
    linewidth=3,
    facecolors="none",
    edgecolors="face",
    c="darkgray",
    label="kNN CEBRA",
)
ax.plot(
    (0, 900),
    (0, 900),
    c="#9932EB",
    # zorder = -999,
    lw=1,
    label="ground truth",
)
ax.set_xticks(np.linspace(0, 900, 4))  # , np.linspace(0, 900, 4).astype(int))
ax.set_yticks(np.linspace(0, 900, 4))  # , np.linspace(0, 900, 4).astype(int))
ax.set_xlabel("True frame")
ax.set_ylabel("Predicted frame")
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
sns.despine(trim=True, ax=ax)
ax.legend(
    frameon=False,
    fontsize="small",
    markerscale=10,
    title="NP, 10 frames",
    handlelength=0.75,
)
plt.savefig("TrueVsPredicted_10frame.svg", transparent=True, bbox_inches="tight")
plt.show()
../../_images/cebra-figures_figures_Figure5_14_0.png

Figure 5 f#

  • The mean absolute error of the correct frame index. Shown for baseline and CEBRA models as computed in c, d, e.

[10]:
white = False

if white:
    white_c = "white"
else:
    white_c = "black"

num_neurons = [10, 30, 50, 100, 200, 400, 600, 800, 900, 1000]


def set_ax(ax, white_c):
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.spines["left"].set_color(white_c)
    ax.spines["bottom"].set_color(white_c)
    ax.set_xticks(
        [10, 200, 400, 600, 800, 1000],
        [10, 200, 400, 600, 800, 1000],
        color=white_c,
        rotation=45,
    )
    ax.set_yticks(
        np.linspace(0, 100, 5),
        np.linspace(0, 100, 5, dtype=int),
        color=white_c,
    )
    ax.set_xlabel("# Neurons", color=white_c)
    ax.set_ylabel("Acc (%, in 1s time window)", color=white_c)
    ax.tick_params(colors=white_c)


fig_330 = plt.figure(figsize=(7 * scale, 10 * scale), dpi=300)
plt.title(f"Mean frame error")
plt.subplots_adjust(wspace=0.5)
ax1 = plt.subplot(111)
c = "#9932EB"

ax1.errorbar(
    num_neurons,
    n_mean_err_frame_diff(frame_np_knn_baseline_330_err)[0],
    n_mean_err_frame_diff(frame_np_knn_baseline_330_err)[1],
    ls="--",
    label="kNN baseline",
    color="k",
    alpha=0.7,
    markersize=20,
)

ax1.errorbar(
    num_neurons,
    n_mean_err_frame_diff(frame_np_bayes_baseline_330_err)[0],
    n_mean_err_frame_diff(frame_np_bayes_baseline_330_err)[1],
    ls="--",
    label="Bayes baseline",
    color="k",
    alpha=0.3,
    markersize=20,
)

ax1.errorbar(
    num_neurons,
    n_mean_err_frame_diff(frame_np_cebra_knn_330_err)[0],
    n_mean_err_frame_diff(frame_np_cebra_knn_330_err)[1],
    label="kNN CEBRA",
    color=c,
    alpha=1,
    markersize=20,
)

ax1.errorbar(
    num_neurons,
    n_mean_err_frame_diff_joint(frame_np_joint_cebra_knn_330_err, "np")[0],
    n_mean_err_frame_diff_joint(frame_np_joint_cebra_knn_330_err, "np")[1],
    ls="dotted",
    label="kNN CEBRA joint",
    color=c,
    alpha=1,
    markersize=20,
)

ax1.set_ylim(0, 300)
set_ax(ax1, white_c)
ax1.set_yticks(
    np.linspace(0, 300, 4),
    np.linspace(0, 300, 4, dtype=int),
    color=white_c,
)
ax1.set_ylabel("Frame difference")

l1 = ax1.legend(
    loc="best", frameon=False, title="NP, 10 frames", handlelength=1, fontsize="small"
)
for text in l1.get_texts():
    text.set_color(white_c)

sns.despine(trim=True)
plt.savefig("mean_frame_error_10frames.svg", transparent=True, bbox_inches="tight")
plt.show()
../../_images/cebra-figures_figures_Figure5_16_0.png
[11]:
white = False

if white:
    white_c = "white"
else:
    white_c = "black"

fig_33 = plt.figure(figsize=(7 * scale, 10 * scale), dpi=300)
fig_33.suptitle(f"Mean frame error")
plt.subplots_adjust(wspace=0.5)
ax1 = plt.subplot(111)

c = "#9932EB"

ax1.errorbar(
    num_neurons,
    n_mean_err_frame_diff(frame_np_knn_baseline_33_err)[0],
    n_mean_err_frame_diff(frame_np_knn_baseline_33_err)[1],
    ls="--",
    label="kNN baseline",
    color="k",
    alpha=0.7,
    markersize=20,
)

ax1.errorbar(
    num_neurons,
    n_mean_err_frame_diff(frame_np_bayes_baseline_33_err)[0],
    n_mean_err_frame_diff(frame_np_bayes_baseline_33_err)[1],
    ls="--",
    label="Bayes baseline",
    color="k",
    alpha=0.3,
    markersize=20,
)

ax1.errorbar(
    num_neurons,
    n_mean_err_frame_diff(frame_np_cebra_knn_33_err)[0],
    n_mean_err_frame_diff(frame_np_cebra_knn_33_err)[1],
    label="kNN CEBRA",
    color=c,
    alpha=1,
    markersize=20,
)

ax1.errorbar(
    num_neurons,
    n_mean_err_frame_diff_joint(frame_np_joint_cebra_knn_33_err, "np")[0],
    n_mean_err_frame_diff_joint(frame_np_joint_cebra_knn_33_err, "np")[1],
    ls="dotted",
    label="kNN CEBRA joint",
    color=c,
    alpha=1,
    markersize=20,
)

ax1.set_ylim(0, 400)
set_ax(ax1, white_c)
ax1.set_yticks(
    np.linspace(0, 400, 5),
    np.linspace(0, 400, 5, dtype=int),
    color=white_c,
)
ax1.set_ylabel("Frame difference")

l1 = ax1.legend(
    loc="best", frameon=False, title="NP, 1 frame", handlelength=1, fontsize="small"
)
for text in l1.get_texts():
    text.set_color(white_c)

sns.despine(trim=True, ax=ax1)
plt.savefig("mean_frame_error_1frame.svg", transparent=True, bbox_inches="tight")
plt.show()
../../_images/cebra-figures_figures_Figure5_17_0.png

Figure 5 g#

  • Diagram of the cortical areas considered, and decoding performance from CEBRA jointly trained (2P+NP), 10 frame receptive field.

[12]:
colors = {
    "VISl": "#A7D1F9",
    "VISrl": "#10CCF5",
    "VISal": "#49E7DD",
    "VISp": "#9932EB",
    "VISam": "#99F7CA",
    "VISpm": "#9EF19D",
}

np_decoding = data["cortex_decoding"]

fig = plt.figure(figsize=(7.5 * scale, 10 * scale), dpi=300)
# plt.title('Decoding by cortical area - DINO feature', fontsize=35, y=1.1)
ax = plt.subplot(111)
for area in ["VISal", "VISl", "VISrl", "VISp", "VISam", "VISpm"]:
    ax.errorbar(
        [10, 30, 50, 100, 200, 400, 600, 800],
        [np.mean(np_decoding[area][k]) for k in [10, 30, 50, 100, 200, 400, 600, 800]],
        [
            np.std(np_decoding[area][k]) / np.sqrt(len(np_decoding[area][k]))
            for k in [10, 30, 50, 100, 200, 400, 600, 800]
        ],
        label=area,
        # lw=5,
        color=colors[area],
    )

ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)

plt.xticks([10, 200, 400, 600, 800], [10, 200, 400, 600, 800], color="k")
plt.yticks(np.linspace(0, 100, 5), np.linspace(0, 100, 5, dtype=int), color="k")

plt.xlabel("# Neurons")
plt.ylabel("Acc (%, 1s time window)")
plt.ylim(0, 100)
l = plt.legend(frameon=False, fontsize="small", title="Area")
sns.despine(trim=True, ax=plt.gca())
plt.savefig("visual_areas.svg", transparent=True, bbox_inches="tight")
plt.show()
../../_images/cebra-figures_figures_Figure5_19_0.png

Figure 5 h#

  • V1 decoding performance vs. layer category using 900 neurons with a 330 ms receptive field \cebra-Behavior model.

[13]:
layer_decoding = data["layer_decoding"][900]
fig = plt.figure(figsize=(7.5 * scale, 10 * scale), dpi=300)
ax = fig.add_subplot(111)

labels = ["900"]
colors = ["black", "gray", "lightgray"]

ax.errorbar(
    np.arange(3),
    [np.mean(layer_decoding[layer]) for layer in [2, 4, 5]],
    [
        np.std(layer_decoding[layer]) / np.sqrt(len(layer_decoding[layer]))
        for layer in [2, 4, 5]
    ],
    c="k",
    # lw=4,
)
ax.set_xticks(np.arange(3))
ax.set_xticklabels(["2/3", "4", "5/6"])
ax.set_yticks(np.linspace(70, 100, 4))
ax.set_yticklabels(np.linspace(70, 100, 4, dtype=int))
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
plt.ylabel("Acc (%, 1s time window)")
plt.xlabel("Layer")
plt.ylim(79, 100)
sns.despine(trim=True, ax=plt.gca())
plt.savefig("layer_comparison.svg", transparent=True, bbox_inches="tight")
plt.show()
../../_images/cebra-figures_figures_Figure5_21_0.png