Skip to content

Commit 5ce0fe8

Browse files
authored
gh-148378: Allow multiple consecutive recording ops per macro op (GH-148496)
1 parent 21da9d7 commit 5ce0fe8

File tree

9 files changed

+318
-82
lines changed

9 files changed

+318
-82
lines changed

Include/internal/pycore_optimizer.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,15 @@ typedef struct _PyJitTracerInitialState {
9191
_Py_CODEUNIT *jump_backward_instr;
9292
} _PyJitTracerInitialState;
9393

94+
#define MAX_RECORDED_VALUES 3 // Capacity of recorded_values[] below and of _PyOpcodeRecordEntry.indices[]
9495
typedef struct _PyJitTracerPreviousState {
9596
int instr_oparg;
9697
int instr_stacklevel;
9798
_Py_CODEUNIT *instr;
9899
PyCodeObject *instr_code; // Strong
99100
struct _PyInterpreterFrame *instr_frame;
100-
PyObject *recorded_value; // Strong, may be NULL
101+
PyObject *recorded_values[MAX_RECORDED_VALUES]; // Strong references; each slot may be NULL
102+
int recorded_count; // Set to the opcode's record-entry count after its record functions have run
101103
} _PyJitTracerPreviousState;
102104

103105
typedef struct _PyJitTracerTranslatorState {
@@ -481,7 +483,12 @@ void _PyJit_TracerFree(_PyThreadStateImpl *_tstate);
481483
#ifdef _Py_TIER2
482484
typedef void (*_Py_RecordFuncPtr)(_PyInterpreterFrame *frame, _PyStackRef *stackpointer, int oparg, PyObject **recorded_value);
483485
PyAPI_DATA(const _Py_RecordFuncPtr) _PyOpcode_RecordFunctions[];
484-
PyAPI_DATA(const uint8_t) _PyOpcode_RecordFunctionIndices[256];
486+
487+
typedef struct {
488+
uint8_t count; // Number of valid entries in indices[]
489+
uint8_t indices[MAX_RECORDED_VALUES]; // Each valid entry is an index into _PyOpcode_RecordFunctions[]
490+
} _PyOpcodeRecordEntry;
491+
PyAPI_DATA(const _PyOpcodeRecordEntry) _PyOpcode_RecordEntries[256]; // Looked up by opcode
485492
#endif
486493

487494
#ifdef __cplusplus

Lib/test/test_generated_cases.py

Lines changed: 198 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@ def skip_if_different_mount_drives():
2929

3030
test_tools.skip_if_missing("cases_generator")
3131
with test_tools.imports_under_tool("cases_generator"):
32-
from analyzer import StackItem
32+
from analyzer import StackItem, analyze_files
3333
from cwriter import CWriter
3434
import parser
3535
from stack import Local, Stack
3636
import tier1_generator
3737
import optimizer_generator
38+
import record_function_generator
3839

3940

4041
def handle_stderr():
@@ -1948,6 +1949,202 @@ def test_recording_after_non_specializing(self):
19481949
with self.assertRaisesRegex(SyntaxError, "Recording uop"):
19491950
self.run_cases_test(input, "")
19501951

1952+
def test_multiple_consecutive_recording_uops(self):
1953+
"""Multiple consecutive recording uops at the start of a macro are legal."""
1954+
input = """
1955+
tier2 op(_RECORD_A, (a, b -- a, b)) {
1956+
RECORD_VALUE(a);
1957+
}
1958+
tier2 op(_RECORD_B, (a, b -- a, b)) {
1959+
RECORD_VALUE(b);
1960+
}
1961+
op(_DO_STUFF, (a, b -- res)) {
1962+
res = a;
1963+
INPUTS_DEAD();
1964+
}
1965+
macro(OP) = _RECORD_A + _RECORD_B + _DO_STUFF;
1966+
"""
# The expected output below contains only the _DO_STUFF body: neither
# _RECORD_A nor _RECORD_B contributes any code to the generated case.
1967+
output = """
1968+
TARGET(OP) {
1969+
#if _Py_TAIL_CALL_INTERP
1970+
int opcode = OP;
1971+
(void)(opcode);
1972+
#endif
1973+
frame->instr_ptr = next_instr;
1974+
next_instr += 1;
1975+
INSTRUCTION_STATS(OP);
1976+
_PyStackRef a;
1977+
_PyStackRef res;
1978+
// _DO_STUFF
1979+
{
1980+
a = stack_pointer[-2];
1981+
res = a;
1982+
}
1983+
stack_pointer[-2] = res;
1984+
stack_pointer += -1;
1985+
ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__);
1986+
DISPATCH();
1987+
}
1988+
"""
1989+
self.run_cases_test(input, output)
1990+
1991+
def test_multiple_recording_uops_after_specializing(self):
1992+
"""Multiple recording uops after a specializing uop are legal."""
1993+
input = """
1994+
specializing op(_SPECIALIZE_OP, (counter/1, a, b -- a, b)) {
1995+
SPAM();
1996+
}
1997+
tier2 op(_RECORD_A, (a, b -- a, b)) {
1998+
RECORD_VALUE(a);
1999+
}
2000+
tier2 op(_RECORD_B, (a, b -- a, b)) {
2001+
RECORD_VALUE(b);
2002+
}
2003+
op(_DO_STUFF, (a, b -- res)) {
2004+
res = a;
2005+
INPUTS_DEAD();
2006+
}
2007+
macro(OP) = _SPECIALIZE_OP + _RECORD_A + _RECORD_B + unused/2 + _DO_STUFF;
2008+
"""
# The expected output keeps the _SPECIALIZE_OP and _DO_STUFF bodies and the
# cache skip for unused/2, but has no code for _RECORD_A or _RECORD_B.
# next_instr advances by 4 -- presumably 1 for the instruction itself plus
# 1 (counter/1) and 2 (unused/2) cache entries; confirm against the generator.
2009+
output = """
2010+
TARGET(OP) {
2011+
#if _Py_TAIL_CALL_INTERP
2012+
int opcode = OP;
2013+
(void)(opcode);
2014+
#endif
2015+
_Py_CODEUNIT* const this_instr = next_instr;
2016+
(void)this_instr;
2017+
frame->instr_ptr = next_instr;
2018+
next_instr += 4;
2019+
INSTRUCTION_STATS(OP);
2020+
_PyStackRef a;
2021+
_PyStackRef res;
2022+
// _SPECIALIZE_OP
2023+
{
2024+
uint16_t counter = read_u16(&this_instr[1].cache);
2025+
(void)counter;
2026+
SPAM();
2027+
}
2028+
/* Skip 2 cache entries */
2029+
// _DO_STUFF
2030+
{
2031+
a = stack_pointer[-2];
2032+
res = a;
2033+
}
2034+
stack_pointer[-2] = res;
2035+
stack_pointer += -1;
2036+
ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__);
2037+
DISPATCH();
2038+
}
2039+
"""
2040+
self.run_cases_test(input, output)
2041+
2042+
def test_recording_uop_between_real_uops_rejected(self):
2043+
"""A recording uop sandwiched between real uops is rejected."""
2044+
input = """
2045+
tier2 op(_RECORD_A, (a, b -- a, b)) {
2046+
RECORD_VALUE(a);
2047+
}
2048+
op(_FIRST, (a, b -- a, b)) {
2049+
first(a);
2050+
}
2051+
tier2 op(_RECORD_B, (a, b -- a, b)) {
2052+
RECORD_VALUE(b);
2053+
}
2054+
macro(OP) = _RECORD_A + _FIRST + _RECORD_B;
2055+
"""
# In the macro above _RECORD_B comes after the ordinary uop _FIRST; the
# analysis is expected to fail with a SyntaxError matching the pattern below.
2056+
with self.assertRaisesRegex(SyntaxError,
2057+
"must precede all "
2058+
"non-recording, non-specializing uops"):
2059+
self.run_cases_test(input, "")
2060+
2061+
2062+
class TestRecorderTableGeneration(unittest.TestCase):
2063+
2064+
def setUp(self) -> None:
# Prepare the path of the scratch input file that generate_tables() writes.
2065+
super().setUp()
2066+
self.maxDiff = None
2067+
self.temp_dir = tempfile.gettempdir()
2068+
# Fixed file name inside the system temp directory; removed again in tearDown().
self.temp_input_filename = os.path.join(self.temp_dir, "input.txt")
2069+
2070+
def tearDown(self) -> None:
# Remove the scratch input file; a file that does not exist is not an error.
2071+
try:
2072+
os.remove(self.temp_input_filename)
2073+
except FileNotFoundError:
2074+
pass
2075+
super().tearDown()
2076+
2077+
def generate_tables(self, input: str) -> str:
# Write `input` between the parser's BEGIN/END markers to the scratch file,
# analyze that file, and return the text that
# record_function_generator.generate_recorder_tables() emits for it.
2078+
import io
2079+
with open(self.temp_input_filename, "w+") as f:
2080+
f.write(parser.BEGIN_MARKER)
2081+
f.write(input)
2082+
f.write(parser.END_MARKER)
2083+
with handle_stderr():
2084+
analysis = analyze_files([self.temp_input_filename])
2085+
# Collect the generated text in memory instead of writing a file.
buf = io.StringIO()
2086+
# NOTE(review): the meaning of the 0 and False arguments is defined by CWriter,
# which is not visible here -- confirm before changing them.
out = CWriter(buf, 0, False)
2087+
record_function_generator.generate_recorder_tables(analysis, out)
2088+
return buf.getvalue()
2089+
2090+
def test_single_recording_uop_generates_count(self):
# A macro with one recording uop should produce a table entry with count 1.
2091+
input = """
2092+
tier2 op(_RECORD_TOS, (value -- value)) {
2093+
RECORD_VALUE(value);
2094+
}
2095+
op(_DO_STUFF, (value -- res)) {
2096+
res = value;
2097+
}
2098+
macro(OP) = _RECORD_TOS + _DO_STUFF;
2099+
"""
2100+
output = self.generate_tables(input)
2101+
# The index name must appear in the output, and OP's entry must be
# {count, {indices...}} with exactly that one index.
self.assertIn("_RECORD_TOS_INDEX", output)
2102+
self.assertIn("[OP] = {1, {_RECORD_TOS_INDEX}}", output)
2103+
2104+
def test_three_recording_uops_generate_count_3_in_order(self):
# Three recording uops should produce count 3 with the indices listed in
# the same order as the uops appear in the macro.
2105+
input = """
2106+
tier2 op(_RECORD_X, (a, b, c -- a, b, c)) {
2107+
RECORD_VALUE(a);
2108+
}
2109+
tier2 op(_RECORD_Y, (a, b, c -- a, b, c)) {
2110+
RECORD_VALUE(b);
2111+
}
2112+
tier2 op(_RECORD_Z, (a, b, c -- a, b, c)) {
2113+
RECORD_VALUE(c);
2114+
}
2115+
op(_DO_STUFF, (a, b, c -- res)) {
2116+
res = a;
2117+
}
2118+
macro(OP) = _RECORD_X + _RECORD_Y + _RECORD_Z + _DO_STUFF;
2119+
"""
2120+
output = self.generate_tables(input)
2121+
self.assertIn(
2122+
"[OP] = {3, {_RECORD_X_INDEX, _RECORD_Y_INDEX, _RECORD_Z_INDEX}}",
2123+
output,
2124+
)
2125+
2126+
def test_four_recording_uops_rejected(self):
# Four recording uops in one macro is one more than MAX_RECORDED_VALUES
# (defined as 3 in pycore_optimizer.h), so table generation must fail.
2127+
input = """
2128+
tier2 op(_RECORD_A, (a, b, c, d -- a, b, c, d)) {
2129+
RECORD_VALUE(a);
2130+
}
2131+
tier2 op(_RECORD_B, (a, b, c, d -- a, b, c, d)) {
2132+
RECORD_VALUE(b);
2133+
}
2134+
tier2 op(_RECORD_C, (a, b, c, d -- a, b, c, d)) {
2135+
RECORD_VALUE(c);
2136+
}
2137+
tier2 op(_RECORD_D, (a, b, c, d -- a, b, c, d)) {
2138+
RECORD_VALUE(d);
2139+
}
2140+
op(_DO_STUFF, (a, b, c, d -- res)) {
2141+
res = a;
2142+
}
2143+
macro(OP) = _RECORD_A + _RECORD_B + _RECORD_C + _RECORD_D + _DO_STUFF;
2144+
"""
2145+
with self.assertRaisesRegex(ValueError, "exceeds MAX_RECORDED_VALUES"):
2146+
self.generate_tables(input)
2147+
19512148

19522149
class TestGeneratedAbstractCases(unittest.TestCase):
19532150
def setUp(self) -> None:

Modules/_testinternalcapi/test_cases.c.h

Lines changed: 11 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Python/bytecodes.c

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6349,7 +6349,10 @@ dummy_func(
63496349
ERROR_IF(err < 0);
63506350
DISPATCH();
63516351
}
6352-
Py_CLEAR(tracer->prev_state.recorded_value);
6352+
for (int i = 0; i < tracer->prev_state.recorded_count; i++) {
6353+
Py_CLEAR(tracer->prev_state.recorded_values[i]);
6354+
}
6355+
tracer->prev_state.recorded_count = 0;
63536356
tracer->prev_state.instr = next_instr;
63546357
PyObject *prev_code = PyStackRef_AsPyObjectBorrow(frame->f_executable);
63556358
if (tracer->prev_state.instr_code != (PyCodeObject *)prev_code) {
@@ -6363,11 +6366,12 @@ dummy_func(
63636366
(&next_instr[1])->counter = trigger_backoff_counter();
63646367
}
63656368

6366-
uint8_t record_func_index = _PyOpcode_RecordFunctionIndices[opcode];
6367-
if (record_func_index) {
6368-
_Py_RecordFuncPtr doesnt_escape = _PyOpcode_RecordFunctions[record_func_index];
6369-
doesnt_escape(frame, stack_pointer, oparg, &tracer->prev_state.recorded_value);
6369+
const _PyOpcodeRecordEntry *record_entry = &_PyOpcode_RecordEntries[opcode];
6370+
for (int i = 0; i < record_entry->count; i++) {
6371+
_Py_RecordFuncPtr doesnt_escape = _PyOpcode_RecordFunctions[record_entry->indices[i]];
6372+
doesnt_escape(frame, stack_pointer, oparg, &tracer->prev_state.recorded_values[i]);
63706373
}
6374+
tracer->prev_state.recorded_count = record_entry->count;
63716375
DISPATCH_GOTO_NON_TRACING();
63726376
#else
63736377
(void)prev_instr;

Python/generated_cases.c.h

Lines changed: 11 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Python/optimizer.c

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@ _PyJit_translate_single_bytecode_to_trace(
866866
assert(nuops > 0);
867867
uint32_t orig_oparg = oparg; // For OPARG_TOP/BOTTOM
868868
uint32_t orig_target = target;
869+
int record_idx = 0;
869870
for (int i = 0; i < nuops; i++) {
870871
oparg = orig_oparg;
871872
target = orig_target;
@@ -946,8 +947,9 @@ _PyJit_translate_single_bytecode_to_trace(
946947
operand = next->op.arg;
947948
}
948949
else if (_PyUop_Flags[uop] & HAS_RECORDS_VALUE_FLAG) {
949-
PyObject *recorded_value = tracer->prev_state.recorded_value;
950-
tracer->prev_state.recorded_value = NULL;
950+
PyObject *recorded_value = tracer->prev_state.recorded_values[record_idx];
951+
tracer->prev_state.recorded_values[record_idx] = NULL;
952+
record_idx++;
951953
operand = (uintptr_t)recorded_value;
952954
}
953955
// All other instructions
@@ -1060,12 +1062,16 @@ _PyJit_TryInitializeTracing(
10601062
tracer->prev_state.instr_frame = frame;
10611063
tracer->prev_state.instr_oparg = oparg;
10621064
tracer->prev_state.instr_stacklevel = tracer->initial_state.stack_depth;
1063-
tracer->prev_state.recorded_value = NULL;
1064-
uint8_t record_func_index = _PyOpcode_RecordFunctionIndices[curr_instr->op.code];
1065-
if (record_func_index) {
1066-
_Py_RecordFuncPtr record_func = _PyOpcode_RecordFunctions[record_func_index];
1067-
record_func(frame, stack_pointer, oparg, &tracer->prev_state.recorded_value);
1065+
tracer->prev_state.recorded_count = 0;
1066+
for (int i = 0; i < MAX_RECORDED_VALUES; i++) {
1067+
tracer->prev_state.recorded_values[i] = NULL;
10681068
}
1069+
const _PyOpcodeRecordEntry *record_entry = &_PyOpcode_RecordEntries[curr_instr->op.code];
1070+
for (int i = 0; i < record_entry->count; i++) {
1071+
_Py_RecordFuncPtr record_func = _PyOpcode_RecordFunctions[record_entry->indices[i]];
1072+
record_func(frame, stack_pointer, oparg, &tracer->prev_state.recorded_values[i]);
1073+
}
1074+
tracer->prev_state.recorded_count = record_entry->count;
10691075
assert(curr_instr->op.code == JUMP_BACKWARD_JIT || curr_instr->op.code == RESUME_CHECK_JIT || (exit != NULL));
10701076
tracer->initial_state.jump_backward_instr = curr_instr;
10711077

@@ -1117,7 +1123,10 @@ _PyJit_FinalizeTracing(PyThreadState *tstate, int err)
11171123
Py_CLEAR(tracer->initial_state.func);
11181124
Py_CLEAR(tracer->initial_state.executor);
11191125
Py_CLEAR(tracer->prev_state.instr_code);
1120-
Py_CLEAR(tracer->prev_state.recorded_value);
1126+
for (int i = 0; i < MAX_RECORDED_VALUES; i++) {
1127+
Py_CLEAR(tracer->prev_state.recorded_values[i]);
1128+
}
1129+
tracer->prev_state.recorded_count = 0;
11211130
uop_buffer_init(buffer, &tracer->uop_array[0], UOP_MAX_TRACE_LENGTH);
11221131
tracer->is_tracing = false;
11231132
}

0 commit comments

Comments
 (0)