@@ -1632,16 +1632,18 @@ def main(
16321632 input_1 : R .Tensor ((1 , 3 , 10 ), dtype = "float32" )
16331633 ) -> R .Tuple (R .Tensor ((1 , 3 , 5 ), dtype = "float32" )):
16341634 with R .dataflow ():
1635- lv : R .Tensor ((1 , 3 , 5 ), dtype = "float32" ) = R .nn .adaptive_avg_pool1d (
1636- input_1 , output_size = [5 ], layout = "NCW"
1635+ lv : R .Tensor ((1 , 3 , 1 , 10 ), dtype = "float32" ) = R .expand_dims (input_1 , axis = [- 2 ])
1636+ lv1 : R .Tensor ((1 , 3 , 1 , 5 ), dtype = "float32" ) = R .nn .adaptive_avg_pool2d (
1637+ lv , output_size = [1 , 5 ], layout = "NCHW"
16371638 )
1638- gv : R .Tuple (R .Tensor ((1 , 3 , 5 ), dtype = "float32" )) = (lv ,)
1639+ lv2 : R .Tensor ((1 , 3 , 5 ), dtype = "float32" ) = R .squeeze (lv1 , axis = [- 2 ])
1640+ gv : R .Tuple (R .Tensor ((1 , 3 , 5 ), dtype = "float32" )) = (lv2 ,)
16391641 R .output (gv )
16401642 return gv
16411643
16421644 example_args = (torch .randn (1 , 3 , 10 , dtype = torch .float32 ),)
1643- verify_model (AdaptiveAvgPool1d0 (), example_args , {}, expected1 )
1644- verify_model (AdaptiveAvgPool1d1 (), example_args , {}, expected1 )
1645+ verify_model (AdaptiveAvgPool1d0 (), example_args , {}, expected1 , run_ep_decomposition = True )
1646+ verify_model (AdaptiveAvgPool1d1 (), example_args , {}, expected1 , run_ep_decomposition = True )
16451647
16461648
16471649def test_adaptive_avgpool2d ():
@@ -1673,8 +1675,8 @@ def main(
16731675 return gv
16741676
16751677 example_args = (torch .randn (1 , 3 , 10 , 10 , dtype = torch .float32 ),)
1676- verify_model (AdaptiveAvgPool2d0 (), example_args , {}, expected1 )
1677- verify_model (AdaptiveAvgPool2d1 (), example_args , {}, expected1 )
1678+ verify_model (AdaptiveAvgPool2d0 (), example_args , {}, expected1 , run_ep_decomposition = True )
1679+ verify_model (AdaptiveAvgPool2d1 (), example_args , {}, expected1 , run_ep_decomposition = True )
16781680
16791681
16801682def test_adaptive_avgpool3d ():
@@ -1705,8 +1707,8 @@ def main(
17051707 return gv
17061708
17071709 example_args = (torch .randn (1 , 3 , 8 , 8 , 8 , dtype = torch .float32 ),)
1708- verify_model (AdaptiveAvgPool3d0 (), example_args , {}, expected1 )
1709- verify_model (AdaptiveAvgPool3d1 (), example_args , {}, expected1 )
1710+ verify_model (AdaptiveAvgPool3d0 (), example_args , {}, expected1 , run_ep_decomposition = True )
1711+ verify_model (AdaptiveAvgPool3d1 (), example_args , {}, expected1 , run_ep_decomposition = True )
17101712
17111713
17121714def test_addmm ():
@@ -1781,21 +1783,23 @@ def forward(self, input):
17811783 class expected1 :
17821784 @R .function
17831785 def main (
1784- input_1 : R .Tensor ((1 , 3 , 10 ), dtype = "float32" )
1786+ input : R .Tensor ((1 , 3 , 10 ), dtype = "float32" )
17851787 ) -> R .Tuple (R .Tensor ((1 , 3 , 10 ), dtype = "float32" )):
17861788 with R .dataflow ():
1787- lv : R .Tensor ((1 , 3 , 10 ), dtype = "float32" ) = R .nn .avg_pool1d (
1788- input_1 ,
1789- pool_size = [1 ],
1790- strides = [1 ],
1791- dilation = [1 ],
1792- padding = [0 , 0 ],
1789+ lv : R .Tensor ((1 , 3 , 1 , 10 ), dtype = "float32" ) = R .expand_dims (input , axis = [- 2 ])
1790+ lv1 : R .Tensor ((1 , 3 , 1 , 10 ), dtype = "float32" ) = R .nn .avg_pool2d (
1791+ lv ,
1792+ pool_size = [1 , 1 ],
1793+ strides = [1 , 1 ],
1794+ dilation = [1 , 1 ],
1795+ padding = [0 , 0 , 0 , 0 ],
17931796 ceil_mode = False ,
1794- count_include_pad = True ,
1795- layout = "NCW " ,
1796- out_layout = "NCW " ,
1797+ count_include_pad = False ,
1798+ layout = "NCHW " ,
1799+ out_layout = "NCHW " ,
17971800 )
1798- gv : R .Tuple (R .Tensor ((1 , 3 , 10 ), dtype = "float32" )) = (lv ,)
1801+ lv2 : R .Tensor ((1 , 3 , 10 ), dtype = "float32" ) = R .squeeze (lv1 , axis = [- 2 ])
1802+ gv : R .Tuple (R .Tensor ((1 , 3 , 10 ), dtype = "float32" )) = (lv2 ,)
17991803 R .output (gv )
18001804 return gv
18011805
@@ -1816,20 +1820,24 @@ def forward(self, input):
18161820 @tvm .script .ir_module
18171821 class expected2 :
18181822 @R .function
1819- def main (input_1 : R .Tensor ((1 , 3 , 10 ), dtype = "float32" )):
1823+ def main (
1824+ input : R .Tensor ((1 , 3 , 10 ), dtype = "float32" )
1825+ ) -> R .Tuple (R .Tensor ((1 , 3 , 6 ), dtype = "float32" )):
18201826 with R .dataflow ():
1821- lv = R .nn .avg_pool1d (
1822- input_1 ,
1823- pool_size = [3 ],
1824- strides = [2 ],
1825- dilation = [1 ],
1826- padding = [1 , 1 ],
1827+ lv : R .Tensor ((1 , 3 , 1 , 10 ), dtype = "float32" ) = R .expand_dims (input , axis = [- 2 ])
1828+ lv1 : R .Tensor ((1 , 3 , 1 , 6 ), dtype = "float32" ) = R .nn .avg_pool2d (
1829+ lv ,
1830+ pool_size = [1 , 3 ],
1831+ strides = [1 , 2 ],
1832+ dilation = [1 , 1 ],
1833+ padding = [0 , 1 , 0 , 1 ],
18271834 ceil_mode = True ,
1828- count_include_pad = True ,
1829- layout = "NCW " ,
1830- out_layout = "NCW " ,
1835+ count_include_pad = False ,
1836+ layout = "NCHW " ,
1837+ out_layout = "NCHW " ,
18311838 )
1832- gv = (lv ,)
1839+ lv2 : R .Tensor ((1 , 3 , 6 ), dtype = "float32" ) = R .squeeze (lv1 , axis = [- 2 ])
1840+ gv : R .Tuple (R .Tensor ((1 , 3 , 6 ), dtype = "float32" )) = (lv2 ,)
18331841 R .output (gv )
18341842 return gv
18351843
@@ -1840,28 +1848,32 @@ def forward(self, input):
18401848 @tvm .script .ir_module
18411849 class expected3 :
18421850 @R .function
1843- def main (input_1 : R .Tensor ((1 , 3 , 10 ), dtype = "float32" )):
1851+ def main (
1852+ input : R .Tensor ((1 , 3 , 10 ), dtype = "float32" )
1853+ ) -> R .Tuple (R .Tensor ((1 , 3 , 5 ), dtype = "float32" )):
18441854 with R .dataflow ():
1845- lv = R .nn .avg_pool1d (
1846- input_1 ,
1847- pool_size = [2 ],
1848- strides = [2 ],
1849- dilation = [1 ],
1850- padding = [0 , 0 ],
1855+ lv : R .Tensor ((1 , 3 , 1 , 10 ), dtype = "float32" ) = R .expand_dims (input , axis = [- 2 ])
1856+ lv1 : R .Tensor ((1 , 3 , 1 , 5 ), dtype = "float32" ) = R .nn .avg_pool2d (
1857+ lv ,
1858+ pool_size = [1 , 2 ],
1859+ strides = [1 , 2 ],
1860+ dilation = [1 , 1 ],
1861+ padding = [0 , 0 , 0 , 0 ],
18511862 ceil_mode = False ,
1852- count_include_pad = True ,
1853- layout = "NCW " ,
1854- out_layout = "NCW " ,
1863+ count_include_pad = False ,
1864+ layout = "NCHW " ,
1865+ out_layout = "NCHW " ,
18551866 )
1856- gv = (lv ,)
1867+ lv2 : R .Tensor ((1 , 3 , 5 ), dtype = "float32" ) = R .squeeze (lv1 , axis = [- 2 ])
1868+ gv : R .Tuple (R .Tensor ((1 , 3 , 5 ), dtype = "float32" )) = (lv2 ,)
18571869 R .output (gv )
18581870 return gv
18591871
18601872 example_args = (torch .randn (1 , 3 , 10 , dtype = torch .float32 ),)
1861- verify_model (AvgPool1d1 (), example_args , {}, expected1 )
1862- verify_model (AvgPool1d2 (), example_args , {}, expected2 )
1863- verify_model (AvgPool1d3 (), example_args , {}, expected2 )
1864- verify_model (AvgPool1d4 (), example_args , {}, expected3 )
1873+ verify_model (AvgPool1d1 (), example_args , {}, expected1 , run_ep_decomposition = True )
1874+ verify_model (AvgPool1d2 (), example_args , {}, expected2 , run_ep_decomposition = True )
1875+ verify_model (AvgPool1d3 (), example_args , {}, expected2 , run_ep_decomposition = True )
1876+ verify_model (AvgPool1d4 (), example_args , {}, expected3 , run_ep_decomposition = True )
18651877
18661878
18671879def test_avg_pool2d ():
@@ -1951,10 +1963,10 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
19511963 return gv
19521964
19531965 example_args = (torch .randn (1 , 3 , 10 , 10 , dtype = torch .float32 ),)
1954- verify_model (AvgPool2d1 (), example_args , {}, expected1 )
1955- verify_model (AvgPool2d2 (), example_args , {}, expected2 )
1956- verify_model (AvgPool2d3 (), example_args , {}, expected2 )
1957- verify_model (AvgPool2d4 (), example_args , {}, expected3 )
1966+ verify_model (AvgPool2d1 (), example_args , {}, expected1 , run_ep_decomposition = True )
1967+ verify_model (AvgPool2d2 (), example_args , {}, expected2 , run_ep_decomposition = True )
1968+ verify_model (AvgPool2d3 (), example_args , {}, expected2 , run_ep_decomposition = True )
1969+ verify_model (AvgPool2d4 (), example_args , {}, expected3 , run_ep_decomposition = True )
19581970
19591971
19601972def test_avg_pool3d ():
@@ -2047,10 +2059,10 @@ def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")):
20472059 return gv
20482060
20492061 example_args = (torch .randn (1 , 3 , 8 , 8 , 8 , dtype = torch .float32 ),)
2050- verify_model (AvgPool3d1 (), example_args , {}, expected1 )
2051- verify_model (AvgPool3d2 (), example_args , {}, expected2 )
2052- verify_model (AvgPool3d3 (), example_args , {}, expected2 )
2053- verify_model (AvgPool3d4 (), example_args , {}, expected3 )
2062+ verify_model (AvgPool3d1 (), example_args , {}, expected1 , run_ep_decomposition = True )
2063+ verify_model (AvgPool3d2 (), example_args , {}, expected2 , run_ep_decomposition = True )
2064+ verify_model (AvgPool3d3 (), example_args , {}, expected2 , run_ep_decomposition = True )
2065+ verify_model (AvgPool3d4 (), example_args , {}, expected3 , run_ep_decomposition = True )
20542066
20552067
20562068def test_baddbmm ():
@@ -2284,15 +2296,15 @@ def main(
22842296
22852297 model = ConvTranspose1d1 ()
22862298 binding = {"w1" : model .conv .weight .detach ().numpy (), "w2" : model .conv .bias .detach ().numpy ()}
2287- verify_model (model , example_args , binding , expected1 )
2299+ verify_model (model , example_args , binding , expected1 , run_ep_decomposition = True )
22882300
22892301 model = ConvTranspose1d1Func ()
22902302 binding = {"w1" : model .weight .detach ().numpy (), "w2" : model .bias .detach ().numpy ()}
2291- verify_model (model , example_args , binding , expected1 )
2303+ verify_model (model , example_args , binding , expected1 , run_ep_decomposition = True )
22922304
22932305 model = ConvTranspose1d2 ()
22942306 binding = {"w1" : model .conv .weight .detach ().numpy ()}
2295- verify_model (model , example_args , binding , expected2 )
2307+ verify_model (model , example_args , binding , expected2 , run_ep_decomposition = True )
22962308
22972309
22982310def test_conv_transpose2d ():
@@ -2378,15 +2390,15 @@ def main(
23782390
23792391 model = ConvTranspose2d1 ()
23802392 binding = {"w1" : model .conv .weight .detach ().numpy (), "w2" : model .conv .bias .detach ().numpy ()}
2381- verify_model (model , example_args , binding , expected1 )
2393+ verify_model (model , example_args , binding , expected1 , run_ep_decomposition = True )
23822394
23832395 model = ConvTranspose2d1Func ()
23842396 binding = {"w1" : model .weight .detach ().numpy (), "w2" : model .bias .detach ().numpy ()}
2385- verify_model (model , example_args , binding , expected1 )
2397+ verify_model (model , example_args , binding , expected1 , run_ep_decomposition = True )
23862398
23872399 model = ConvTranspose2d2 ()
23882400 binding = {"w1" : model .conv .weight .detach ().numpy ()}
2389- verify_model (model , example_args , binding , expected2 )
2401+ verify_model (model , example_args , binding , expected2 , run_ep_decomposition = True )
23902402
23912403
23922404def test_conv1d ():
@@ -2470,15 +2482,15 @@ def main(
24702482
24712483 model = Conv1D1 ()
24722484 binding = {"w1" : model .conv .weight .detach ().numpy (), "w2" : model .conv .bias .detach ().numpy ()}
2473- verify_model (model , example_args , binding , expected1 )
2485+ verify_model (model , example_args , binding , expected1 , run_ep_decomposition = True )
24742486
24752487 model = Conv1D1Func ()
24762488 binding = {"w1" : model .weight .detach ().numpy (), "w2" : model .bias .detach ().numpy ()}
2477- verify_model (model , example_args , binding , expected1 )
2489+ verify_model (model , example_args , binding , expected1 , run_ep_decomposition = True )
24782490
24792491 model = Conv1D2 ()
24802492 binding = {"w1" : model .conv .weight .detach ().numpy ()}
2481- verify_model (model , example_args , binding , expected2 )
2493+ verify_model (model , example_args , binding , expected2 , run_ep_decomposition = True )
24822494
24832495
24842496def test_conv2d ():
@@ -2562,15 +2574,15 @@ def main(
25622574
25632575 model = Conv2D1 ()
25642576 binding = {"w1" : model .conv .weight .detach ().numpy (), "w2" : model .conv .bias .detach ().numpy ()}
2565- verify_model (model , example_args , binding , expected1 )
2577+ verify_model (model , example_args , binding , expected1 , run_ep_decomposition = True )
25662578
25672579 model = Conv2D1Func ()
25682580 binding = {"w1" : model .weight .numpy (), "w2" : model .bias .numpy ()}
2569- verify_model (model , example_args , binding , expected1 )
2581+ verify_model (model , example_args , binding , expected1 , run_ep_decomposition = True )
25702582
25712583 model = Conv2D2 ()
25722584 binding = {"w1" : model .conv .weight .detach ().numpy ()}
2573- verify_model (model , example_args , binding , expected2 )
2585+ verify_model (model , example_args , binding , expected2 , run_ep_decomposition = True )
25742586
25752587
25762588def test_conv3d ():
@@ -2654,15 +2666,15 @@ def main(
26542666
26552667 model = Conv3D1 ()
26562668 binding = {"w1" : model .conv .weight .detach ().numpy (), "w2" : model .conv .bias .detach ().numpy ()}
2657- verify_model (model , example_args , binding , expected1 )
2669+ verify_model (model , example_args , binding , expected1 , run_ep_decomposition = True )
26582670
26592671 model = Conv3D1Func ()
26602672 binding = {"w1" : model .weight .detach ().numpy (), "w2" : model .bias .detach ().numpy ()}
2661- verify_model (model , example_args , binding , expected1 )
2673+ verify_model (model , example_args , binding , expected1 , run_ep_decomposition = True )
26622674
26632675 model = Conv3D2 ()
26642676 binding = {"w1" : model .conv .weight .detach ().numpy ()}
2665- verify_model (model , example_args , binding , expected2 )
2677+ verify_model (model , example_args , binding , expected2 , run_ep_decomposition = True )
26662678
26672679
26682680def test_pad ():
@@ -6523,7 +6535,7 @@ def forward(self, x):
65236535 with torch .no_grad ():
65246536 pytorch_output = model (x )
65256537 exported_program = export (model , args = (x ,))
6526- mod = from_exported_program (exported_program )
6538+ mod = from_exported_program (exported_program , run_ep_decomposition = True )
65276539 target = tvm .target .Target ("llvm" )
65286540 ex = relax .build (mod , target )
65296541 vm = relax .VirtualMachine (ex , tvm .cpu ())
@@ -6559,7 +6571,7 @@ def forward(self, x):
65596571 with torch .no_grad ():
65606572 pytorch_output2 = model2 (x2 )
65616573 exported_program2 = export (model2 , args = (x2 ,))
6562- mod2 = from_exported_program (exported_program2 )
6574+ mod2 = from_exported_program (exported_program2 , run_ep_decomposition = True )
65636575 ex2 = relax .build (mod2 , target )
65646576 vm2 = relax .VirtualMachine (ex2 , tvm .cpu ())
65656577 x2_tvm = tvm .runtime .tensor (x2 .numpy ())
@@ -6616,7 +6628,7 @@ def forward(self, x):
66166628 with torch .no_grad ():
66176629 pytorch_output = model (x )
66186630 exported_program = export (model , args = (x ,))
6619- mod = from_exported_program (exported_program )
6631+ mod = from_exported_program (exported_program , run_ep_decomposition = True )
66206632 target = tvm .target .Target ("llvm" )
66216633 ex = relax .build (mod , target )
66226634 vm = relax .VirtualMachine (ex , tvm .cpu ())
@@ -6652,7 +6664,7 @@ def forward(self, x):
66526664 with torch .no_grad ():
66536665 pytorch_output2 = model2 (x2 )
66546666 exported_program2 = export (model2 , args = (x2 ,))
6655- mod2 = from_exported_program (exported_program2 )
6667+ mod2 = from_exported_program (exported_program2 , run_ep_decomposition = True )
66566668 ex2 = relax .build (mod2 , target )
66576669 vm2 = relax .VirtualMachine (ex2 , tvm .cpu ())
66586670 x2_tvm = tvm .runtime .tensor (x2 .numpy ())
0 commit comments