Skip to content

Commit a38b686

Browse files
committed
Fixing some missed changes in last merge with main.
1 parent e58aa50 commit a38b686

2 files changed

Lines changed: 19 additions & 121 deletions

File tree

jetstream/core/orchestrator.py

Lines changed: 0 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,109 +1611,6 @@ def list_adapters_from_tensorstore(self):
16111611
return listed_adapters
16121612

16131613

1614-
def load_adapter_to_tensorstore(
1615-
self,
1616-
adapter_id: str,
1617-
adapter_path: str):
1618-
"""Load the adapter to adapter_tensorstore for each engine."""
1619-
logger.info("Loading adapter_id=%s from %s.",
1620-
adapter_id, adapter_path)
1621-
1622-
for idx, tensorstore in enumerate(self._prefill_adapterstore):
1623-
try:
1624-
engine = self._prefill_engines[idx]
1625-
adapter_params, adapter_config = engine.load_single_adapter(
1626-
adapter_path)
1627-
1628-
if not adapter_params or not adapter_config:
1629-
raise ValueError(
1630-
f"Failed to load adapter={adapter_id} from {adapter_path}.")
1631-
1632-
tensorstore.register_adapter(
1633-
adapter_id,
1634-
adapter_path,
1635-
adapter_config)
1636-
1637-
asyncio.run(tensorstore.load_adapter(adapter_id, adapter_params, True))
1638-
1639-
logger.info("Successfully loaded '%s' in engine_%d.",
1640-
adapter_id, idx)
1641-
engine.print_stats(f"After loading '{adapter_id}' in engine_{idx}")
1642-
1643-
except Exception as e:
1644-
logger.info("Adapter loading failed with error: %s", str(e))
1645-
raise e
1646-
1647-
for idx, tensorstore in enumerate(self._generate_adapterstore):
1648-
try:
1649-
engine = self._generate_engines[idx]
1650-
adapter_params, adapter_config = engine.load_single_adapter(
1651-
adapter_path)
1652-
1653-
if not adapter_params or not adapter_config:
1654-
raise ValueError(
1655-
f"Failed to load adapter={adapter_id} from {adapter_path}.")
1656-
1657-
tensorstore.register_adapter(
1658-
adapter_id,
1659-
adapter_path,
1660-
adapter_config)
1661-
1662-
asyncio.run(tensorstore.load_adapter(adapter_id, adapter_params, True))
1663-
1664-
logger.info("Successfully loaded '%s' in engine_%d.",
1665-
adapter_id, idx)
1666-
engine.print_stats(f"After loading '{adapter_id}' in engine_{idx}")
1667-
1668-
except Exception as e:
1669-
logger.info("Adapter loading failed with error: %s", str(e))
1670-
raise e
1671-
1672-
1673-
def unload_adapter_from_tensorstore(
1674-
self,
1675-
adapter_id: str):
1676-
"""Unload the adapter from adapter_tensorstore of each engine."""
1677-
logger.info("Unloading adapter_id=%s", adapter_id)
1678-
1679-
for idx, tensorstore in enumerate(self._prefill_adapterstore):
1680-
try:
1681-
engine = self._prefill_engines[idx]
1682-
asyncio.run(tensorstore.unload_adapter(adapter_id))
1683-
1684-
logger.info("Successfully unloaded '%s' in engine_%d.",
1685-
adapter_id, idx)
1686-
engine.print_stats(f"After unloading '{adapter_id}' in engine_{idx}")
1687-
1688-
except Exception as e:
1689-
logger.info("Adapter unloading failed with error: %s", str(e))
1690-
raise e
1691-
1692-
for idx, tensorstore in enumerate(self._generate_adapterstore):
1693-
try:
1694-
engine = self._generate_engines[idx]
1695-
asyncio.run(tensorstore.unload_adapter(adapter_id))
1696-
1697-
logger.info("Successfully unloaded '%s' in engine_%d.",
1698-
adapter_id, idx)
1699-
engine.print_stats(f"After unloading '{adapter_id}' in engine_{idx}")
1700-
1701-
except Exception as e:
1702-
logger.info("Adapter unloading failed with error: %s", str(e))
1703-
raise e
1704-
1705-
1706-
def list_adapters_from_tensorstore(self):
1707-
"""List all the adapters from the adapter_tensorstore of each engine."""
1708-
logger.info("Listing loaded adapters.")
1709-
1710-
listed_adapters = {}
1711-
for tensorstore in self._generate_adapterstore:
1712-
listed_adapters.update(tensorstore.adapter_registry)
1713-
1714-
return listed_adapters
1715-
1716-
17171614
class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer):
17181615
"""Coordinates a set of prefill and generate slices for LLM decoding."""
17191616

jetstream/core/proto/jetstream_pb2_grpc.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,13 @@
1919

2020

2121
class OrchestratorStub(object):
22-
"""TODO: Merge this with main JetStream core once we settle on an API.
22+
"""TODO: Merge this with main JetStream core once we settle on an API."""
2323

24+
def __init__(self, channel):
25+
"""Constructor.
26+
27+
Args:
28+
channel: A grpc.Channel.
2429
"""
2530
self.Decode = channel.unary_stream(
2631
"/jetstream_proto.Orchestrator/Decode",
@@ -35,23 +40,19 @@ class OrchestratorStub(object):
3540

3641

3742
class OrchestratorServicer(object):
38-
"""TODO: Merge this with main JetStream core once we settle on an API.
39-
40-
"""
43+
"""TODO: Merge this with main JetStream core once we settle on an API."""
4144

42-
def Decode(self, request, context):
43-
"""Query LLM to generate text or tokens.
44-
"""
45-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
46-
context.set_details('Method not implemented!')
47-
raise NotImplementedError('Method not implemented!')
45+
def Decode(self, request, context):
46+
"""Query LLM to generate text or tokens."""
47+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
48+
context.set_details("Method not implemented!")
49+
raise NotImplementedError("Method not implemented!")
4850

49-
def HealthCheck(self, request, context):
50-
"""Checks if the model server is live.
51-
"""
52-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
53-
context.set_details('Method not implemented!')
54-
raise NotImplementedError('Method not implemented!')
51+
def HealthCheck(self, request, context):
52+
"""Checks if the model server is live."""
53+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
54+
context.set_details("Method not implemented!")
55+
raise NotImplementedError("Method not implemented!")
5556

5657

5758
def add_OrchestratorServicer_to_server(servicer, server):
@@ -73,9 +74,9 @@ def add_OrchestratorServicer_to_server(servicer, server):
7374
server.add_generic_rpc_handlers((generic_handler,))
7475

7576

76-
# This class is part of an EXPERIMENTAL API.
77+
# This class is part of an EXPERIMENTAL API.
7778
class Orchestrator(object):
78-
"""TODO: Merge this with main JetStream core once we settle on an API.
79+
"""TODO: Merge this with main JetStream core once we settle on an API."""
7980

8081
@staticmethod
8182
def Decode(

0 commit comments

Comments
 (0)