You can download and run the notebook locally.


Extended Data Figure 7: Multi-session training and rapid decoding#

Import plotting and data-loading dependencies#

[1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.markers import MarkerStyle
import seaborn as sns
import joblib as jl

Define data-loading and plotting helper functions#

[2]:
def recover_python_datatypes(element):
    """Convert stringified lists (as saved to CSV) back into numpy arrays."""
    if isinstance(element, str):
        if element.startswith("[") and element.endswith("]"):
            if "," in element:
                # comma-separated, e.g. "[1.0, 2.0]"
                element = np.fromstring(element[1:-1], dtype=float, sep=",")
            else:
                # whitespace-separated, e.g. "[1.0 2.0]"
                element = np.fromstring(element[1:-1], dtype=float, sep=" ")
    return element


def figure():
    plt.figure(figsize=(2, 2), dpi=200)
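
As a quick, illustrative check of the loading helper (not part of the original pipeline):

recover_python_datatypes("[1.0, 2.0, 3.0]")  # -> array([1., 2., 3.])
recover_python_datatypes("[0.5 0.25]")       # -> array([0.5 , 0.25])
recover_python_datatypes(0.7)                # non-strings pass through unchanged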

Define metric functions#

[3]:
def mean(v):
    return sum(v) / len(v)


def sort_results(metrics, full_metrics, results):
    """Assign a seed index (0-9) to each of the 10 repeats within a configuration."""
    sorted_results = results.copy()
    sorted_results["seed"] = -1
    idx = sorted_results.pivot_table("logdir", index=full_metrics, aggfunc="count")
    sorted_results = sorted_results.set_index(full_metrics)
    for i in idx.T:
        sorted_results.loc[i, "seed"] = np.arange(10)

    # Sanity check: every configuration received exactly the seeds 0..9.
    assert (
        sorted_results.groupby(full_metrics)["seed"]
        .apply(tuple)
        .apply(lambda v: v == tuple(range(10)))
        .all()
    )
    return sorted_results.reset_index()



metrics = [
    "batch_size",
    "conditional",
    "delta",
    "device",
    "learning_rate",
    "loss_distance",
    "model",
    "num_hidden_units",
    "num_output",
    "num_steps",
    "save_emissions",
    "save_frequency",
    "share_model",
    "temperature",
    "time_offset",
    "tqdm",
    "train_ratio",
    "valid_frequency",
    "valid_ratio",
    "variant",
    "repeat",
    "steps",
]
full_metrics = metrics + ["data_mode"]
results = pd.read_csv("../data/Figure4SupplementMultisession_rat.csv").applymap(
    recover_python_datatypes
)

sorted_results = sort_results(metrics, full_metrics, results)
/tmp/ipykernel_852782/2666736443.py:12: PerformanceWarning: indexing past lexsort depth may impact performance.
  sorted_results.loc[i, "seed"] = np.arange(10)

Load and check metrics#

[4]:
aggregated_metrics = [m for m in full_metrics if m not in ["repeat", "seed"]]
aggregated_results = sorted_results.groupby(aggregated_metrics).agg(mean).reset_index()

# Select the single-session configuration with the best train consistency,
# then export its per-seed results.
idx = (
    aggregated_results[aggregated_results.data_mode == "single-session"]
    .sort_values("train_consistency")
    .groupby("data_mode")
    .tail(1)
    .set_index(aggregated_metrics)
    .index
)
sorted_results.set_index(aggregated_metrics).loc[idx].reset_index().groupby(
    aggregated_metrics + ["seed"]
).mean().reset_index().to_csv("cebra-behavior.csv")
/tmp/ipykernel_852782/495670736.py:2: FutureWarning: Dropping invalid columns in DataFrameGroupBy.agg is deprecated. In a future version, a TypeError will be raised. Before calling .agg, select only columns which should be valid for the function.
  aggregated_results = sorted_results.groupby(aggregated_metrics).agg(mean).reset_index()
[5]:
def compute(sorted_results):
    sorted_results = sorted_results.copy()
    sorted_results["mean_train_consistency"] = sorted_results[
        "train_consistency_raw"
    ].apply(np.mean)

    # One row per animal and run: expand the per-animal decoding errors.
    sorted_results = sorted_results[
        (sorted_results.num_hidden_units == 32) & (sorted_results.steps == 5000)
    ].explode("test_position_error_raw")

    # 4 rats, cycled in row order
    sorted_results["animal"] = np.arange(len(sorted_results)) % 4
    return sorted_results


# Per-animal mean test error, collected into one list per (data_mode, animal).
result = (
    compute(sorted_results)
    .pivot_table(
        "test_position_error_raw",
        index=("data_mode", "animal"),
        columns="seed",
        aggfunc="mean",
    )
    .agg(list, axis=1)
    .unstack()
    .T
)

Comparison of decoding vs. consistency#

  • Define plotting functions

  • Comparison of decoding metrics for single- or multi-session training at various consistency levels (averaged across all 12 comparisons). Models were trained for 5,000 (single-session) or 10,000 (multi-session) steps with a 0.003 learning rate; the batch size was 7,200 samples per session. Multi-session training requires longer training or a higher learning rate to reach the same accuracy, due to its 4-fold larger effective batch size, but it converges to the same decoding accuracy. We plot points at intervals of 500 steps (n=10 seeds); training progresses from the lower-right to the upper-left corner in both plots (see the API sketch below).
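The configuration above maps onto the CEBRA API roughly as follows. This is a minimal sketch rather than the training script used for the figure: the session variables are hypothetical placeholders, and the hyperparameters are those quoted in the caption.

from cebra import CEBRA

# neural_sessions: list of four (time x neurons) arrays, one per rat
# behavior_sessions: list of matching (time x behavior) label arrays

single_model = CEBRA(
    batch_size=7200, learning_rate=3e-3,
    max_iterations=5000, output_dimension=32,
)
# single_model.fit(neural_sessions[0], behavior_sessions[0])

multi_model = CEBRA(
    batch_size=7200, learning_rate=3e-3,
    max_iterations=10000, output_dimension=32,
)
# Passing lists of sessions and labels triggers multi-session training,
# which samples from every session at each step (hence the 4-fold larger
# effective batch size noted in the caption).
# multi_model.fit(neural_sessions, behavior_sessions)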

[6]:
def plot_hippocampus(sorted_results, metric, yaxis):

    filtered_results = (
        sorted_results[
            (sorted_results.num_hidden_units == 32)
            & (sorted_results.num_output == 32)
            & (sorted_results.steps > 100)
            & (sorted_results.steps < 5000)
        ]
        .set_index("data_mode")
        .sort_values("valid_position_error")
        .groupby(["steps", "data_mode", "seed", "repeat", "num_output"])
        .head(1)
    )
    filtered_results = filtered_results[filtered_results.steps % 250 == 0]
    filtered_results["train_consistency"] *= 100  # fraction -> % R2
    filtered_results[metric] *= 100  # m -> cm
    filtered_results = filtered_results.groupby(["steps", "data_mode", "seed"]).mean()

    plt.figure(figsize=(2, 2), dpi=200)
    sns.scatterplot(
        data=filtered_results.reset_index().sort_values("data_mode"),
        x="train_consistency",
        y=metric,
        style="data_mode",
        hue="data_mode",
        palette=["#840884", "k"],
        markers=[MarkerStyle("v", "none"), MarkerStyle("o", "none")],
        alpha=0.5,
        legend=False,
    )

    plt.xlabel("Consistency [% R²]")
    plt.ylabel(yaxis)
    plt.xticks([90, 92, 94, 96, 98], rotation=90)
    sns.despine(trim=True)
    plt.gca().invert_yaxis()
    plt.show()




def plot_allen(results):
    def add_seed(line):
        # Each (data_mode, repeat, steps) group holds exactly 10 runs.
        line["seed"] = np.arange(10)
        return line

    traj_results = results.groupby(["data_mode", "repeat", "steps"]).apply(add_seed)
    traj_results = (
        traj_results.groupby(["data_mode", "seed", "steps"]).apply("mean").reset_index()
    )
    traj_results = (
        traj_results.set_index("data_mode").drop("time-contrastive").reset_index()
    )
    traj_results.train_consistency *= 100  # fraction -> % R2
    traj_results = traj_results[traj_results.steps % 500 == 0]
    traj_results.data_mode = traj_results.data_mode.apply(
        {"multi-session": "multi", "single-session": "single"}.get
    )

    plt.figure(figsize=(2, 2), dpi=200)
    sns.scatterplot(
        data=traj_results[["data_mode", "steps", "test_accuracy", "train_consistency"]],
        x="train_consistency",
        y="test_accuracy",
        style="data_mode",
        hue="data_mode",
        palette=["#840884", "k"],
        markers=[MarkerStyle("v", "none"), MarkerStyle("o", "none")],
        alpha=0.6,
    )
    plt.xlabel("Consistency [% R²]")
    plt.ylabel("Decoding Accuracy [%]")
    plt.xticks([94, 96, 98, 100])
    sns.despine(trim=True)
    plt.legend(frameon=False)
    plt.show()


plot_hippocampus(
    sorted_results, metric="test_position_error", yaxis="Decoding Error [cm]"
)

allen_results = pd.read_hdf("../data/EDFigure7/allen-consistency.h5", key="allen")
plot_allen(allen_results)
[Figure: decoding error vs. consistency, rat hippocampus]
[Figure: decoding accuracy vs. consistency, Allen dataset]

Consistency matrices for single- vs. multi-session training, for the hippocampus (32D embedding) and Allen data (128D embedding), respectively#

  • Consistency is reported at the point in training where the average position decoding error falls below 14 cm (corresponding to a 7 cm error for rat 1), and at a decoding accuracy of 60% on the Allen dataset (see the sketch of the consistency computation below).
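The off-diagonal entries of these matrices are R² scores of linear fits between session embeddings. Below is a minimal, generic sketch; it assumes the embeddings have already been aligned (e.g. binned by behavior label) so that rows correspond across sessions, and it is an illustration rather than the exact metric implemented in the CEBRA library:

import numpy as np
from sklearn.linear_model import LinearRegression

def pairwise_consistency(embeddings):
    # embeddings: list of (samples x dims) arrays, row-aligned across sessions
    n = len(embeddings)
    scores = np.full((n, n), np.nan)
    for i in range(n):
        for j in range(n):
            if i != j:
                fit = LinearRegression().fit(embeddings[i], embeddings[j])
                scores[i, j] = fit.score(embeddings[i], embeddings[j])
    return scores  # 12 off-diagonal values for 4 sessions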

[7]:
def show_sweep(df):
    """Print the columns (and values) that vary across the hyperparameter sweep."""
    for c in df.columns:
        try:
            values = df[c].unique()
        except TypeError:
            continue
        if len(values) > 1:
            print(c, values)


filtered_results = aggregated_results[
    (aggregated_results.batch_size == 512)
    & (aggregated_results.learning_rate == 0.0003)
    & (aggregated_results.num_output == 32)
    & (aggregated_results.num_hidden_units == 32)
].set_index("data_mode")

# Keep, per data mode, the run closest to (but below) 14 cm validation error.
filtered_results = filtered_results[filtered_results["valid_position_error"] < 0.14]
filtered_results = (
    filtered_results.sort_values("valid_position_error").groupby("data_mode").tail(1)
)
display(filtered_results["valid_position_error"])

print("Rat hippocampus dataset")
_, axes = plt.subplots(2, 2, figsize=(4, 4), dpi=200)

for ax_row, data_mode in zip(axes, ["single-session", "multi-session"]):

    for ax, split in zip(ax_row, ["train"]):

        values = filtered_results[f"{split}_consistency_raw"].loc[data_mode]

        # 4x4 matrix; the 12 pairwise scores fill the off-diagonal
        cfm = np.zeros((4, 4))
        cfm[:] = float("nan")
        cfm[~np.eye(4).astype(bool)] = values
        sns.heatmap(
            cfm * 100,
            cmap="gray_r",
            vmin=50,
            vmax=100,
            annot=True,
            fmt=".1f",
            square=True,
            cbar=False,
            ax=ax,
        )

for ax in axes.flatten():
    ax.axis("off")

plt.subplots_adjust(wspace=-0.005, hspace=0.05)
plt.show()
data_mode
multi-session     0.139947
single-session    0.139954
Name: valid_position_error, dtype: float64
Rat hippocampus dataset
[Figure: consistency matrices, rat hippocampus]
[8]:
def agg(values):
    # "_raw" columns hold one array of pairwise scores per repeat; stack and
    # average them elementwise, e.g. [0.90, 0.95] and [0.92, 0.93] -> [0.91, 0.94].
    if values.name.endswith("_raw"):
        return np.stack(values, axis=0).mean(axis=0)
    return values.mean()


def filter_results(results):
    results_ = (
        results.groupby([c for c in metrics + ["data_mode"] if c != "repeat"])
        .agg(agg)
        .reset_index()
    )
    filtered_results = results_[
        (results_.batch_size == 7200)
        & (results_.learning_rate == 0.003)
        & (results_.num_output == 128)
        & (results_.num_hidden_units == 128)
    ].set_index("data_mode")

    # Keep, per data mode, the run closest to (but above) 60% validation accuracy.
    filtered_results = filtered_results[filtered_results["valid_accuracy"] > 60]
    filtered_results = (
        filtered_results.sort_values("valid_accuracy").groupby("data_mode").head(1)
    )
    return filtered_results


def plot_consistency_allen(filtered_results):
    _, axes = plt.subplots(2, 1, figsize=(4, 4), dpi=200)

    for ax, data_mode in zip(axes, ["single-session", "multi-session"]):

        split = "train"
        values = filtered_results[f"{split}_consistency_raw"].loc[data_mode]

        # 4x4 matrix; the 12 pairwise scores fill the off-diagonal
        cfm = np.zeros((4, 4))
        cfm[:] = float("nan")
        cfm[~np.eye(4).astype(bool)] = values
        sns.heatmap(
            cfm * 100,
            cmap="gray_r",
            vmin=50,
            vmax=100,
            annot=True,
            fmt=".1f",
            square=True,
            cbar=False,
            ax=ax,
        )
        ax.axis("off")

    plt.subplots_adjust(wspace=-0.005, hspace=0.05)
    plt.show()


allen_results = pd.read_hdf("../data/EDFigure7/allen-consistency.h5", key="allen")
filtered_results_allen = filter_results(allen_results)
print("Allen dataset")
plot_consistency_allen(filtered_results_allen)
/tmp/ipykernel_852782/1272793558.py:20: FutureWarning: Dropping invalid columns in DataFrameGroupBy.agg is deprecated. In a future version, a TypeError will be raised. Before calling .agg, select only columns which should be valid for the function.
  results.groupby([c for c in metrics + ["data_mode"] if c != "repeat"])
Allen dataset
[Figure: consistency matrices, Allen dataset]
[9]:
filtered_results[f"train_consistency_raw"].loc["multi-session"], filtered_results[
    f"train_consistency_raw"
].loc["single-session"]
[9]:
(array([0.95081546, 0.91219784, 0.97159572, 0.9668054 , 0.89214698,
        0.95330396, 0.93798715, 0.90619733, 0.92378746, 0.97700708,
        0.93911369, 0.90545486]),
 array([0.95555933, 0.88572262, 0.97200465, 0.9593825 , 0.84737061,
        0.94094321, 0.93445153, 0.90453392, 0.92011637, 0.97647199,
        0.94380083, 0.87852935]))

Statistics#

[10]:
import scipy.stats

scipy.stats.ttest_rel(
    filtered_results["train_consistency_raw"].loc["multi-session"],
    filtered_results["train_consistency_raw"].loc["single-session"],
    alternative="greater",
)
[10]:
Ttest_relResult(statistic=2.2233272445184493, pvalue=0.024042329397411348)
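Under this one-sided paired t-test across the 12 session pairs, multi-session training yields significantly higher train consistency than single-session training (t ≈ 2.22, p ≈ 0.024).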

Performance on train, validation, and test sets#

[11]:
def _select(values):
    # values == (selected_index, candidate_0, candidate_1, ...)
    index, values = values[0], values[1:]
    return values[index]


def hyperparameter_pivot(
    data, select_with, metric, index=[], columns=[], select_func=np.argmin
):
    # Position of the best run according to `select_with`, as a 1-tuple per cell.
    idx = data.pivot_table(
        select_with, index=index, columns=columns, aggfunc=lambda v: (select_func(v),)
    )

    # All candidate values of `metric`, as a tuple per cell.
    agg = data.pivot_table(metric, index=index, columns=columns, aggfunc=tuple)

    if len(agg.columns) == 1:
        idx.columns = [metric]
        agg.columns = [metric]

    # Adding DataFrames of tuples concatenates the tuples cellwise, so each
    # cell becomes (selected_index, value_0, value_1, ...), which _select unpacks.
    return (idx + agg).applymap(_select)
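# Worked example (illustrative): if select_func picked position 1 and a cell's
# candidate values are (0.13, 0.12), the summed cell is
# (1,) + (0.13, 0.12) == (1, 0.13, 0.12), and _select returns 0.12.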


from IPython.display import display, HTML

KEYS = (
    "train_consistency",
    "valid_consistency",
    "test_consistency",
    "valid_position_error",
    "test_position_error",
)


def show_results(select_by="valid_accuracy", func=np.argmax, keys=KEYS):

    frames = (
        hyperparameter_pivot(
            results,
            select_by,
            key,
            index=["data_mode"],
            columns=["batch_size"],
            select_func=func,
        )
        * (100 if "consistency" in key else 1)
        for key in keys
    )

    titles = [
        "Consistency (R², train)",
        "Consistency (R², valid)",
        "Consistency (R², test)",
        "Position error (valid)",
        "Position error (test)",
    ]

    return HTML(
        ("\xa0" * 3).join(
            frame._repr_html_() + title for frame, title in zip(frames, titles)
        )
    )


select_metrics = ("valid_total_r2",)
for key in select_metrics:
    display(show_results(key))
    print("selected by best", key)

data_mode        Consistency        Consistency        Consistency        Position error   Position error
                 (R², train)        (R², valid)        (R², test)         (valid)          (test)
multi-session    95.259507          84.270572          89.312167          0.124064         0.126732
single-session   94.677608          82.158812          85.715708          0.127632         0.128096

selected by best valid_total_r2
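At the configurations selected by the best valid_total_r2, multi-session training is more consistent on every split and gives a slightly lower position error than single-session training.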

Load data for fine-tuning on an unseen subject, and define plotting code#

[12]:
data = pd.read_hdf("../data/EDFigure7.h5")
[13]:
def one_rat(ax, rat_no, rat_df):
    # Globals used: steps, xticks, yticks (set before calling plot()).
    adapt_reinit = [True, False]
    adapt_tune_all = [True, False]

    for reinit in adapt_reinit:
        for tune_all in adapt_tune_all:
            # gray: training from scratch; purple: pre-trained model
            c = "gray" if reinit else "purple"
            # solid: full model adapted; dashed: input embedding only
            ls = "solid" if tune_all else "dashed"
            exp_type = rat_df[
                (rat_df["adapt_reinit"] == reinit)
                & (rat_df["adapt_tune_all"] == tune_all)
            ]
            # median_err is stored as a stringified list, logged every 25 steps
            ax.plot(
                np.arange(int(steps / 25)),
                np.array(eval(exp_type["median_err"].item()))[: int(steps / 25)] * 100,
                linestyle=ls,
                color=c,
                lw=3,
            )

    ax.set_title(f"Rat {rat_no+1}", fontsize=20)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_ylabel("Median Error (cm)", fontsize=30)
    ax.set_xlabel("Adaptation steps", fontsize=30)
    ax.set_xticks(xticks / 25)
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticks, fontsize=25)
    ax.set_xticklabels(xticks, fontsize=25)


def plot(rat_df):

    plt.figure(figsize=(7, 7))
    plt.subplots_adjust(wspace=0.5, hspace=0.5)
    ax = plt.subplot(111)
    rat = 3  # rat #4 (zero-indexed), the held-out animal
    plt.yticks(fontsize=25)

    one_rat(ax, rat, rat_df)

    custom_lines = [
        Line2D([0], [0], color="gray", lw=4),
        Line2D([0], [0], color="purple", lw=4),
        Line2D([0], [0], color="black", lw=4, ls="solid"),
        Line2D([0, 2], [0, 0], color="black", lw=4, ls="dashed"),
    ]

    ax.legend(
        custom_lines,
        [
            "From scratch",
            "Pre-trained",
            "Full model",
            "Input embedding",
        ],
        loc=(1.2, 0),
        frameon=False,
        fontsize=30,
    )

Adapt to an unseen dataset#

  • Here, 3 rats were used for pretraining, and rat #4 was held out as the test subject. The gray lines show training from scratch. We also tested fine-tuning either the input embedding or the full model, as described in the diagram on the left of the figure. We measured the average wall-clock time to adapt for 100 steps (0.65 ± 0.13 s) and for 500 steps (3.07 ± 0.61 s) over 40 repeated experiments (see the API sketch below).
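In the CEBRA API, this adaptation can be sketched by refitting a pretrained model on the held-out subject. The sketch below assumes the adapt=True flag of CEBRA.fit (which re-initializes and retrains the input layer while keeping the remaining pretrained weights); the data variables are hypothetical placeholders, so check the installed version's documentation before relying on it:

from cebra import CEBRA

model = CEBRA(output_dimension=32, max_iterations=5000)
# model.fit(neural_rats123, behavior_rats123)      # pretrain on rats 1-3

# Adapt to the unseen rat; per the timing above, a few hundred steps
# (i.e. seconds of compute) suffice.
# model.fit(neural_rat4, behavior_rat4, adapt=True)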

[14]:
steps = 5000
yticks = np.arange(15, 55, 10)
xticks = np.arange(0, 6000, 1000)
plot(data)
[Figure: adaptation to the unseen rat; median decoding error vs. adaptation steps]

Panel, inset#

[15]:
steps = 500
yticks = np.arange(15, 55, 10)
xticks = np.arange(0, steps + 25, 25 * 10)
plot(data)
[Figure: inset, first 500 adaptation steps]