We present STSBench, a large-scale dataset for modeling neuronal responses in the dorsal stream. We show that our dataset can be used for:
- Benchmarking encoding models of dorsal stream responses.
- Reconstructing visual input from neural activity in the dorsal stream.
- Comparing the features of the visual world extracted by the dorsal and ventral visual streams.
We provide the code necessary to reproduce our results, and we encourage others to build upon our baseline models and to use STSBench in their own work.
The encoding and reconstruction directories contain the code for training and evaluating encoding and reconstruction models, and the plotting directory contains the code for visualizing results and constructing tables. Preprocessing scripts for constructing STSBench from raw data are in preprocessing; they are provided for reference, since the released dataset is already the output of these steps. See the repo structure section below for more details.
To install requirements:
pip install -r requirements.txt
To setup with Docker:
docker compose run -p 0.0.0.0:10116:8888 -d -e CUDA_VISIBLE_DEVICES=0 --name mt_bench jupyter_dev --gpus '"device=1"'
To access Jupyter Lab in the container, navigate to localhost:10116 with the password 1234.
We provide preprocessed neural data and metadata in this repo for convenience. To train and evaluate dorsal stream encoding or reconstruction models, you need to download the corresponding stimuli from the STSBench dataset and place them in dataset/dorsal_stream/.
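For example, assuming the stimuli have been downloaded and unpacked to a local folder (the path below is a placeholder), copying them into place looks like:

```sh
# Hypothetical path to the downloaded STSBench stimulus videos
mkdir -p dataset/dorsal_stream
cp /path/to/stsbench_stimuli/*.mp4 dataset/dorsal_stream/
```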
We provide preprocessed neural data from the Things Ventral Stream Dataset (TVSD) in this repo for convenience. To train and evaluate ventral stream reconstruction models, you need to download the images from the THINGS image dataset and run preprocessing/preprocess_ventral_dataset.ipynb to convert the images into .mp4 videos in a directory called dataset/ventral_stream/. Note that you only need to specify the appropriate path to the THINGS dataset and run the last cell in the notebook to do this; the other cells in the notebook are for preprocessing the neural data from TVSD.
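Conceptually, the conversion writes each THINGS image out as a short .mp4 clip. A rough sketch of that idea (not the notebook code; the paths, clip length, and fps here are assumptions, and writing .mp4 requires the imageio-ffmpeg backend):

```python
import glob
import os
import imageio.v2 as imageio

things_dir = "/path/to/THINGS/images"   # assumed location of the THINGS images
out_dir = "dataset/ventral_stream"
os.makedirs(out_dir, exist_ok=True)

for path in glob.glob(os.path.join(things_dir, "*.jpg")):
    img = imageio.imread(path)
    name = os.path.splitext(os.path.basename(path))[0]
    # Repeat the still image to form a short clip; frame count and fps are placeholders.
    imageio.mimwrite(os.path.join(out_dir, f"{name}.mp4"), [img] * 30, fps=30)
```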
To train all the encoding models in the paper with a grid search over hyperparameters, run this command:
sh grid_search_all.sh
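grid_search_all.sh runs the grid search and evaluation for every encoding model. To grid search a single model, grid_search.py can be pointed at one config, following the same --config pattern used by test.py below (treat this exact invocation as an assumption):

```sh
python3 grid_search.py --config configs/dorsal_stream_simple3d5.yaml
```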
To train the diffusion reconstruction models in the paper, run these commands:
python3 train_vqvae.py --config configs/dorsal_stream_diffusion.yaml
python3 train_ddpm_cond.py --config configs/dorsal_stream_diffusion.yaml
python3 train_vqvae.py --config configs/ventral_stream_diffusion.yaml
python3 train_ddpm_cond.py --config configs/ventral_stream_diffusion.yaml
python3 train_vqvae.py --config configs/dorsal_stream_diffusion_video.yaml
python3 train_ddpm_cond.py --config configs/dorsal_stream_diffusion_video.yaml
To evaluate one encoding model on STSBench, download the encoding model checkpoints (see below) and dataset, then run the following command:
python3 test.py --config ./configs/dorsal_stream_simple3d5.yaml
Any of the models can be substituted for simple3d5 (the 3D CNN-5 model) here by replacing the config.
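For example, to evaluate every dorsal stream encoding model in one go, you can loop over the config files (this assumes all of them follow the dorsal_stream_*.yaml naming used in encoding/configs):

```sh
for cfg in ./configs/dorsal_stream_*.yaml; do
    python3 test.py --config "$cfg"
done
```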
To evaluate the reconstruction diffusion models on STSBench, download the reconstruction model checkpoints (see below) and dataset, then run the following commands:
python3 sample_ddpm_cond.py --config ./configs/dorsal_stream_diffusion.yaml
python3 eval.py --config ./configs/dorsal_stream_diffusion.yaml
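The ventral stream and video diffusion models can be evaluated the same way by swapping in their configs, for example:

```sh
python3 sample_ddpm_cond.py --config ./configs/ventral_stream_diffusion.yaml
python3 eval.py --config ./configs/ventral_stream_diffusion.yaml
```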
You can download pretrained encoding and reconstruction models here:
- STSBench Checkpoints
- The encoding_checkpoints folder should be renamed checkpoints and placed in the encoding directory as indicated above.
- The reconstruction_checkpoints folder should be renamed checkpoints and placed in the reconstruction directory as indicated above (see the commands below).
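Assuming the checkpoint archive is unpacked in the repository root, the renaming amounts to:

```sh
mv encoding_checkpoints encoding/checkpoints
mv reconstruction_checkpoints reconstruction/checkpoints
```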
The configs used to train the encoding models are in encoding/configs, and the loss function hyperparameters used for training were determined via a grid search and are specified separately in the corresponding encoding/logs/*_grid.npy file. The configs used to train the reconstruction models are in reconstruction/configs.
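The *_grid.npy files are ordinary NumPy archives; a minimal sketch of inspecting one (the internal layout is not documented here, so allow_pickle is used and the printout is just for inspection):

```python
import numpy as np

# Inspect the grid-search results for the 3D ResNet - Self Motion encoding model
grid = np.load("encoding/logs/dorsal_stream_dorsalnet_112_res0_grid.npy", allow_pickle=True)
print(grid.shape, grid.dtype)
print(grid)
```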
Our encoding models achieve the following performance (test set correlation) at predicting neuronal responses:
| Training Scheme | Model name | Test correlation |
|---|---|---|
| End-to-end | 3D CNN-5 | 0.338 |
| Pretrained | 3D ResNet-Kinetics | 0.303 |
| Pretrained | 3D ResNet-Motion | 0.289 |
| Hand-tuned | 3D Gabor | 0.266 |
| Pretrained | 2D ResNet-ImageNet | 0.185 |
See the associated paper for a detailed description of each model.
Our diffusion models achieve the following performance (PSNR and LPIPS on the test set) at reconstructing images from neuronal activity:
| Dataset | PSNR | LPIPS |
|---|---|---|
| STSBench | 14.16 | 0.67 |
| TVSD | 10.63 | 0.59 |
See the associated paper for a detailed description of each model, and additional comparisons to null models and baselines.
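For reference, PSNR here follows its standard definition; a minimal NumPy version (not necessarily the exact code in eval.py) is:

```python
import numpy as np

def psnr(reference: np.ndarray, reconstruction: np.ndarray, max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio between two images of the same shape."""
    mse = np.mean((reference.astype(np.float64) - reconstruction.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(max_val ** 2 / mse))
```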
The code and models in this repository are released under the MIT License. Images or videos included in this repository for reproducibility purposes, and code from other repositories, remain bound by their original licenses. Please cite the corresponding paper if you use the code or dataset in your work.
- The diffusion model code is a port of the PyTorch Stable Diffusion Implementation from Explaining AI.
- The encoding model readout and training code was adapted from models released with the Things Ventral Stream Dataset.
- The 3D ResNet-Self Motion and 3D Gabor model code was adapted from Your Head is There to Move You Around.
- We include the TVSD neural data here for ease of use because it was released under a CC license, but any reuse of this data should cite the original paper.
- Code for plotting results and making LaTeX tables was written with assistance from ChatGPT. Any LLM generated code was proofread to ensure correctness.
├── assets                                       # Gifs to display in README
├── dataset
│   ├── dorsal_stream                            # Drop *.mp4 videos for STSBench here
│   │   ...
│   ├── ventral_stream                           # Drop *.mp4 videos for TVSD here
│   │   ...
│   ├── dorsal_stream_neuron_table.pickle        # Neuron metadata for STSBench
│   ├── dorsal_stream_dataset.pickle             # Neural activity & stimulus IDs for STSBench
│   └── ventral_stream_dataset.pickle            # Neural activity & stimulus IDs for TVSD (V4 subset)
├── preprocessing
│   ├── construct_dataset_neurips.ipynb          # Notebook to construct STSBench dataset
│   ├── postprocess_fixation_rf_neurips.ipynb    # Postprocessing receptive field mapping task data
│   ├── postprocess_fixation_video_neurips.ipynb # Postprocessing video fixation task data
│   ├── postprocess_neuron_properties.ipynb      # Adding neuron metadata to the dataset
│   ├── preprocess_ventral_dataset.ipynb         # Preprocess ventral stream data from TVSD to the same format
│   ├── run_all.sh                               # Shell script to execute dorsal stream preprocessing notebooks in order
│   └── utils.py                                 # General utility functions used in preprocessing scripts
├── plotting
│   ├── figures
│   │   ...
│   ├── tables
│   │   ...
│   ├── make_table1.py                           # Write encoding metrics to LaTeX table
│   ├── make_table2.py                           # Write reconstruction metrics to LaTeX table
│   ├── plot_encoding_model_arch.ipynb           # Plot example features and predictions
│   ├── plot_encoding_model_perf_layers.ipynb    # Plot performance over layers
│   ├── plot_reconstruction.py                   # Plot reconstructions as figure
│   ├── plot_reconstruction_gif.py               # Plot reconstructions as gif
│   ├── plot_reconstruction.sh                   # Script to plot reconstructions for selected images
│   ├── plot_rfs_waveforms.ipynb                 # Visualize RFs and spike waveforms
│   └── visualize_learned_filters.ipynb          # Visualize learned CNN filters in the Conv3D-1 model
├── encoding
│   ├── baselines                                # Baseline encoding models
│   │   ├── dorsalnet                            # 3D ResNet - Self Motion model
│   │   │   ├── checkpoints
│   │   │   │   └── dorsalnet.pt                 # Checkpoint for 3D ResNet - Self Motion model
│   │   │   ├── dorsal_net.py                    # Model for 3D ResNet - Self Motion
│   │   │   └── resblocks.py                     # Helpers for 3D ResNet - Self Motion
│   │   ├── gaborpyramid
│   │   │   └── gabor_pyramid.py                 # Model for 3D Gabor
│   │   └── simple3d.py                          # Models for 3D CNNs trained end-to-end
│   ├── checkpoints
│   │   ├── dorsal_stream_dorsalnet_112_res0.pth # Encoding model checkpoint
│   │   └── ...
│   ├── configs
│   │   ├── dorsal_stream_dorsalnet.yaml         # Config file for 3D ResNet - Self Motion model
│   │   └── ...
│   ├── logs
│   │   ├── dorsal_stream_dorsalnet_112_res0_grid.npy  # Results of grid search
│   │   ├── dorsal_stream_dorsalnet_112_res0_test.npy  # Test set correlation to average
│   │   ├── dorsal_stream_dorsalnet_112_res0_train.txt # Training logs
│   │   └── ...
│   ├── dataloader.py                            # Data loader utilities
│   ├── dataset.py                               # Dataset class for stimuli and neural data
│   ├── model.py                                 # Readout architecture and generic feature extractor
│   ├── grid_search.py                           # Script to grid search model with specific config
│   ├── test.py                                  # Script for testing model with specific config
│   ├── train.py                                 # Script for training model with specific config
│   ├── grid_search_all.sh                       # Script to run grid search and eval all encoding models
│   ├── train_all.sh                             # Script to train and eval all encoding models
│   └── utils.py
├── decoding
│   └── optic_flow_decoding.ipynb                # Notebook to train and test optic flow decoders
└── reconstruction
    ├── LICENSE                                  # License for diffusion model code from Explaining AI
    ├── checkpoints                               # Model checkpoints for diffusion and baselines
    │   ├── dorsal_stream
    │   │   ...
    │   └── ventral_stream
    │       ...
    ├── configs
    │   ├── dorsal_stream_diffusion.yaml         # Model configuration files for diffusion
    │   └── ...
    ├── logs                                     # Training logs and sampled images for specific models
    │   ├── dorsal_stream
    │   │   ...
    │   ├── ventral_stream
    │   │   ...
    │   ├── dorsal_diffusion_LPIPS.npy           # LPIPS metric evaluated for diffusion
    │   ├── dorsal_diffusion_PSNR.npy            # PSNR metric evaluated for diffusion
    │   └── ...
    ├── models                                   # LDM components, see Explaining AI
    │   ├── __init__.py
    │   ├── blocks.py
    │   ├── discriminator.py
    │   ├── lpips.py
    │   ├── unet_cond_base.py
    │   ├── vqvae.py
    │   └── weights
    │       └── v0.1
    │           └── vgg.pth                      # VGG checkpoint for LPIPS
    ├── scheduler                                # LDM noise scheduler, see Explaining AI
    │   ├── __init__.py
    │   └── linear_noise_scheduler.py
    ├── __init__.py
    ├── baselines.py                             # Linear and CNN baselines and training code
    ├── dataloader.py
    ├── dataset.py
    ├── train_vqvae.py                           # Script to train VQVAE from scratch for LDM
    ├── train_ddpm_cond.py                       # Script to train LDM from scratch
    ├── sample_ddpm_cond.py                      # Script to sample from trained LDM
    ├── eval.py                                  # Script to compute metrics for reconstructions
    └── utils.py



