# Synthetic neural benchmarking
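This notebook demonstrates how to use CEBRA, piVAE, t-SNE, and UMAP on synthetic data.
Google Colab, Intel (pre-M1) Mac, Windows*, and Ubuntu users can skip the Apple-silicon (M1) setup cell below and go straight to the install cell (cell 2).
*This notebook has not been tested on Windows.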
[ ]:
# Attention: Apple-silicon (M1) users ONLY — run this cell once.
# TF is a dependency of piVAE, which is used in this demo, and on M1 it needs a
# few extra installation steps, handled below.
# See also: https://developer.apple.com/metal/tensorflow-plugin/ if any issues.
# Every conda call passes `-y`: a notebook `!` shell has no interactive stdin,
# so a "Proceed ([y]/n)?" prompt could never be answered.
!pip uninstall -y tensorflow-deps tensorflow-macos tensorflow-metal keras
!conda install -y pytorch torchvision torchaudio -c pytorch
!pip install --upgrade --force-reinstall scikit-learn
!conda install -c apple -y tensorflow-deps==2.5.0 --force-reinstall
!python -m pip install tensorflow-macos
!python -m pip install tensorflow-metal
Install note: make sure the demo dependencies are installed before using this notebook:
[1]:
# Install CEBRA (`--pre` also allows pre-release versions) with the `datasets`
# and `demos` extras used in this notebook.
# NOTE: on Colab this replaced pandas 2.0.3 with 1.3.4 and pip reported conflicts
# with preinstalled packages (see the output below); restart the runtime if
# imports misbehave afterwards.
!pip install --pre 'cebra[datasets,demos]'
Collecting cebra[datasets,demos]
Downloading cebra-0.4.0-py2.py3-none-any.whl.metadata (5.8 kB)
Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (1.4.2)
Collecting literate-dataclasses (from cebra[datasets,demos])
Downloading literate_dataclasses-0.0.6-py3-none-any.whl.metadata (2.3 kB)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (1.2.2)
Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (1.11.4)
Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (2.3.1+cu121)
Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (4.66.4)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (3.7.1)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (2.31.0)
Requirement already satisfied: ipykernel in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (5.5.6)
Collecting jupyter (from cebra[datasets,demos])
Downloading jupyter-1.0.0-py2.py3-none-any.whl.metadata (995 bytes)
Requirement already satisfied: nbconvert in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (6.5.4)
Requirement already satisfied: seaborn in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (0.13.1)
Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (3.9.0)
Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (2.0.3)
Collecting nlb-tools (from cebra[datasets,demos])
Downloading nlb_tools-0.0.4-py3-none-any.whl.metadata (3.8 kB)
Collecting hdf5storage (from cebra[datasets,demos])
Downloading hdf5storage-0.1.19-py2.py3-none-any.whl.metadata (24 kB)
Requirement already satisfied: openpyxl in /usr/local/lib/python3.10/dist-packages (from cebra[datasets,demos]) (3.1.5)
Requirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.10/dist-packages (from h5py->cebra[datasets,demos]) (1.25.2)
Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.10/dist-packages (from ipykernel->cebra[datasets,demos]) (0.2.0)
Requirement already satisfied: ipython>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from ipykernel->cebra[datasets,demos]) (7.34.0)
Requirement already satisfied: traitlets>=4.1.0 in /usr/local/lib/python3.10/dist-packages (from ipykernel->cebra[datasets,demos]) (5.7.1)
Requirement already satisfied: jupyter-client in /usr/local/lib/python3.10/dist-packages (from ipykernel->cebra[datasets,demos]) (6.1.12)
Requirement already satisfied: tornado>=4.2 in /usr/local/lib/python3.10/dist-packages (from ipykernel->cebra[datasets,demos]) (6.3.3)
Requirement already satisfied: notebook in /usr/local/lib/python3.10/dist-packages (from jupyter->cebra[datasets,demos]) (6.5.5)
Collecting qtconsole (from jupyter->cebra[datasets,demos])
Downloading qtconsole-5.5.2-py3-none-any.whl.metadata (5.1 kB)
Requirement already satisfied: jupyter-console in /usr/local/lib/python3.10/dist-packages (from jupyter->cebra[datasets,demos]) (6.1.0)
Requirement already satisfied: ipywidgets in /usr/local/lib/python3.10/dist-packages (from jupyter->cebra[datasets,demos]) (7.7.1)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->cebra[datasets,demos]) (1.2.1)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->cebra[datasets,demos]) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->cebra[datasets,demos]) (4.53.1)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->cebra[datasets,demos]) (1.4.5)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->cebra[datasets,demos]) (24.1)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->cebra[datasets,demos]) (9.4.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->cebra[datasets,demos]) (3.1.2)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->cebra[datasets,demos]) (2.8.2)
Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (4.9.4)
Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (4.12.3)
Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (6.1.0)
Requirement already satisfied: defusedxml in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (0.7.1)
Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (0.4)
Requirement already satisfied: jinja2>=3.0 in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (3.1.4)
Requirement already satisfied: jupyter-core>=4.7 in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (5.7.2)
Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (0.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (2.1.5)
Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (0.8.4)
Requirement already satisfied: nbclient>=0.5.0 in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (0.10.0)
Requirement already satisfied: nbformat>=5.1 in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (5.10.4)
Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (1.5.1)
Requirement already satisfied: pygments>=2.4.1 in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (2.16.1)
Requirement already satisfied: tinycss2 in /usr/local/lib/python3.10/dist-packages (from nbconvert->cebra[datasets,demos]) (1.3.0)
Collecting pandas (from cebra[datasets,demos])
Downloading pandas-1.3.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting pynwb (from nlb-tools->cebra[datasets,demos])
Downloading pynwb-2.8.1-py3-none-any.whl.metadata (8.9 kB)
Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.10/dist-packages (from pandas->cebra[datasets,demos]) (2023.4)
Requirement already satisfied: et-xmlfile in /usr/local/lib/python3.10/dist-packages (from openpyxl->cebra[datasets,demos]) (1.1.0)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->cebra[datasets,demos]) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->cebra[datasets,demos]) (3.7)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->cebra[datasets,demos]) (2.0.7)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->cebra[datasets,demos]) (2024.7.4)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->cebra[datasets,demos]) (3.5.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->cebra[datasets,demos]) (3.15.4)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch->cebra[datasets,demos]) (4.12.2)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->cebra[datasets,demos]) (1.13.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->cebra[datasets,demos]) (3.3)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->cebra[datasets,demos]) (2023.6.0)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->cebra[datasets,demos])
Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->cebra[datasets,demos])
Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->cebra[datasets,demos])
Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->cebra[datasets,demos])
Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->cebra[datasets,demos])
Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->cebra[datasets,demos])
Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch->cebra[datasets,demos])
Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch->cebra[datasets,demos])
Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch->cebra[datasets,demos])
Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-nccl-cu12==2.20.5 (from torch->cebra[datasets,demos])
Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting nvidia-nvtx-cu12==12.1.105 (from torch->cebra[datasets,demos])
Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)
Requirement already satisfied: triton==2.3.1 in /usr/local/lib/python3.10/dist-packages (from torch->cebra[datasets,demos]) (2.3.1)
Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch->cebra[datasets,demos])
Downloading nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.10/dist-packages (from ipython>=5.0.0->ipykernel->cebra[datasets,demos]) (71.0.4)
Collecting jedi>=0.16 (from ipython>=5.0.0->ipykernel->cebra[datasets,demos])
Downloading jedi-0.19.1-py2.py3-none-any.whl.metadata (22 kB)
Requirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from ipython>=5.0.0->ipykernel->cebra[datasets,demos]) (4.4.2)
Requirement already satisfied: pickleshare in /usr/local/lib/python3.10/dist-packages (from ipython>=5.0.0->ipykernel->cebra[datasets,demos]) (0.7.5)
Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from ipython>=5.0.0->ipykernel->cebra[datasets,demos]) (3.0.47)
Requirement already satisfied: backcall in /usr/local/lib/python3.10/dist-packages (from ipython>=5.0.0->ipykernel->cebra[datasets,demos]) (0.2.0)
Requirement already satisfied: matplotlib-inline in /usr/local/lib/python3.10/dist-packages (from ipython>=5.0.0->ipykernel->cebra[datasets,demos]) (0.1.7)
Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.10/dist-packages (from ipython>=5.0.0->ipykernel->cebra[datasets,demos]) (4.9.0)
Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.10/dist-packages (from jupyter-core>=4.7->nbconvert->cebra[datasets,demos]) (4.2.2)
Requirement already satisfied: pyzmq>=13 in /usr/local/lib/python3.10/dist-packages (from jupyter-client->ipykernel->cebra[datasets,demos]) (24.0.1)
Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.10/dist-packages (from nbformat>=5.1->nbconvert->cebra[datasets,demos]) (2.20.0)
Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.10/dist-packages (from nbformat>=5.1->nbconvert->cebra[datasets,demos]) (4.19.2)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->cebra[datasets,demos]) (1.16.0)
Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.10/dist-packages (from beautifulsoup4->nbconvert->cebra[datasets,demos]) (2.5)
Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->nbconvert->cebra[datasets,demos]) (0.5.1)
Requirement already satisfied: widgetsnbextension~=3.6.0 in /usr/local/lib/python3.10/dist-packages (from ipywidgets->jupyter->cebra[datasets,demos]) (3.6.7)
Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from ipywidgets->jupyter->cebra[datasets,demos]) (3.0.11)
Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.10/dist-packages (from notebook->jupyter->cebra[datasets,demos]) (23.1.0)
Requirement already satisfied: nest-asyncio>=1.5 in /usr/local/lib/python3.10/dist-packages (from notebook->jupyter->cebra[datasets,demos]) (1.6.0)
Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.10/dist-packages (from notebook->jupyter->cebra[datasets,demos]) (1.8.3)
Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.10/dist-packages (from notebook->jupyter->cebra[datasets,demos]) (0.18.1)
Requirement already satisfied: prometheus-client in /usr/local/lib/python3.10/dist-packages (from notebook->jupyter->cebra[datasets,demos]) (0.20.0)
Requirement already satisfied: nbclassic>=0.4.7 in /usr/local/lib/python3.10/dist-packages (from notebook->jupyter->cebra[datasets,demos]) (1.1.0)
Collecting hdmf>=3.14.0 (from pynwb->nlb-tools->cebra[datasets,demos])
Downloading hdmf-3.14.2-py3-none-any.whl.metadata (8.8 kB)
Collecting qtpy>=2.4.0 (from qtconsole->jupyter->cebra[datasets,demos])
Downloading QtPy-2.4.1-py3-none-any.whl.metadata (12 kB)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->cebra[datasets,demos]) (1.3.0)
Collecting ruamel-yaml>=0.16 (from hdmf>=3.14.0->pynwb->nlb-tools->cebra[datasets,demos])
Downloading ruamel.yaml-0.18.6-py3-none-any.whl.metadata (23 kB)
Requirement already satisfied: parso<0.9.0,>=0.8.3 in /usr/local/lib/python3.10/dist-packages (from jedi>=0.16->ipython>=5.0.0->ipykernel->cebra[datasets,demos]) (0.8.4)
Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat>=5.1->nbconvert->cebra[datasets,demos]) (23.2.0)
Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat>=5.1->nbconvert->cebra[datasets,demos]) (2023.12.1)
Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat>=5.1->nbconvert->cebra[datasets,demos]) (0.35.1)
Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat>=5.1->nbconvert->cebra[datasets,demos]) (0.19.0)
Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.10/dist-packages (from nbclassic>=0.4.7->notebook->jupyter->cebra[datasets,demos]) (0.2.4)
Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.10/dist-packages (from pexpect>4.3->ipython>=5.0.0->ipykernel->cebra[datasets,demos]) (0.7.0)
Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=5.0.0->ipykernel->cebra[datasets,demos]) (0.2.13)
Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.10/dist-packages (from argon2-cffi->notebook->jupyter->cebra[datasets,demos]) (21.2.0)
Requirement already satisfied: jupyter-server<3,>=1.8 in /usr/local/lib/python3.10/dist-packages (from notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook->jupyter->cebra[datasets,demos]) (1.24.0)
Collecting ruamel.yaml.clib>=0.2.7 (from ruamel-yaml>=0.16->hdmf>=3.14.0->pynwb->nlb-tools->cebra[datasets,demos])
Downloading ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl.metadata (2.2 kB)
Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter->cebra[datasets,demos]) (1.16.0)
Requirement already satisfied: pycparser in /usr/local/lib/python3.10/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter->cebra[datasets,demos]) (2.22)
Requirement already satisfied: anyio<4,>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook->jupyter->cebra[datasets,demos]) (3.7.1)
Requirement already satisfied: websocket-client in /usr/local/lib/python3.10/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook->jupyter->cebra[datasets,demos]) (1.8.0)
Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<4,>=3.1.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook->jupyter->cebra[datasets,demos]) (1.3.1)
Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<4,>=3.1.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook->jupyter->cebra[datasets,demos]) (1.2.2)
Downloading cebra-0.4.0-py2.py3-none-any.whl (202 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 202.2/202.2 kB 1.9 MB/s eta 0:00:00
Downloading hdf5storage-0.1.19-py2.py3-none-any.whl (53 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 53.6/53.6 kB 917.6 kB/s eta 0:00:00
Downloading jupyter-1.0.0-py2.py3-none-any.whl (2.7 kB)
Downloading literate_dataclasses-0.0.6-py3-none-any.whl (5.0 kB)
Downloading nlb_tools-0.0.4-py3-none-any.whl (39 kB)
Downloading pandas-1.3.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.5 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 11.5/11.5 MB 20.0 MB/s eta 0:00:00
Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)
Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)
Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)
Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
Downloading pynwb-2.8.1-py3-none-any.whl (1.4 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.4/1.4 MB 2.9 MB/s eta 0:00:00
Downloading qtconsole-5.5.2-py3-none-any.whl (123 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 123.4/123.4 kB 5.1 MB/s eta 0:00:00
Downloading hdmf-3.14.2-py3-none-any.whl (336 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 336.0/336.0 kB 19.1 MB/s eta 0:00:00
Downloading jedi-0.19.1-py2.py3-none-any.whl (1.6 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 38.3 MB/s eta 0:00:00
Downloading QtPy-2.4.1-py3-none-any.whl (93 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 93.5/93.5 kB 6.0 MB/s eta 0:00:00
Downloading nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl (21.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.3/21.3 MB 55.8 MB/s eta 0:00:00
Downloading ruamel.yaml-0.18.6-py3-none-any.whl (117 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 117.8/117.8 kB 5.2 MB/s eta 0:00:00
Downloading ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (526 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 526.7/526.7 kB 20.5 MB/s eta 0:00:00
Installing collected packages: ruamel.yaml.clib, qtpy, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, literate-dataclasses, jedi, ruamel-yaml, pandas, nvidia-cusparse-cu12, nvidia-cudnn-cu12, hdf5storage, nvidia-cusolver-cu12, qtconsole, hdmf, pynwb, cebra, nlb-tools, jupyter
Attempting uninstall: pandas
Found existing installation: pandas 2.0.3
Uninstalling pandas-2.0.3:
Successfully uninstalled pandas-2.0.3
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 1.11.1 requires pandas>=1.5.0, but you have pandas 1.3.4 which is incompatible.
cudf-cu12 24.4.1 requires pandas<2.2.2dev0,>=2.0, but you have pandas 1.3.4 which is incompatible.
google-colab 1.0.0 requires pandas==2.0.3, but you have pandas 1.3.4 which is incompatible.
mizani 0.9.3 requires pandas>=1.3.5, but you have pandas 1.3.4 which is incompatible.
plotnine 0.12.4 requires pandas>=1.5.0, but you have pandas 1.3.4 which is incompatible.
statsmodels 0.14.2 requires pandas!=2.1.0,>=1.4, but you have pandas 1.3.4 which is incompatible.
xarray 2023.7.0 requires pandas>=1.4, but you have pandas 1.3.4 which is incompatible.
Successfully installed cebra-0.4.0 hdf5storage-0.1.19 hdmf-3.14.2 jedi-0.19.1 jupyter-1.0.0 literate-dataclasses-0.0.6 nlb-tools-0.0.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.5.82 nvidia-nvtx-cu12-12.1.105 pandas-1.3.4 pynwb-2.8.1 qtconsole-5.5.2 qtpy-2.4.1 ruamel-yaml-0.18.6 ruamel.yaml.clib-0.2.8
[2]:
#import CEBRA:
import cebra.datasets
from cebra import CEBRA
from cebra.datasets import get_datapath
[ ]:
# Only needed for the t-SNE, UMAP, and piVAE baselines:
!pip install openTSNE
!pip install umap-learn
# Sparse-clone only the `third_party` folder of the CEBRA repository into ./CEBRA;
# the piVAE code imported further below is loaded from there.
# (The `cd` runs in the same one-off shell as the clone, so the notebook's own
# working directory is unchanged.)
!git clone --depth 1 --filter=blob:none --sparse https://github.com/AdaptiveMotorControlLab/CEBRA.git && cd CEBRA && git sparse-checkout set third_party
# NOTE(review): unpinned, so this installs the newest TensorFlow/Keras release;
# the piVAE reference code may expect an older Keras API — TODO pin a tested version.
!pip install tensorflow keras
[4]:
#import other packages:
import matplotlib.pyplot as plt
import numpy as np
import torch
import joblib as jl
import sklearn.linear_model
import openTSNE
import umap
import keras
[ ]:
#import piVAE:
import os
import sys

# The piVAE reference code lives in the sparse clone created by the install cell
# (`git clone ... CEBRA`), i.e. ./CEBRA/third_party relative to the notebook's
# working directory. Building the path from the cwd, instead of hard-coding the
# Colab location '/content/...', keeps this working on local Mac/Linux installs
# and still resolves to /content/CEBRA/third_party on Colab.
PIVAE_ROOT = os.path.join(os.getcwd(), 'CEBRA', 'third_party')
# Guard so re-running the cell does not stack duplicate sys.path entries.
if PIVAE_ROOT not in sys.path:
    sys.path.insert(0, PIVAE_ROOT)

import pivae.pivae_code.pi_vae as pivae
from tensorflow.keras.callbacks import ModelCheckpoint
## Let’s load the data
The data will be downloaded automatically into a `data/` folder relative to the notebook's working directory (this dataset ends up in `data/synthetic/`).
[6]:
# Load the synthetic "continuous-label-poisson" dataset; on first use it is
# downloaded (see the download message in the output below).
dataset = cebra.datasets.init("continuous-label-poisson")
100%|██████████| 30.1M/30.1M [00:02<00:00, 14.2MB/s]
Download complete. Dataset saved in 'data/synthetic/continuous_label_poisson.jl'
[7]:
# Ground-truth latents `z` (first two dimensions), coloured by the label `u`.
data = dataset.data
plt.scatter(*data['z'][:, :2].T, c=data['u'], s=1, cmap='cool')
plt.axis('off')
[7]:
(-3.411717481214122, 9.427972850408525, -4.042572592877058, 4.03572573445655)

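## Define the reconstruction score used for all methods
Each method's embedding is mapped onto the ground-truth latents `z` with a linear regression. The score is that regression's R² (coefficient of determination), fit and evaluated on the same samples.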
[8]:
def reconstruction_score(x, y):
    """Fit a linear map from embedding ``x`` to target ``y`` and score it.

    The same samples are used for fitting and for scoring (no held-out split).

    Args:
        x: embedding, one row per sample.
        y: regression target with the same number of rows as ``x``.

    Returns:
        A ``(score, prediction)`` pair: ``LinearRegression.score(x, y)`` (the
        coefficient of determination R^2) and the fitted prediction for ``x``.
    """
    regressor = sklearn.linear_model.LinearRegression().fit(x, y)
    return regressor.score(x, y), regressor.predict(x)
## CEBRA
Define & fit a 🦓 CEBRA model.
For a quick CPU demo, you can drop `max_iterations` to 500; for the full run, set it to 10,000. This notebook uses 5,000.
[9]:
# Number of CEBRA training iterations: drop to 500 for a quick CPU demo,
# use 10,000 for the full run.
max_iterations = 5000
[10]:
# Keyword arguments grouped by role; the values are the benchmark settings.
cebra_model = CEBRA(
    # encoder and embedding space
    model_architecture="offset1-model-mse",
    output_dimension=2,
    distance='euclidean',
    # optimisation
    batch_size=512,
    learning_rate=1e-4,
    max_iterations=max_iterations,
    # positive-pair sampling conditioned on the continuous label
    # (presumably a Gaussian of width `delta` around the reference label — see CEBRA docs)
    conditional='delta',
    delta=0.1,
    # runtime
    device="cuda_if_available",
    verbose=True,
)
[11]:
# Train on the first 12,000 samples only. The transform below embeds all samples,
# so the remaining ones are never seen during training.
cebra_model.fit(data['x'][:12000], data['u'][:12000])
pos: 0.6551 neg: 4.5135 total: 5.1686 temperature: 1.0000: 100%|██████████| 5000/5000 [16:31<00:00, 5.04it/s]
[11]:
CEBRA(batch_size=512, conditional='delta', delta=0.1, distance='euclidean', learning_rate=0.0001, max_iterations=5000, model_architecture='offset1-model-mse', output_dimension=2, verbose=True)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
CEBRA(batch_size=512, conditional='delta', delta=0.1, distance='euclidean', learning_rate=0.0001, max_iterations=5000, model_architecture='offset1-model-mse', output_dimension=2, verbose=True)
[12]:
# Embed every sample, then map the embedding linearly onto the true 2-D latents.
cebra_output = cebra_model.transform(data['x'])
cebra_score, transformed_cebra_z = reconstruction_score(cebra_output, data['z'][:, :2])
print(f"linear reconstruction score: {cebra_score}")
plt.scatter(*transformed_cebra_z[:, :2].T, c=data['u'], s=1, cmap='cool')
plt.axis('off')
linear reconstruction score: 0.9155677466979806
[12]:
(-0.7319542288780212,
8.35180081129074,
-2.9202826380729676,
2.9944043517112733)

## t-SNE
[13]:
# 2-D openTSNE embedding with PCA initialisation.
tsne_model = openTSNE.TSNE(
    n_components=2,
    metric='euclidean',
    perplexity=84,
    initialization='pca',
    random_state=None,  # unseeded, so embeddings vary between runs
)
[14]:
# Fit exactly once: openTSNE's `fit` returns the embedding of `data['x']`, and
# every call recomputes the full (expensive) t-SNE optimisation from scratch.
tsne_output = tsne_model.fit(data['x'])
[15]:
# Linearly map the t-SNE embedding onto the true 2-D latents and plot the fit.
tsne_score, transformed_tsne_z = reconstruction_score(tsne_output, data['z'][:, :2])
print(f"linear reconstruction score: {tsne_score}")
plt.scatter(*transformed_tsne_z[:, :2].T, c=data['u'], s=1, cmap='cool')
plt.axis('off')
linear reconstruction score: 0.788776430987681
[15]:
(-0.5648915170388875,
7.0257705428562085,
-2.726191011070682,
2.5588439549227457)

## UMAP
[16]:
# 2-D UMAP embedding of the neural data.
umap_model = umap.UMAP(
    n_components=2,
    metric='euclidean',
    n_neighbors=68,
    min_dist=0.2475,
    random_state=None,  # unseeded, so embeddings vary between runs
)
# fit_transform fits the model and returns the embedding of the same samples.
umap_output = umap_model.fit_transform(data['x'])
[17]:
# Linearly map the UMAP embedding onto the true 2-D latents and plot the fit.
umap_score, transformed_umap_z = reconstruction_score(umap_output, data['z'][:, :2])
print(f"linear reconstruction score: {umap_score}")
plt.scatter(*transformed_umap_z[:, :2].T, c=data['u'], s=1, cmap='cool')
plt.axis('off')
linear reconstruction score: 0.8302637230854306
[17]:
(-0.45553902685642245,
6.961537101864815,
-2.585074985027313,
2.8202560782432555)

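## piVAE
The dataset parsing, model configuration, and training are all adapted from [zhd96/pi-vae](https://github.com/zhd96/pi-vae).
piVAE has a long run time, so training saves a checkpoint of the weights with the best validation loss, which is reloaded before prediction.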
[ ]:
u_true, z_true, x_true = data['u'], data['z'], data['x']
# Split the recording into 50 chunks ("trials") of 300 time steps each; the
# trailing axis holds the features of each time step.
x_all, u_all = (arr.reshape(50, 300, -1) for arr in (x_true, u_true))
# First 40 trials for training, the last 10 for validation.
x_train, x_valid = x_all[:40], x_all[40:]
u_train, u_valid = u_all[:40], u_all[40:]
print(f'Train set has {len(x_train)} samples')
[ ]:
def custom_data_generator(x_train, u_train):
    """Endlessly yield one trial at a time as ``([x, u], x)``.

    Cycles over the first axis of ``x_train`` / ``u_train`` forever, so a caller
    such as Keras ``fit`` must bound each epoch (e.g. with ``steps_per_epoch``).
    Note: the training cells in this notebook call ``pivae.custom_data_generator``
    rather than this function.

    Args:
        x_train: indexable sequence of per-trial inputs.
        u_train: indexable sequence of per-trial labels, aligned with ``x_train``.

    Yields:
        ``([x_train[i], u_train[i]], x_train[i])`` for i = 0, 1, ..., repeating.

    Raises:
        ValueError: if ``x_train`` is empty (it would otherwise loop forever
            without ever yielding).
    """
    n_trials = len(x_train)
    if n_trials == 0:
        raise ValueError("custom_data_generator needs at least one trial")
    while True:
        for i in range(n_trials):
            # Inputs are [data, label]; the target is the data itself
            # (or another appropriate target).
            yield [x_train[i], u_train[i]], x_train[i]
[ ]:
# piVAE with a 2-D latent and a Poisson observation model; input/label sizes
# are read from the trailing (feature) axis of the reshaped trials.
vae = pivae.vae_mdl(
    dim_x=x_all.shape[-1],
    dim_u=u_all.shape[-1],
    dim_z=2,
    mdl='poisson',
    disc=False,  # presumably: label treated as continuous — see pi_vae.py
    gen_nodes=60,
    n_blk=2,
    learning_rate=5e-4,
)
[ ]:
# Keras 3, which the unpinned `pip install tensorflow keras` pulls in, rejects
# weights-only checkpoints whose filename does not end in `.weights.h5`. The name
# still ends in `.h5`, so Keras 2 keeps saving HDF5 as before.
model_chk_path = 'synthetic_pivae.weights.h5'
# Keep only the weights from the epoch with the lowest validation loss.
mcp = ModelCheckpoint(model_chk_path,
                      monitor="val_loss",
                      save_best_only=True,
                      save_weights_only=True)
# One step per trial: each epoch visits every training / validation trial once.
s_n = vae.fit(pivae.custom_data_generator(x_train, u_train),
              steps_per_epoch=len(x_train),
              epochs=100,
              verbose=0,
              validation_data=pivae.custom_data_generator(x_valid, u_valid),
              validation_steps=len(x_valid),
              callbacks=[mcp])
[ ]:
# Restore the best (lowest val_loss) weights written by the checkpoint callback.
vae.load_weights(model_chk_path)
# `predict_generator` is deprecated in TF2 Keras and removed in Keras 3;
# `predict` accepts the generator directly together with `steps`.
outputs = vae.predict(pivae.custom_data_generator(x_all, u_all),
                      steps=len(x_all))
# Output order: post_mean, post_log_var, z_sample, fire_rate, lam_mean,
# lam_log_var, z_mean, z_log_var
z_post = outputs[0]  # posterior mean of the latent
[ ]:
# Linearly map the piVAE posterior mean onto the true 2-D latents and plot the fit.
pivae_score, transformed_pivae_z = reconstruction_score(z_post, data['z'][:, :2])
print(f"linear reconstruction score: {pivae_score}")
plt.scatter(*transformed_pivae_z[:, :2].T, c=data['u'], s=1, cmap='cool')
plt.axis('off')