99
1010
1111def trace_cache (func ):
12+ @no_eval_frame
1213 def call_with_cache (* args , ** kwargs ):
13- args , kwargs = convert_arguments (args ), convert_arguments (kwargs )
14+ args , kwargs = convert_arguments (args ), convert_arguments (kwargs )
15+ args , kwargs , outter_names = construct_inner_proxy_tensor (func .__name__ , * args , ** kwargs )
16+
1417 if frame_enter (func .__name__ , args ):
15- return cache_and_return (func .__name__ , args )
18+ return cache_and_return (func .__name__ , outter_names )
1619 ret = func (* args )
17- frame_leave (func .__name__ , ret )
20+ frame_leave (func .__name__ , outter_names , ret )
1821 return ret
1922 return call_with_cache
2023
2124
25+ def construct_inner_proxy_tensor (func_name , * args , ** kwargs ):
26+ flat_args = paddle .utils .flatten (args )
27+ flat_kwargs = paddle .utils .flatten (kwargs )
28+ outter_names = []
29+ name_i = 0
30+ for i , v in enumerate (flat_args ):
31+ if isinstance (v , ProxyTensor ):
32+ name = '{}_input_{}' .format (func_name , name_i )
33+ outter_names .append (v .name )
34+ flat_args [i ] = ProxyTensor (name , v .meta )
35+ name_i = name_i + 1
36+ for i , v in enumerate (flat_kwargs ):
37+ if isinstance (v , ProxyTensor ):
38+ name = '{}_input_{}' .format (func_name , name_i )
39+ outter_names .append (v .name )
40+ flat_kwargs [i ] = ProxyTensor (name , v .meta )
41+ name_i = name_i + 1
42+
43+ args = paddle .utils .pack_sequence_as (args , flat_args )
44+ kwargs = paddle .utils .pack_sequence_as (kwargs , flat_kwargs )
45+
46+ return args , kwargs , outter_names
47+
2248@no_eval_frame
2349# should generate a unique name for every function
2450def frame_enter (name , inputs ):
@@ -52,7 +78,7 @@ def frame_enter(name, inputs):
5278
5379
5480@no_eval_frame
55- def frame_leave (name , outputs ):
81+ def frame_leave (name , outter_names , outputs ):
5682 key_name = SymbolicTraceContext ().sir_key_stack [- 1 ]
5783 SymbolicTraceContext ().sir_key_stack .pop ()
5884
@@ -88,24 +114,20 @@ def frame_leave(name, outputs):
88114 return
89115
90116 # at the first time, the inputs and outputs need not change
91- SymbolicTraceContext ().call_SIR (cur_sir .name , cur_sir . inputs , cur_sir .outputs )
117+ SymbolicTraceContext ().call_SIR (cur_sir .name , [ Symbol ( name ) for name in outter_names ] , cur_sir .outputs )
92118 log (1 , cur_sir , "\n " )
93119 return
94120
95121
96122@no_eval_frame
97- def cache_and_return (name , inputs ):
123+ def cache_and_return (name , outter_names ):
98124 key_name = SymbolicTraceContext ().sir_key_stack [- 1 ]
99125 SymbolicTraceContext ().sir_key_stack .pop ()
100126
101127 # find sir and it's origin_outputs
102128 cached_sir = SymbolicTraceContext ().statement_factory [key_name ]
103129 origin_outputs = SIRRuntimeCache ().get_origin_outputs (key_name )
104130
105- # gen call_SIR inputs
106- flat_inputs = paddle .utils .flatten (inputs )
107- symbol_inputs = [Symbol (x .name ) for x in flat_inputs if isinstance (x , ProxyTensor )]
108-
109131 # create return value
110132 outputs = gen_new_proxy_tensor_output (origin_outputs )
111133
@@ -114,7 +136,7 @@ def cache_and_return(name, inputs):
114136 symbol_outputs = [Symbol (x .name ) for x in flat_outputs if isinstance (x , ProxyTensor )]
115137
116138 # add call_SIR
117- SymbolicTraceContext ().call_SIR (cached_sir .name , symbol_inputs , symbol_outputs )
139+ SymbolicTraceContext ().call_SIR (cached_sir .name , [ Symbol ( name ) for name in outter_names ] , symbol_outputs )
118140 return outputs
119141
120142
0 commit comments