-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path: mnist.patch
More file actions
129 lines (117 loc) · 4.75 KB
/
mnist.patch
File metadata and controls
129 lines (117 loc) · 4.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
--- main.py.d0b7e37 2026-02-18 14:50:27.722389236 -0600
+++ main.py 2026-02-18 14:59:02.737245645 -0600
@@ -1,3 +1,8 @@
+#
+# Copyright (C) 2025, Northwestern University and Argonne National Laboratory
+# See COPYRIGHT notice in top-level directory.
+#
+
import argparse
import torch
import torch.nn as nn
@@ -5,7 +10,11 @@
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+import comm_file, pnetcdf_io
+from mpi4py import MPI
class Net(nn.Module):
def __init__(self):
@@ -42,7 +51,7 @@
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
- if batch_idx % args.log_interval == 0:
+ if rank == 0 and batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
@@ -62,9 +71,14 @@
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
+ # aggregate loss among all ranks
+ test_loss = comm.mpi_comm.allreduce(test_loss, op=MPI.SUM)
+ correct = comm.mpi_comm.allreduce(correct, op=MPI.SUM)
+
test_loss /= len(test_loader.dataset)
- print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+ if rank == 0:
+ print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
@@ -92,6 +106,8 @@
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true',
help='For Saving the current Model')
+ parser.add_argument('--input-file', type=str, required=True,
+ help='NetCDF file storing train and test samples')
args = parser.parse_args()
use_accel = not args.no_accel and torch.accelerator.is_available()
@@ -103,7 +119,7 @@
else:
device = torch.device("cpu")
- train_kwargs = {'batch_size': args.batch_size}
+ train_kwargs = {'batch_size': args.batch_size//nprocs}
test_kwargs = {'batch_size': args.test_batch_size}
if use_accel:
accel_kwargs = {'num_workers': 1,
@@ -117,25 +133,54 @@
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
- dataset1 = datasets.MNIST('../data', train=True, download=True,
- transform=transform)
- dataset2 = datasets.MNIST('../data', train=False,
- transform=transform)
- train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
- test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+ # Open files storing training and testing samples
+ infile = args.input_file
+ train_file = pnetcdf_io.dataset(infile, 'train_samples', 'train_labels', transform, comm.mpi_comm)
+ test_file = pnetcdf_io.dataset(infile, 'test_samples', 'test_labels', transform, comm.mpi_comm)
+
+ # create distributed samplers
+ train_sampler = DistributedSampler(train_file, num_replicas=nprocs, rank=rank, shuffle=True)
+ test_sampler = DistributedSampler(test_file, num_replicas=nprocs, rank=rank, shuffle=False)
+
+ # add distributed samplers to DataLoaders
+ train_loader = torch.utils.data.DataLoader(train_file, sampler=train_sampler, **train_kwargs)
+ test_loader = torch.utils.data.DataLoader(test_file, sampler=test_sampler, **test_kwargs, drop_last=False)
model = Net().to(device)
+
+ # use DDP
+ model = DDP(model, device_ids=[device] if use_accel else None)
+
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
+ # train sampler set epoch
+ train_sampler.set_epoch(epoch)
+ test_sampler.set_epoch(epoch)
+
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
scheduler.step()
if args.save_model:
- torch.save(model.state_dict(), "mnist_cnn.pt")
+ if rank == 0:
+ torch.save(model.state_dict(), "mnist_cnn.pt")
+
+ # close files
+ train_file.close()
+ test_file.close()
if __name__ == '__main__':
- main()
+ ## initialize parallel environment
+ comm, device = comm_file.init_parallel()
+
+ rank = comm.get_rank()
+ nprocs = comm.get_size()
+
+ main()
+
+ comm.finalize()
+