import os, json
from torchvision import datasets, transforms
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock
import ray
from ray import tune

DATA_ROOT = '../data/mnist'
EPOCH_SIZE = 512
TEST_SIZE = 256

# Code that was defined inline in the 02 lesson, but loaded from this file
# in subsequent lessons.

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # One 3x3 convolution with 3 output channels, followed by a linear
        # classifier. For 28x28 MNIST images: 28 -> 26 after the convolution,
        # 26 -> 8 after max pooling with kernel 3, so 3 * 8 * 8 = 192 features.
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

def get_data_loaders():
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))])

    # We add FileLock here because multiple workers on the same machine could try
    # to download the data at the same time. That would cause overwrites, since
    # DataLoader is not threadsafe. You wouldn't need this for single-process training.
    lock_file = f'{DATA_ROOT}/data.lock'
    if not os.path.exists(DATA_ROOT):
        os.makedirs(DATA_ROOT)

    with FileLock(os.path.expanduser(lock_file)):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=True, download=True, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=False, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)
    return train_loader, test_loader

# In the notebook, this was called "train" and it referenced EPOCH_SIZE directly.
# Here, we eliminate the global variable by returning a closure.
def make_train_step(bound):
    def train_step(model, optimizer, train_loader, device=torch.device("cpu")):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Stop after roughly `bound` examples, so one training "epoch" stays small.
            if batch_idx * len(data) > bound:
                return
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    return train_step

# In the notebook, this was called "test" and it referenced TEST_SIZE directly.
# Here, we eliminate the global variable by returning a closure.
def make_test_step(bound):
    def test_step(model, data_loader, device=torch.device("cpu")):
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(data_loader):
                if batch_idx * len(data) > bound:
                    break
                data, target = data.to(device), target.to(device)
                outputs = model(data)
                _, predicted = torch.max(outputs.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        return correct / total
    return test_step

# This is called train_mnist in the notebook, but it's redefined later to be
# what we call train_mnist in this file:
def train_mnist_no_tune(config):
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config["momentum"])
    train_step = make_train_step(EPOCH_SIZE)
    test_step = make_test_step(TEST_SIZE)
    for i in range(10):
        train_step(model, optimizer, train_loader)
        acc = test_step(model, test_loader)
        print(acc)
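
# For example, a hypothetical call like the following (the hyperparameter values
# are illustrative only) would print the test accuracy after each of the 10
# training iterations:
#
#   train_mnist_no_tune({"lr": 0.01, "momentum": 0.9})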

def train_mnist(config):
    from ray.tune import report
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config["momentum"])
    train_step = make_train_step(EPOCH_SIZE)
    test_step = make_test_step(TEST_SIZE)
    for i in range(10):
        train_step(model, optimizer, train_loader)
        acc = test_step(model, test_loader)
        # Send the metric back to Tune after each iteration.
        report(mean_accuracy=acc)
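
# A minimal sketch of launching the function-based trainable above with Tune
# (assumes Ray is already initialized; the search space shown is illustrative):
#
#   analysis = tune.run(
#       train_mnist,
#       config={"lr": tune.grid_search([0.001, 0.01, 0.1]), "momentum": 0.9})
#
# The mean_accuracy values passed to report() then appear in analysis.dataframe().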

# This class implements more features than the notebook version,
# including the ability to save and restore from checkpoints and the
# tracking of timesteps, following this example:
# https://github.com/ray-project/ray/blob/releases/0.8.6/python/ray/tune/examples/bohb_example.py
class TrainMNIST(tune.Trainable):
    def _setup(self, config):
        self.timestep = 0
        self.config = config
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.config["lr"])
        self.train_step = make_train_step(EPOCH_SIZE)
        self.test_step = make_test_step(TEST_SIZE)

    # Called by Tune once per training iteration.
    def _train(self):
        self.timestep += 1
        self.train_step(self.model, self.optimizer, self.train_loader)
        acc = self.test_step(self.model, self.test_loader)
        return {"mean_accuracy": acc}

    def _save(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "checkpoint")
        with open(path, "w") as f:
            f.write(json.dumps({"timestep": self.timestep}))
        return path

    def _restore(self, checkpoint_path):
        with open(checkpoint_path) as f:
            self.timestep = json.loads(f.read())["timestep"]
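
# Note: Tune calls _save and _restore automatically when trials are checkpointed,
# paused, or resumed (for example by schedulers such as BOHB in the example linked
# above). A sketch of requesting periodic checkpoints (checkpoint_freq is assumed
# to be supported by tune.run in this Ray version):
#
#   tune.run(TrainMNIST, config=config, checkpoint_freq=2,
#            stop={"training_iteration": 10})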

if __name__ == '__main__':

    config = {
        "lr": tune.grid_search([0.001, 0.01, 0.1]),
        "momentum": tune.grid_search([0.001, 0.01, 0.1, 0.9])
    }

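    # Because ray_auto_init=False is passed below, Ray must already be running
    # before tune.run is called. A minimal sketch for a local, single-node run
    # would be to call ray.init() here first.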
    analysis = tune.run(
        TrainMNIST,
        config=config,
        stop={"training_iteration": 10},
        verbose=1,           # Change to 0 to reduce the output.
        ray_auto_init=False  # Don't allow Tune to initialize Ray.
    )
    print("Best config: ", analysis.get_best_config(metric="mean_accuracy"))
    print("Best performing trials:")
    print(analysis.dataframe().sort_values('mean_accuracy', ascending=False).head())