Skip to content

Commit bdcc454

Browse files
zhxchen17pytorchmergebot
authored andcommitted
[dynamo] Add missing fields for THPPyInterpreterFrame. (pytorch#103227)
Fixes pytorch#103210 Test Plan: Before the fix: ``` pytest test/dynamo/test_export.py -k suppress_errors ``` got result: ``` File "/data/users/zhxchen17/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/zhxchen17/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/zhxchen17/pytorch/torch/_dynamo/eval_frame.py", line 295, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/data/users/zhxchen17/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/zhxchen17/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/zhxchen17/pytorch/torch/_dynamo/eval_frame.py", line 448, in catch_errors return callback(frame, cache_size, hooks, frame_state) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/zhxchen17/pytorch/torch/_dynamo/convert_frame.py", line 127, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/data/users/zhxchen17/pytorch/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert return _compile( ^^^^^^^^^ File "/data/users/zhxchen17/pytorch/torch/_dynamo/utils.py", line 180, in time_wrapper r = func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/data/users/zhxchen17/pytorch/torch/_dynamo/convert_frame.py", line 511, in _compile exception_handler(e, code, frame) File "/data/users/zhxchen17/pytorch/torch/_dynamo/convert_frame.py", line 216, in exception_handler log.error(format_error_msg(e, code, record_filename, frame)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/zhxchen17/pytorch/torch/_dynamo/exc.py", line 248, in format_error_msg stack_above_dynamo = filter_stack(extract_stack(frame)) ^^^^^^^^^^^^^^^^^^^^ File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/traceback.py", line 231, in extract_stack stack = StackSummary.extract(walk_stack(f), limit=limit) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/traceback.py", line 393, in extract return klass._extract_from_extended_frame_gen( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/traceback.py", line 416, in _extract_from_extended_frame_gen for f, (lineno, end_lineno, colno, end_colno) in frame_gen: File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/traceback.py", line 390, in extended_frame_gen for f, lineno in frame_gen: File "/home/zhxchen17/miniconda3/envs/dev/lib/python3.11/traceback.py", line 334, in walk_stack yield f, f.f_lineno ^^^^^^^^^^ AttributeError: 'torch._C.dynamo.eval_frame._PyInterpreterFrame' object has no attribute 'f_lineno' ``` After the fix: ``` pytest test/dynamo/test_export.py -k suppress_errors -s ``` Got Result: ``` File "/data/users/zhxchen17/pytorch/torch/_dynamo/exc.py", line 135, in unimplemented raise Unsupported(msg) torch._dynamo.exc.Unsupported: map() operator doesn't support scalar or zero-sized tensors during tracing. ========== The above exception occurred while processing the following code ========== File "/data/users/zhxchen17/pytorch/test/dynamo/test_export.py", line 3043, in forward def forward(self, xs): File "/data/users/zhxchen17/pytorch/test/dynamo/test_export.py", line 3047, in forward return map(body, xs) ========== unimplemented [("map() operator doesn't support scalar or zero-sized tensors during tracing.", 1)] . =============================== 1 passed, 133 deselected in 4.60s ================================ ``` Pull Request resolved: pytorch#103227 Approved by: https://github.com/williamwen42
1 parent a8c5286 commit bdcc454

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

test/dynamo/test_export.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3405,6 +3405,25 @@ def foo(args):
34053405

34063406
self.assertTrue(torch.allclose(foo(inp_container), gm(inp_container)))
34073407

3408+
@config.patch(suppress_errors=True)
3409+
@config.patch(verbose=True)
3410+
def test_export_with_map_zero_sized_tensor_suppress_errors(self):
3411+
from functorch.experimental.control_flow import map
3412+
3413+
class Module(torch.nn.Module):
3414+
def forward(self, xs):
3415+
def body(x):
3416+
return x + 1
3417+
3418+
return map(body, xs)
3419+
3420+
mod = Module()
3421+
xs = torch.randn(0, 2)
3422+
with self.assertRaises(
3423+
torch._dynamo.exc.Unsupported,
3424+
):
3425+
out_graph, _ = torch._dynamo.export(mod, xs)
3426+
34083427

34093428
common_utils.instantiate_parametrized_tests(ExportTests)
34103429

torch/csrc/dynamo/eval_frame.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,24 @@ static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyOb
5959
return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame));
6060
}
6161

62+
static PyObject* THPPyInterpreterFrame_f_lineno(THPPyInterpreterFrame* self, PyObject* _noargs) {
63+
if (!self->frame->frame_obj) {
64+
return PyLong_FromLong(self->frame->f_code->co_firstlineno);
65+
}
66+
int lineno = PyFrame_GetLineNumber(self->frame->frame_obj);
67+
if (lineno < 0) {
68+
Py_RETURN_NONE;
69+
}
70+
return PyLong_FromLong(lineno);
71+
}
72+
73+
static PyObject* THPPyInterpreterFrame_f_back(THPPyInterpreterFrame* self, PyObject* _noargs) {
74+
if (!self->frame->frame_obj) {
75+
Py_RETURN_NONE;
76+
}
77+
return (PyObject*)PyFrame_GetBack(self->frame->frame_obj);
78+
}
79+
6280
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
6381
static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {
6482
{"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL},
@@ -69,6 +87,8 @@ static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {
6987
{"frame_obj", (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL},
7088
{"previous", (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL},
7189
{"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL},
90+
{"f_lineno", (getter)THPPyInterpreterFrame_f_lineno, NULL, NULL, NULL},
91+
{"f_back", (getter)THPPyInterpreterFrame_f_back, NULL, NULL, NULL},
7292
{NULL}};
7393

7494
static PyTypeObject THPPyInterpreterFrameType = {

0 commit comments

Comments
 (0)