STSBench: A Large-Scale Dataset for Modeling Neuronal Activity in the Dorsal Stream

📚 Paper | 📊 Dataset

(Animated GIFs of example activations; see the assets directory.)

We present STSBench, a large-scale dataset for modeling neuronal responses in the dorsal stream. We show that our dataset can be used for:

  1. Benchmarking encoding models of dorsal stream responses.
  2. Reconstructing visual input from neural activity in the dorsal stream.
  3. Comparing the features of the visual world extracted by the dorsal and ventral visual streams.

We provide the code necessary to reproduce our results, and we encourage others to build on our baseline models and to use STSBench in their own work.

Repo organization

The encoding and reconstruction directories contain the code for training and evaluating encoding and reconstruction models, and the plotting directory contains code for visualizing results and constructing tables. Preprocessing scripts for constructing STSBench from raw data are in preprocessing; these are provided for reference, since the released dataset is the output of these preprocessing steps. See the repo structure section below for more details.

Requirements

To install requirements:

pip install -r requirements.txt

To set up with Docker:

docker compose run -p 0.0.0.0:10116:8888 -d -e CUDA_VISIBLE_DEVICES=0 --name mt_bench jupyter_dev --gpus '"device=1"'

To access Jupyter Lab in the container, navigate to localhost:10116 with the password 1234.

Downloading and preprocessing STSBench

We provide preprocessed neural data and metadata in this repo for convenience. To train and evaluate dorsal stream encoding or reconstruction models, you need to download the corresponding stimuli from the STSBench dataset and place them in dataset/dorsal_stream/.
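The neural data and metadata ship as Python pickles. A minimal loading sketch follows; the object's actual fields are defined by the preprocessing notebooks, so inspect the structure before relying on specific keys:

import pickle

# Load the preprocessed STSBench neural data and stimulus IDs.
with open("dataset/dorsal_stream_dataset.pickle", "rb") as f:
    dataset = pickle.load(f)

# Inspect the structure before indexing into it; the available fields
# are determined by the preprocessing notebooks, not guaranteed here.
print(type(dataset))
if isinstance(dataset, dict):
    print(list(dataset.keys()))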

Downloading and preprocessing TVSD

We provide preprocessed neural data from the Things Ventral Stream Dataset (TVSD) in this repo for convenience. To train and evaluate ventral stream reconstruction models, you need to download the images from the THINGS image dataset and run preprocessing/preprocess_ventral_dataset.ipynb to convert the images into .mp4 videos in a directory called dataset/ventral_stream/. Note that you only need to specify the appropriate path to the THINGS dataset and run the last cell in the notebook to do this; the other cells in the notebook preprocess the neural data from TVSD.
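Conceptually, turning a still image into a short video clip can be as simple as repeating the frame and encoding the result as .mp4. A rough sketch of that idea, assuming imageio with its ffmpeg plugin installed; the input path, frame count, and fps here are illustrative assumptions, and the notebook remains the authoritative implementation:

import imageio.v2 as imageio

# Hypothetical input path; substitute a real THINGS image.
img = imageio.imread("path/to/things_image.jpg")

# Repeat the still frame to form a short clip (30 frames at 30 fps assumed).
imageio.mimwrite("dataset/ventral_stream/example.mp4", [img] * 30, fps=30)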

Training

To train all the encoding models in the paper with a grid search over hyperparameters, run this command:

sh grid_search_all.sh
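To grid search a single model instead, grid_search.py can presumably be invoked directly with one config, mirroring the --config interface of train.py and test.py (an assumption based on the repo layout; check grid_search_all.sh for the exact invocation):

python3 grid_search.py --config configs/dorsal_stream_simple3d5.yaml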

To train the diffusion reconstruction models in the paper, run these commands:

python3 train_vqvae.py --config configs/dorsal_stream_diffusion.yaml
python3 train_ddpm_cond.py --config configs/dorsal_stream_diffusion.yaml

python3 train_vqvae.py --config configs/ventral_stream_diffusion.yaml
python3 train_ddpm_cond.py --config configs/ventral_stream_diffusion.yaml

python3 train_vqvae.py --config configs/dorsal_stream_diffusion_video.yaml
python3 train_ddpm_cond.py --config configs/dorsal_stream_diffusion_video.yaml

Evaluation

To evaluate one encoding model on STSBench, download the encoding model checkpoints (see below) and the dataset, then run the following command:

python3 test.py --config ./configs/dorsal_stream_simple3d5.yaml

Any of the models can be substituted for simple3d5 (the 3D CNN-5 model) here by replacing the config.

To evaluate the reconstruction diffusion models on STSBench, download the reconstruction model checkpoints (see below) and the dataset, then run the following commands:

python3 sample_ddpm_cond.py --config ./configs/dorsal_stream_diffusion.yaml
python3 eval.py --config ./configs/dorsal_stream_diffusion.yaml
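eval.py saves its metrics under reconstruction/logs (per the repo structure below). A sketch for inspecting them, assuming the .npy files hold per-sample metric arrays:

import numpy as np

# Load saved reconstruction metrics (the array layout is an assumption).
lpips = np.load("reconstruction/logs/dorsal_diffusion_LPIPS.npy")
psnr = np.load("reconstruction/logs/dorsal_diffusion_PSNR.npy")
print(f"LPIPS: {lpips.mean():.2f}  PSNR: {psnr.mean():.2f}")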

Pre-trained Models

You can download pretrained encoding and reconstruction models here:

  • STSBench Checkpoints
  • The encoding_checkpoints folder should be renamed checkpoints and placed in the encoding directory, as shown in the repo structure below.
  • The reconstruction_checkpoints folder should be renamed checkpoints and placed in the reconstruction directory, as shown in the repo structure below.

The configs used to train the encoding models are in encoding/configs, and the loss function hyperparameters used for training were determined via a grid search and are specified separately in the corresponding encoding/logs/*_grid.npy file. The configs used to train the reconstruction models are in reconstruction/configs.
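To see which hyperparameters the grid search selected for a given model, the corresponding log can be loaded with NumPy (a sketch; the exact array or dict layout inside the file is an assumption, hence allow_pickle):

import numpy as np

# Grid-search results for the 3D ResNet - Self Motion encoding model.
grid = np.load("encoding/logs/dorsal_stream_dorsalnet_112_res0_grid.npy",
               allow_pickle=True)
print(grid)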

Results

Our encoding models achieve the following performance ($R^2$ on the test set) at predicting neuronal firing rates, averaged across neurons in the dataset:

Training Scheme   Model name              $R^2$
End-to-end        3D CNN-5                0.338
Pretrained        3D ResNet-Kinetics      0.303
Pretrained        3D ResNet-Motion        0.289
Hand-tuned        3D Gabor                0.266
Pretrained        2D ResNet-ImageNet      0.185

See the associated paper for a detailed description of each model.
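For reference, the headline metric is $R^2$ computed per neuron on held-out responses and then averaged across neurons. A minimal sketch of that computation (illustrative only; test.py implements the version actually used for the table):

import numpy as np

def mean_r2(y_true, y_pred):
    # y_true, y_pred: (n_samples, n_neurons) arrays of firing rates.
    ss_res = np.sum((y_true - y_pred) ** 2, axis=0)
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2, axis=0)
    return np.mean(1.0 - ss_res / ss_tot)  # average across neurons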

Our diffusion models achieve the following performance (PSNR and LPIPS on the test set) at reconstructing images from neuronal activity:

Dataset     PSNR     LPIPS
STSBench    14.16    0.67
TVSD        10.63    0.59

See the associated paper for a detailed description of each model, and additional comparisons to null models and baselines.
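PSNR follows its standard definition (higher is better), while LPIPS is a learned perceptual distance (lower is better) computed with the VGG checkpoint in reconstruction/models/weights. A sketch of the PSNR side only; eval.py is the repo's actual implementation:

import numpy as np

def psnr(img_true, img_rec, max_val=1.0):
    # Peak signal-to-noise ratio in dB for images scaled to [0, max_val].
    mse = np.mean((img_true - img_rec) ** 2)
    return 10.0 * np.log10(max_val ** 2 / mse)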

Contributing

The code and models in this repository are released under the MIT License. Images or videos included in this repository for reproducibility purposes, as well as code from other repositories, are bound by their original licenses. Please cite the corresponding paper if you use the code or dataset in your work.

Acknowledgements

  1. The diffusion model code is a port of the PyTorch Stable Diffusion Implementation from Explaining AI.
  2. The encoding model readout and training code was adapted from models released with the Things Ventral Stream Dataset.
  3. The 3D ResNet-Self Motion and 3D Gabor model code was adapted from Your Head is There to Move You Around.
  4. We include the TVSD neural data here for ease of use because it was released under a CC license, but any reuse of this data should cite the original paper.
  5. Code for plotting results and making LaTeX tables was written with assistance from ChatGPT. Any LLM-generated code was proofread to ensure correctness.

Repo Structure

├── assets                                          # GIFs to display in README

├── dataset
│   ├── dorsal_stream                               # Drop *.mp4 videos for STSBench here
│   │   ...
│   ├── ventral_stream                              # Drop *.mp4 videos for TVSD here
│   │   ...
│   ├── dorsal_stream_neuron_table.pickle           # Neuron metadata for STSBench
│   ├── dorsal_stream_dataset.pickle                # Neural activity & stimulus IDs for STSBench
│   └── ventral_stream_dataset.pickle               # Neural activity & stimulus IDs for TVSD (V4 subset)

├── preprocessing
│   ├── construct_dataset_neurips.ipynb             # Notebook to construct STSBench dataset
│   ├── postprocess_fixation_rf_neurips.ipynb       # Postprocess receptive field mapping task data
│   ├── postprocess_fixation_video_neurips.ipynb    # Postprocess video fixation task data
│   ├── postprocess_neuron_properties.ipynb         # Add neuron metadata to the dataset
│   ├── preprocess_ventral_dataset.ipynb            # Preprocess ventral stream data from TVSD to the same format
│   ├── run_all.sh                                  # Shell script to run dorsal stream preprocessing notebooks in order
│   └── utils.py                                    # General utility functions used in preprocessing scripts

├── plotting
│   ├── figures
│   │   ...
│   ├── tables
│   │   ...
│   ├── make_table1.py                              # Write encoding metrics to LaTeX table
│   ├── make_table2.py                              # Write reconstruction metrics to LaTeX table
│   ├── plot_encoding_model_arch.ipynb              # Plot example features and predictions
│   ├── plot_encoding_model_perf_layers.ipynb       # Plot performance over layers
│   ├── plot_reconstruction.py                      # Plot reconstructions as figure
│   ├── plot_reconstruction_gif.py                  # Plot reconstructions as gif
│   ├── plot_reconstruction.sh                      # Script to plot reconstructions for selected images
│   ├── plot_rfs_waveforms.ipynb                    # Visualize RFs and spike waveforms
│   └── visualize_learned_filters.ipynb             # Visualize learned CNN filters in the Conv3D-1 model

├── encoding
│   ├── baselines                                   # Baseline encoding models
│   │   ├── dorsalnet                               # 3D ResNet - Self Motion model
│   │   │   ├── checkpoints
│   │   │   │   └── dorsalnet.pt                    # Checkpoint for 3D ResNet - Self Motion model
│   │   │   ├── dorsal_net.py                       # Model for 3D ResNet - Self Motion
│   │   │   └── resblocks.py                        # Helpers for 3D ResNet - Self Motion
│   │   ├── gaborpyramid
│   │   │   └── gabor_pyramid.py                    # Model for 3D Gabor
│   │   └── simple3d.py                             # Models for 3D CNNs trained end-to-end
│   ├── checkpoints
│   │   ├── dorsal_stream_dorsalnet_112_res0.pth    # Encoding model checkpoint
│   │   └── ...
│   ├── configs
│   │   ├── dorsal_stream_dorsalnet.yaml            # Config file for 3D ResNet - Self Motion model
│   │   └── ...
│   ├── logs
│   │   ├── dorsal_stream_dorsalnet_112_res0_grid.npy   # Results of grid search
│   │   ├── dorsal_stream_dorsalnet_112_res0_test.npy   # Test set correlation to average
│   │   ├── dorsal_stream_dorsalnet_112_res0_train.txt  # Training logs
│   │   └── ...
│   ├── dataloader.py                               # Data loader utilities
│   ├── dataset.py                                  # Dataset class for stimuli and neural data
│   ├── model.py                                    # Readout architecture and generic feature extractor
│   ├── grid_search.py                              # Script to grid search model with specific config
│   ├── test.py                                     # Script for testing model with specific config
│   ├── train.py                                    # Script for training model with specific config
│   ├── grid_search_all.sh                          # Script to run grid search and eval all encoding models
│   ├── train_all.sh                                # Script to train and eval all encoding models
│   └── utils.py

├── decoding
│   └── optic_flow_decoding.ipynb                   # Notebook to train and test optic flow decoders

└── reconstruction
    ├── LICENSE                                     # License for diffusion model code from Explaining AI
    ├── checkpoints                                 # Model checkpoints for diffusion and baselines
    │   ├── dorsal_stream
    │   │   ...
    │   └── ventral_stream
    │       ...
    ├── configs
    │   ├── dorsal_stream_diffusion.yaml            # Model configuration files for diffusion
    │   └── ...
    ├── logs                                        # Training logs and sampled images for specific models
    │   ├── dorsal_stream
    │   │   ...
    │   ├── ventral_stream
    │   │   ...
    │   ├── dorsal_diffusion_LPIPS.npy              # LPIPS metric evaluated for diffusion
    │   ├── dorsal_diffusion_PSNR.npy               # PSNR metric evaluated for diffusion
    │   └── ...
    ├── models                                      # LDM components, see Explaining AI
    │   ├── __init__.py
    │   ├── blocks.py
    │   ├── discriminator.py
    │   ├── lpips.py
    │   ├── unet_cond_base.py
    │   ├── vqvae.py
    │   └── weights
    │       └── v0.1
    │           └── vgg.pth                         # VGG checkpoint for LPIPS
    ├── scheduler                                   # LDM noise scheduler, see Explaining AI
    │   ├── __init__.py
    │   └── linear_noise_scheduler.py
    ├── __init__.py
    ├── baselines.py                                # Linear and CNN baselines and training code
    ├── dataloader.py
    ├── dataset.py
    ├── train_vqvae.py                              # Script to train VQVAE from scratch for LDM
    ├── train_ddpm_cond.py                          # Script to train LDM from scratch
    ├── sample_ddpm_cond.py                         # Script to sample from trained LDM
    ├── eval.py                                     # Script to compute metrics for reconstructions
    └── utils.py
