Skip to content

Commit a41e4cd

Browse files
committed
Refactor part-3.
1 parent eb74d86 commit a41e4cd

5 files changed

Lines changed: 111 additions & 199 deletions

File tree

jetstream/core/orchestrator.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -580,7 +580,6 @@ def _prefill_thread(self, idx: int):
580580

581581
if request is None:
582582
break
583-
584583
request.metadata.prefill_dequeue_time = time.perf_counter()
585584
is_bos = True
586585
logging.info(
@@ -590,7 +589,6 @@ def _prefill_thread(self, idx: int):
590589
self._prefill_backlog.qsize(),
591590
is_bos,
592591
)
593-
594592
# Tokenize and padding the text or token input.
595593
padded_tokens, true_length = self._process_prefill_content(
596594
request, tokenizer, is_bos, prefill_engine.max_prefill_length
@@ -703,7 +701,6 @@ def _transfer_thread(self, idx: int):
703701
# Place the request on the correct generate backlog and block if full.
704702
new_request.metadata.generate_enqueue_time = time.perf_counter()
705703
self._generate_backlogs[target_idx].put(new_request, block=True)
706-
707704
logging.info(
708705
"Successfully transferred prefill "
709706
"from prefill engine %d to generate engine %d "
@@ -727,7 +724,6 @@ def _generate_thread(self, idx: int):
727724
decode_state = generate_engine.init_decode_state()
728725

729726
generate_params = self._generate_params[idx]
730-
731727
logging.info("---------Generate params %d loaded.---------", idx)
732728
time_of_last_generate = time.time()
733729
time_of_last_print = time.time()
@@ -841,7 +837,6 @@ def _generate_thread(self, idx: int):
841837
generate_params, decode_state
842838
)
843839
sampled_tokens.copy_to_host_async()
844-
845840
# Respond to detokenization backpressure.
846841
my_detokenize_backlog.put((generate_timestep, sampled_tokens), block=True)
847842
generate_timestep += 1
@@ -1135,7 +1130,6 @@ async def Decode( # pylint: disable=invalid-overridden-method
11351130
prefill_content, is_client_side_tokenization = self._get_prefill_content(
11361131
request
11371132
)
1138-
11391133
# Wrap request as an ActiveRequest.
11401134
active_request = ActiveRequest(
11411135
max_tokens=request.max_tokens,

jetstream/core/proto/jetstream_pb2.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -11,12 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
# -*- coding: utf-8 -*-
1615
# Generated by the protocol buffer compiler. DO NOT EDIT!
17-
# NO CHECKED-IN PROTOBUF GENCODE
1816
# source: jetstream.proto
19-
# Protobuf Python Version: 5.29.0
17+
# Protobuf Python Version: 4.25.1
2018
"""Generated protocol buffer code."""
2119
from google.protobuf import descriptor as _descriptor
2220
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -34,8 +32,8 @@
3432
_globals = globals()
3533
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
3634
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'jetstream_pb2', _globals)
37-
if not _descriptor._USE_C_DESCRIPTORS:
38-
DESCRIPTOR._loaded_options = None
35+
if _descriptor._USE_C_DESCRIPTORS == False:
36+
DESCRIPTOR._options = None
3937
_globals['_DECODEREQUEST']._serialized_start=37
4038
_globals['_DECODEREQUEST']._serialized_end=437
4139
_globals['_DECODEREQUEST_TEXTCONTENT']._serialized_start=293

jetstream/core/proto/jetstream_pb2_grpc.py

Lines changed: 88 additions & 106 deletions
Original file line number | Diff line number | Diff line change
@@ -15,124 +15,106 @@
1515
"""Client and server classes corresponding to protobuf-defined services."""
1616
import grpc
1717

18-
from jetstream.core.proto import jetstream_pb2 as jetstream_dot_core_dot_proto_dot_jetstream__pb2
18+
from jetstream.core.proto import jetstream_pb2 as jetstream__pb2
1919

2020

2121
class OrchestratorStub(object):
22-
"""TODO: Merge this with main JetStream core once we settle on an API."""
22+
"""TODO: Merge this with main JetStream core once we settle on an API.
2323
24-
def __init__(self, channel):
25-
"""Constructor.
26-
27-
Args:
28-
channel: A grpc.Channel.
2924
"""
30-
self.Decode = channel.unary_stream(
31-
"/jetstream_proto.Orchestrator/Decode",
32-
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString,
33-
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString,
34-
)
35-
self.HealthCheck = channel.unary_unary(
36-
"/jetstream_proto.Orchestrator/HealthCheck",
37-
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString,
38-
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString,
39-
)
25+
26+
def __init__(self, channel):
27+
"""Constructor.
28+
29+
Args:
30+
channel: A grpc.Channel.
31+
"""
32+
self.Decode = channel.unary_stream(
33+
'/jetstream_proto.Orchestrator/Decode',
34+
request_serializer=jetstream__pb2.DecodeRequest.SerializeToString,
35+
response_deserializer=jetstream__pb2.DecodeResponse.FromString,
36+
)
37+
self.HealthCheck = channel.unary_unary(
38+
'/jetstream_proto.Orchestrator/HealthCheck',
39+
request_serializer=jetstream__pb2.HealthCheckRequest.SerializeToString,
40+
response_deserializer=jetstream__pb2.HealthCheckResponse.FromString,
41+
)
4042

4143

4244
class OrchestratorServicer(object):
43-
"""TODO: Merge this with main JetStream core once we settle on an API."""
45+
"""TODO: Merge this with main JetStream core once we settle on an API.
4446
45-
def Decode(self, request, context):
46-
"""Query LLM to generate text or tokens."""
47-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
48-
context.set_details("Method not implemented!")
49-
raise NotImplementedError("Method not implemented!")
47+
"""
5048

51-
def HealthCheck(self, request, context):
52-
"""Checks if the model server is live."""
53-
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
54-
context.set_details("Method not implemented!")
55-
raise NotImplementedError("Method not implemented!")
49+
def Decode(self, request, context):
50+
"""Query LLM to generate text or tokens.
51+
"""
52+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
53+
context.set_details('Method not implemented!')
54+
raise NotImplementedError('Method not implemented!')
55+
56+
def HealthCheck(self, request, context):
57+
"""Checks if the model server is live.
58+
"""
59+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
60+
context.set_details('Method not implemented!')
61+
raise NotImplementedError('Method not implemented!')
5662

5763

5864
def add_OrchestratorServicer_to_server(servicer, server):
59-
rpc_method_handlers = {
60-
"Decode": grpc.unary_stream_rpc_method_handler(
61-
servicer.Decode,
62-
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.FromString,
63-
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.SerializeToString,
64-
),
65-
"HealthCheck": grpc.unary_unary_rpc_method_handler(
66-
servicer.HealthCheck,
67-
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.FromString,
68-
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.SerializeToString,
69-
),
70-
}
71-
generic_handler = grpc.method_handlers_generic_handler(
72-
"jetstream_proto.Orchestrator", rpc_method_handlers
73-
)
74-
server.add_generic_rpc_handlers((generic_handler,))
75-
76-
77-
# This class is part of an EXPERIMENTAL API.
65+
rpc_method_handlers = {
66+
'Decode': grpc.unary_stream_rpc_method_handler(
67+
servicer.Decode,
68+
request_deserializer=jetstream__pb2.DecodeRequest.FromString,
69+
response_serializer=jetstream__pb2.DecodeResponse.SerializeToString,
70+
),
71+
'HealthCheck': grpc.unary_unary_rpc_method_handler(
72+
servicer.HealthCheck,
73+
request_deserializer=jetstream__pb2.HealthCheckRequest.FromString,
74+
response_serializer=jetstream__pb2.HealthCheckResponse.SerializeToString,
75+
),
76+
}
77+
generic_handler = grpc.method_handlers_generic_handler(
78+
'jetstream_proto.Orchestrator', rpc_method_handlers)
79+
server.add_generic_rpc_handlers((generic_handler,))
80+
81+
82+
# This class is part of an EXPERIMENTAL API.
7883
class Orchestrator(object):
79-
"""TODO: Merge this with main JetStream core once we settle on an API."""
80-
81-
@staticmethod
82-
def Decode(
83-
request,
84-
target,
85-
options=(),
86-
channel_credentials=None,
87-
call_credentials=None,
88-
insecure=False,
89-
compression=None,
90-
wait_for_ready=None,
91-
timeout=None,
92-
metadata=None,
93-
):
94-
return grpc.experimental.unary_stream(
95-
request,
96-
target,
97-
"/jetstream_proto.Orchestrator/Decode",
98-
jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString,
99-
jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString,
100-
options,
101-
channel_credentials,
102-
insecure,
103-
call_credentials,
104-
compression,
105-
wait_for_ready,
106-
timeout,
107-
metadata,
108-
)
109-
110-
@staticmethod
111-
def HealthCheck(
112-
request,
113-
target,
114-
options=(),
115-
channel_credentials=None,
116-
call_credentials=None,
117-
insecure=False,
118-
compression=None,
119-
wait_for_ready=None,
120-
timeout=None,
121-
metadata=None,
122-
):
123-
return grpc.experimental.unary_unary(
124-
request,
125-
target,
126-
"/jetstream_proto.Orchestrator/HealthCheck",
127-
jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString,
128-
jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString,
129-
options,
130-
channel_credentials,
131-
insecure,
132-
call_credentials,
133-
compression,
134-
wait_for_ready,
135-
timeout,
136-
metadata,
137-
)
84+
"""TODO: Merge this with main JetStream core once we settle on an API.
85+
86+
"""
13887

88+
@staticmethod
89+
def Decode(request,
90+
target,
91+
options=(),
92+
channel_credentials=None,
93+
call_credentials=None,
94+
insecure=False,
95+
compression=None,
96+
wait_for_ready=None,
97+
timeout=None,
98+
metadata=None):
99+
return grpc.experimental.unary_stream(request, target, '/jetstream_proto.Orchestrator/Decode',
100+
jetstream__pb2.DecodeRequest.SerializeToString,
101+
jetstream__pb2.DecodeResponse.FromString,
102+
options, channel_credentials,
103+
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
104+
105+
@staticmethod
106+
def HealthCheck(request,
107+
target,
108+
options=(),
109+
channel_credentials=None,
110+
call_credentials=None,
111+
insecure=False,
112+
compression=None,
113+
wait_for_ready=None,
114+
timeout=None,
115+
metadata=None):
116+
return grpc.experimental.unary_unary(request, target, '/jetstream_proto.Orchestrator/HealthCheck',
117+
jetstream__pb2.HealthCheckRequest.SerializeToString,
118+
jetstream__pb2.HealthCheckResponse.FromString,
119+
options, channel_credentials,
120+
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

jetstream/core/proto/multi_lora_decoding_pb2.py

Lines changed: 3 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)