This project implements a distributed training pipeline for learning representations over spatial cell graphs using PyTorch, PyTorch Geometric, and Distributed Data Parallel (DDP).
This project uses uv for dependency management.
Install uv if you don't have it:

```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```

Then create the environment and install dependencies:

```bash
uv sync
```

This will create a `.venv/` in the project root with all dependencies from `pyproject.toml` / `uv.lock`. Prefix commands with `uv run` (e.g. `uv run python main_ddp.py ...`) or activate the venv directly:

```bash
source .venv/bin/activate
```

- Place your dataset in the `data/pretraining/` directory.
- The data should be preprocessed and saved as PyTorch objects using `torch.save`.
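As a concrete illustration of the expected on-disk format, the sketch below saves one toy preprocessed sample with `torch.save`. The file name `sample_0.pt`, the dict layout, and the feature sizes are all assumptions for illustration; the real pipeline may store richer graph objects (e.g. PyTorch Geometric `Data`):

```python
import os

import torch

# Illustrative preprocessing sketch (not the project's actual schema):
# one sample stored as a dict of tensors via torch.save.
os.makedirs("data/pretraining", exist_ok=True)

num_cells = 100
sample = {
    "x": torch.randn(num_cells, 50),  # per-cell features (illustrative size)
    "pos": torch.rand(num_cells, 2),  # 2-D spatial coordinates
}
torch.save(sample, "data/pretraining/sample_0.pt")

# The training script can later reload the object:
reloaded = torch.load("data/pretraining/sample_0.pt")
```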
To train the full HEIST model on all available GPUs, start distributed training with:

```bash
python main_ddp.py --data_dir data/pretraining/ --pe --cross_message_passing
```

To train on a single GPU instead, replace `main_ddp.py` with `main.py`:

```bash
python main.py --data_dir data/pretraining/ --pe --cross_message_passing
```

| Argument | Description | Default |
|---|---|---|
| `--data_dir` | Path to preprocessed data | `data/pretraining/` |
| `--pe_dim` | Positional Encoding Dim | 128 |
| `--init_dim` | Initial MLP Hidden Dim | 128 |
| `--hidden_dim` | Hidden Dimension | 128 |
| `--output_dim` | Output Dimension | 128 |
| `--num_layers` | Number of MLP Layers | 10 |
| `--num_heads` | Number of Transformer Heads | 8 |
| `--batch_size` | Batch Size | 128 |
| `--lr` | Learning Rate | 1e-3 |
| `--wd` | Weight Decay | 3e-3 |
| `--num_epochs` | Number of Training Epochs | 20 |
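For reference, the flags above can be mirrored with a minimal `argparse` sketch. The flag names and defaults follow the table; the parser structure itself is illustrative, and the real `main_ddp.py` may define additional options:

```python
import argparse

# Illustrative parser mirroring the documented flags and defaults;
# not a copy of the project's actual argument handling.
parser = argparse.ArgumentParser(description="HEIST pre-training (sketch)")
parser.add_argument("--data_dir", default="data/pretraining/")
parser.add_argument("--pe", action="store_true", help="use positional encodings")
parser.add_argument("--cross_message_passing", action="store_true")
parser.add_argument("--pe_dim", type=int, default=128)
parser.add_argument("--init_dim", type=int, default=128)
parser.add_argument("--hidden_dim", type=int, default=128)
parser.add_argument("--output_dim", type=int, default=128)
parser.add_argument("--num_layers", type=int, default=10)
parser.add_argument("--num_heads", type=int, default=8)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--wd", type=float, default=3e-3)
parser.add_argument("--num_epochs", type=int, default=20)

# Parse the flags used in the README's example command:
args = parser.parse_args(["--pe", "--cross_message_passing"])
```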
- Checkpoints are saved automatically under the `saved_models/` directory.
- The best model is saved as `HEIST.pth`.
Pre-trained HEIST checkpoints are hosted on the Hugging Face Hub. Load any of them in a few lines:

```python
import torch

from model.model import GraphEncoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GraphEncoder.from_pretrained("HirenMadhu/HEIST").to(device)
model.eval()
```

For an end-to-end tutorial on extracting cell embeddings from a pre-trained model (including preprocessing, graph construction, and visualization with PHATE), see `cell_embeddings.ipynb`.
If you want to resume from a saved checkpoint, ensure that the model and optimizer state dictionaries are correctly loaded in the script.
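A minimal resume sketch, assuming a checkpoint dict with `"model"`, `"optimizer"`, and `"epoch"` keys; the actual layout of the files under `saved_models/` may differ:

```python
import os
import tempfile

import torch


def save_checkpoint(model, optimizer, epoch, path):
    # Bundle model + optimizer state so training can resume exactly.
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch}, path)


def load_checkpoint(model, optimizer, path, device="cpu"):
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1  # first epoch to run after resuming


# Toy usage with a stand-in model (not the HEIST GraphEncoder):
model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
path = os.path.join(tempfile.gettempdir(), "heist_ckpt_demo.pt")
save_checkpoint(model, opt, epoch=5, path=path)
start_epoch = load_checkpoint(model, opt, path)
```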
After training, you can evaluate the model using the provided evaluation script.
```bash
bash eval.sh
```

This script will:

- Activate the uv-managed `.venv` environment.
- Run a series of evaluations across multiple datasets and tasks:
  - Representation space calculation for `dfci`, `upmc`, `charville`, `sea`, `melanoma`, `placenta`, and `lung`.
  - Tissue classification: predict primary outcomes and recurrence for clinical datasets.
  - Melanoma and cell clustering evaluations.
  - Placenta dataset analysis.
  - Gene imputation tasks, in both standard and fine-tuned versions.
- Evaluation results will be stored in the corresponding directories or logged to the console.
- Ensure that the trained model checkpoint is available and named correctly (default: `HEIST`).
This work is licensed under a Creative Commons Attribution 4.0 International License (CC BY 4.0).
If you find this work useful, please cite:
```bibtex
@inproceedings{
madhu2026heist,
  title={{HEIST}: A Graph Foundation Model for Spatial Transcriptomics and Proteomics Data},
  author={Hiren Madhu and Jo{\~a}o Felipe Rocha and Tinglin Huang and Siddharth Viswanath and Smita Krishnaswamy and Rex Ying},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026},
  url={https://openreview.net/forum?id=lK82jpa8jr}
}
```