Skip to content

Commit 4b8e7b4

Browse files
Ambient Code Bot and claude committed
refactor: unify Dial and router retry into a single loop
Extract _dial_and_connect() to perform Dial + router connection as a single atomic operation. This eliminates the duplicated Dial call that was in the separate router retry block, addressing the code review feedback about entangled and repeated code.

The single retry loop in handle_async now retries the full _dial_and_connect() on transient errors, which naturally handles both Dial failures and router connection failures with the same backoff logic and always gets a fresh router token on retry.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 21e1ad3 commit 4b8e7b4

2 files changed

Lines changed: 56 additions & 84 deletions

File tree

python/packages/jumpstarter/jumpstarter/client/lease.py

Lines changed: 24 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,21 @@ def __contextmanager__(self) -> Generator[Self]:
320320
grpc.StatusCode.INTERNAL,
321321
})
322322

323-
async def handle_async(self, stream): # noqa: C901
323+
async def _dial_and_connect(self, stream):
324+
"""Dial the controller and connect to the router stream.
325+
326+
Performs a single Dial + router connection attempt. Raises on failure
327+
so the caller can decide whether to retry.
328+
"""
329+
response = await self.controller.Dial(jumpstarter_pb2.DialRequest(lease_name=self.name))
330+
async with connect_router_stream(
331+
response.router_endpoint, response.router_token, stream, self.tls_config, self.grpc_options
332+
):
333+
pass
334+
335+
async def handle_async(self, stream):
324336
logger.debug("Connecting to Lease with name %s", self.name)
325-
# Retry Dial and router connection with exponential backoff for transient
337+
# Retry Dial + router connection with exponential backoff for transient
326338
# errors. This handles:
327339
# 1. The race condition where the client acquires a lease before the
328340
# exporter has transitioned to LEASE_READY status (FAILED_PRECONDITION).
@@ -335,8 +347,8 @@ async def handle_async(self, stream): # noqa: C901
335347
attempt = 0
336348
while True:
337349
try:
338-
response = await self.controller.Dial(jumpstarter_pb2.DialRequest(lease_name=self.name))
339-
break
350+
await self._dial_and_connect(stream)
351+
return
340352
except AioRpcError as e:
341353
remaining = deadline - time.monotonic()
342354
if e.code() == grpc.StatusCode.FAILED_PRECONDITION and "not ready" in str(e.details()):
@@ -349,7 +361,7 @@ async def handle_async(self, stream): # noqa: C901
349361
raise
350362
delay = min(base_delay * (2**attempt), max_delay, remaining)
351363
logger.debug(
352-
"Exporter not ready, retrying Dial in %.1fs (attempt %d, %.1fs remaining)",
364+
"Exporter not ready, retrying in %.1fs (attempt %d, %.1fs remaining)",
353365
delay,
354366
attempt + 1,
355367
remaining,
@@ -361,15 +373,15 @@ async def handle_async(self, stream): # noqa: C901
361373
if e.code() in self._TRANSIENT_GRPC_CODES:
362374
if remaining <= 0:
363375
logger.warning(
364-
"Dial failed with transient error after %d attempts (%.1fs elapsed): %s",
376+
"Connection failed with transient error after %d attempts (%.1fs elapsed): %s",
365377
attempt + 1,
366378
self.dial_timeout,
367379
e.details(),
368380
)
369381
return
370382
delay = min(base_delay * (2**attempt), max_delay, remaining)
371383
logger.info(
372-
"Dial failed with %s, retrying in %.1fs (attempt %d, %.1fs remaining): %s",
384+
"Connection failed with %s, retrying in %.1fs (attempt %d, %.1fs remaining): %s",
373385
e.code().name,
374386
delay,
375387
attempt + 1,
@@ -389,60 +401,22 @@ async def handle_async(self, stream): # noqa: C901
389401
else:
390402
logger.warning("Connection to exporter lost: %s", e.details())
391403
return
392-
393-
# Connect to the router with retry for transient failures.
394-
# After a successful Dial, the router endpoint may still be temporarily
395-
# unreachable (e.g. after a tunnel drop). Retry the connection to give
396-
# the network time to recover.
397-
remaining = deadline - time.monotonic()
398-
router_attempt = 0
399-
while True:
400-
try:
401-
async with connect_router_stream(
402-
response.router_endpoint, response.router_token, stream, self.tls_config, self.grpc_options
403-
):
404-
return
405-
except AioRpcError as e:
406-
remaining = deadline - time.monotonic()
407-
if e.code() in self._TRANSIENT_GRPC_CODES and remaining > 0:
408-
delay = min(base_delay * (2**router_attempt), max_delay, remaining)
409-
logger.info(
410-
"Router connection failed with %s, retrying in %.1fs (attempt %d, %.1fs remaining): %s",
411-
e.code().name,
412-
delay,
413-
router_attempt + 1,
414-
remaining,
415-
e.details(),
416-
)
417-
await sleep(delay)
418-
router_attempt += 1
419-
# Re-dial to get a fresh router token since the old one may
420-
# have expired during the retry window
421-
try:
422-
response = await self.controller.Dial(
423-
jumpstarter_pb2.DialRequest(lease_name=self.name)
424-
)
425-
except AioRpcError:
426-
logger.debug("Re-dial failed during router retry, will retry from Dial")
427-
continue
428-
logger.warning("Router connection failed: %s (code=%s)", e.details(), e.code().name)
429-
return
430404
except OSError as e:
431405
# OSError can occur when the router endpoint is unreachable
432406
remaining = deadline - time.monotonic()
433407
if remaining > 0:
434-
delay = min(base_delay * (2**router_attempt), max_delay, remaining)
408+
delay = min(base_delay * (2**attempt), max_delay, remaining)
435409
logger.info(
436-
"Router connection failed with OSError, retrying in %.1fs (attempt %d, %.1fs remaining): %s",
410+
"Connection failed with OSError, retrying in %.1fs (attempt %d, %.1fs remaining): %s",
437411
delay,
438-
router_attempt + 1,
412+
attempt + 1,
439413
remaining,
440414
e,
441415
)
442416
await sleep(delay)
443-
router_attempt += 1
417+
attempt += 1
444418
continue
445-
logger.warning("Router connection failed: %s", e)
419+
logger.warning("Connection failed: %s", e)
446420
return
447421

448422
@asynccontextmanager

python/packages/jumpstarter/jumpstarter/client/lease_test.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -582,12 +582,12 @@ def _make_lease_for_handle():
582582
return lease
583583

584584

585-
class TestHandleAsyncTransientDialRetry:
586-
"""Tests for transient gRPC error retry in handle_async Dial phase."""
585+
class TestHandleAsyncTransientRetry:
586+
"""Tests for transient gRPC error retry in handle_async (unified Dial + router loop)."""
587587

588588
@pytest.mark.anyio
589-
async def test_dial_retries_on_unavailable_then_succeeds(self):
590-
"""Dial should retry on UNAVAILABLE and succeed on the next attempt."""
589+
async def test_retries_on_dial_unavailable_then_succeeds(self):
590+
"""Should retry on UNAVAILABLE from Dial and succeed on the next attempt."""
591591
lease = _make_lease_for_handle()
592592

593593
dial_response = Mock(router_endpoint="ep", router_token="tok")
@@ -615,8 +615,8 @@ async def fake_router(*args, **kwargs):
615615
assert call_count == 2
616616

617617
@pytest.mark.anyio
618-
async def test_dial_transient_error_returns_after_timeout(self):
619-
"""Dial should give up and return when dial_timeout is exceeded."""
618+
async def test_transient_error_returns_after_timeout(self):
619+
"""Should give up and return when dial_timeout is exceeded."""
620620
lease = _make_lease_for_handle()
621621
lease.dial_timeout = 0.0 # already expired
622622

@@ -636,8 +636,8 @@ async def test_dial_transient_error_returns_after_timeout(self):
636636
[grpc.StatusCode.RESOURCE_EXHAUSTED, grpc.StatusCode.ABORTED, grpc.StatusCode.INTERNAL],
637637
ids=["RESOURCE_EXHAUSTED", "ABORTED", "INTERNAL"],
638638
)
639-
async def test_dial_retries_multiple_transient_codes(self, status_code):
640-
"""Dial should retry on RESOURCE_EXHAUSTED, ABORTED, INTERNAL."""
639+
async def test_retries_multiple_transient_codes(self, status_code):
640+
"""Should retry on RESOURCE_EXHAUSTED, ABORTED, INTERNAL."""
641641
lease = _make_lease_for_handle()
642642
dial_response = Mock(router_endpoint="ep", router_token="tok")
643643
call_count = 0
@@ -663,13 +663,9 @@ async def fake_router(*args, **kwargs):
663663

664664
assert call_count == 2, f"Expected 2 calls for {status_code}, got {call_count}"
665665

666-
667-
class TestHandleAsyncRouterRetry:
668-
"""Tests for router connection retry in handle_async."""
669-
670666
@pytest.mark.anyio
671-
async def test_router_retries_on_transient_error_then_succeeds(self):
672-
"""Router connection should retry on transient error, re-dial, then succeed."""
667+
async def test_router_transient_error_retries_full_dial_and_connect(self):
668+
"""Router transient error should retry the full Dial + connect cycle."""
673669
lease = _make_lease_for_handle()
674670
dial_response = Mock(router_endpoint="ep", router_token="tok")
675671
lease.controller.Dial = AsyncMock(return_value=dial_response)
@@ -689,30 +685,30 @@ async def fake_router(*args, **kwargs):
689685
await lease.handle_async(Mock())
690686

691687
assert connect_count == 2
692-
# Dial called once for initial + once for re-dial
688+
# Dial is called fresh each attempt (unified loop)
693689
assert lease.controller.Dial.call_count == 2
694690

695691
@pytest.mark.anyio
696-
async def test_router_non_transient_error_returns_immediately(self):
697-
"""Router connection should not retry on non-transient errors."""
692+
async def test_non_transient_error_returns_immediately(self):
693+
"""Non-transient errors should not be retried."""
698694
lease = _make_lease_for_handle()
699695
dial_response = Mock(router_endpoint="ep", router_token="tok")
700696
lease.controller.Dial = AsyncMock(return_value=dial_response)
701697

702698
@asynccontextmanager
703699
async def fail_router(*args, **kwargs):
704-
raise _make_aio_rpc_error(grpc.StatusCode.PERMISSION_DENIED, "no access")
700+
raise _make_aio_rpc_error(grpc.StatusCode.NOT_FOUND, "not found")
705701
yield # pragma: no cover
706702

707703
with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fail_router):
708704
await lease.handle_async(Mock())
709705

710-
# Only the initial Dial, no re-dial
706+
# Only one Dial attempt, no retry
711707
assert lease.controller.Dial.call_count == 1
712708

713709
@pytest.mark.anyio
714-
async def test_router_transient_error_returns_after_timeout(self):
715-
"""Router should give up when dial_timeout is exceeded."""
710+
async def test_transient_router_error_returns_after_timeout(self):
711+
"""Should give up when dial_timeout is exceeded during router retries."""
716712
lease = _make_lease_for_handle()
717713
lease.dial_timeout = 0.0 # already expired
718714
dial_response = Mock(router_endpoint="ep", router_token="tok")
@@ -730,8 +726,8 @@ async def fail_router(*args, **kwargs):
730726
assert lease.controller.Dial.call_count == 1
731727

732728
@pytest.mark.anyio
733-
async def test_router_redial_failure_is_swallowed(self):
734-
"""When re-dial fails during router retry, the error is logged and retry continues."""
729+
async def test_dial_failure_on_retry_is_retried_again(self):
730+
"""When Dial fails with a transient error during retry, it should keep retrying."""
735731
lease = _make_lease_for_handle()
736732
dial_response = Mock(router_endpoint="ep", router_token="tok")
737733

@@ -741,11 +737,10 @@ async def dial_side_effect(req):
741737
nonlocal dial_count
742738
dial_count += 1
743739
if dial_count == 1:
744-
return dial_response
740+
return dial_response # first Dial succeeds, router will fail
745741
if dial_count == 2:
746-
# Re-dial fails
747742
raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "re-dial failed")
748-
return dial_response
743+
return dial_response # third Dial succeeds
749744

750745
lease.controller.Dial = AsyncMock(side_effect=dial_side_effect)
751746

@@ -755,22 +750,23 @@ async def dial_side_effect(req):
755750
async def fake_router(*args, **kwargs):
756751
nonlocal connect_count
757752
connect_count += 1
758-
if connect_count <= 2:
753+
if connect_count == 1:
759754
raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "router fail")
760755
yield
761756

762757
with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock):
763758
with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fake_router):
764759
await lease.handle_async(Mock())
765760

766-
# Should have retried: connect fails, re-dial fails, connect fails again,
767-
# re-dial succeeds, third connect succeeds
768-
assert connect_count == 3
761+
# Attempt 1: Dial OK -> router fails (UNAVAILABLE)
762+
# Attempt 2: Dial fails (UNAVAILABLE) -> retried
763+
# Attempt 3: Dial OK -> router OK
769764
assert dial_count == 3
765+
assert connect_count == 2
770766

771767
@pytest.mark.anyio
772-
async def test_router_oserror_retries_then_succeeds(self):
773-
"""Router connection should retry on OSError, then succeed."""
768+
async def test_oserror_retries_then_succeeds(self):
769+
"""OSError from router should retry the full Dial + connect cycle."""
774770
lease = _make_lease_for_handle()
775771
dial_response = Mock(router_endpoint="ep", router_token="tok")
776772
lease.controller.Dial = AsyncMock(return_value=dial_response)
@@ -790,10 +786,12 @@ async def fake_router(*args, **kwargs):
790786
await lease.handle_async(Mock())
791787

792788
assert connect_count == 2
789+
# Dial called fresh each attempt
790+
assert lease.controller.Dial.call_count == 2
793791

794792
@pytest.mark.anyio
795-
async def test_router_oserror_returns_after_timeout(self):
796-
"""Router should give up on OSError when dial_timeout is exceeded."""
793+
async def test_oserror_returns_after_timeout(self):
794+
"""Should give up on OSError when dial_timeout is exceeded."""
797795
lease = _make_lease_for_handle()
798796
lease.dial_timeout = 0.0 # already expired
799797
dial_response = Mock(router_endpoint="ep", router_token="tok")

0 commit comments

Comments
 (0)