Skip to content

Commit 03d55df

Browse files
authored
[Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(5) (#18417)
* f1 * f2 * f3 * f5 * f7
1 parent 1e28bf9 commit 03d55df

File tree

1 file changed

+62
-58
lines changed

1 file changed

+62
-58
lines changed

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 62 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4580,47 +4580,51 @@ def forward(self, x, y):
45804580
class Expected0:
45814581
@R.function
45824582
def main(
4583-
inp_0: R.Tensor((2, 3), dtype="float32"),
4584-
inp_1: R.Tensor((2, 3), dtype="float32"),
4583+
x: R.Tensor((2, 3), dtype="float32"),
4584+
y: R.Tensor((2, 3), dtype="float32"),
45854585
) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
45864586
with R.dataflow():
4587-
lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=0)
4588-
gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
4587+
lv: R.Tensor((4, 3), dtype="float32") = R.concat((x, y), axis=0)
4588+
lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3]))
4589+
gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)
45894590
R.output(gv)
45904591
return gv
45914592

45924593
@I.ir_module
45934594
class Expected1:
45944595
@R.function
45954596
def main(
4596-
inp_0: R.Tensor((2, 3), dtype="float32"),
4597-
inp_1: R.Tensor((2, 3), dtype="float32"),
4597+
x: R.Tensor((2, 3), dtype="float32"),
4598+
y: R.Tensor((2, 3), dtype="float32"),
45984599
) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
45994600
with R.dataflow():
4600-
lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=1)
4601-
gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
4601+
lv: R.Tensor((2, 6), dtype="float32") = R.concat((x, y), axis=1)
4602+
lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3]))
4603+
gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)
46024604
R.output(gv)
46034605
return gv
46044606

46054607
@I.ir_module
46064608
class Expected3:
46074609
@R.function
46084610
def main(
4609-
inp_0: R.Tensor((2, 3), dtype="float32"),
4610-
inp_1: R.Tensor((2, 3), dtype="float32"),
4611+
x: R.Tensor((2, 3), dtype="float32"),
4612+
y: R.Tensor((2, 3), dtype="float32"),
46114613
) -> R.Tuple(R.Tensor((2, 3, 2), dtype="float32")):
46124614
with R.dataflow():
4613-
lv: R.Tensor((2, 3, 2), dtype="float32") = R.stack((inp_0, inp_1), axis=-1)
4614-
gv: R.Tuple(R.Tensor((2, 3, 2), dtype="float32")) = (lv,)
4615+
lv: R.Tensor((2, 3, 1), dtype="float32") = R.expand_dims(x, axis=[2])
4616+
lv1: R.Tensor((2, 3, 1), dtype="float32") = R.expand_dims(y, axis=[2])
4617+
lv2: R.Tensor((2, 3, 2), dtype="float32") = R.concat((lv, lv1), axis=-1)
4618+
gv: R.Tuple(R.Tensor((2, 3, 2), dtype="float32")) = (lv2,)
46154619
R.output(gv)
46164620
return gv
46174621

46184622
example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32))
46194623

4620-
verify_model(Stack0(), example_args, {}, Expected0)
4621-
verify_model(Stack1(), example_args, {}, Expected1)
4622-
verify_model(Stack2(), example_args, {}, Expected1)
4623-
verify_model(Stack3(), example_args, {}, Expected3)
4624+
verify_model(Stack0(), example_args, {}, Expected0, run_ep_decomposition=True)
4625+
verify_model(Stack1(), example_args, {}, Expected1, run_ep_decomposition=True)
4626+
verify_model(Stack2(), example_args, {}, Expected1, run_ep_decomposition=True)
4627+
verify_model(Stack3(), example_args, {}, Expected3, run_ep_decomposition=True)
46244628

46254629

46264630
def test_tile():
@@ -4644,7 +4648,7 @@ def main(
46444648
) -> R.Tuple(R.Tensor((1, 6), dtype="float32")):
46454649
# block 0
46464650
with R.dataflow():
4647-
lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2])
4651+
lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, repeats=[1, 2])
46484652
gv: R.Tuple(R.Tensor((1, 6), dtype="float32")) = (lv,)
46494653
R.output(gv)
46504654
return gv
@@ -4657,15 +4661,15 @@ def main(
46574661
) -> R.Tuple(R.Tensor((4, 6), dtype="float32")):
46584662
# block 0
46594663
with R.dataflow():
4660-
lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
4664+
lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, repeats=[4, 2])
46614665
gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,)
46624666
R.output(gv)
46634667
return gv
46644668

46654669
example_args = (torch.randn(1, 3, dtype=torch.float32),)
4666-
verify_model(Tile1(), example_args, {}, expected1)
4667-
verify_model(Tile2(), example_args, {}, expected2)
4668-
verify_model(Tile3(), example_args, {}, expected2)
4670+
verify_model(Tile1(), example_args, {}, expected1, run_ep_decomposition=True)
4671+
verify_model(Tile2(), example_args, {}, expected2, run_ep_decomposition=True)
4672+
verify_model(Tile3(), example_args, {}, expected2, run_ep_decomposition=True)
46694673

46704674

46714675
def test_transpose():
@@ -4687,7 +4691,7 @@ def main(
46874691
return gv
46884692

46894693
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
4690-
verify_model(Transpose(), example_args, {}, expected1)
4694+
verify_model(Transpose(), example_args, {}, expected1, run_ep_decomposition=True)
46914695

46924696

46934697
def test_unsqueeze():
@@ -4727,8 +4731,8 @@ def main(
47274731

47284732
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
47294733

4730-
verify_model(Unsqueeze1(), example_args, {}, expected1)
4731-
verify_model(Unsqueeze2(), example_args, {}, expected2)
4734+
verify_model(Unsqueeze1(), example_args, {}, expected1, run_ep_decomposition=True)
4735+
verify_model(Unsqueeze2(), example_args, {}, expected2, run_ep_decomposition=True)
47324736

47334737

47344738
def test_view():
@@ -4750,7 +4754,7 @@ def main(
47504754
return gv
47514755

47524756
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
4753-
verify_model(View(), example_args, {}, expected1)
4757+
verify_model(View(), example_args, {}, expected1, run_ep_decomposition=True)
47544758

47554759

47564760
def test_arange():
@@ -4771,7 +4775,7 @@ def main(
47714775
return gv
47724776

47734777
example_args = (torch.randn(10, 10, dtype=torch.float32),)
4774-
verify_model(Arange(), example_args, {}, Expected)
4778+
verify_model(Arange(), example_args, {}, Expected, run_ep_decomposition=True)
47754779

47764780

47774781
def test_hamming_window():
@@ -4798,7 +4802,7 @@ def main(
47984802
return gv
47994803

48004804
example_args = (torch.randn(10, 10, dtype=torch.float32),)
4801-
verify_model(HammingWindow(), example_args, {}, Expected)
4805+
verify_model(HammingWindow(), example_args, {}, Expected, run_ep_decomposition=True)
48024806

48034807

48044808
def test_contiguous():
@@ -4818,7 +4822,7 @@ def main(
48184822
return gv
48194823

48204824
example_args = (torch.randn(10, 10, dtype=torch.float32),)
4821-
verify_model(Contiguous(), example_args, {}, Expected)
4825+
verify_model(Contiguous(), example_args, {}, Expected, run_ep_decomposition=True)
48224826

48234827

48244828
def test_clone():
@@ -4838,7 +4842,7 @@ def main(
48384842
return gv
48394843

48404844
example_args = (torch.randn(10, 10, dtype=torch.float32),)
4841-
verify_model(Clone(), example_args, {}, Expected)
4845+
verify_model(Clone(), example_args, {}, Expected, run_ep_decomposition=True)
48424846

48434847

48444848
def test_empty():
@@ -4850,7 +4854,7 @@ def forward(self, input):
48504854
class Expected:
48514855
@R.function
48524856
def main(
4853-
inp_0: R.Tensor((10, 10), dtype="float32")
4857+
input: R.Tensor((10, 10), dtype="float32")
48544858
) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
48554859
with R.dataflow():
48564860
lv: R.Tensor((10, 10), dtype="float32") = R.zeros(
@@ -4861,7 +4865,7 @@ def main(
48614865
return gv
48624866

48634867
example_args = (torch.randn(10, 10, dtype=torch.float32),)
4864-
verify_model(Empty(), example_args, {}, Expected)
4868+
verify_model(Empty(), example_args, {}, Expected, run_ep_decomposition=True)
48654869

48664870

48674871
def test_fill():
@@ -4873,18 +4877,18 @@ def forward(self, input: torch.Tensor):
48734877
class Expected:
48744878
@R.function
48754879
def main(
4876-
inp_0: R.Tensor((10, 10), dtype="float32")
4880+
input: R.Tensor((10, 10), dtype="float32")
48774881
) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
48784882
with R.dataflow():
4879-
lv: R.Tensor((10, 10), dtype="float32") = R.full(
4880-
R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32"
4883+
lv: R.Tensor((10, 10), dtype="float32") = R.full_like(
4884+
input, R.const(1.5, "float32"), dtype="void"
48814885
)
48824886
gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
48834887
R.output(gv)
48844888
return gv
48854889

48864890
example_args = (torch.randn(10, 10, dtype=torch.float32),)
4887-
verify_model(Fill(), example_args, {}, Expected)
4891+
verify_model(Fill(), example_args, {}, Expected, run_ep_decomposition=True)
48884892

48894893

48904894
def test_fill_inplace():
@@ -4897,18 +4901,20 @@ def forward(self, input: torch.Tensor):
48974901
class Expected:
48984902
@R.function
48994903
def main(
4900-
x: R.Tensor((2, 3), dtype="float32")
4901-
) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
4904+
input: R.Tensor((2, 3), dtype="float32")
4905+
) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")):
49024906
with R.dataflow():
4903-
lv: R.Tensor((2, 3), dtype="float32") = R.full(
4904-
R.shape([2, 3]), R.const(42.0, "float32"), dtype="float32"
4907+
lv: R.Tensor((2, 3), dtype="float32") = R.full_like(
4908+
input, R.const(42.0, "float32"), dtype="void"
49054909
)
4906-
gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
4910+
gv: R.Tuple(
4911+
R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")
4912+
) = (lv, lv)
49074913
R.output(gv)
49084914
return gv
49094915

49104916
example_args = (torch.randn(2, 3, dtype=torch.float32),)
4911-
verify_model(FillInplace(), example_args, {}, Expected)
4917+
verify_model(FillInplace(), example_args, {}, Expected, run_ep_decomposition=True)
49124918

49134919

49144920
def test_masked_fill():
@@ -4923,16 +4929,14 @@ def main(
49234929
input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool")
49244930
) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
49254931
with R.dataflow():
4926-
lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
4927-
input, R.const(0, "int32"), dtype="void"
4928-
)
4932+
lv: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
49294933
lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input)
49304934
gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,)
49314935
R.output(gv)
49324936
return gv
49334937

49344938
example_args = (torch.randn(128, 128, dtype=torch.float32), torch.rand(128, 128) < 0.5)
4935-
verify_model(Masked_Fill(), example_args, {}, Expected)
4939+
verify_model(Masked_Fill(), example_args, {}, Expected, run_ep_decomposition=True)
49364940

49374941

49384942
def test_masked_fill_inplace():
@@ -4945,18 +4949,18 @@ class Expected:
49454949
@R.function
49464950
def main(
49474951
input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool")
4948-
) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
4952+
) -> R.Tuple(R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32")):
49494953
with R.dataflow():
4950-
lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
4951-
input, R.const(1.5, "float32"), dtype="void"
4952-
)
4954+
lv: R.Tensor((), dtype="float32") = R.const(1.5, "float32")
49534955
lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input)
4954-
gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,)
4956+
gv: R.Tuple(
4957+
R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32")
4958+
) = (lv1, lv1)
49554959
R.output(gv)
49564960
return gv
49574961

49584962
example_args = (torch.randn(128, 128, dtype=torch.float32), torch.rand(128, 128) < 0.5)
4959-
verify_model(Masked_Fill_Inplace(), example_args, {}, Expected)
4963+
verify_model(Masked_Fill_Inplace(), example_args, {}, Expected, run_ep_decomposition=True)
49604964

49614965

49624966
def test_new_ones():
@@ -4980,7 +4984,7 @@ def main(
49804984
return gv
49814985

49824986
example_args = (torch.randn(1, 2, 3, dtype=torch.float32),)
4983-
verify_model(NewOnes(), example_args, {}, expected1)
4987+
verify_model(NewOnes(), example_args, {}, expected1, run_ep_decomposition=True)
49844988

49854989

49864990
def test_new_zeros():
@@ -5003,7 +5007,7 @@ def main(
50035007
return gv
50045008

50055009
example_args = (torch.randn(1, 128, 128, dtype=torch.float32),)
5006-
verify_model(NewZeros(), example_args, {}, expected1)
5010+
verify_model(NewZeros(), example_args, {}, expected1, run_ep_decomposition=True)
50075011

50085012

50095013
def test_to_copy():
@@ -5094,11 +5098,11 @@ def main(
50945098
return gv
50955099

50965100
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
5097-
verify_model(ToFloat(), example_args, {}, expected_float)
5098-
verify_model(ToHalf(), example_args, {}, expected_half)
5099-
verify_model(Type(), example_args, {}, expected_type)
5100-
verify_model(To1(), example_args, {}, expected_to1)
5101-
verify_model(To2(), example_args, {}, expected_to2)
5101+
verify_model(ToFloat(), example_args, {}, expected_float, run_ep_decomposition=True)
5102+
verify_model(ToHalf(), example_args, {}, expected_half, run_ep_decomposition=True)
5103+
verify_model(Type(), example_args, {}, expected_type, run_ep_decomposition=True)
5104+
verify_model(To1(), example_args, {}, expected_to1, run_ep_decomposition=True)
5105+
verify_model(To2(), example_args, {}, expected_to2, run_ep_decomposition=True)
51025106

51035107

51045108
def test_keep_params():

0 commit comments

Comments
 (0)