Skip to content

Commit 060345f

Browse files
committed
Silence pyright Optional-access on async_op=True wait() calls
PyTorch's distributed stubs type dist.barrier/all_reduce/broadcast as returning Work | None because the sync path (async_op=False) returns None. Pyright cannot narrow the return when async_op=True is a literal kwarg, so every .wait(...) on the async handle flags reportOptionalMemberAccess. Add the same # type: ignore the rest of the codebase uses at three sites: _barrier_with_timeout, the eval-loss broadcast, and check_nccl_health.
1 parent 442223f commit 060345f

3 files changed

Lines changed: 3 additions & 3 deletions

File tree

kempnerforge/distributed/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _barrier_with_timeout(seconds: float, reason: str) -> None:
107107
"""
108108
work = dist.barrier(async_op=True)
109109
try:
110-
done = work.wait(timeout=timedelta(seconds=seconds))
110+
done = work.wait(timeout=timedelta(seconds=seconds)) # type: ignore[reportOptionalMemberAccess]
111111
except RuntimeError as e:
112112
raise RuntimeError(
113113
f"Barrier timed out after {seconds}s during {reason}. "

kempnerforge/resilience/health.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def check_nccl_health(timeout_sec: float = 10.0) -> bool:
235235
# timeout; older/alternate backends may return False instead. Handle
236236
# both so the timeout is honored regardless of version.
237237
try:
238-
done = work.wait(timeout=timedelta(seconds=timeout_sec))
238+
done = work.wait(timeout=timedelta(seconds=timeout_sec)) # type: ignore[reportOptionalMemberAccess]
239239
except RuntimeError as e:
240240
logger.warning(f"NCCL health check timed out after {timeout_sec}s: {e}")
241241
return False

kempnerforge/training/eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def run_eval(
9393
async_op=True,
9494
)
9595
try:
96-
done = work.wait(timeout=timedelta(seconds=_EVAL_BROADCAST_TIMEOUT_SEC))
96+
done = work.wait(timeout=timedelta(seconds=_EVAL_BROADCAST_TIMEOUT_SEC)) # type: ignore[reportOptionalMemberAccess]
9797
except RuntimeError as e:
9898
logger.error(
9999
f"Eval loss broadcast timed out after {_EVAL_BROADCAST_TIMEOUT_SEC}s; "

0 commit comments

Comments
 (0)