Commit 5e79433

Author: Dean Wampler (committed)

MNIST code

1 parent a506b2d commit 5e79433

File tree: 1 file changed, +164 −0 lines

ray-tune/mnist.py

@@ -0,0 +1,164 @@
import os, json
from torchvision import datasets, transforms
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock
import ray
from ray import tune

DATA_ROOT = '../data/mnist'
EPOCH_SIZE = 512
TEST_SIZE = 256

# Code that was defined inline in the 02 lesson, but loaded from this file
# in subsequent lessons.

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)
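
# Why 192: MNIST images are 1x28x28; conv1 (kernel 3) yields 3x26x26, and
# F.max_pool2d with kernel (and stride) 3 yields 3x8x8, so the flattened
# feature size is 3 * 8 * 8 = 192, matching view(-1, 192) and nn.Linear(192, 10).
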
def get_data_loaders():
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # We add FileLock here because multiple workers on the same machine could
    # try to download the data. This would cause overwrites, since DataLoader
    # is not threadsafe. You wouldn't need this for single-process training.
    lock_file = f'{DATA_ROOT}/data.lock'
    if not os.path.exists(DATA_ROOT):
        os.makedirs(DATA_ROOT)

    with FileLock(os.path.expanduser(lock_file)):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=True, download=True, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=False, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)
    return train_loader, test_loader
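
# For example, one batch from these loaders looks like:
#
#   images, labels = next(iter(train_loader))
#   images.shape   # torch.Size([64, 1, 28, 28])
#   labels.shape   # torch.Size([64])
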
# In the notebook, this was called "train" and it referenced EPOCH_SIZE directly.
# Here, we eliminate the global variable by returning a closure.
def make_train_step(bound):
    def train_step(model, optimizer, train_loader, device=torch.device("cpu")):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            if batch_idx * len(data) > bound:
                return
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    return train_step

# In the notebook, this was called "test" and it referenced TEST_SIZE directly.
# Here, we eliminate the global variable by returning a closure.
def make_test_step(bound):
    def test_step(model, data_loader, device=torch.device("cpu")):
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(data_loader):
                if batch_idx * len(data) > bound:
                    break
                data, target = data.to(device), target.to(device)
                outputs = model(data)
                _, predicted = torch.max(outputs.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        return correct / total
    return test_step
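
# The bound caps how much data each call touches: the loops stop once
# batch_idx * len(data) exceeds bound, so with 64-sample batches,
# make_train_step(EPOCH_SIZE) trains on nine batches per call and
# make_test_step(TEST_SIZE) evaluates five.
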
# This is called train_mnist in the notebook, but it's redefined later to be
# what we call train_mnist in this file:
def train_mnist_no_tune(config):
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'])
    train_step = make_train_step(EPOCH_SIZE)
    test_step = make_test_step(TEST_SIZE)
    for i in range(10):
        train_step(model, optimizer, train_loader)
        acc = test_step(model, test_loader)
        print(acc)

def train_mnist(config):
    from ray.tune import report
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'])
    train_step = make_train_step(EPOCH_SIZE)
    test_step = make_test_step(TEST_SIZE)
    for i in range(10):
        train_step(model, optimizer, train_loader)
        acc = test_step(model, test_loader)
        report(mean_accuracy=acc)
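
# Illustrative sketch (assuming Tune's 0.8.x function API): because train_mnist
# reports results via ray.tune.report, it can be passed to tune.run directly:
#
#   analysis = tune.run(
#       train_mnist,
#       config={"lr": 0.01, "momentum": 0.9})
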
# This class implements more features than the notebook version,
# including the ability to save and restore from checkpoints and the
# tracking of timesteps, following this example:
# https://github.com/ray-project/ray/blob/releases/0.8.6/python/ray/tune/examples/bohb_example.py
class TrainMNIST(tune.Trainable):
    def _setup(self, config):
        self.timestep = 0
        self.config = config
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.config["lr"])
        self.train_step = make_train_step(EPOCH_SIZE)
        self.test_step = make_test_step(TEST_SIZE)

    def _train(self):
        self.timestep += 1
        self.train_step(self.model, self.optimizer, self.train_loader)
        acc = self.test_step(self.model, self.test_loader)
        return {"mean_accuracy": acc}

    def _save(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "checkpoint")
        with open(path, "w") as f:
            f.write(json.dumps({"timestep": self.timestep}))
        return path

    def _restore(self, checkpoint_path):
        with open(checkpoint_path) as f:
            self.timestep = json.loads(f.read())["timestep"]
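
# Illustrative sketch (assuming the 0.8.x tune.run signature): the _save and
# _restore hooks above are exercised when Tune is asked to checkpoint, e.g.:
#
#   analysis = tune.run(
#       TrainMNIST,
#       config={"lr": 0.01},
#       stop={"training_iteration": 10},
#       checkpoint_freq=2,       # calls _save every two iterations
#       checkpoint_at_end=True)  # and once more when the trial stops
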
if __name__ == '__main__':

    config = {
        "lr": tune.grid_search([0.001, 0.01, 0.1]),
        "momentum": tune.grid_search([0.001, 0.01, 0.1, 0.9])
    }

    analysis = tune.run(
        TrainMNIST,
        config=config,
        stop={"training_iteration": 10},
        verbose=1,            # Set to 0 for less output, 2 for more.
        ray_auto_init=False   # Don't allow Tune to initialize Ray.
    )
    print("Best config: ", analysis.get_best_config(metric="mean_accuracy"))
    print("Best performing trials:")
    print(analysis.dataframe().sort_values('mean_accuracy', ascending=False).head())
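
Since tune.run is called with ray_auto_init=False, Ray must already be initialized before this script's __main__ block runs. A minimal driver sketch (assuming the Ray 0.8.x API; the num_cpus value and the mnist import path are illustrative):

    import ray
    from ray import tune
    import mnist

    ray.init(num_cpus=4)  # start Ray ourselves; Tune won't auto-initialize it

    analysis = tune.run(
        mnist.TrainMNIST,
        config={"lr": tune.grid_search([0.01, 0.1]),
                "momentum": tune.grid_search([0.1, 0.9])},
        stop={"training_iteration": 10},
        ray_auto_init=False)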
