@@ -138,3 +138,220 @@ def forward(self, x):
                 loss.item(),
             )
         )
+
+def train_func_3():
+    import gzip
+    import os
+    import shutil
+
+    import torch
+    from minio import Minio
+    from pytorch_lightning import LightningModule, Trainer
+    from pytorch_lightning.callbacks.progress import TQDMProgressBar
+    from torch import nn
+    from torch.nn import functional as F
+    from torch.utils.data import DataLoader, RandomSampler, random_split
+    from torchmetrics import Accuracy
+    from torchvision import transforms
+    from torchvision.datasets import MNIST
+
+    PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
+    BATCH_SIZE = 256 if torch.cuda.is_available() else 64
+
+    local_mnist_path = os.path.dirname(os.path.abspath(__file__))
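+    # NOTE (assumption): __file__ resolves to wherever the job runtime
+    # materializes this script, so the MNIST data is stored next to it.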
+
+    print("prior to running the trainer")
+    print("MASTER_ADDR is", os.getenv("MASTER_ADDR"))
+    print("MASTER_PORT is", os.getenv("MASTER_PORT"))
+
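+    # The {{.Name}} expressions below are Go template placeholders that the
+    # surrounding test harness substitutes before this script ever runs, so
+    # Python only sees plain string literals.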
+    STORAGE_BUCKET_EXISTS = "{{.StorageBucketDefaultEndpointExists}}"
+    print("STORAGE_BUCKET_EXISTS:", STORAGE_BUCKET_EXISTS)
+    print(f"{'Storage_Bucket_Default_Endpoint is {{.StorageBucketDefaultEndpoint}}' if '{{.StorageBucketDefaultEndpointExists}}' == 'true' else ''}")
+    print(f"{'Storage_Bucket_Name is {{.StorageBucketName}}' if '{{.StorageBucketNameExists}}' == 'true' else ''}")
+    print(f"{'Storage_Bucket_Mnist_Directory is {{.StorageBucketMnistDir}}' if '{{.StorageBucketMnistDirExists}}' == 'true' else ''}")
+
+    class LitMNIST(LightningModule):
+        def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
+            super().__init__()
+
+            # Set our init args as class attributes
+            self.data_dir = data_dir
+            self.hidden_size = hidden_size
+            self.learning_rate = learning_rate
+
+            # Hardcode some dataset-specific attributes
+            self.num_classes = 10
+            self.dims = (1, 28, 28)
+            channels, width, height = self.dims
+            self.transform = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize((0.1307,), (0.3081,)),
+                ]
+            )
+
+            # Define the PyTorch model: a small fully connected classifier
+            self.model = nn.Sequential(
+                nn.Flatten(),
+                nn.Linear(channels * width * height, hidden_size),
+                nn.ReLU(),
+                nn.Dropout(0.1),
+                nn.Linear(hidden_size, hidden_size),
+                nn.ReLU(),
+                nn.Dropout(0.1),
+                nn.Linear(hidden_size, self.num_classes),
+            )
+
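+            # NOTE (assumption): torchmetrics < 0.11 is assumed here; newer
+            # versions require e.g. Accuracy(task="multiclass", num_classes=10).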
+            self.val_accuracy = Accuracy()
+            self.test_accuracy = Accuracy()
+
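+        # forward returns log-probabilities; paired with F.nll_loss in the
+        # step methods below, this amounts to cross-entropy over the logits.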
+        def forward(self, x):
+            x = self.model(x)
+            return F.log_softmax(x, dim=1)
+
+        def training_step(self, batch, batch_idx):
+            x, y = batch
+            logits = self(x)
+            loss = F.nll_loss(logits, y)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            x, y = batch
+            logits = self(x)
+            loss = F.nll_loss(logits, y)
+            preds = torch.argmax(logits, dim=1)
+            self.val_accuracy.update(preds, y)
+
+            # Calling self.log will surface up scalars for you in TensorBoard
+            self.log("val_loss", loss, prog_bar=True)
+            self.log("val_acc", self.val_accuracy, prog_bar=True)
+
+        def test_step(self, batch, batch_idx):
+            x, y = batch
+            logits = self(x)
+            loss = F.nll_loss(logits, y)
+            preds = torch.argmax(logits, dim=1)
+            self.test_accuracy.update(preds, y)
+
+            # Calling self.log will surface up scalars for you in TensorBoard
+            self.log("test_loss", loss, prog_bar=True)
+            self.log("test_acc", self.test_accuracy, prog_bar=True)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+            return optimizer
+
+        ####################
+        # DATA RELATED HOOKS
+        ####################
+
+        def prepare_data(self):
+            # download
+            print("Downloading MNIST dataset...")
+
+            if (
+                "{{.StorageBucketDefaultEndpointExists}}" == "true"
+                and "{{.StorageBucketDefaultEndpoint}}" != ""
+            ):
+                print("Using storage bucket to download datasets...")
+                dataset_dir = os.path.join(self.data_dir, "MNIST/raw")
+                endpoint = "{{.StorageBucketDefaultEndpoint}}"
+                access_key = "{{.StorageBucketAccessKeyId}}"
+                secret_key = "{{.StorageBucketSecretKey}}"
+                bucket_name = "{{.StorageBucketName}}"
+
+                # Strip the scheme if the endpoint URL carries one; the Minio
+                # client wants a bare host[:port] plus a separate secure flag
+                secure = True
+                if endpoint.startswith("https://"):
+                    endpoint = endpoint[len("https://"):]
+                elif endpoint.startswith("http://"):
+                    endpoint = endpoint[len("http://"):]
+                    secure = False
+
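+                # NOTE: cert_check=False skips TLS certificate verification,
+                # which is tolerable for in-cluster endpoints with self-signed
+                # certificates but unsafe over untrusted networks.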
+                client = Minio(
+                    endpoint,
+                    access_key=access_key,
+                    secret_key=secret_key,
+                    cert_check=False,
+                    secure=secure,
+                )
+
+                if not os.path.exists(dataset_dir):
+                    os.makedirs(dataset_dir)
+                else:
+                    print(f"Directory '{dataset_dir}' already exists")
+
+                # To download datasets from a specific directory of the
+                # storage bucket, pass the directory name as the prefix
+                prefix = "{{.StorageBucketMnistDir}}"
+                # Recursively download every object under the prefix
+                for item in client.list_objects(
+                    bucket_name, prefix=prefix, recursive=True
+                ):
+                    # object_name is "<prefix>/<file>"; keep only the file part
+                    file_name = item.object_name[len(prefix) + 1:]
+                    dataset_file_path = os.path.join(dataset_dir, file_name)
+                    print(dataset_file_path)
+                    if not os.path.exists(dataset_file_path):
+                        client.fget_object(
+                            bucket_name, item.object_name, dataset_file_path
+                        )
+                    else:
+                        print(f"File-path '{dataset_file_path}' already exists")
+                    # Decompress the archive next to it, dropping the .gz suffix
+                    with gzip.open(dataset_file_path, "rb") as f_in:
+                        with open(os.path.splitext(dataset_file_path)[0], "wb") as f_out:
+                            shutil.copyfileobj(f_in, f_out)
+                    # Delete the compressed file
+                    os.remove(dataset_file_path)
+                download_datasets = False
+
+            else:
+                print("Using default MNIST mirror reference to download datasets...")
+                download_datasets = True
+
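+            # If the raw files came from the bucket above, download=False
+            # makes torchvision consume them as-is instead of hitting the
+            # public MNIST mirrors.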
+            MNIST(self.data_dir, train=True, download=download_datasets)
+            MNIST(self.data_dir, train=False, download=download_datasets)
+
+        def setup(self, stage=None):
+            # Assign train/val datasets for use in dataloaders
+            if stage == "fit" or stage is None:
+                mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
+                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
+
+            # Assign test dataset for use in dataloader(s)
+            if stage == "test" or stage is None:
+                self.mnist_test = MNIST(
+                    self.data_dir, train=False, transform=self.transform
+                )
+
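+        # RandomSampler with num_samples=1000 caps each epoch at 1000 random
+        # samples to keep the run short (recent torch versions allow
+        # num_samples without replacement); replace_sampler_ddp=False below
+        # stops Lightning from swapping this sampler out under DDP.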
+        def train_dataloader(self):
+            return DataLoader(
+                self.mnist_train,
+                batch_size=BATCH_SIZE,
+                sampler=RandomSampler(self.mnist_train, num_samples=1000),
+            )
+
+        def val_dataloader(self):
+            return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
+
+        def test_dataloader(self):
+            return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
+
+    # Init the LightningModule pointing at the local dataset directory
+    model = LitMNIST(data_dir=local_mnist_path)
+
+    print("GROUP:", int(os.environ.get("GROUP_WORLD_SIZE", 1)))
+    print("LOCAL:", int(os.environ.get("LOCAL_WORLD_SIZE", 1)))
+
+    # Initialize a trainer
+    trainer = Trainer(
+        accelerator="has to be specified",  # placeholder left for the caller, e.g. "gpu" or "cpu"
+        max_epochs=3,
+        callbacks=[TQDMProgressBar(refresh_rate=20)],
+        num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
+        devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
+        replace_sampler_ddp=False,
+        strategy="ddp",
+    )
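+    # NOTE: replace_sampler_ddp is the pytorch_lightning 1.x argument name;
+    # lightning 2.x renamed it to use_distributed_sampler.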
+
+    # Train the model ⚡
+    trainer.fit(model)
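+
+    # Optionally report held-out accuracy once training finishes:
+    # trainer.test(model)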