@@ -4580,47 +4580,51 @@ def forward(self, x, y):
45804580 class Expected0 :
45814581 @R .function
45824582 def main (
4583- inp_0 : R .Tensor ((2 , 3 ), dtype = "float32" ),
4584- inp_1 : R .Tensor ((2 , 3 ), dtype = "float32" ),
4583+ x : R .Tensor ((2 , 3 ), dtype = "float32" ),
4584+ y : R .Tensor ((2 , 3 ), dtype = "float32" ),
45854585 ) -> R .Tuple (R .Tensor ((2 , 2 , 3 ), dtype = "float32" )):
45864586 with R .dataflow ():
4587- lv : R .Tensor ((2 , 2 , 3 ), dtype = "float32" ) = R .stack ((inp_0 , inp_1 ), axis = 0 )
4588- gv : R .Tuple (R .Tensor ((2 , 2 , 3 ), dtype = "float32" )) = (lv ,)
4587+ lv : R .Tensor ((4 , 3 ), dtype = "float32" ) = R .concat ((x , y ), axis = 0 )
4588+ lv1 : R .Tensor ((2 , 2 , 3 ), dtype = "float32" ) = R .reshape (lv , R .shape ([2 , 2 , 3 ]))
4589+ gv : R .Tuple (R .Tensor ((2 , 2 , 3 ), dtype = "float32" )) = (lv1 ,)
45894590 R .output (gv )
45904591 return gv
45914592
45924593 @I .ir_module
45934594 class Expected1 :
45944595 @R .function
45954596 def main (
4596- inp_0 : R .Tensor ((2 , 3 ), dtype = "float32" ),
4597- inp_1 : R .Tensor ((2 , 3 ), dtype = "float32" ),
4597+ x : R .Tensor ((2 , 3 ), dtype = "float32" ),
4598+ y : R .Tensor ((2 , 3 ), dtype = "float32" ),
45984599 ) -> R .Tuple (R .Tensor ((2 , 2 , 3 ), dtype = "float32" )):
45994600 with R .dataflow ():
4600- lv : R .Tensor ((2 , 2 , 3 ), dtype = "float32" ) = R .stack ((inp_0 , inp_1 ), axis = 1 )
4601- gv : R .Tuple (R .Tensor ((2 , 2 , 3 ), dtype = "float32" )) = (lv ,)
4601+ lv : R .Tensor ((2 , 6 ), dtype = "float32" ) = R .concat ((x , y ), axis = 1 )
4602+ lv1 : R .Tensor ((2 , 2 , 3 ), dtype = "float32" ) = R .reshape (lv , R .shape ([2 , 2 , 3 ]))
4603+ gv : R .Tuple (R .Tensor ((2 , 2 , 3 ), dtype = "float32" )) = (lv1 ,)
46024604 R .output (gv )
46034605 return gv
46044606
46054607 @I .ir_module
46064608 class Expected3 :
46074609 @R .function
46084610 def main (
4609- inp_0 : R .Tensor ((2 , 3 ), dtype = "float32" ),
4610- inp_1 : R .Tensor ((2 , 3 ), dtype = "float32" ),
4611+ x : R .Tensor ((2 , 3 ), dtype = "float32" ),
4612+ y : R .Tensor ((2 , 3 ), dtype = "float32" ),
46114613 ) -> R .Tuple (R .Tensor ((2 , 3 , 2 ), dtype = "float32" )):
46124614 with R .dataflow ():
4613- lv : R .Tensor ((2 , 3 , 2 ), dtype = "float32" ) = R .stack ((inp_0 , inp_1 ), axis = - 1 )
4614- gv : R .Tuple (R .Tensor ((2 , 3 , 2 ), dtype = "float32" )) = (lv ,)
4615+ lv : R .Tensor ((2 , 3 , 1 ), dtype = "float32" ) = R .expand_dims (x , axis = [2 ])
4616+ lv1 : R .Tensor ((2 , 3 , 1 ), dtype = "float32" ) = R .expand_dims (y , axis = [2 ])
4617+ lv2 : R .Tensor ((2 , 3 , 2 ), dtype = "float32" ) = R .concat ((lv , lv1 ), axis = - 1 )
4618+ gv : R .Tuple (R .Tensor ((2 , 3 , 2 ), dtype = "float32" )) = (lv2 ,)
46154619 R .output (gv )
46164620 return gv
46174621
46184622 example_args = (torch .randn (2 , 3 , dtype = torch .float32 ), torch .randn (2 , 3 , dtype = torch .float32 ))
46194623
4620- verify_model (Stack0 (), example_args , {}, Expected0 )
4621- verify_model (Stack1 (), example_args , {}, Expected1 )
4622- verify_model (Stack2 (), example_args , {}, Expected1 )
4623- verify_model (Stack3 (), example_args , {}, Expected3 )
4624+ verify_model (Stack0 (), example_args , {}, Expected0 , run_ep_decomposition = True )
4625+ verify_model (Stack1 (), example_args , {}, Expected1 , run_ep_decomposition = True )
4626+ verify_model (Stack2 (), example_args , {}, Expected1 , run_ep_decomposition = True )
4627+ verify_model (Stack3 (), example_args , {}, Expected3 , run_ep_decomposition = True )
46244628
46254629
46264630def test_tile ():
@@ -4644,7 +4648,7 @@ def main(
46444648 ) -> R .Tuple (R .Tensor ((1 , 6 ), dtype = "float32" )):
46454649 # block 0
46464650 with R .dataflow ():
4647- lv : R .Tensor ((1 , 6 ), dtype = "float32" ) = R .tile (x , [ 2 ])
4651+ lv : R .Tensor ((1 , 6 ), dtype = "float32" ) = R .tile (x , repeats = [ 1 , 2 ])
46484652 gv : R .Tuple (R .Tensor ((1 , 6 ), dtype = "float32" )) = (lv ,)
46494653 R .output (gv )
46504654 return gv
@@ -4657,15 +4661,15 @@ def main(
46574661 ) -> R .Tuple (R .Tensor ((4 , 6 ), dtype = "float32" )):
46584662 # block 0
46594663 with R .dataflow ():
4660- lv : R .Tensor ((4 , 6 ), dtype = "float32" ) = R .tile (x , [4 , 2 ])
4664+ lv : R .Tensor ((4 , 6 ), dtype = "float32" ) = R .tile (x , repeats = [4 , 2 ])
46614665 gv : R .Tuple (R .Tensor ((4 , 6 ), dtype = "float32" )) = (lv ,)
46624666 R .output (gv )
46634667 return gv
46644668
46654669 example_args = (torch .randn (1 , 3 , dtype = torch .float32 ),)
4666- verify_model (Tile1 (), example_args , {}, expected1 )
4667- verify_model (Tile2 (), example_args , {}, expected2 )
4668- verify_model (Tile3 (), example_args , {}, expected2 )
4670+ verify_model (Tile1 (), example_args , {}, expected1 , run_ep_decomposition = True )
4671+ verify_model (Tile2 (), example_args , {}, expected2 , run_ep_decomposition = True )
4672+ verify_model (Tile3 (), example_args , {}, expected2 , run_ep_decomposition = True )
46694673
46704674
46714675def test_transpose ():
@@ -4687,7 +4691,7 @@ def main(
46874691 return gv
46884692
46894693 example_args = (torch .randn (1 , 2 , 3 , 4 , dtype = torch .float32 ),)
4690- verify_model (Transpose (), example_args , {}, expected1 )
4694+ verify_model (Transpose (), example_args , {}, expected1 , run_ep_decomposition = True )
46914695
46924696
46934697def test_unsqueeze ():
@@ -4727,8 +4731,8 @@ def main(
47274731
47284732 example_args = (torch .randn (1 , 3 , 10 , 10 , dtype = torch .float32 ),)
47294733
4730- verify_model (Unsqueeze1 (), example_args , {}, expected1 )
4731- verify_model (Unsqueeze2 (), example_args , {}, expected2 )
4734+ verify_model (Unsqueeze1 (), example_args , {}, expected1 , run_ep_decomposition = True )
4735+ verify_model (Unsqueeze2 (), example_args , {}, expected2 , run_ep_decomposition = True )
47324736
47334737
47344738def test_view ():
@@ -4750,7 +4754,7 @@ def main(
47504754 return gv
47514755
47524756 example_args = (torch .randn (1 , 2 , 3 , 4 , dtype = torch .float32 ),)
4753- verify_model (View (), example_args , {}, expected1 )
4757+ verify_model (View (), example_args , {}, expected1 , run_ep_decomposition = True )
47544758
47554759
47564760def test_arange ():
@@ -4771,7 +4775,7 @@ def main(
47714775 return gv
47724776
47734777 example_args = (torch .randn (10 , 10 , dtype = torch .float32 ),)
4774- verify_model (Arange (), example_args , {}, Expected )
4778+ verify_model (Arange (), example_args , {}, Expected , run_ep_decomposition = True )
47754779
47764780
47774781def test_hamming_window ():
@@ -4798,7 +4802,7 @@ def main(
47984802 return gv
47994803
48004804 example_args = (torch .randn (10 , 10 , dtype = torch .float32 ),)
4801- verify_model (HammingWindow (), example_args , {}, Expected )
4805+ verify_model (HammingWindow (), example_args , {}, Expected , run_ep_decomposition = True )
48024806
48034807
48044808def test_contiguous ():
@@ -4818,7 +4822,7 @@ def main(
48184822 return gv
48194823
48204824 example_args = (torch .randn (10 , 10 , dtype = torch .float32 ),)
4821- verify_model (Contiguous (), example_args , {}, Expected )
4825+ verify_model (Contiguous (), example_args , {}, Expected , run_ep_decomposition = True )
48224826
48234827
48244828def test_clone ():
@@ -4838,7 +4842,7 @@ def main(
48384842 return gv
48394843
48404844 example_args = (torch .randn (10 , 10 , dtype = torch .float32 ),)
4841- verify_model (Clone (), example_args , {}, Expected )
4845+ verify_model (Clone (), example_args , {}, Expected , run_ep_decomposition = True )
48424846
48434847
48444848def test_empty ():
@@ -4850,7 +4854,7 @@ def forward(self, input):
48504854 class Expected :
48514855 @R .function
48524856 def main (
4853- inp_0 : R .Tensor ((10 , 10 ), dtype = "float32" )
4857+ input : R .Tensor ((10 , 10 ), dtype = "float32" )
48544858 ) -> R .Tuple (R .Tensor ((10 , 10 ), dtype = "float32" )):
48554859 with R .dataflow ():
48564860 lv : R .Tensor ((10 , 10 ), dtype = "float32" ) = R .zeros (
@@ -4861,7 +4865,7 @@ def main(
48614865 return gv
48624866
48634867 example_args = (torch .randn (10 , 10 , dtype = torch .float32 ),)
4864- verify_model (Empty (), example_args , {}, Expected )
4868+ verify_model (Empty (), example_args , {}, Expected , run_ep_decomposition = True )
48654869
48664870
48674871def test_fill ():
@@ -4873,18 +4877,18 @@ def forward(self, input: torch.Tensor):
48734877 class Expected :
48744878 @R .function
48754879 def main (
4876- inp_0 : R .Tensor ((10 , 10 ), dtype = "float32" )
4880+ input : R .Tensor ((10 , 10 ), dtype = "float32" )
48774881 ) -> R .Tuple (R .Tensor ((10 , 10 ), dtype = "float32" )):
48784882 with R .dataflow ():
4879- lv : R .Tensor ((10 , 10 ), dtype = "float32" ) = R .full (
4880- R . shape ([ 10 , 10 ]), R .const (1.5 , "float32" ), dtype = "float32 "
4883+ lv : R .Tensor ((10 , 10 ), dtype = "float32" ) = R .full_like (
4884+ input , R .const (1.5 , "float32" ), dtype = "void "
48814885 )
48824886 gv : R .Tuple (R .Tensor ((10 , 10 ), dtype = "float32" )) = (lv ,)
48834887 R .output (gv )
48844888 return gv
48854889
48864890 example_args = (torch .randn (10 , 10 , dtype = torch .float32 ),)
4887- verify_model (Fill (), example_args , {}, Expected )
4891+ verify_model (Fill (), example_args , {}, Expected , run_ep_decomposition = True )
48884892
48894893
48904894def test_fill_inplace ():
@@ -4897,18 +4901,20 @@ def forward(self, input: torch.Tensor):
48974901 class Expected :
48984902 @R .function
48994903 def main (
4900- x : R .Tensor ((2 , 3 ), dtype = "float32" )
4901- ) -> R .Tuple (R .Tensor ((2 , 3 ), dtype = "float32" )):
4904+ input : R .Tensor ((2 , 3 ), dtype = "float32" )
4905+ ) -> R .Tuple (R .Tensor ((2 , 3 ), dtype = "float32" ), R . Tensor (( 2 , 3 ), dtype = "float32" ) ):
49024906 with R .dataflow ():
4903- lv : R .Tensor ((2 , 3 ), dtype = "float32" ) = R .full (
4904- R . shape ([ 2 , 3 ]), R .const (42.0 , "float32" ), dtype = "float32 "
4907+ lv : R .Tensor ((2 , 3 ), dtype = "float32" ) = R .full_like (
4908+ input , R .const (42.0 , "float32" ), dtype = "void "
49054909 )
4906- gv : R .Tuple (R .Tensor ((2 , 3 ), dtype = "float32" )) = (lv ,)
4910+ gv : R .Tuple (
4911+ R .Tensor ((2 , 3 ), dtype = "float32" ), R .Tensor ((2 , 3 ), dtype = "float32" )
4912+ ) = (lv , lv )
49074913 R .output (gv )
49084914 return gv
49094915
49104916 example_args = (torch .randn (2 , 3 , dtype = torch .float32 ),)
4911- verify_model (FillInplace (), example_args , {}, Expected )
4917+ verify_model (FillInplace (), example_args , {}, Expected , run_ep_decomposition = True )
49124918
49134919
49144920def test_masked_fill ():
@@ -4923,16 +4929,14 @@ def main(
49234929 input : R .Tensor ((128 , 128 ), dtype = "float32" ), mask : R .Tensor ((128 , 128 ), dtype = "bool" )
49244930 ) -> R .Tuple (R .Tensor ((128 , 128 ), dtype = "float32" )):
49254931 with R .dataflow ():
4926- lv : R .Tensor ((128 , 128 ), dtype = "float32" ) = R .full_like (
4927- input , R .const (0 , "int32" ), dtype = "void"
4928- )
4932+ lv : R .Tensor ((), dtype = "float32" ) = R .const (0.0 , "float32" )
49294933 lv1 : R .Tensor ((128 , 128 ), dtype = "float32" ) = R .where (mask , lv , input )
49304934 gv : R .Tuple (R .Tensor ((128 , 128 ), dtype = "float32" )) = (lv1 ,)
49314935 R .output (gv )
49324936 return gv
49334937
49344938 example_args = (torch .randn (128 , 128 , dtype = torch .float32 ), torch .rand (128 , 128 ) < 0.5 )
4935- verify_model (Masked_Fill (), example_args , {}, Expected )
4939+ verify_model (Masked_Fill (), example_args , {}, Expected , run_ep_decomposition = True )
49364940
49374941
49384942def test_masked_fill_inplace ():
@@ -4945,18 +4949,18 @@ class Expected:
49454949 @R .function
49464950 def main (
49474951 input : R .Tensor ((128 , 128 ), dtype = "float32" ), mask : R .Tensor ((128 , 128 ), dtype = "bool" )
4948- ) -> R .Tuple (R .Tensor ((128 , 128 ), dtype = "float32" )):
4952+ ) -> R .Tuple (R .Tensor ((128 , 128 ), dtype = "float32" ), R . Tensor (( 128 , 128 ), dtype = "float32" ) ):
49494953 with R .dataflow ():
4950- lv : R .Tensor ((128 , 128 ), dtype = "float32" ) = R .full_like (
4951- input , R .const (1.5 , "float32" ), dtype = "void"
4952- )
4954+ lv : R .Tensor ((), dtype = "float32" ) = R .const (1.5 , "float32" )
49534955 lv1 : R .Tensor ((128 , 128 ), dtype = "float32" ) = R .where (mask , lv , input )
4954- gv : R .Tuple (R .Tensor ((128 , 128 ), dtype = "float32" )) = (lv1 ,)
4956+ gv : R .Tuple (
4957+ R .Tensor ((128 , 128 ), dtype = "float32" ), R .Tensor ((128 , 128 ), dtype = "float32" )
4958+ ) = (lv1 , lv1 )
49554959 R .output (gv )
49564960 return gv
49574961
49584962 example_args = (torch .randn (128 , 128 , dtype = torch .float32 ), torch .rand (128 , 128 ) < 0.5 )
4959- verify_model (Masked_Fill_Inplace (), example_args , {}, Expected )
4963+ verify_model (Masked_Fill_Inplace (), example_args , {}, Expected , run_ep_decomposition = True )
49604964
49614965
49624966def test_new_ones ():
@@ -4980,7 +4984,7 @@ def main(
49804984 return gv
49814985
49824986 example_args = (torch .randn (1 , 2 , 3 , dtype = torch .float32 ),)
4983- verify_model (NewOnes (), example_args , {}, expected1 )
4987+ verify_model (NewOnes (), example_args , {}, expected1 , run_ep_decomposition = True )
49844988
49854989
49864990def test_new_zeros ():
@@ -5003,7 +5007,7 @@ def main(
50035007 return gv
50045008
50055009 example_args = (torch .randn (1 , 128 , 128 , dtype = torch .float32 ),)
5006- verify_model (NewZeros (), example_args , {}, expected1 )
5010+ verify_model (NewZeros (), example_args , {}, expected1 , run_ep_decomposition = True )
50075011
50085012
50095013def test_to_copy ():
@@ -5094,11 +5098,11 @@ def main(
50945098 return gv
50955099
50965100 example_args = (torch .randn (1 , 2 , 3 , 4 , dtype = torch .float32 ),)
5097- verify_model (ToFloat (), example_args , {}, expected_float )
5098- verify_model (ToHalf (), example_args , {}, expected_half )
5099- verify_model (Type (), example_args , {}, expected_type )
5100- verify_model (To1 (), example_args , {}, expected_to1 )
5101- verify_model (To2 (), example_args , {}, expected_to2 )
5101+ verify_model (ToFloat (), example_args , {}, expected_float , run_ep_decomposition = True )
5102+ verify_model (ToHalf (), example_args , {}, expected_half , run_ep_decomposition = True )
5103+ verify_model (Type (), example_args , {}, expected_type , run_ep_decomposition = True )
5104+ verify_model (To1 (), example_args , {}, expected_to1 , run_ep_decomposition = True )
5105+ verify_model (To2 (), example_args , {}, expected_to2 , run_ep_decomposition = True )
51025106
51035107
51045108def test_keep_params ():
0 commit comments