Fix torch.compile graph breaks from Params4bit __getattr__ (#1904, #1917) #1916
Titus-von-Koeller wants to merge 3 commits into main
Replace __getattr__ + _QUANT_STATE_ATTR_MAP on Params4bit with @property descriptors.

Dynamo cannot trace __getattr__ on torch.Tensor subclasses, causing graph breaks that multiply under activation checkpointing. Properties use the descriptor protocol, which Dynamo handles correctly.

Add a regression test that compiles Linear4bit with fullgraph=True and torch.utils.checkpoint to catch this class of issue.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
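As a rough illustration of the mechanism this commit relies on, the plain-Python sketch below contrasts the two lookup paths. The classes `FakeQuantState`, `ParamsViaGetattr`, and `ParamsViaProperty` are hypothetical stand-ins rather than the bitsandbytes implementation, and the contents of the attribute map are assumed. No torch is involved, so it shows only the Python-level difference, not Dynamo's behavior.

```python
class FakeQuantState:
    """Hypothetical stand-in for the quantization state object."""

    def __init__(self, absmax, code):
        self.absmax = absmax
        self.code = code


class ParamsViaGetattr:
    """Old style: names not found normally fall through to __getattr__."""

    # Assumed shape of the map; the real contents are not shown in this PR text.
    _QUANT_STATE_ATTR_MAP = {"absmax": "absmax", "code": "code"}

    def __init__(self, quant_state):
        self.quant_state = quant_state

    def __getattr__(self, name):
        # Python calls this only after normal attribute lookup has failed.
        attr_map = type(self)._QUANT_STATE_ATTR_MAP
        if name in attr_map:
            return getattr(self.quant_state, attr_map[name])
        raise AttributeError(name)


class ParamsViaProperty:
    """New style: each forwarded name is a descriptor stored on the class."""

    def __init__(self, quant_state):
        self.quant_state = quant_state

    @property
    def absmax(self):
        return self.quant_state.absmax

    @property
    def code(self):
        return self.quant_state.code


state = FakeQuantState(absmax=[0.5, 1.0], code="nf4")
old = ParamsViaGetattr(state)
new = ParamsViaProperty(state)

# Callers that use getattr(obj, "absmax") see the same value either way.
same_value = getattr(old, "absmax") == getattr(new, "absmax")

# The property is found on the class itself; the __getattr__ variant has no
# such class-level entry and depends on the fallback hook instead.
found_on_class = isinstance(type(new).__dict__.get("absmax"), property)
missing_on_class = "absmax" not in type(old).__dict__
```

Both variants return the same value to a caller using `getattr`, which is consistent with the PR's claim that string-based attribute traversal keeps working after the switch.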
PR Review: #1916 -- Fix torch.compile graph breaks from Params4bit __getattr__ (#1904)

Classification: Bug fix replacing [...]

No blocking issues. The approach is correct. Properties are resolved via Python's descriptor protocol at the class MRO level, which Dynamo handles natively. The behavioral semantics are preserved: each property delegates to [...]

One observation (non-blocking): The [...]

Regression test is well-designed. Uses [...]

CI status: All CUDA/GPU tests pass across A10, L40S, T4 (CUDA 11.8/12.8/13.0). All builds pass. Lint passes. 4 CPU test failures on torch 2.10.0 appear pre-existing/infra -- PR only touches Python code.
matmul_4bit mutates quant_state.dtype in-place on CPU, which Dynamo flags as a side effect under fullgraph=True + activation checkpointing. This is a pre-existing issue unrelated to the __getattr__ → @property fix.

Skip the regression test on CPU and track the mutation fix separately in #1917.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The mutation `quant_state.dtype = A.dtype` is unnecessary: MatMul4Bit.forward already casts via `.to(A.dtype)`, and gemv_4bit doesn't read state.dtype. Removing it eliminates the Dynamo graph break on CPU under activation checkpointing, so the regression test no longer needs a CPU skip.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
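The reasoning in this commit can be sketched without torch: if the output is cast from the input's own type, writing that type back into the shared state object adds nothing except a side effect. The names `FakeQuantState`, `dequantize`, `scale_mutating`, and `scale_pure` below are hypothetical, Python's `int` and `float` stand in for tensor dtypes, and elementwise scaling stands in for the matmul. This is an analogy under those assumptions, not the bitsandbytes code path.

```python
from dataclasses import dataclass


@dataclass
class FakeQuantState:
    """Hypothetical stand-in for a quantization state; only dtype matters here."""

    dtype: type


def dequantize(packed):
    """Hypothetical dequantization step that always yields floats."""
    return [float(v) for v in packed]


def scale_mutating(a, packed, state):
    # Writes to the shared state object: a side effect visible to the caller.
    state.dtype = type(a[0])
    weights = dequantize(packed)
    return [state.dtype(x * w) for x, w in zip(a, weights)]


def scale_pure(a, packed, state):
    # Casts from the input's own type and leaves the state untouched.
    target = type(a[0])
    weights = dequantize(packed)
    return [target(x * w) for x, w in zip(a, weights)]


inputs = [1, 2, 3]
packed = [2, 2, 2]
state_a = FakeQuantState(dtype=float)
state_b = FakeQuantState(dtype=float)

out_mutating = scale_mutating(inputs, packed, state_a)
out_pure = scale_pure(inputs, packed, state_b)

# Same result either way; only the mutating variant changed its state argument.
results_match = out_mutating == out_pure
state_a_changed = state_a.dtype is int
state_b_unchanged = state_b.dtype is float
```

The two functions return identical results, so the write to `state.dtype` in the first variant is redundant in this sketch, which mirrors the argument made for dropping the mutation.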
Summary
- Replace `__getattr__` + `_QUANT_STATE_ATTR_MAP` on `Params4bit` with `@property` descriptors to eliminate `torch.compile` graph breaks under activation checkpointing
- Remove the `quant_state.dtype` mutation in the `matmul_4bit` CPU path that caused a separate `torch.compile` graph break under activation checkpointing
- Add regression test `test_linear4bit_torch_compile_activation_checkpointing` that compiles with `fullgraph=True` + `torch.utils.checkpoint` to catch this class of issue

Context
PR #1866 added `__getattr__` to `Params4bit` for FSDP `state_dict` traversal. Since `Params4bit` is a `torch.Tensor` subclass, Dynamo cannot trace through `__getattr__`, creating graph breaks on every attribute access. With activation checkpointing these multiply across layers, causing significant compilation overhead (#1904).

`@property` descriptors use Python's descriptor protocol (resolved at class level), which Dynamo handles correctly, so there are no graph breaks. FSDP still works because `getattr(weight, "absmax")` resolves the same way.

Attributes that collide with `Params4bit` instance attrs (`blocksize`, `quant_type`) or `torch.Tensor` attrs (`dtype`, `shape`) are intentionally omitted; they're packed into the `bitsandbytes__*` blob and never traversed by FSDP as separate keys. `QuantState.__getattr__` is left unchanged since `QuantState` is not a `Tensor` subclass.

Additionally, `matmul_4bit()` mutated `quant_state.dtype = A.dtype` on the CPU path (#1917). This in-place mutation is unnecessary: `MatMul4Bit.forward` already casts via `.to(A.dtype)`, and `gemv_4bit` doesn't read `state.dtype`. Removing it eliminates the Dynamo graph break on CPU under activation checkpointing.

Validated against three code states
[Table not recovered; the surviving cell fragments are `__getattr__)`, `__getattr__)`, and `@property)`.]

Test plan
- `test_linear4bit_torch_compile_activation_checkpointing`: 4 variants pass (nf4/fp4 × compress_statistics), including CPU
- `test_linear4bit_torch_compile`: all 64 variants pass (no regressions)
- `test_params4bit_quant_state_attr_access`: all 4 variants pass (FSDP traversal still works)
- `pre-commit run --all-files`: clean

Fixes #1904
Fixes #1917
🤖 Generated with Claude Code