Skip to content

Commit c2bee4f

Browse files
committed
fix: preserve tunnel client test hook
1 parent b510e56 commit c2bee4f

1 file changed

Lines changed: 16 additions & 16 deletions

File tree

  • packages/prime/src/prime_cli/commands

packages/prime/src/prime_cli/commands/tunnel.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import signal
3-
from typing import List, Optional
3+
from typing import Any, List, Optional
44

55
import typer
66
from rich.table import Table
@@ -18,6 +18,16 @@
1818
app = PlainTyper(help="Manage tunnels for exposing local services", no_args_is_help=True)
1919
console = get_console()
2020

21+
TunnelClient: Any = None
22+
23+
24+
def _tunnel_client_class() -> Any:
25+
if TunnelClient is not None:
26+
return TunnelClient
27+
from prime_tunnel.core.client import TunnelClient as loaded_tunnel_client
28+
29+
return loaded_tunnel_client
30+
2131

2232
def _format_tunnel_for_output(tunnel) -> dict:
2333
created_at = tunnel.created_at
@@ -157,9 +167,7 @@ def list_tunnels(
157167
validate_output_format(output, console)
158168

159169
async def fetch_tunnels():
160-
from prime_tunnel.core.client import TunnelClient
161-
162-
client = TunnelClient()
170+
client = _tunnel_client_class()()
163171
try:
164172
return await client.list_tunnels_page(
165173
team_id=team_id,
@@ -248,9 +256,7 @@ def tunnel_status(
248256
"""Get status of a specific tunnel."""
249257

250258
async def fetch_status():
251-
from prime_tunnel.core.client import TunnelClient
252-
253-
client = TunnelClient()
259+
client = _tunnel_client_class()()
254260
try:
255261
return await client.get_tunnel(tunnel_id)
256262
finally:
@@ -336,9 +342,7 @@ def stop_tunnel(
336342
if all:
337343

338344
async def fetch_tunnel_ids() -> List[str]:
339-
from prime_tunnel.core.client import TunnelClient
340-
341-
client = TunnelClient()
345+
client = _tunnel_client_class()()
342346
try:
343347
scoped_team_id = team_id
344348
if scoped_team_id is None:
@@ -400,9 +404,7 @@ async def fetch_tunnel_ids() -> List[str]:
400404
if labels:
401405

402406
async def validate_label_scope() -> None:
403-
from prime_tunnel.core.client import TunnelClient
404-
405-
client = TunnelClient()
407+
client = _tunnel_client_class()()
406408
try:
407409
scoped_user_id = client.config.user_id if only_mine else None
408410
scoped_team_id = team_id if team_id is not None else client.config.team_id
@@ -454,9 +456,7 @@ async def validate_label_scope() -> None:
454456
return
455457

456458
async def delete_tunnels() -> tuple[List[str], List[dict], List[dict]]:
457-
from prime_tunnel.core.client import TunnelClient
458-
459-
client = TunnelClient()
459+
client = _tunnel_client_class()()
460460
succeeded: List[str] = []
461461
not_found: List[dict] = []
462462
failed: List[dict] = []

0 commit comments

Comments
 (0)