You can download and run the notebook locally:

Download jupyter notebook


Extended Data Figure 6: Persistence across dimensions#

import plot and data loading dependencies#

[1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from persim import plot_diagrams
[2]:
data = pd.read_hdf("../data/EDFigure6.h5")

CEBRA rat dataset toplogy analysis#

For each dimension of CEBRA-Behavior embedding from the rat hippocampus dataset, we obtained Betti number by applying persistent co-homology. The colored dots are lifespans observed in hypothesis based CEBRA-Behavior. To rule out noisy lifespans, we set a threshold (colored diagonal lines) as maximum lifespan based on 500 seeds of shuffled-CEBRA embedding for each dimension. We obtained circular coordinate from the first cocycle from persistent co-homology analysis and map it to position.

[3]:
def plot_barcode(topology_dgms, shuffled_max, ax, label_vis):
    plot_diagrams(
        topology_dgms,
        ax=ax,
        legend=False,
    )
    ax.plot(
        [
            -0.5,
            2,
        ],
        [-0.5 + shuffled_max[0], 2 + shuffled_max[0]],
        color="C0",
        linewidth=3,
        alpha=0.5,
    )
    ax.plot(
        [
            -0.5,
            2,
        ],
        [-0.5 + shuffled_max[1], 2 + shuffled_max[1]],
        color="orange",
        linewidth=3,
        alpha=0.5,
    )
    ax.plot(
        [-0.50, 2],
        [-0.5 + shuffled_max[2], 2 + shuffled_max[2]],
        color="green",
        linewidth=3,
        alpha=0.5,
    )
    ax.set_xlabel("Birth", fontsize=15)
    ax.set_xticks([0, 1, 2])
    ax.set_xticklabels([0, 1, 2])
    ax.tick_params(labelsize=13)
    if label_vis:
        ax.set_ylabel("Death", fontsize=15)
    else:
        ax.set_ylabel("")


def plot_polar(radial_angle, label, ax, first, last):
    r = np.ones(1000)
    r_ind = label[:, 1] == 1
    l_ind = label[:, 2] == 1
    right = ax.scatter(
        radial_angle[r_ind], r[r_ind], s=5, c=label[r_ind, 0], cmap="cool"
    )
    left = ax.scatter(
        radial_angle[l_ind], r[l_ind], s=5, c=label[l_ind, 0], cmap="viridis"
    )
    ax.set_yticklabels([])
    ax.set_xticks(np.pi / 180.0 * np.linspace(0, 360, 8, endpoint=False))
    if first:
        ax.set_xticklabels(
            [
                "",
                "45$^\circ$",
                "90$^\circ$",
                "135$^\circ$",
                "180$^\circ$",
                "225$^\circ$",
                "270$^\circ$",
                "315$^\circ$",
            ]
        )
    if last:
        ax.set_xticklabels(
            [
                "0$^\circ$",
                "45$^\circ$",
                "90$^\circ$",
                "135$^\circ$",
                "",
                "225$^\circ$",
                "270$^\circ$",
                "315$^\circ$",
            ]
        )
    if not first and not last:
        ax.set_xticks(np.pi / 180.0 * np.linspace(0, 360, 9))
        ax.set_xticklabels(
            [
                "",
                "45$^\circ$",
                "90$^\circ$",
                "135$^\circ$",
                "",
                "225$^\circ$",
                "270$^\circ$",
                "315$^\circ$",
                "",
            ]
        )
    ax.tick_params(labelsize=11)
    return right, left


def plot_position_angle(radial_angle, position, ax, yax_label=False):
    r_ind = position[:, 1] == 1
    l_ind = position[:, 2] == 1

    def _to_deg(rad):
        return rad * 180 / np.pi

    ax.scatter(_to_deg(radial_angle[r_ind]), position[r_ind, 0] * 100, s=1, c="black")
    ax.scatter(_to_deg(radial_angle[l_ind]), position[l_ind, 0] * 100, s=1, c="gray")

    ax.set_xlabel("Radial angle [$^\circ$]", fontsize=15)
    if yax_label:
        ax.set_ylabel("Position [cm]", fontsize=15)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.tick_params(labelsize=13)
    ax.set_xticks([0, 90, 180, 270, 360])
    ax.set_xticklabels([0, 90, 180, 270, 360])


def display_labels_flag(i):
    if i == 0:
        first = True
        last = False
        label_vis = True
    elif i == 4:
        first = False
        last = True
        label_vis = False
    else:
        first = False
        last = False
        label_vis = False
    return first, last, label_vis
[4]:
behavior_topology = data["behavior"]
max_dims_behavior = data["shuffled-max"]
circular_coord = data["circular_coord_analysis"]
dims = [3, 8, 16, 32, 64]
fig = plt.figure(figsize=(15, 11), dpi=300)
subfigs = fig.subfigures(
    3,
    1,
)
barcode_axs = subfigs[0].subplots(1, 5, sharey=True)
polar_axs = subfigs[1].subplots(1, 5, sharey=True, subplot_kw={"projection": "polar"})
position_axs = subfigs[2].subplots(1, 5, sharey=True)
for i, d in enumerate(dims):

    first, last, label_vis = display_labels_flag(i)
    barcode_ax = barcode_axs[i]
    barcode_ax.set_title(f"Dimension {d}", fontsize=20, y=1.2)
    plot_barcode(
        behavior_topology[d]["dgms"], max_dims_behavior[d], barcode_ax, label_vis
    )
    polar_ax = polar_axs[i]
    right, left = plot_polar(
        circular_coord["circular_coord"][0][d],
        circular_coord["position"],
        polar_ax,
        first,
        last,
    )
    position_ax = position_axs[i]

    plot_position_angle(
        circular_coord["circular_coord"][0][d],
        circular_coord["position"],
        position_ax,
        label_vis,
    )

barcode_legend = [
    Line2D([0], [0], ls="--", color="k"),
    Line2D([0], [0], markerfacecolor="C0", marker="o", color="w", markersize=6),
    Line2D([0], [0], markerfacecolor="C1", marker="o", color="w", markersize=6),
    Line2D([0], [0], markerfacecolor="C2", marker="o", color="w", markersize=6),
]

position_legend = [
    Line2D([0], [0], markerfacecolor="black", marker="o", color="w", markersize=6),
    Line2D([0], [0], markerfacecolor="gray", marker="o", color="w", markersize=6),
]


subfigs[0].legend(
    barcode_legend,
    ["$\infty$", "$H_0$", "$H_1$", "$H_2$"],
    frameon=False,
    bbox_to_anchor=[0.98, 0.7],
    fontsize=15,
)
cax_r = polar_axs[-1].inset_axes([1.2, 0.1, 0.05, 1])
cax_l = polar_axs[-1].inset_axes([1.3, 0.1, 0.05, 1])
r_bar = subfigs[1].colorbar(right, ax=polar_axs, cax=cax_r)
l_bar = subfigs[1].colorbar(left, ax=polar_axs, cax=cax_l)
r_bar.set_ticks([])
l_bar.set_ticks([0, 1.5], fontsize=15)
l_bar.set_label("position [cm]", rotation=0, y=-0.1, fontsize=12)
subfigs[2].legend(
    position_legend,
    ["right", "left"],
    frameon=False,
    bbox_to_anchor=[0.98, 0.4],
    fontsize=15,
)
[4]:
<matplotlib.legend.Legend at 0x7f29a55799a0>
../../_images/cebra-figures_figures_ExtendedDataFigure6_6_1.png