Skip to content

Commit 5c8a128

Browse files
authored
Merge pull request #48 from alan-turing-institute/train_timing_flat_bask
Add support for running different enc/decoder depths on Baskerville
2 parents 1e834b8 + 4cff12b commit 5c8a128

File tree

2 files changed

+119
-6
lines changed

2 files changed

+119
-6
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#!/bin/bash -l
2+
# vim: et:ts=4:sts=4:sw=4
3+
#SBATCH --qos turing
4+
#SBATCH --account usjs9456-ati-test
5+
#SBATCH --time 1:00:0
6+
#SBATCH --nodes 1
7+
#SBATCH --ntasks-per-node 1
8+
#SBATCH --gpus-per-node 1
9+
#SBATCH --mem 0
10+
#SBATCH --constraint=a100_80
11+
#SBATCH --job-name aurora-train
12+
#SBATCH --output bask-encoder-%a.txt
13+
14+
# Execute using:
15+
# sbatch --array=5-29 ./bask-train-ddp-1x1.sh
16+
17+
# 1 node, 1 GPU
18+
# For this we don't need to 'skip' any GPUs
19+
20+
#set -o xtrace
21+
set -o errexit
22+
23+
pushd ../../scripts
24+
25+
if [ ! -d ../../downloads ]; then
26+
echo "Please run the batch-download.sh script to download the data."
27+
exit 1
28+
fi
29+
30+
echo
31+
echo "## Loading modules"
32+
33+
module -q purge
34+
module -q load baskerville
35+
module -q load bask-apps/live
36+
module -q load PyTorch/2.0.1-foss-2022a-CUDA-11.7.0
37+
module -q load torchvision/0.15.2-foss-2022a-CUDA-11.7.0
38+
39+
echo
40+
echo "## Configuring environment"
41+
42+
export PRIMARY_PORT=$((16384 + $RANDOM % 16384))
43+
export PRIMARY_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
44+
export OMP_NUM_THREADS=1
45+
export ENCODER_DEPTH=${SLURM_ARRAY_TASK_ID}
46+
47+
echo
48+
echo "## Initialising virtual environment"
49+
50+
python -m venv venv
51+
. ./venv/bin/activate
52+
53+
pip install --quiet --upgrade pip
54+
pip install --quiet ../../.[bask]
55+
56+
echo
57+
echo "## Details"
58+
echo
59+
echo "Nodes: ${SLURM_JOB_NUM_NODES}"
60+
echo "GPUs per node: ${SLURM_GPUS_PER_NODE}"
61+
echo "Primary address: ${PRIMARY_ADDR}"
62+
echo "Primary port: ${PRIMARY_PORT}"
63+
echo "Encoder depth: ${ENCODER_DEPTH}"
64+
65+
echo
66+
echo "## Running model"
67+
68+
# Track GPU and CPU metrics
69+
#nvidia-smi dmon -o TD -s puct -d 1 > log-train-gpu.txt &
70+
#vmstat -t 1 -y > log-train-cpu.txt &
71+
72+
# Perform the prediction
73+
# Repeat this 4 times so we get better logs
74+
srun bash -c \
75+
'python -m torch.distributed.run \
76+
--nnodes ${SLURM_JOB_NUM_NODES} \
77+
--nproc-per-node ${SLURM_GPUS_PER_NODE} \
78+
--master_addr ${PRIMARY_ADDR} \
79+
--master_port ${PRIMARY_PORT} \
80+
--node_rank ${SLURM_NODEID} \
81+
train_ed.py \
82+
--download_path ../../downloads \
83+
--encoders ${ENCODER_DEPTH} \
84+
--grad_accum 8'
85+
86+
echo
87+
echo "## Tidying up"
88+
89+
deactivate
90+
popd

train/scripts/train_ed.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@
3131
help="path to download directory",
3232
default="../../era5/era_v_inf",
3333
)
34+
parser.add_argument(
35+
"--encoders",
36+
"-e",
37+
type=int,
38+
help="encoder/decoder depth",
39+
default=12,
40+
)
41+
parser.add_argument(
42+
"--grad_accum",
43+
"-g",
44+
type=int,
45+
help="gradient accumulation steps; must be a multiple of the world size",
46+
default=1,
47+
)
3448
args = parser.parse_args()
3549

3650
if args.xpu:
@@ -76,8 +90,9 @@
7690
RANK = int(os.environ["RANK"])
7791
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
7892

93+
assert args.grad_accum % WORLD_SIZE == 0
7994

80-
def main(download_path: str, xpu: bool = False):
95+
def main(download_path: str, encoder_depth: int, xpu: bool = False):
8196
if xpu:
8297
comms_backend = "ccl"
8398
device_type = "xpu"
@@ -102,15 +117,17 @@ def main(download_path: str, xpu: bool = False):
102117
model = Aurora(
103118
use_lora=False, # Model was not fine-tuned.
104119
autocast=True, # Use AMP.
105-
encoder_depths=(12, 12, 12),
120+
encoder_depths=(encoder_depth, encoder_depth, encoder_depth),
106121
encoder_num_heads=(4, 8, 16),
107-
decoder_depths=(12, 12, 12),
122+
decoder_depths=(encoder_depth, encoder_depth, encoder_depth),
108123
decoder_num_heads=(16, 8, 4),
109124
embed_dim=256,
110125
num_heads=8,
111126
)
112127
# can no longer load checkpoint as we have different model size
113128
# model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
129+
if not xpu:
130+
torch.cuda.set_device(LOCAL_RANK)
114131

115132
download_path = Path(download_path)
116133

@@ -147,6 +164,8 @@ def main(download_path: str, xpu: bool = False):
147164

148165
times = []
149166

167+
n_batches_per_optim = args.grad_accum // WORLD_SIZE
168+
150169
time_start = time.time()
151170
for batch, (X, y) in enumerate(data_loader):
152171
print(f"batch {batch}...", flush=True)
@@ -170,8 +189,9 @@ def main(download_path: str, xpu: bool = False):
170189
print("performing backward pass...", flush=True)
171190
loss.backward()
172191

173-
print("optimizing...", flush=True)
174-
optimizer.step()
192+
if batch % n_batches_per_optim == 0:
193+
print("optimizing...")
194+
optimizer.step()
175195

176196
time_end = time.time()
177197
times.append(time_end - time_start)
@@ -185,6 +205,9 @@ def main(download_path: str, xpu: bool = False):
185205
avg_time = sum([sum(t[1:]) for t in gathered_times]) / sum(
186206
[len(times[1:]) for t in gathered_times]
187207
)
208+
print(
209+
f"Encoder/decoder depth: ({encoder_depth}, {encoder_depth}, {encoder_depth})", flush=True
210+
)
188211
print(
189212
f"Average time per epoch (ignoring first): {avg_time} seconds", flush=True
190213
)
@@ -206,4 +229,4 @@ def main(download_path: str, xpu: bool = False):
206229
print("done", flush=True)
207230

208231

209-
main(args.download_path, xpu=args.xpu)
232+
main(args.download_path, args.encoders, xpu=args.xpu)

0 commit comments

Comments
 (0)