Skip to content

Commit eb74d86

Browse files
committed
Refactoring part-2.
1 parent e4d875a commit eb74d86

6 files changed

Lines changed: 8 additions & 100 deletions

File tree

jetstream/core/orchestrator.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -516,9 +516,10 @@ def _export_lora_request_info(self):
516516
max_loras = 0
517517
if self._metrics_collector:
518518
for idx, engine in enumerate(self._generate_engines):
519-
adapters_list_str += asyncio.run(self._adapter_tensorstore.get_hbm_loaded_adapters())
520519
max_loras += engine.max_concurrent_decodes
521520

521+
adapters_list_str += asyncio.run(self._adapter_tensorstore.get_hbm_loaded_adapters())
522+
522523
self._metrics_collector.get_lora_request_info_metric(max_loras,
523524
adapters_list_str).set_to_current_time()
524525

@@ -580,7 +581,6 @@ def _prefill_thread(self, idx: int):
580581
if request is None:
581582
break
582583

583-
584584
request.metadata.prefill_dequeue_time = time.perf_counter()
585585
is_bos = True
586586
logging.info(
@@ -616,7 +616,6 @@ def _prefill_thread(self, idx: int):
616616
padded_tokens=padded_tokens,
617617
true_length=true_length,
618618
)
619-
620619
del final_params
621620

622621
request.prefill_result = prefill_result
@@ -705,7 +704,6 @@ def _transfer_thread(self, idx: int):
705704
new_request.metadata.generate_enqueue_time = time.perf_counter()
706705
self._generate_backlogs[target_idx].put(new_request, block=True)
707706

708-
elapsed_time = (new_request.metadata.generate_enqueue_time - new_request.metadata.transfer_dequeue_time) * 1e6
709707
logging.info(
710708
"Successfully transferred prefill "
711709
"from prefill engine %d to generate engine %d "
@@ -821,7 +819,8 @@ def _generate_thread(self, idx: int):
821819
decode_state = generate_engine.insert(
822820
new_request.prefill_result, decode_state, slot=slot
823821
)
824-
822+
823+
# Export the lora_request_info metric
825824
self._export_lora_request_info()
826825

827826
del new_request.prefill_result
@@ -1123,7 +1122,6 @@ async def Decode( # pylint: disable=invalid-overridden-method
11231122
request: jetstream_pb2.DecodeRequest,
11241123
context: Optional[grpc.aio.ServicerContext] = None,
11251124
) -> AsyncIterator[jetstream_pb2.DecodeResponse]:
1126-
11271125
"""Decode."""
11281126
if context is None:
11291127
logging.warning(
@@ -1134,7 +1132,6 @@ async def Decode( # pylint: disable=invalid-overridden-method
11341132
return_channel = async_multifuture.AsyncMultifuture()
11351133
if context:
11361134
context.add_done_callback(return_channel.cancel)
1137-
11381135
prefill_content, is_client_side_tokenization = self._get_prefill_content(
11391136
request
11401137
)

jetstream/core/proto/jetstream.proto

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,5 +93,4 @@ message HealthCheckRequest {}
9393
message HealthCheckResponse {
9494
// Denotes whether the model server is live
9595
bool is_live = 1;
96-
9796
}

jetstream/core/proto/jetstream_pb2.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131

32-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fjetstream.proto\x12\x0fjetstream_proto\"\x90\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x12\n\nadapter_id\x18\x08 \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x14\n\x12HealthCheckRequest\"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\"\x15\n\x13ListAdaptersRequest\"s\n\x14ListAdaptersResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12\x33\n\radapter_infos\x18\x03 \x03(\x0b\x32\x1c.jetstream_proto.AdapterInfo\"\x82\x01\n\x0b\x41\x64\x61pterInfo\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0cloading_cost\x18\x02 \x01(\x03\x12\x10\n\x08size_hbm\x18\x03 \x01(\x03\x12\x10\n\x08size_cpu\x18\x04 \x01(\x03\x12\x15\n\rlast_accessed\x18\x05 \x01(\x02\x12\x0e\n\x06status\x18\x06 \x01(\t\">\n\x12LoadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\"=\n\x13LoadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\"*\n\x14UnloadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"?\n\x15UnloadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t2\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse\"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse\"\x00\x32\xb2\x02\n\x13MultiAdapterManager\x12]\n\x0cListAdapters\x12$.jetstream_proto.ListAdaptersRequest\x1a%.jetstream_proto.ListAdaptersResponse\"\x00\x12Z\n\x0bLoadAdapter\x12#.jetstream_proto.LoadAdapterRequest\x1a$.jetstream_proto.LoadAdapterResponse\"\x00\x12`\n\rUnloadAdapter\x12%.jetstream_proto.UnloadAdapterRequest\x1a&.jetstream_proto.UnloadAdapterResponse\"\x00\x62\x06proto3')
32+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fjetstream.proto\x12\x0fjetstream_proto\"\x90\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x12\n\nadapter_id\x18\x08 \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04\"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x14\n\x12HealthCheckRequest\"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse\"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse\"\x00\x62\x06proto3')
3333

3434
_globals = globals()
3535
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -56,22 +56,6 @@
5656
_globals['_HEALTHCHECKREQUEST']._serialized_end=793
5757
_globals['_HEALTHCHECKRESPONSE']._serialized_start=795
5858
_globals['_HEALTHCHECKRESPONSE']._serialized_end=833
59-
_globals['_LISTADAPTERSREQUEST']._serialized_start=835
60-
_globals['_LISTADAPTERSREQUEST']._serialized_end=856
61-
_globals['_LISTADAPTERSRESPONSE']._serialized_start=858
62-
_globals['_LISTADAPTERSRESPONSE']._serialized_end=973
63-
_globals['_ADAPTERINFO']._serialized_start=976
64-
_globals['_ADAPTERINFO']._serialized_end=1106
65-
_globals['_LOADADAPTERREQUEST']._serialized_start=1108
66-
_globals['_LOADADAPTERREQUEST']._serialized_end=1170
67-
_globals['_LOADADAPTERRESPONSE']._serialized_start=1172
68-
_globals['_LOADADAPTERRESPONSE']._serialized_end=1233
69-
_globals['_UNLOADADAPTERREQUEST']._serialized_start=1235
70-
_globals['_UNLOADADAPTERREQUEST']._serialized_end=1277
71-
_globals['_UNLOADADAPTERRESPONSE']._serialized_start=1279
72-
_globals['_UNLOADADAPTERRESPONSE']._serialized_end=1342
73-
_globals['_ORCHESTRATOR']._serialized_start=1345
74-
_globals['_ORCHESTRATOR']._serialized_end=1530
75-
_globals['_MULTIADAPTERMANAGER']._serialized_start=1533
76-
_globals['_MULTIADAPTERMANAGER']._serialized_end=1839
59+
_globals['_ORCHESTRATOR']._serialized_start=836
60+
_globals['_ORCHESTRATOR']._serialized_end=1021
7761
# @@protoc_insertion_point(module_scope)

jetstream/core/proto/jetstream_pb2_grpc.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -136,74 +136,3 @@ def HealthCheck(
136136
metadata,
137137
)
138138

139-
140-
class MultiAdapterManagerStub(object):
141-
"""MultiAdapterManagerStub."""
142-
143-
def __init__(self, channel):
144-
"""Constructor.
145-
146-
Args:
147-
channel: A grpc.Channel.
148-
"""
149-
self.ListAdapters = channel.unary_unary(
150-
'/jetstream_proto.MultiAdapterManager/ListAdapters',
151-
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersRequest.SerializeToString,
152-
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersResponse.FromString,
153-
_registered_method=True)
154-
self.LoadAdapter = channel.unary_unary(
155-
'/jetstream_proto.MultiAdapterManager/LoadAdapter',
156-
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterRequest.SerializeToString,
157-
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterResponse.FromString,
158-
_registered_method=True)
159-
self.UnloadAdapter = channel.unary_unary(
160-
'/jetstream_proto.MultiAdapterManager/UnloadAdapter',
161-
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterRequest.SerializeToString,
162-
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterResponse.FromString,
163-
_registered_method=True)
164-
165-
166-
class MultiAdapterManagerServicer(object):
167-
"""TODO: Merge this with main JetStream core once we settle on an API."""
168-
169-
def ListAdapters(self, request, context):
170-
"""Lists all the currently loaded LoRA adapters."""
171-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
172-
context.set_details("Method not implemented!")
173-
raise NotImplementedError("Method not implemented!")
174-
175-
def LoadAdapter(self, request, context):
176-
"""Check the feasibility and load the new LoRA adapter."""
177-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
178-
context.set_details("Method not implemented!")
179-
raise NotImplementedError("Method not implemented!")
180-
181-
def UnloadAdapter(self, request, context):
182-
"""Unload a LoRA adapter."""
183-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
184-
context.set_details("Method not implemented!")
185-
raise NotImplementedError("Method not implemented!")
186-
187-
188-
def add_MultiAdapterManagerServicer_to_server(servicer, server):
189-
rpc_method_handlers = {
190-
"ListAdapters": grpc.unary_unary_rpc_method_handler(
191-
servicer.ListAdapters,
192-
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersRequest.FromString,
193-
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.ListAdaptersResponse.SerializeToString,
194-
),
195-
"LoadAdapter": grpc.unary_unary_rpc_method_handler(
196-
servicer.LoadAdapter,
197-
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterRequest.FromString,
198-
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.LoadAdapterResponse.SerializeToString,
199-
),
200-
"UnloadAdapter": grpc.unary_unary_rpc_method_handler(
201-
servicer.UnloadAdapter,
202-
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterRequest.FromString,
203-
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.UnloadAdapterResponse.SerializeToString,
204-
),
205-
}
206-
generic_handler = grpc.method_handlers_generic_handler(
207-
"jetstream_proto.MultiAdapterManager", rpc_method_handlers
208-
)
209-
server.add_generic_rpc_handlers((generic_handler,))

jetstream/core/proto/multi_lora_decoding_pb2_grpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import grpc
44
import warnings
55

6-
from jetstream.core.proto import multi_lora_decoding_pb2 as multi__lora__decoding__pb2
6+
import multi_lora_decoding_pb2 as multi__lora__decoding__pb2
77

88
GRPC_GENERATED_VERSION = '1.70.0'
99
GRPC_VERSION = grpc.__version__

jetstream/core/server_lib.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ def create_driver(
135135
generate_params = [ge.load_params() for ge in engines.generate_engines]
136136
shared_params = [ie.load_params() for ie in engines.interleaved_engines]
137137
logging.info("Loaded all weights.")
138-
139138
interleaved_mode = (
140139
len(config.prefill_slices) + len(config.generate_slices) == 0
141140
)

0 commit comments

Comments
 (0)