@@ -6,6 +6,10 @@ def train_func():
     from torchvision import datasets, transforms
     import torch.distributed as dist
     from pathlib import Path
+    from minio import Minio
+    import shutil
+    import gzip
+
 
     # [1] Setup PyTorch DDP. Distributed environment will be set automatically by Training Operator.
     dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
@@ -45,13 +49,63 @@ def forward(self, x):
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
     # [4] Setup FashionMNIST dataloader and distribute data across PyTorchJob workers.
-    Path(f"./data{local_rank}").mkdir(exist_ok=True)
-    dataset = datasets.FashionMNIST(
-        f"./data{local_rank}",
-        download=True,
-        train=True,
-        transform=transforms.Compose([transforms.ToTensor()]),
-    )
+    dataset_path = "./data"
+    dataset_dir = os.path.join(dataset_path, "MNIST/raw")
+    with_aws = "{{.StorageBucketNameExists}}"
+    endpoint = "{{.StorageBucketDefaultEndpoint}}"
+    access_key = "{{.StorageBucketAccessKeyId}}"
+    secret_key = "{{.StorageBucketSecretKey}}"
+    bucket_name = "{{.StorageBucketName}}"
+    prefix = "{{.StorageBucketMnistDir}}"
+    if with_aws == "true":  # download from the S3/MinIO bucket only when one is configured
+        client = Minio(
+            endpoint,
+            access_key=access_key,
+            secret_key=secret_key,
+            cert_check=False,
+            secure=False,  # TODO
+        )
+
+        if not os.path.exists(dataset_dir):
+            os.makedirs(dataset_dir)
+
+        for item in client.list_objects(
+            bucket_name, prefix=prefix, recursive=True
+        ):
+            file_name = item.object_name[len(prefix) + 1:]
+            dataset_file_path = os.path.join(dataset_dir, file_name)
+            print(f"Downloading dataset file {file_name} to {dataset_file_path}..")
+            if not os.path.exists(dataset_file_path):
+                client.fget_object(
+                    bucket_name, item.object_name, dataset_file_path
+                )
+                # Unzip the downloaded archive,
+                # e.g. ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
+                with gzip.open(dataset_file_path, "rb") as f_in:
+                    filename = file_name.split(".")[0]  # -> t10k-images-idx3-ubyte
+                    file_path = "/".join(dataset_file_path.split("/")[:-1])  # -> ./data/MNIST/raw
+                    full_file_path = os.path.join(file_path, filename)  # -> ./data/MNIST/raw/t10k-images-idx3-ubyte
+                    print(f"Extracting {dataset_file_path} to {file_path}..")
+
+                    with open(full_file_path, "wb") as f_out:
+                        shutil.copyfileobj(f_in, f_out)
+                    print(f"Dataset file downloaded: {full_file_path}\n")
+                # Delete the compressed archive once it has been extracted.
+                os.remove(dataset_file_path)
+
+        dataset = datasets.MNIST(
+            dataset_path,
+            train=True,
+            download=False,
+            transform=transforms.Compose([transforms.ToTensor()]),
+        )
+    else:
+        dataset = datasets.MNIST(
+            dataset_path,
+            train=True,
+            download=True,
+            transform=transforms.Compose([transforms.ToTensor()]),
+        )
     train_loader = torch.utils.data.DataLoader(
         dataset=dataset,
         batch_size=128,
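
The added download logic can also be exercised on its own, outside the PyTorchJob. Below is a minimal sketch, assuming the Go-template placeholders have already been rendered into real connection details; the endpoint, credential, bucket, and prefix values are illustrative only and not taken from this commit.

import gzip
import os
import shutil

from minio import Minio


def download_mnist_from_bucket(endpoint, access_key, secret_key, bucket_name, prefix, dataset_dir):
    """Mirror gzipped MNIST files from an S3-compatible bucket and extract them."""
    client = Minio(endpoint, access_key=access_key, secret_key=secret_key, secure=False)
    os.makedirs(dataset_dir, exist_ok=True)

    for item in client.list_objects(bucket_name, prefix=prefix, recursive=True):
        file_name = item.object_name[len(prefix) + 1:]       # strip "<prefix>/"
        archive_path = os.path.join(dataset_dir, file_name)  # .../t10k-images-idx3-ubyte.gz
        extracted_path = archive_path[:-len(".gz")]          # .../t10k-images-idx3-ubyte
        if os.path.exists(extracted_path):
            continue
        client.fget_object(bucket_name, item.object_name, archive_path)
        with gzip.open(archive_path, "rb") as f_in, open(extracted_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.remove(archive_path)  # keep only the extracted file, as the training script does


# Illustrative values only -- substitute the rendered template values.
download_mnist_from_bucket(
    endpoint="minio.example.svc:9000",
    access_key="minio",
    secret_key="minio123",
    bucket_name="mnist-bucket",
    prefix="mnist-datasets",
    dataset_dir="./data/MNIST/raw",
)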