Skip to content

Commit 506a0bb

Browse files
authored
[Relax][PyTorch] Add decomposed operator support for AdaptiveAvgPool (#18437)
* Add decomposed operator support for AdaptiveAvgPool * Refactor avg_pool1d tests
1 parent 6785c8f commit 506a0bb

File tree

2 files changed

+88
-73
lines changed

2 files changed

+88
-73
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,9 @@ def create_convert_map(
950950
# linear algebra
951951
"linalg_vector_norm.default": self._norm,
952952
# neural network
953+
"_adaptive_avg_pool1d.default": self._adaptive_avg_pool1d,
954+
"_adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
955+
"_adaptive_avg_pool3d.default": self._adaptive_avg_pool3d,
953956
"_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional,
954957
"_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
955958
"batch_norm.default": self._batch_norm_legit_no_training,

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 85 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,16 +1632,18 @@ def main(
16321632
input_1: R.Tensor((1, 3, 10), dtype="float32")
16331633
) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")):
16341634
with R.dataflow():
1635-
lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.adaptive_avg_pool1d(
1636-
input_1, output_size=[5], layout="NCW"
1635+
lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input_1, axis=[-2])
1636+
lv1: R.Tensor((1, 3, 1, 5), dtype="float32") = R.nn.adaptive_avg_pool2d(
1637+
lv, output_size=[1, 5], layout="NCHW"
16371638
)
1638-
gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv,)
1639+
lv2: R.Tensor((1, 3, 5), dtype="float32") = R.squeeze(lv1, axis=[-2])
1640+
gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv2,)
16391641
R.output(gv)
16401642
return gv
16411643

16421644
example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
1643-
verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1)
1644-
verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1)
1645+
verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1, run_ep_decomposition=True)
1646+
verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1, run_ep_decomposition=True)
16451647

16461648

16471649
def test_adaptive_avgpool2d():
@@ -1673,8 +1675,8 @@ def main(
16731675
return gv
16741676

16751677
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
1676-
verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1)
1677-
verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)
1678+
verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1, run_ep_decomposition=True)
1679+
verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1, run_ep_decomposition=True)
16781680

16791681

16801682
def test_adaptive_avgpool3d():
@@ -1705,8 +1707,8 @@ def main(
17051707
return gv
17061708

17071709
example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
1708-
verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1)
1709-
verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1)
1710+
verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1, run_ep_decomposition=True)
1711+
verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1, run_ep_decomposition=True)
17101712

17111713

17121714
def test_addmm():
@@ -1781,21 +1783,23 @@ def forward(self, input):
17811783
class expected1:
17821784
@R.function
17831785
def main(
1784-
input_1: R.Tensor((1, 3, 10), dtype="float32")
1786+
input: R.Tensor((1, 3, 10), dtype="float32")
17851787
) -> R.Tuple(R.Tensor((1, 3, 10), dtype="float32")):
17861788
with R.dataflow():
1787-
lv: R.Tensor((1, 3, 10), dtype="float32") = R.nn.avg_pool1d(
1788-
input_1,
1789-
pool_size=[1],
1790-
strides=[1],
1791-
dilation=[1],
1792-
padding=[0, 0],
1789+
lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input, axis=[-2])
1790+
lv1: R.Tensor((1, 3, 1, 10), dtype="float32") = R.nn.avg_pool2d(
1791+
lv,
1792+
pool_size=[1, 1],
1793+
strides=[1, 1],
1794+
dilation=[1, 1],
1795+
padding=[0, 0, 0, 0],
17931796
ceil_mode=False,
1794-
count_include_pad=True,
1795-
layout="NCW",
1796-
out_layout="NCW",
1797+
count_include_pad=False,
1798+
layout="NCHW",
1799+
out_layout="NCHW",
17971800
)
1798-
gv: R.Tuple(R.Tensor((1, 3, 10), dtype="float32")) = (lv,)
1801+
lv2: R.Tensor((1, 3, 10), dtype="float32") = R.squeeze(lv1, axis=[-2])
1802+
gv: R.Tuple(R.Tensor((1, 3, 10), dtype="float32")) = (lv2,)
17991803
R.output(gv)
18001804
return gv
18011805

@@ -1816,20 +1820,24 @@ def forward(self, input):
18161820
@tvm.script.ir_module
18171821
class expected2:
18181822
@R.function
1819-
def main(input_1: R.Tensor((1, 3, 10), dtype="float32")):
1823+
def main(
1824+
input: R.Tensor((1, 3, 10), dtype="float32")
1825+
) -> R.Tuple(R.Tensor((1, 3, 6), dtype="float32")):
18201826
with R.dataflow():
1821-
lv = R.nn.avg_pool1d(
1822-
input_1,
1823-
pool_size=[3],
1824-
strides=[2],
1825-
dilation=[1],
1826-
padding=[1, 1],
1827+
lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input, axis=[-2])
1828+
lv1: R.Tensor((1, 3, 1, 6), dtype="float32") = R.nn.avg_pool2d(
1829+
lv,
1830+
pool_size=[1, 3],
1831+
strides=[1, 2],
1832+
dilation=[1, 1],
1833+
padding=[0, 1, 0, 1],
18271834
ceil_mode=True,
1828-
count_include_pad=True,
1829-
layout="NCW",
1830-
out_layout="NCW",
1835+
count_include_pad=False,
1836+
layout="NCHW",
1837+
out_layout="NCHW",
18311838
)
1832-
gv = (lv,)
1839+
lv2: R.Tensor((1, 3, 6), dtype="float32") = R.squeeze(lv1, axis=[-2])
1840+
gv: R.Tuple(R.Tensor((1, 3, 6), dtype="float32")) = (lv2,)
18331841
R.output(gv)
18341842
return gv
18351843

@@ -1840,28 +1848,32 @@ def forward(self, input):
18401848
@tvm.script.ir_module
18411849
class expected3:
18421850
@R.function
1843-
def main(input_1: R.Tensor((1, 3, 10), dtype="float32")):
1851+
def main(
1852+
input: R.Tensor((1, 3, 10), dtype="float32")
1853+
) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")):
18441854
with R.dataflow():
1845-
lv = R.nn.avg_pool1d(
1846-
input_1,
1847-
pool_size=[2],
1848-
strides=[2],
1849-
dilation=[1],
1850-
padding=[0, 0],
1855+
lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input, axis=[-2])
1856+
lv1: R.Tensor((1, 3, 1, 5), dtype="float32") = R.nn.avg_pool2d(
1857+
lv,
1858+
pool_size=[1, 2],
1859+
strides=[1, 2],
1860+
dilation=[1, 1],
1861+
padding=[0, 0, 0, 0],
18511862
ceil_mode=False,
1852-
count_include_pad=True,
1853-
layout="NCW",
1854-
out_layout="NCW",
1863+
count_include_pad=False,
1864+
layout="NCHW",
1865+
out_layout="NCHW",
18551866
)
1856-
gv = (lv,)
1867+
lv2: R.Tensor((1, 3, 5), dtype="float32") = R.squeeze(lv1, axis=[-2])
1868+
gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv2,)
18571869
R.output(gv)
18581870
return gv
18591871

18601872
example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
1861-
verify_model(AvgPool1d1(), example_args, {}, expected1)
1862-
verify_model(AvgPool1d2(), example_args, {}, expected2)
1863-
verify_model(AvgPool1d3(), example_args, {}, expected2)
1864-
verify_model(AvgPool1d4(), example_args, {}, expected3)
1873+
verify_model(AvgPool1d1(), example_args, {}, expected1, run_ep_decomposition=True)
1874+
verify_model(AvgPool1d2(), example_args, {}, expected2, run_ep_decomposition=True)
1875+
verify_model(AvgPool1d3(), example_args, {}, expected2, run_ep_decomposition=True)
1876+
verify_model(AvgPool1d4(), example_args, {}, expected3, run_ep_decomposition=True)
18651877

18661878

18671879
def test_avg_pool2d():
@@ -1951,10 +1963,10 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
19511963
return gv
19521964

19531965
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
1954-
verify_model(AvgPool2d1(), example_args, {}, expected1)
1955-
verify_model(AvgPool2d2(), example_args, {}, expected2)
1956-
verify_model(AvgPool2d3(), example_args, {}, expected2)
1957-
verify_model(AvgPool2d4(), example_args, {}, expected3)
1966+
verify_model(AvgPool2d1(), example_args, {}, expected1, run_ep_decomposition=True)
1967+
verify_model(AvgPool2d2(), example_args, {}, expected2, run_ep_decomposition=True)
1968+
verify_model(AvgPool2d3(), example_args, {}, expected2, run_ep_decomposition=True)
1969+
verify_model(AvgPool2d4(), example_args, {}, expected3, run_ep_decomposition=True)
19581970

19591971

19601972
def test_avg_pool3d():
@@ -2047,10 +2059,10 @@ def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")):
20472059
return gv
20482060

20492061
example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
2050-
verify_model(AvgPool3d1(), example_args, {}, expected1)
2051-
verify_model(AvgPool3d2(), example_args, {}, expected2)
2052-
verify_model(AvgPool3d3(), example_args, {}, expected2)
2053-
verify_model(AvgPool3d4(), example_args, {}, expected3)
2062+
verify_model(AvgPool3d1(), example_args, {}, expected1, run_ep_decomposition=True)
2063+
verify_model(AvgPool3d2(), example_args, {}, expected2, run_ep_decomposition=True)
2064+
verify_model(AvgPool3d3(), example_args, {}, expected2, run_ep_decomposition=True)
2065+
verify_model(AvgPool3d4(), example_args, {}, expected3, run_ep_decomposition=True)
20542066

20552067

20562068
def test_baddbmm():
@@ -2284,15 +2296,15 @@ def main(
22842296

22852297
model = ConvTranspose1d1()
22862298
binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
2287-
verify_model(model, example_args, binding, expected1)
2299+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
22882300

22892301
model = ConvTranspose1d1Func()
22902302
binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
2291-
verify_model(model, example_args, binding, expected1)
2303+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
22922304

22932305
model = ConvTranspose1d2()
22942306
binding = {"w1": model.conv.weight.detach().numpy()}
2295-
verify_model(model, example_args, binding, expected2)
2307+
verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
22962308

22972309

22982310
def test_conv_transpose2d():
@@ -2378,15 +2390,15 @@ def main(
23782390

23792391
model = ConvTranspose2d1()
23802392
binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
2381-
verify_model(model, example_args, binding, expected1)
2393+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
23822394

23832395
model = ConvTranspose2d1Func()
23842396
binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
2385-
verify_model(model, example_args, binding, expected1)
2397+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
23862398

23872399
model = ConvTranspose2d2()
23882400
binding = {"w1": model.conv.weight.detach().numpy()}
2389-
verify_model(model, example_args, binding, expected2)
2401+
verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
23902402

23912403

23922404
def test_conv1d():
@@ -2470,15 +2482,15 @@ def main(
24702482

24712483
model = Conv1D1()
24722484
binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
2473-
verify_model(model, example_args, binding, expected1)
2485+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
24742486

24752487
model = Conv1D1Func()
24762488
binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
2477-
verify_model(model, example_args, binding, expected1)
2489+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
24782490

24792491
model = Conv1D2()
24802492
binding = {"w1": model.conv.weight.detach().numpy()}
2481-
verify_model(model, example_args, binding, expected2)
2493+
verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
24822494

24832495

24842496
def test_conv2d():
@@ -2562,15 +2574,15 @@ def main(
25622574

25632575
model = Conv2D1()
25642576
binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
2565-
verify_model(model, example_args, binding, expected1)
2577+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
25662578

25672579
model = Conv2D1Func()
25682580
binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()}
2569-
verify_model(model, example_args, binding, expected1)
2581+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
25702582

25712583
model = Conv2D2()
25722584
binding = {"w1": model.conv.weight.detach().numpy()}
2573-
verify_model(model, example_args, binding, expected2)
2585+
verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
25742586

25752587

25762588
def test_conv3d():
@@ -2654,15 +2666,15 @@ def main(
26542666

26552667
model = Conv3D1()
26562668
binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
2657-
verify_model(model, example_args, binding, expected1)
2669+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
26582670

26592671
model = Conv3D1Func()
26602672
binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
2661-
verify_model(model, example_args, binding, expected1)
2673+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
26622674

26632675
model = Conv3D2()
26642676
binding = {"w1": model.conv.weight.detach().numpy()}
2665-
verify_model(model, example_args, binding, expected2)
2677+
verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
26662678

26672679

26682680
def test_pad():
@@ -6523,7 +6535,7 @@ def forward(self, x):
65236535
with torch.no_grad():
65246536
pytorch_output = model(x)
65256537
exported_program = export(model, args=(x,))
6526-
mod = from_exported_program(exported_program)
6538+
mod = from_exported_program(exported_program, run_ep_decomposition=True)
65276539
target = tvm.target.Target("llvm")
65286540
ex = relax.build(mod, target)
65296541
vm = relax.VirtualMachine(ex, tvm.cpu())
@@ -6559,7 +6571,7 @@ def forward(self, x):
65596571
with torch.no_grad():
65606572
pytorch_output2 = model2(x2)
65616573
exported_program2 = export(model2, args=(x2,))
6562-
mod2 = from_exported_program(exported_program2)
6574+
mod2 = from_exported_program(exported_program2, run_ep_decomposition=True)
65636575
ex2 = relax.build(mod2, target)
65646576
vm2 = relax.VirtualMachine(ex2, tvm.cpu())
65656577
x2_tvm = tvm.runtime.tensor(x2.numpy())
@@ -6616,7 +6628,7 @@ def forward(self, x):
66166628
with torch.no_grad():
66176629
pytorch_output = model(x)
66186630
exported_program = export(model, args=(x,))
6619-
mod = from_exported_program(exported_program)
6631+
mod = from_exported_program(exported_program, run_ep_decomposition=True)
66206632
target = tvm.target.Target("llvm")
66216633
ex = relax.build(mod, target)
66226634
vm = relax.VirtualMachine(ex, tvm.cpu())
@@ -6652,7 +6664,7 @@ def forward(self, x):
66526664
with torch.no_grad():
66536665
pytorch_output2 = model2(x2)
66546666
exported_program2 = export(model2, args=(x2,))
6655-
mod2 = from_exported_program(exported_program2)
6667+
mod2 = from_exported_program(exported_program2, run_ep_decomposition=True)
66566668
ex2 = relax.build(mod2, target)
66576669
vm2 = relax.VirtualMachine(ex2, tvm.cpu())
66586670
x2_tvm = tvm.runtime.tensor(x2.numpy())

0 commit comments

Comments
 (0)