Commit 4e16400

sungsooha and claude committed
fix: add torch.inference_mode() to parallel worker threads
torch.inference_mode() is thread-local, so ThreadPoolExecutor workers do not inherit the parent thread's context. Without it, the parallel workers run with autograd enabled, which costs extra memory and gives different semantics than the sequential path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
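The thread-locality the commit message describes can be observed directly: a worker submitted to a `ThreadPoolExecutor` from inside an `inference_mode()` block does not see the guard. A minimal sketch (the helper name `check_worker_mode` is illustrative, not from the patched module):

```python
import concurrent.futures

import torch


def check_worker_mode() -> bool:
    # Runs on a ThreadPoolExecutor worker thread; the inference-mode
    # guard entered by the submitting thread is NOT active here.
    return torch.is_inference_mode_enabled()


with torch.inference_mode():
    parent_enabled = torch.is_inference_mode_enabled()  # True in this thread
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        worker_enabled = pool.submit(check_worker_mode).result()

print(parent_enabled, worker_enabled)  # True False
```

Because the guard is a thread-local RAII state, each worker must enter `torch.inference_mode()` itself, which is exactly what this commit adds.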
1 parent f412e29 commit 4e16400

File tree: 1 file changed (+1, −1 lines)


modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -114,7 +114,7 @@ def _process_weight(item: _WeightQuantWork) -> tuple[str, torch.Tensor, str | No
 
 def _process_device_batch(items: list[_WeightQuantWork], device: torch.device):
     """Process all weight items on a single GPU. Runs in a dedicated thread."""
-    with torch.cuda.device(device):
+    with torch.inference_mode(), torch.cuda.device(device):
         results = [_process_weight(item) for item in items]
         torch.cuda.synchronize(device)
         return results
```
