Skip to content

Commit 7fdfb11

Browse files
bobrenjc93 and meta-codesync[bot]
authored and committed
follow up source attribution review comments (#180697)
Summary:

## Summary

Follow-up to #179350 to address the post-land review comments.

Root cause: The landed source-attribution change still had three rough edges: `SourceLocation.format()` kept its own single-line formatter instead of reusing Dynamo's existing multiline-aware source rendering, caret-only line normalization was applied globally in `munge_exc`, and the new version-dependent expected strings were still using `assertExpectedInline` even though `EXPECTTEST_ACCEPT=1` cannot rewrite those dynamic expectations.

Proposed fix:
- factor the shared source-range rendering into `torch._dynamo.utils.format_source_range()` and use it from both `get_instruction_source_311()` and `SourceLocation.format()`
- move the caret-only normalization into a local helper in `test_error_messages.py`
- switch the version-dependent expectation checks to `assertEqual`, and add a multiline `SourceLocation.format()` regression test

Why this is the right long-term fix: A shared formatter keeps traceback rendering and Dynamo source attribution in sync for multiline spans, the caret normalization stays scoped to the one test file that needs it, and the dynamic expectation checks stop relying on expecttest behavior they cannot support while the new regression test locks in the multiline formatting behavior.

Drafted via Codex, published after manual review by bobrenjc93

X-link: pytorch/pytorch#180697
Approved by: https://github.com/williamwen42
Reviewed By: wdvr
Differential Revision: D103354335
fbshipit-source-id: c8141c92ee54cd931d65f664ee4c0a8f0c557725
1 parent 03012fa commit 7fdfb11

1 file changed

Lines changed: 91 additions & 80 deletions

File tree

  • userbenchmark/dynamo/dynamobench/_dynamo

userbenchmark/dynamo/dynamobench/_dynamo/utils.py

Lines changed: 91 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -4694,83 +4694,73 @@ def nextline(lineno: int, col: int) -> tuple[int, int]:
46944694
return None
46954695

46964696

4697-
def get_instruction_source_311(code: types.CodeType, inst: Instruction) -> str:
4698-
"""
4699-
Python 3.11+ only. Returns lines of source code (from code object `code`)
4700-
corresponding to `inst`'s location data, and underlines relevant code to `inst`.
4697+
def format_source_range(
4698+
filename: str,
4699+
lineno: int | None,
4700+
end_lineno: int | None = None,
4701+
col_offset: int | None = None,
4702+
end_col_offset: int | None = None,
4703+
*,
4704+
function_name: str = "<unknown>",
4705+
) -> str:
4706+
if lineno is None:
4707+
return ""
47014708

4702-
Example: CALL on `g`:
4703-
f(g(
4704-
^^
4705-
h(x)))
4706-
^^^^^
4709+
last_lineno = end_lineno if end_lineno is not None else lineno
4710+
source_lines = [
4711+
linecache.getline(filename, current_lineno).rstrip()
4712+
for current_lineno in range(lineno, last_lineno + 1)
4713+
]
4714+
if not any(source_lines):
4715+
return ""
47074716

4708-
We need our own implementation in < 3.13 since `format_frame_summary` in
4709-
Python's `traceback` module doesn't handle multi-line expressions
4710-
(and their anchor extraction code is not completely correct).
4711-
"""
4712-
if sys.version_info >= (3, 13):
4713-
# multiline traceback implemented in 3.13+
4717+
if (
4718+
sys.version_info >= (3, 13)
4719+
and end_lineno is not None
4720+
and end_lineno != lineno
4721+
and col_offset is not None
4722+
and end_col_offset is not None
4723+
):
4724+
# Keep single-line ranges on Dynamo's manual path. The stdlib traceback
4725+
# formatter is useful for multiline spans on 3.13+, but for single-line
4726+
# spans it can emit version-specific mixed `~`/`^` markers or omit the
4727+
# marker line entirely for comments.
47144728
frame_summary = traceback.FrameSummary(
4715-
code.co_filename,
4716-
inst.positions.lineno,
4717-
code.co_name,
4718-
end_lineno=inst.positions.end_lineno,
4719-
colno=inst.positions.col_offset,
4720-
end_colno=inst.positions.end_col_offset,
4729+
filename,
4730+
lineno,
4731+
function_name,
4732+
end_lineno=end_lineno,
4733+
colno=col_offset,
4734+
end_colno=end_col_offset,
47214735
)
47224736
result = traceback.format_list([frame_summary])[0]
4723-
# remove first line containing filename info
47244737
result = "\n".join(result.splitlines()[1:])
4725-
# indent lines with original indentation
4726-
orig_lines = [
4727-
linecache.getline(code.co_filename, lineno).rstrip()
4728-
for lineno in range(inst.positions.lineno, inst.positions.end_lineno + 1)
4729-
]
4730-
orig_lines_dedent = textwrap.dedent("\n".join(orig_lines)).splitlines()
4731-
indent_len = len(orig_lines[0]) - len(orig_lines_dedent[0])
4732-
indent = orig_lines[0][:indent_len]
4733-
result = textwrap.indent(textwrap.dedent(result), indent)
4734-
return result
4735-
4736-
assert hasattr(inst, "positions") and inst.positions is not None
4737-
if inst.positions.lineno is None:
4738-
return ""
4739-
# The rstrip + "\n" pattern is used throughout this function to handle
4740-
# linecache.getline errors. Error lines are treated as empty strings "", but we want
4741-
# to treat them as blank lines "\n".
4742-
first_line = linecache.getline(code.co_filename, inst.positions.lineno).rstrip()
4743-
if inst.positions.end_lineno is None:
4744-
return first_line
4745-
if inst.positions.col_offset is None or inst.positions.end_col_offset is None:
4738+
source_lines_dedent = textwrap.dedent("\n".join(source_lines)).splitlines()
4739+
if not source_lines_dedent:
4740+
return result
4741+
indent_len = len(source_lines[0]) - len(source_lines_dedent[0])
4742+
indent = source_lines[0][:indent_len]
4743+
return textwrap.indent(textwrap.dedent(result), indent)
4744+
4745+
first_line = source_lines[0]
4746+
if end_lineno is None or col_offset is None or end_col_offset is None:
47464747
return first_line
47474748

4748-
# character index of the start of the instruction
4749-
start_offset = _fix_offset(first_line, inst.positions.col_offset)
4750-
# character index of the end of the instruction
4751-
# compute later since end may be a different line
4752-
end_offset = None
4753-
# expression corresponding to the instruction so we can get anchors
4749+
start_offset = _fix_offset(first_line, col_offset)
47544750
segment = ""
4755-
# underline markers to be printed - start with `~` marker and replace with `^` later
4756-
markers = []
4751+
markers: list[str] = []
47574752

4758-
# Compute segment and initial markers
4759-
if inst.positions.end_lineno == inst.positions.lineno:
4760-
end_offset = _fix_offset(first_line, inst.positions.end_col_offset)
4753+
if end_lineno == lineno:
4754+
end_offset = _fix_offset(first_line, end_col_offset)
47614755
segment = first_line[start_offset:end_offset]
47624756
markers.append(" " * start_offset + "~" * (end_offset - start_offset))
47634757
else:
47644758
segment = first_line[start_offset:] + "\n"
47654759
markers.append(" " * start_offset + "~" * (len(first_line) - start_offset))
4766-
last_line = linecache.getline(
4767-
code.co_filename, inst.positions.end_lineno
4768-
).rstrip()
4769-
end_offset = _fix_offset(last_line, inst.positions.end_col_offset)
4770-
for lineno in range(inst.positions.lineno + 1, inst.positions.end_lineno):
4771-
line = linecache.getline(code.co_filename, lineno).rstrip()
4760+
last_line = source_lines[-1]
4761+
end_offset = _fix_offset(last_line, end_col_offset)
4762+
for line in source_lines[1:-1]:
47724763
segment += line + "\n"
4773-
# don't underline leading spaces
47744764
num_spaces = len(line) - len(line.lstrip())
47754765
markers.append(" " * num_spaces + "~" * (len(line) - num_spaces))
47764766
segment += last_line[:end_offset]
@@ -4783,49 +4773,70 @@ def get_instruction_source_311(code: types.CodeType, inst: Instruction) -> str:
47834773
except AssertionError:
47844774
pass
47854775

4786-
# replace `~` markers with `^` where necessary
47874776
if anchors is None:
47884777
markers = [marker.replace("~", "^") for marker in markers]
47894778
else:
4790-
# make markers mutable
47914779
mutable_markers: list[list[str]] = [list(marker) for marker in markers]
47924780

4793-
# anchor positions do not take start_offset into account
47944781
if anchors.left_end_lineno == 0:
47954782
anchors.left_end_offset += start_offset
47964783
if anchors.right_start_lineno == 0:
47974784
anchors.right_start_offset += start_offset
47984785

4799-
# Turn `~`` markers between anchors to `^`
4800-
for lineno in range(len(markers)):
4801-
for col in range(len(mutable_markers[lineno])):
4802-
if lineno < anchors.left_end_lineno:
4786+
for marker_lineno in range(len(markers)):
4787+
for col in range(len(mutable_markers[marker_lineno])):
4788+
if marker_lineno < anchors.left_end_lineno:
48034789
continue
4804-
if lineno == anchors.left_end_lineno and col < anchors.left_end_offset:
4790+
if (
4791+
marker_lineno == anchors.left_end_lineno
4792+
and col < anchors.left_end_offset
4793+
):
48054794
continue
48064795
if (
4807-
lineno == anchors.right_start_lineno
4796+
marker_lineno == anchors.right_start_lineno
48084797
and col >= anchors.right_start_offset
48094798
):
48104799
continue
4811-
if lineno > anchors.right_start_lineno:
4800+
if marker_lineno > anchors.right_start_lineno:
48124801
continue
4813-
if mutable_markers[lineno][col] == "~":
4814-
mutable_markers[lineno][col] = "^"
4802+
if mutable_markers[marker_lineno][col] == "~":
4803+
mutable_markers[marker_lineno][col] = "^"
48154804

4816-
# make markers into strings again
48174805
markers = ["".join(marker) for marker in mutable_markers]
48184806

48194807
result = ""
4820-
for i in range(len(markers)):
4821-
result += (
4822-
linecache.getline(code.co_filename, inst.positions.lineno + i).rstrip()
4823-
+ "\n"
4824-
)
4825-
result += markers[i] + "\n"
4808+
for line, marker in zip(source_lines, markers):
4809+
result += line + "\n"
4810+
result += marker + "\n"
48264811
return result
48274812

48284813

4814+
def get_instruction_source_311(code: types.CodeType, inst: Instruction) -> str:
4815+
"""
4816+
Python 3.11+ only. Returns lines of source code (from code object `code`)
4817+
corresponding to `inst`'s location data, and underlines relevant code to `inst`.
4818+
4819+
Example: CALL on `g`:
4820+
f(g(
4821+
^^
4822+
h(x)))
4823+
^^^^^
4824+
4825+
We need our own implementation in < 3.13 since `format_frame_summary` in
4826+
Python's `traceback` module doesn't handle multi-line expressions
4827+
(and their anchor extraction code is not completely correct).
4828+
"""
4829+
assert hasattr(inst, "positions") and inst.positions is not None
4830+
return format_source_range(
4831+
code.co_filename,
4832+
inst.positions.lineno,
4833+
inst.positions.end_lineno,
4834+
inst.positions.col_offset,
4835+
inst.positions.end_col_offset,
4836+
function_name=code.co_name,
4837+
)
4838+
4839+
48294840
def get_static_address_type(t: Any) -> Any:
48304841
if isinstance(t, torch.Tensor):
48314842
return getattr(t, "_dynamo_static_input_type", None)

0 commit comments

Comments
 (0)