@@ -19,28 +19,22 @@ def __init__(self, num_classes: int):
1919 self .layer4 = resnet34 .layer4 # 512, 1/8
2020
2121 # Classifier
22- self .classifier_2s = nn .Conv2d (128 , num_classes , kernel_size = 1 )
23- self .classifier_4s = nn .Conv2d (256 , num_classes , kernel_size = 1 )
24- self .classifier_8s = nn .Conv2d (512 , num_classes , kernel_size = 1 )
22+ self .classifier = nn .Conv2d (512 , num_classes , kernel_size = 1 )
2523
2624 def forward (self , x ):
2725 # Encoder
28- initial_conv = self .initial_conv (x )
29- layer1 = self .layer1 (initial_conv )
30- layer2 = self .layer2 (layer1 )
31- layer3 = self .layer3 (layer2 )
32- layer4 = self .layer4 (layer3 )
26+ x = self .initial_conv (x )
27+ x = self .layer1 (x )
28+ x = self .layer2 (x )
29+ x = self .layer3 (x )
30+ x = self .layer4 (x )
3331
3432 # Classifier
35- classifier_8s = self .classifier_8s (layer4 )
36- classifier_4s = self .classifier_4s (layer3 )
37- classifier_2s = self .classifier_2s (layer2 )
33+ x = self .classifier (x )
3834
39- # FCN
40- classifier_4s += F .interpolate (classifier_8s , scale_factor = 2 , mode = 'bilinear' , align_corners = False )
41- classifier_2s += F .interpolate (classifier_4s , scale_factor = 2 , mode = 'bilinear' , align_corners = False )
42- out = F .interpolate (classifier_2s , scale_factor = 2 , mode = 'bilinear' , align_corners = False )
43- return out
35+ # Upsample
36+ x = F .interpolate (x , scale_factor = 8 , mode = 'bilinear' , align_corners = False )
37+ return x
4438
4539 def make_initial_conv (self , in_channels : int , out_channels : int ):
4640 return nn .Sequential (
@@ -65,11 +59,11 @@ def load_backbone(num_classes: int, pretrained=False):
6559
6660if __name__ == '__main__' :
6761 device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
68- model = Backbone (8 ).to (device )
62+ model = Backbone (20 ).to (device )
6963 model .eval ()
7064
71- torchsummary .torchsummary .summary (model , (3 , 256 , 512 ))
65+ torchsummary .torchsummary .summary (model , (3 , 400 , 800 ))
7266
7367 writer = torch .utils .tensorboard .SummaryWriter ('../runs' )
74- writer .add_graph (model , torch .rand (1 , 3 , 256 , 512 ).to (device ))
68+ writer .add_graph (model , torch .rand (1 , 3 , 400 , 800 ).to (device ))
7569 writer .close ()
0 commit comments