@@ -186,8 +186,8 @@ def __repr__(self):
186186 return self .__class__ .__name__ + '()'
187187
188188
189- # 데이터셋 불러오는 코드 검증
190- def show_dataset (images : torch .Tensor , masks : torch .Tensor ):
189+ # 데이터셋 불러오는 코드 검증 (Shape: [batch, channel, height, width])
190+ def show_dataset (image : torch .Tensor , target : torch .Tensor ):
191191 def make_plt_subplot (nrows : int , ncols : int , index : int , title : str , image ):
192192 plt .subplot (nrows , ncols , index )
193193 plt .title (title )
@@ -197,8 +197,8 @@ def make_plt_subplot(nrows: int, ncols: int, index: int, title: str, image):
197197
198198 to_pil_image = torchvision .transforms .ToPILImage ()
199199
200- assert images .shape [0 ] == masks .shape [0 ]
201- for i in range (images .shape [0 ]):
202- make_plt_subplot (1 , 2 , 1 , 'Input image' , to_pil_image (images [i ].squeeze ().cpu ()))
203- make_plt_subplot (1 , 2 , 2 , 'Groundtruth' , to_pil_image (masks [i ].cpu ()))
200+ assert image .shape [0 ] == target .shape [0 ]
201+ for i in range (image .shape [0 ]):
202+ make_plt_subplot (1 , 2 , 1 , 'Input image' , to_pil_image (image [i ].squeeze ().cpu ()))
203+ make_plt_subplot (1 , 2 , 2 , 'Groundtruth' , to_pil_image (target [i ].cpu ()))
204204 plt .show ()
0 commit comments