microsoft/automated-brain-explanations

🧠 Automated brain explanations 🧠

How does the brain process language? We've been studying this question using large-scale brain-imaging datasets collected from human subjects as they read and listen to stories. Along the way, we've used LLMs to help us predict and explain patterns in this data and found a bunch of cool things! This repo contains code for doing these analyses & applying the tools we've developed to various domains.

Reference

This repo contains code underlying two neuroscience studies:

Generative causal testing to bridge data-driven models and scientific theories in language neuroscience (Antonello*, Singh*, et al., 2024, arXiv)
Evaluating scientific theories as predictive models in language neuroscience (Singh*, Antonello*, et al. 2024, in prep)

This repo also contains code for experiments in three ML papers (for a simple scikit-learn interface to use these, see imodelsX):

Augmenting interpretable models with large language models during training (Singh et al. 2023, Nature Communications)
QA-Emb: Crafting interpretable Embeddings by asking LLMs questions (Benara*, Singh* et al. 2024, NeurIPS)
SASC: Explaining black box text modules in natural language with language models (Singh*, Hsu*, et al. 2023, NeurIPS workshop)
SASC takes in a text module and produces a natural-language explanation for it that describes what types of inputs elicit the largest response from the module (see Fig below). The GCT paper tests this in detail in an fMRI setting.

SASC is similar to the nice concurrent paper by OpenAI, but simplifies explanations to describe the function rather than produce token-level activations. This makes it simpler/faster, and makes it more effective at describing semantic functions from limited data (e.g. fMRI voxels) but worse at finding patterns that depend on sequences / ordering.

To use SASC through imodelsX, install it with pip install imodelsx. The snippet below is a quickstart example.

import numpy as np
from imodelsx import explain_module_sasc

# a toy module that responds to the length of a string
mod = lambda str_list: np.array([len(s) for s in str_list])

# a toy dataset where the longest strings are animals
text_str_list = ["red", "blue", "x", "1", "2", "hippopotamus", "elephant", "rhinoceros"]

# explain the module using single-word (unigram) candidates
explanation_dict = explain_module_sasc(
    text_str_list,
    mod,
    ngrams=1,
)
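
To make "inputs that elicit the largest response" concrete, here is a minimal sketch of just the ranking step, using the same toy module. It is not the SASC implementation in imodelsX: `top_responding_inputs` is a hypothetical helper written for illustration, and it leaves out the LLM step that turns the top inputs into an explanation.

```python
import numpy as np


def top_responding_inputs(text_str_list, mod, k=3):
    """Return the k inputs with the largest module response.

    Hypothetical helper for illustration only; not part of imodelsX.
    """
    responses = np.asarray(mod(text_str_list), dtype=float)
    # Sort by descending response; a stable sort keeps input order on ties.
    order = np.argsort(-responses, kind="stable")[:k]
    return [(text_str_list[i], float(responses[i])) for i in order]


# Same toy module and dataset as the quickstart: response = string length.
mod = lambda str_list: np.array([len(s) for s in str_list])
text_str_list = ["red", "blue", "x", "1", "2", "hippopotamus", "elephant", "rhinoceros"]

top = top_responding_inputs(text_str_list, mod, k=3)
for text, response in top:
    print(f"{text}: {response}")
```

For this toy module the top inputs are the three animal names, which is the pattern an explanation would be expected to summarize.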

Setting up

Dataset

  • The data/decoding folder contains a quickstart example for TR-level decoding
    • it includes everything needed to run, but to visualize the results on a flatmap you also need to download the relevant PCs from here
  • for a quick start, download only the responses / wordsequences for 3 subjects from the encoding scaling laws paper
    • this is all the data you need if you only want to analyze 3 subjects and don't need flatmaps
  • to run Eng1000, grab the em_data directory from here and move its contents to {root_dir}/em_data
  • for more, download data with python experiments/00_load_dataset.py
    • this creates a data dir under the directory you run it from and uses datalad to download the preprocessed data as well as the feature spaces needed for fitting semantic encoding models
  • to make flatmaps, set [pycortex filestore] to {root_dir}/ds003020/derivative/pycortex-db/
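
The filestore setting can also be edited programmatically. The sketch below assumes that pycortex reads it from the `filestore` key in the `[basic]` section of its user config file (whose location is usually reported by `cortex.options.usercfg`); check this against your pycortex install before relying on it. `set_filestore` is a hypothetical helper and `/path/to/root_dir` is a placeholder. It works on config text in memory so that nothing on disk is touched.

```python
import configparser
import io


def set_filestore(cfg_text, filestore):
    """Return config text with [basic] filestore set (hypothetical helper).

    Assumes the filestore lives under the [basic] section, which is an
    assumption about pycortex's config layout.
    """
    # Disable interpolation so '%' characters in paths are kept literally.
    parser = configparser.ConfigParser(interpolation=None)
    parser.read_string(cfg_text)
    if not parser.has_section("basic"):
        parser.add_section("basic")
    parser.set("basic", "filestore", filestore)
    buf = io.StringIO()
    parser.write(buf)
    return buf.getvalue()


root_dir = "/path/to/root_dir"  # placeholder: wherever you put the data
existing_cfg = "[basic]\ndefault_cmap = RdBu_r\n"  # stand-in for a real config
new_cfg = set_filestore(existing_cfg, f"{root_dir}/ds003020/derivative/pycortex-db/")
print(new_cfg)
```

To apply it for real, read the text of your user config, pass it through the helper, and write the result back.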

Code

  • pip install -e . from the repo directory to locally install the neuro package
  • set neuro.config.root_dir/data to where you put all the data
    • loading responses
      • neuro.data.response_utils function load_response
      • loads responses from {neuro.config.root_dir}/ds003020/derivative/preprocessed_data/{subject}, where they are stored in one h5 file per story, e.g. wheretheressmoke.h5
    • loading stimulus
      • ridge_utils.features.stim_utils function load_story_wordseqs
      • loads textgrids from {root_dir}/ds003020/derivative/TextGrids, where each story has a TextGrid file, e.g. wheretheressmoke.TextGrid
      • uses {root_dir}/ds003020/derivative/respdict.json to get the length of each story
  • python experiments/02_fit_encoding.py
    • This script takes many relevant arguments through argparse
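
Before running the fitting script, it can help to check that the files described above are where the loaders expect them. The sketch below builds the paths from the layout listed in this section and reports any that are missing. The helper names are hypothetical (they are not part of the neuro package), and the subject ID `UTS03` and the `/path/to/root_dir` root are placeholder assumptions.

```python
from pathlib import Path


def response_path(root_dir, subject, story):
    """Path of the per-story h5 response file for one subject."""
    return (Path(root_dir) / "ds003020" / "derivative" / "preprocessed_data"
            / subject / f"{story}.h5")


def textgrid_path(root_dir, story):
    """Path of the TextGrid file holding the stimulus for one story."""
    return Path(root_dir) / "ds003020" / "derivative" / "TextGrids" / f"{story}.TextGrid"


def respdict_path(root_dir):
    """Path of the JSON file used to look up the length of each story."""
    return Path(root_dir) / "ds003020" / "derivative" / "respdict.json"


def missing_files(root_dir, subject, stories):
    """Return the expected files that do not exist on disk."""
    expected = [respdict_path(root_dir)]
    for story in stories:
        expected.append(response_path(root_dir, subject, story))
        expected.append(textgrid_path(root_dir, story))
    return [p for p in expected if not p.exists()]


root_dir = "/path/to/root_dir"  # placeholder: wherever you put the data
subject = "UTS03"               # assumed example subject ID; use your own
for path in missing_files(root_dir, subject, ["wheretheressmoke"]):
    print(f"missing: {path}")
```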

Reference