This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 315b190

Fix SIR cache when call func many times (#21)
* Fix SIR cache when call func many times
* Rename ut
* Polish code
* Remove bind
* Rename var and fun
1 parent 038d464 commit 315b190
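For orientation, here is a minimal, self-contained sketch of the idea the fix relies on: the cached SIR is recorded against stable, per-function input names, while each call keeps the caller-side ("outter") names so the cached graph can be re-bound to whatever tensors the current call actually passes in. The names FakeProxyTensor and rename_inputs below are hypothetical stand-ins, not code from this repository.

# Hypothetical sketch of the input-renaming idea behind the fix.
class FakeProxyTensor:
    def __init__(self, name):
        self.name = name

def rename_inputs(func_name, args):
    """Give each proxy input a stable inner name; remember the outer names."""
    outer_names, renamed = [], []
    for i, v in enumerate(args):
        if isinstance(v, FakeProxyTensor):
            outer_names.append(v.name)
            renamed.append(FakeProxyTensor("{}_input_{}".format(func_name, i)))
        else:
            renamed.append(v)
    return renamed, outer_names

# Two different calls produce the same inner names (sum_input_0, sum_input_1),
# so a graph recorded on the first call still matches on later calls; only the
# outer names change per call.
inner, outer = rename_inputs("sum", [FakeProxyTensor("tmp_3"), FakeProxyTensor("tmp_7")])
print([t.name for t in inner], outer)   # ['sum_input_0', 'sum_input_1'] ['tmp_3', 'tmp_7']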

File tree

3 files changed: +36 -16 lines changed


symbolic_trace/trace_cache_entrance.py

Lines changed: 33 additions & 11 deletions
@@ -9,16 +9,42 @@


 def trace_cache(func):
+    @no_eval_frame
     def call_with_cache(*args, **kwargs):
-        args, kwargs = convert_arguments(args), convert_arguments(kwargs)
+        args, kwargs = convert_arguments(args), convert_arguments(kwargs)
+        args, kwargs, outter_names = construct_inner_proxy_tensor(func.__name__, *args, **kwargs)
+
         if frame_enter(func.__name__, args):
-            return cache_and_return(func.__name__, args)
+            return cache_and_return(func.__name__, outter_names)
         ret = func(*args)
-        frame_leave(func.__name__, ret)
+        frame_leave(func.__name__, outter_names, ret)
         return ret
     return call_with_cache


+def construct_inner_proxy_tensor(func_name, *args, **kwargs):
+    flat_args = paddle.utils.flatten(args)
+    flat_kwargs = paddle.utils.flatten(kwargs)
+    outter_names = []
+    name_i = 0
+    for i, v in enumerate(flat_args):
+        if isinstance(v, ProxyTensor):
+            name = '{}_input_{}'.format(func_name, name_i)
+            outter_names.append(v.name)
+            flat_args[i] = ProxyTensor(name, v.meta)
+            name_i = name_i + 1
+    for i, v in enumerate(flat_kwargs):
+        if isinstance(v, ProxyTensor):
+            name = '{}_input_{}'.format(func_name, name_i)
+            outter_names.append(v.name)
+            flat_kwargs[i] = ProxyTensor(name, v.meta)
+            name_i = name_i + 1
+
+    args = paddle.utils.pack_sequence_as(args, flat_args)
+    kwargs = paddle.utils.pack_sequence_as(kwargs, flat_kwargs)
+
+    return args, kwargs, outter_names
+
 @no_eval_frame
 # should generate a unique name for every function
 def frame_enter(name, inputs):
@@ -52,7 +78,7 @@ def frame_enter(name, inputs):


 @no_eval_frame
-def frame_leave(name, outputs):
+def frame_leave(name, outter_names, outputs):
     key_name = SymbolicTraceContext().sir_key_stack[-1]
     SymbolicTraceContext().sir_key_stack.pop()

@@ -88,24 +114,20 @@ def frame_leave(name, outputs):
         return

     # at the first time, the inputs and outputs need not change
-    SymbolicTraceContext().call_SIR(cur_sir.name, cur_sir.inputs, cur_sir.outputs)
+    SymbolicTraceContext().call_SIR(cur_sir.name, [Symbol(name) for name in outter_names], cur_sir.outputs)
     log(1, cur_sir, "\n")
     return


 @no_eval_frame
-def cache_and_return(name, inputs):
+def cache_and_return(name, outter_names):
     key_name = SymbolicTraceContext().sir_key_stack[-1]
     SymbolicTraceContext().sir_key_stack.pop()

     # find sir and it's origin_outputs
     cached_sir = SymbolicTraceContext().statement_factory[key_name]
     origin_outputs = SIRRuntimeCache().get_origin_outputs(key_name)

-    # gen call_SIR inputs
-    flat_inputs = paddle.utils.flatten(inputs)
-    symbol_inputs = [Symbol(x.name) for x in flat_inputs if isinstance(x, ProxyTensor)]
-
     # create return value
     outputs = gen_new_proxy_tensor_output(origin_outputs)

@@ -114,7 +136,7 @@ def cache_and_return(name, inputs):
     symbol_outputs = [Symbol(x.name) for x in flat_outputs if isinstance(x, ProxyTensor)]

     # add call_SIR
-    SymbolicTraceContext().call_SIR(cached_sir.name, symbol_inputs, symbol_outputs)
+    SymbolicTraceContext().call_SIR(cached_sir.name, [Symbol(name) for name in outter_names], symbol_outputs)
     return outputs

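A side note on the nest utilities used above: construct_inner_proxy_tensor round-trips the call arguments through paddle.utils.flatten and paddle.utils.pack_sequence_as, rewriting only the ProxyTensor leaves. A small illustrative sketch of that round-trip, with made-up values and assuming the pair behaves like the usual flatten/pack utilities, as the diff suggests:

import paddle

# Flatten a nested structure, rewrite the leaves, then pack the flat list
# back into the original structure (leaf order is preserved).
nested = ([1, 2], (3,))
flat = paddle.utils.flatten(nested)                    # [1, 2, 3]
flat = [v * 10 for v in flat]                          # rewrite the leaves
rebuilt = paddle.utils.pack_sequence_as(nested, flat)
print(rebuilt)                                         # expected: ([10, 20], (30,))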

tests/error_test_sir_call.py

Lines changed: 3 additions & 5 deletions
@@ -1,14 +1,12 @@
 import unittest
 import paddle
 from symbolic_trace import symbolic_trace
-from symbolic_trace.trace_cache_entrance import frame_enter, frame_leave, cache_and_return
+from symbolic_trace.trace_cache_entrance import trace_cache


+@trace_cache
 def sum(x, y):
-    if frame_enter("sum", (x, y)):
-        return cache_and_return("sum", (x, y))
     ret = x + y
-    frame_leave("sum", (ret))
     return ret

 def main(x, y):
@@ -21,7 +19,7 @@ def test_return_callable(self):
         x = paddle.to_tensor([1.0])
         y = paddle.to_tensor([2.0])
         ret = symbolic_trace(main)(x, y)
-        assert (ret.item() == 3.0), "Should be 4.0"
+        assert (ret.item() == 3.0), "Should be 3.0"

 if __name__ == "__main__":
     unittest.main()
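The commit title targets calling the traced function more than once; a hypothetical variant of the test's driver (not part of this commit), assuming the imports and the decorated sum from the test above, would exercise that path:

# Hypothetical driver, not in this diff: the second call to sum should take the
# cache_and_return path and replay the cached SIR with fresh outer inputs.
def main_twice(x, y):
    a = sum(x, y)   # first call: traces sum and records its SIR
    b = sum(a, y)   # second call: cache hit, re-binds the cached SIR
    return a + b

# ret = symbolic_trace(main_twice)(paddle.to_tensor([1.0]), paddle.to_tensor([2.0]))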
File renamed without changes.
