Official PyTorch/PyTorch Geometric implementation of HGAIT, a financial time-series graph model that combines time-series embeddings with graph-based dependencies among stocks to predict the next-step return for each stock as a regression task.
- In-model representation pipeline: TimeMixing (per-feature temporal embedding) → Variable Attention (inverted multi-head attention) → Feature Importance tokenization
- Graph neighbor attention: Correlation-based top/bottom neighbor selection and gated combination of multiple GATConv outputs
- Combined regression loss: MSE + Pearson-correlation loss (`utils.combined_loss`)
- Training utilities: early stopping, `ReduceLROnPlateau` scheduler, optional Weights & Biases logging
- Modes: `mlp` | `lstm` | `gru`
- For each feature, embeds a length-`L` sequence into a `d_model`-dimensional representation
- Learns cross-feature interactions via inverted multi-head attention
- Produces a feature-importance distribution across features
- Forms a stock-level representation by importance-weighted aggregation of feature embeddings
- Builds top/bottom neighbors from correlation of returns at time t
- Combines outputs from self/top/bottom GATConv branches using learned importance weights
- `LayerNorm` → `MLP(d_model → 2*d_model → 1)` to regress the next-step return
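The correlation-based neighbor selection above can be sketched as follows. This is a minimal standalone version that ranks stocks by Pearson correlation of their recent return series; the repository's actual windowing, tie handling, and graph construction in `data_generator.py` may differ:

```python
import torch

def topk_bottomk_neighbors(returns: torch.Tensor, k: int):
    """returns: (n_stocks, window) recent return series; k: neighbors per side."""
    x = returns - returns.mean(dim=1, keepdim=True)
    x = x / (x.norm(dim=1, keepdim=True) + 1e-8)   # row-normalize each series
    corr = x @ x.t()                               # (n_stocks, n_stocks) Pearson correlations
    eye = torch.eye(corr.size(0), dtype=torch.bool)
    # Mask the diagonal so a stock never selects itself as a neighbor
    top = corr.masked_fill(eye, float('-inf')).topk(k, dim=1).indices
    bottom = corr.masked_fill(eye, float('inf')).topk(k, dim=1, largest=False).indices
    return top, bottom
```

Each stock then attends over its own features plus its top-k (most positively correlated) and bottom-k (most negatively correlated) neighbors through separate GATConv branches, whose outputs are combined with learned gates.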
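The combined regression objective pairs MSE with a Pearson-correlation term. A minimal sketch of the idea (the exact weighting and reduction inside `utils.combined_loss` may differ; `mse_lambda` mirrors the config field of the same name):

```python
import torch
import torch.nn.functional as F

def pearson_corr(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8):
    # Pearson correlation between prediction and label vectors
    pred_c = pred - pred.mean()
    target_c = target - target.mean()
    return (pred_c * target_c).sum() / (pred_c.norm() * target_c.norm() + eps)

def combined_loss(pred, target, mse_lambda: float = 0.5):
    # MSE fits return magnitudes; (1 - correlation) rewards getting the
    # cross-sectional ordering of stocks right even when scales drift.
    mse = F.mse_loss(pred, target)
    corr_loss = 1.0 - pearson_corr(pred, target)
    return mse_lambda * mse + (1.0 - mse_lambda) * corr_loss
```

A perfect prediction drives both terms to zero, while a prediction with the right scale but inverted ranking is still penalized through the correlation term.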
```
HGAIT/
├── train.py              # End-to-end train/val/test script
├── model.py              # HGAIT model (TimeMixing, Variable Attention, Gated GAT, Predictor)
├── data_generator.py     # Raw CSV → per-date graph samples (.pkl)
├── Dataset_reg_std.py    # Per-date PyG Dataset
├── Dataloader.py         # Thin wrapper for PyG DataLoader
├── config.py             # Hyperparameters and paths
├── utils/
│   └── utils.py          # Seeding, metrics/loss, ranking utilities
├── raw_data/
│   ├── csv_maker.py
│   └── CRSP.csv          # (required)
├── data/                 # Output directory for generated .pkl files
├── results/              # Predictions/metrics/CSV outputs
└── model/                # Checkpoints (path configured in config)
```
Below are the key dependencies used by this project.
```bash
pip install numpy==2.2.2 pandas==2.2.3 scikit-learn==1.6.1 tqdm wandb==0.17.5
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0
pip install torch_geometric==2.5.3

# PyG extensions (install variants matching your Torch/CUDA)
pip install \
  torch-scatter==2.1.2+pt21cu121 \
  torch-sparse==0.6.18+pt21cu121 \
  torch-cluster==1.6.3+pt21cu121 \
  torch-spline-conv==1.2.2+pt21cu121
```

- On Windows, PyG and its extensions must match your PyTorch/CUDA versions. Please follow the official installation guides and wheels for your environment.
Generate per-date graph samples from `raw_data/CRSP.csv`. Update the file paths inside `data_generator.py` to fit your environment.

```bash
python data_generator.py
```

Upon success, `data/` will contain files like `YYYY-MM-DD.pkl`.
Adjust `config.py` to your environment, then run:

```bash
python train.py
```

Outputs:
- Checkpoint: `Config.best_model_path`
- Results pickle: `results/results_*.pkl` (per-date prediction/label tensors)
- Results CSV: `results_<seed>_<lr>_<heads>_<heads>_<layers>.csv`
Notes:
- `CUDA_VISIBLE_DEVICES="1"` is set inside `train.py`. Modify or comment it out for your setup.
- `model.HGAIT` requires a `mode` argument (`'gru'` | `'lstm'` | `'mlp'`). Ensure it is passed where the model is constructed in `train.py`.
All settings are managed via the `Config` class in `config.py`.
- Paths: `data_dir`, `best_model_path`, `result_dir`
- Data split: `train_split`, `val_length`
- Model: `n_heads`, `d_model`, `n_layers`, `n_neighbors`, `mode`
- Training: `batch_size`, `num_epochs`, `learning_rate`, `early_stopping_patience`
- Loss: `mse_lambda`
- Logging: `log_to_wandb`
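For orientation, a minimal sketch of what such a `Config` class might look like. Every value below is an illustrative assumption, not the repository's defaults; consult `config.py` for the actual settings:

```python
class Config:
    # Paths -- replace with your own local or relative paths
    data_dir = "./data"                  # assumed default
    best_model_path = "./model/best.pt"  # assumed default
    result_dir = "./results"             # assumed default
    # Data split
    train_split = 0.8                    # illustrative value
    val_length = 60                      # illustrative value
    # Model
    n_heads = 4
    d_model = 64
    n_layers = 2
    n_neighbors = 10
    mode = "gru"                         # one of 'gru' | 'lstm' | 'mlp'
    # Training
    batch_size = 32
    num_epochs = 100
    learning_rate = 1e-3
    early_stopping_patience = 10
    # Loss
    mse_lambda = 0.5                     # weight between MSE and correlation loss
    # Logging
    log_to_wandb = False
```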
Users should replace absolute paths like `/workspace/dongwoo/HGAIT/...` with local paths or use relative paths.
If this repository is useful for your research, please consider citing it (example format):
```bibtex
@article{lee2025hgait,
  title={HGAIT: Heterogeneous Graph Attention with Inverted Transformers for Correlation-Aware Stock Return Prediction},
  author={Lee, Dongwoo and Ock, Seungeun and Song, Jae Wook},
  journal={Expert Systems with Applications},
  pages={129292},
  year={2025},
  publisher={Elsevier}
}
```

- PyTorch and PyTorch Geometric for deep learning and graph neural network tooling
- Weights & Biases for experiment tracking