-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexchange.py
More file actions
140 lines (115 loc) · 4.8 KB
/
Copy pathexchange.py
File metadata and controls
140 lines (115 loc) · 4.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""CapabilityService servicer + gossip client (design doc §2).
Server side: :class:`CapabilityServiceServicer` is a thin adapter
between the wire messages and :class:`CapabilityRegistry` — merge what
the caller pushed, reply with the merged snapshot.
Client side: :func:`exchange_once` performs one push-pull round with a
list of peers. Failures are collected per peer, never raised through:
a dead peer must not stop gossip with the live ones (failure model,
design doc §6). The periodic loop wrapper lives in the server
launcher, not here, so this module stays trivially testable.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence
import grpc
import grpc.aio
from inference_engine.distributed.capability import (
CapabilityRegistry,
NodeCapability,
)
from inference_engine.server.proto_gen.kakeya.v1 import (
distributed_pb2,
distributed_pb2_grpc,
)
# Logger for this module, registered under an explicit dotted name.
_LOG: logging.Logger = logging.getLogger("kakeya.distributed.exchange")

# Default per-RPC deadline in seconds for the client helpers below
# (passed as the gRPC ``timeout=`` when the caller omits ``timeout_s``).
DEFAULT_EXCHANGE_TIMEOUT_S: float = 5.0
class CapabilityServiceServicer(distributed_pb2_grpc.CapabilityServiceServicer):
    """Serve ``CapabilityService`` RPCs out of a node's :class:`CapabilityRegistry`.

    Apart from the registry reference this servicer keeps no state: the
    exchange RPC merges into the registry and answers from its snapshot,
    and the single-card RPC answers from the registry's own card.
    """

    def __init__(self, registry: CapabilityRegistry) -> None:
        self._registry = registry

    @property
    def registry(self) -> CapabilityRegistry:
        """The registry this servicer reads from and merges into."""
        return self._registry

    async def ExchangeCapabilities(  # noqa: N802 - gRPC casing
        self,
        request: distributed_pb2.ExchangeCapabilitiesRequest,
        context: grpc.aio.ServicerContext,
    ) -> distributed_pb2.ExchangeCapabilitiesResponse:
        """Merge the caller's pushed cards, then reply with the merged snapshot."""
        registry = self._registry
        incoming = list(map(NodeCapability.from_proto, request.known_nodes))
        n_changed = registry.merge(incoming)
        if n_changed:
            _LOG.debug("gossip merge updated %d card(s)", n_changed)
        merged_view = [card.to_proto() for card in registry.snapshot()]
        return distributed_pb2.ExchangeCapabilitiesResponse(known_nodes=merged_view)

    async def GetNodeCapability(  # noqa: N802 - gRPC casing
        self,
        request: distributed_pb2.GetNodeCapabilityRequest,
        context: grpc.aio.ServicerContext,
    ) -> distributed_pb2.GetNodeCapabilityResponse:
        """Reply with this node's own card from the registry."""
        own_card = self._registry.self_card
        return distributed_pb2.GetNodeCapabilityResponse(node=own_card.to_proto())
def add_capability_service(
    server: grpc.aio.Server, registry: CapabilityRegistry,
) -> CapabilityServiceServicer:
    """Attach a CapabilityService backed by ``registry`` to ``server``.

    Returns the servicer instance that was registered, so the caller can
    keep a handle on it (e.g. to reach ``servicer.registry``).
    """
    handler = CapabilityServiceServicer(registry)
    distributed_pb2_grpc.add_CapabilityServiceServicer_to_server(handler, server)
    return handler
@dataclass(frozen=True)
class ExchangeReport:
    """Summary of a single gossip round over a list of peers."""

    # Sum of the values returned by ``registry.merge(...)`` across the
    # peers that answered during the round.
    merged_cards: int
    # Peer address -> human-readable failure description, one entry per
    # peer whose exchange failed. Empty when every peer succeeded.
    errors: Dict[str, str]

    @property
    def ok(self) -> bool:
        """True when no peer failed during the round."""
        return len(self.errors) == 0
async def exchange_once(
    registry: CapabilityRegistry,
    peers: Sequence[str],
    *,
    timeout_s: float = DEFAULT_EXCHANGE_TIMEOUT_S,
) -> ExchangeReport:
    """One push-pull gossip round with every address in ``peers``.

    Peers are contacted one after another. For each peer: push our full
    snapshot, decode the peer's reply, and merge it into ``registry``.
    The snapshot is rebuilt per peer, so the push to a later peer includes
    whatever was merged from earlier peers in the same round.

    Args:
        registry: Local registry; read for every outgoing push and merged
            into with each successfully decoded reply.
        peers: gRPC target addresses to gossip with, in order.
        timeout_s: Per-RPC deadline in seconds.

    Returns:
        An :class:`ExchangeReport` holding the summed ``registry.merge``
        results and one error string per failed peer.

    Failures are recorded per peer address and never interrupt the round
    for the remaining peers:

    * RPC errors (connect refused, deadline, ...) -> ``"<CODE>: <details>"``.
    * A reply whose cards cannot be decoded -> ``"bad reply: <Type>: <msg>"``;
      nothing from that reply is merged.
    """
    merged = 0
    errors: Dict[str, str] = {}
    for peer in peers:
        # Built per peer (not hoisted) so cards merged from earlier peers
        # in this round are forwarded to later ones.
        request = distributed_pb2.ExchangeCapabilitiesRequest(
            known_nodes=[c.to_proto() for c in registry.snapshot()],
        )
        try:
            async with grpc.aio.insecure_channel(peer) as channel:
                stub = distributed_pb2_grpc.CapabilityServiceStub(channel)
                response = await stub.ExchangeCapabilities(
                    request, timeout=timeout_s,
                )
        except grpc.aio.AioRpcError as exc:
            errors[peer] = f"{exc.code().name}: {exc.details()}"
            _LOG.warning("gossip with %s failed: %s", peer, errors[peer])
            continue
        try:
            # Decode the whole reply up front. The payload comes from a
            # remote peer; a bad entry must be attributed to that peer
            # rather than raised out of the round. Decoding lazily inside
            # merge() could also leave a partially applied merge, depending
            # on how merge consumes its input.
            cards = [NodeCapability.from_proto(m) for m in response.known_nodes]
        except Exception as exc:  # noqa: BLE001 - per-peer boundary, logged + recorded
            errors[peer] = f"bad reply: {type(exc).__name__}: {exc}"
            _LOG.warning(
                "gossip with %s failed: %s", peer, errors[peer], exc_info=True,
            )
            continue
        merged += registry.merge(cards)
    return ExchangeReport(merged_cards=merged, errors=errors)
async def fetch_node_capability(
    address: str,
    *,
    timeout_s: float = DEFAULT_EXCHANGE_TIMEOUT_S,
) -> Optional[NodeCapability]:
    """Liveness probe: fetch one node's own card, or None on failure.

    Args:
        address: gRPC target address of the node to probe.
        timeout_s: Per-RPC deadline in seconds.

    Returns:
        The card returned by the node's ``GetNodeCapability`` RPC, or
        ``None`` if the RPC fails or the returned card cannot be decoded.
        Both kinds of failure are logged as warnings rather than raised.
    """
    try:
        async with grpc.aio.insecure_channel(address) as channel:
            stub = distributed_pb2_grpc.CapabilityServiceStub(channel)
            response = await stub.GetNodeCapability(
                distributed_pb2.GetNodeCapabilityRequest(), timeout=timeout_s,
            )
    except grpc.aio.AioRpcError as exc:
        _LOG.warning("probe of %s failed: %s", address, exc.code().name)
        return None
    try:
        # The card comes from a remote node; a malformed one is a failed
        # probe (None), not an exception for the caller to handle.
        return NodeCapability.from_proto(response.node)
    except Exception as exc:  # noqa: BLE001 - probe contract is "None on failure"
        _LOG.warning(
            "probe of %s failed: bad reply (%s: %s)",
            address, type(exc).__name__, exc,
        )
        return None