Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions jetstream/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class ActiveRequest:
)
################## Id of the adapter ###################
adapter_id: str = ""
################ Whether the prefill content has bos or not #################
has_bos: bool = False

def enqueue_samples(self, generated_samples: list[ReturnSample]):
"""Adds the generated sample(s) to return channel for current step.
Expand Down Expand Up @@ -600,10 +602,11 @@ def _process_prefill_content(
self,
request: ActiveRequest,
tokenizer: tokenizer_api.Tokenizer,
is_bos: bool,
max_prefill_length: int,
) -> Tuple[jax.Array | np.ndarray, int]:
content = request.prefill_content
# Add bos token if the prefill content doesn't have bos.
is_bos = not request.has_bos
if isinstance(content, str):
# If it's text input, tokenize and pad the input.
return tokenizer.encode(
Expand All @@ -614,6 +617,7 @@ def _process_prefill_content(
)
else:
# If it's token input, pad the input.
content = np.array(content)
return token_utils.pad_tokens(
content,
tokenizer.bos_id,
Expand Down Expand Up @@ -804,18 +808,16 @@ def _prefill_thread(self, idx: int):
if request is None:
break
request.metadata.prefill_dequeue_time = time.perf_counter()
is_bos = True
ThreadDebugLog(
thread_name,
f"Executing prefilling for one ActiveRequest. Current prefill "
f"backlog size: {self._prefill_backlog.qsize()},"
f" is_bos: {is_bos}",
f" has_bos: {request.has_bos}",
)
# Tokenize and padding the text or token input.
padded_tokens, true_length = self._process_prefill_content(
request,
tokenizer,
is_bos,
prefill_engine.max_prefill_length,
)

Expand Down Expand Up @@ -1704,6 +1706,7 @@ async def Decode( # pylint: disable=invalid-overridden-method
prefill_enqueue_time=time.perf_counter(),
),
num_samples=request.num_samples if request.num_samples else 1,
has_bos=request.has_bos,
)
# The first stage is being prefilled, all other stages are handled
# inside the driver (transfer, generate*N, detokenize).
Expand Down
5 changes: 4 additions & 1 deletion jetstream/core/proto/jetstream.proto
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ message DecodeRequest {

string lora_adapter_id = 9;

// Indicates whether the content has a beginning of sequence (BOS) token.
bool has_bos = 10;

reserved 1, 2, 3;
// Next ID: 10
// Next ID: 11
}

message DecodeResponse {
Expand Down
52 changes: 27 additions & 25 deletions jetstream/core/proto/jetstream_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: jetstream.proto
# source: jetstream/core/proto/jetstream.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
Expand All @@ -26,34 +26,36 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0fjetstream.proto\x12\x0fjetstream_proto"\xaa\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x13\n\x0bnum_samples\x18\x08 \x01(\x05\x12\x17\n\x0flora_adapter_id\x18\t \x01(\t\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xbb\x03\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x12\x13\n\x0bnum_samples\x18\x08 \x01(\x05\x12\x17\n\x0flora_adapter_id\x18\t \x01(\t\x12\x0f\n\x07has_bos\x18\n \x01(\x08\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "jetstream_pb2", _globals)
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, "jetstream.core.proto.jetstream_pb2", _globals
)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_DECODEREQUEST"]._serialized_start = 37
_globals["_DECODEREQUEST"]._serialized_end = 463
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 319
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 346
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 348
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 381
_globals["_DECODEREQUEST_METADATA"]._serialized_start = 383
_globals["_DECODEREQUEST_METADATA"]._serialized_end = 413
_globals["_DECODERESPONSE"]._serialized_start = 466
_globals["_DECODERESPONSE"]._serialized_end = 797
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 632
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 648
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 651
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 780
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 739
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 780
_globals["_HEALTHCHECKREQUEST"]._serialized_start = 799
_globals["_HEALTHCHECKREQUEST"]._serialized_end = 819
_globals["_HEALTHCHECKRESPONSE"]._serialized_start = 821
_globals["_HEALTHCHECKRESPONSE"]._serialized_end = 859
_globals["_ORCHESTRATOR"]._serialized_start = 862
_globals["_ORCHESTRATOR"]._serialized_end = 1047
_globals["_DECODEREQUEST"]._serialized_start = 58
_globals["_DECODEREQUEST"]._serialized_end = 501
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 357
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 384
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 386
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 419
_globals["_DECODEREQUEST_METADATA"]._serialized_start = 421
_globals["_DECODEREQUEST_METADATA"]._serialized_end = 451
_globals["_DECODERESPONSE"]._serialized_start = 504
_globals["_DECODERESPONSE"]._serialized_end = 835
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 670
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 686
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 689
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 818
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 777
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 818
_globals["_HEALTHCHECKREQUEST"]._serialized_start = 837
_globals["_HEALTHCHECKREQUEST"]._serialized_end = 857
_globals["_HEALTHCHECKRESPONSE"]._serialized_start = 859
_globals["_HEALTHCHECKRESPONSE"]._serialized_end = 897
_globals["_ORCHESTRATOR"]._serialized_start = 900
_globals["_ORCHESTRATOR"]._serialized_end = 1085
# @@protoc_insertion_point(module_scope)
26 changes: 13 additions & 13 deletions jetstream/core/proto/jetstream_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

from jetstream.core.proto import jetstream_pb2 as jetstream__pb2
from jetstream.core.proto import jetstream_pb2 as jetstream_dot_core_dot_proto_dot_jetstream__pb2


class OrchestratorStub(object):
Expand All @@ -29,13 +29,13 @@ def __init__(self, channel):
"""
self.Decode = channel.unary_stream(
"/jetstream_proto.Orchestrator/Decode",
request_serializer=jetstream__pb2.DecodeRequest.SerializeToString,
response_deserializer=jetstream__pb2.DecodeResponse.FromString,
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString,
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString,
)
self.HealthCheck = channel.unary_unary(
"/jetstream_proto.Orchestrator/HealthCheck",
request_serializer=jetstream__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=jetstream__pb2.HealthCheckResponse.FromString,
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString,
)


Expand All @@ -59,13 +59,13 @@ def add_OrchestratorServicer_to_server(servicer, server):
rpc_method_handlers = {
"Decode": grpc.unary_stream_rpc_method_handler(
servicer.Decode,
request_deserializer=jetstream__pb2.DecodeRequest.FromString,
response_serializer=jetstream__pb2.DecodeResponse.SerializeToString,
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.FromString,
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.SerializeToString,
),
"HealthCheck": grpc.unary_unary_rpc_method_handler(
servicer.HealthCheck,
request_deserializer=jetstream__pb2.HealthCheckRequest.FromString,
response_serializer=jetstream__pb2.HealthCheckResponse.SerializeToString,
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.FromString,
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
Expand Down Expand Up @@ -95,8 +95,8 @@ def Decode(
request,
target,
"/jetstream_proto.Orchestrator/Decode",
jetstream__pb2.DecodeRequest.SerializeToString,
jetstream__pb2.DecodeResponse.FromString,
jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString,
jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString,
options,
channel_credentials,
insecure,
Expand Down Expand Up @@ -124,8 +124,8 @@ def HealthCheck(
request,
target,
"/jetstream_proto.Orchestrator/HealthCheck",
jetstream__pb2.HealthCheckRequest.SerializeToString,
jetstream__pb2.HealthCheckResponse.FromString,
jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString,
jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,
Expand Down
51 changes: 32 additions & 19 deletions jetstream/core/proto/multi_lora_decoding_pb2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: multi_lora_decoding.proto
# source: jetstream/core/proto/multi_lora_decoding.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
Expand All @@ -13,30 +26,30 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x19multi_lora_decoding.proto"\x15\n\x13ListAdaptersRequest"c\n\x14ListAdaptersResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12#\n\radapter_infos\x18\x03 \x03(\x0b\x32\x0c.AdapterInfo"\x82\x01\n\x0b\x41\x64\x61pterInfo\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0cloading_cost\x18\x02 \x01(\x03\x12\x10\n\x08size_hbm\x18\x03 \x01(\x03\x12\x10\n\x08size_cpu\x18\x04 \x01(\x03\x12\x15\n\rlast_accessed\x18\x05 \x01(\x02\x12\x0e\n\x06status\x18\x06 \x01(\t">\n\x12LoadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t"=\n\x13LoadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t"*\n\x14UnloadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t"?\n\x15UnloadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t2\xc7\x01\n\x02v1\x12\x37\n\x06models\x12\x14.ListAdaptersRequest\x1a\x15.ListAdaptersResponse"\x00\x12@\n\x11load_lora_adapter\x12\x13.LoadAdapterRequest\x1a\x14.LoadAdapterResponse"\x00\x12\x46\n\x13unload_lora_adapter\x12\x15.UnloadAdapterRequest\x1a\x16.UnloadAdapterResponse"\x00\x62\x06proto3'
b'\n.jetstream/core/proto/multi_lora_decoding.proto"\x15\n\x13ListAdaptersRequest"c\n\x14ListAdaptersResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12#\n\radapter_infos\x18\x03 \x03(\x0b\x32\x0c.AdapterInfo"\x82\x01\n\x0b\x41\x64\x61pterInfo\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0cloading_cost\x18\x02 \x01(\x03\x12\x10\n\x08size_hbm\x18\x03 \x01(\x03\x12\x10\n\x08size_cpu\x18\x04 \x01(\x03\x12\x15\n\rlast_accessed\x18\x05 \x01(\x02\x12\x0e\n\x06status\x18\x06 \x01(\t">\n\x12LoadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t"=\n\x13LoadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t"*\n\x14UnloadAdapterRequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t"?\n\x15UnloadAdapterResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x15\n\rerror_message\x18\x02 \x01(\t2\xc7\x01\n\x02v1\x12\x37\n\x06models\x12\x14.ListAdaptersRequest\x1a\x15.ListAdaptersResponse"\x00\x12@\n\x11load_lora_adapter\x12\x13.LoadAdapterRequest\x1a\x14.LoadAdapterResponse"\x00\x12\x46\n\x13unload_lora_adapter\x12\x15.UnloadAdapterRequest\x1a\x16.UnloadAdapterResponse"\x00\x62\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, "multi_lora_decoding_pb2", _globals
DESCRIPTOR, "jetstream.core.proto.multi_lora_decoding_pb2", _globals
)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_LISTADAPTERSREQUEST"]._serialized_start = 29
_globals["_LISTADAPTERSREQUEST"]._serialized_end = 50
_globals["_LISTADAPTERSRESPONSE"]._serialized_start = 52
_globals["_LISTADAPTERSRESPONSE"]._serialized_end = 151
_globals["_ADAPTERINFO"]._serialized_start = 154
_globals["_ADAPTERINFO"]._serialized_end = 284
_globals["_LOADADAPTERREQUEST"]._serialized_start = 286
_globals["_LOADADAPTERREQUEST"]._serialized_end = 348
_globals["_LOADADAPTERRESPONSE"]._serialized_start = 350
_globals["_LOADADAPTERRESPONSE"]._serialized_end = 411
_globals["_UNLOADADAPTERREQUEST"]._serialized_start = 413
_globals["_UNLOADADAPTERREQUEST"]._serialized_end = 455
_globals["_UNLOADADAPTERRESPONSE"]._serialized_start = 457
_globals["_UNLOADADAPTERRESPONSE"]._serialized_end = 520
_globals["_V1"]._serialized_start = 523
_globals["_V1"]._serialized_end = 722
_globals["_LISTADAPTERSREQUEST"]._serialized_start = 50
_globals["_LISTADAPTERSREQUEST"]._serialized_end = 71
_globals["_LISTADAPTERSRESPONSE"]._serialized_start = 73
_globals["_LISTADAPTERSRESPONSE"]._serialized_end = 172
_globals["_ADAPTERINFO"]._serialized_start = 175
_globals["_ADAPTERINFO"]._serialized_end = 305
_globals["_LOADADAPTERREQUEST"]._serialized_start = 307
_globals["_LOADADAPTERREQUEST"]._serialized_end = 369
_globals["_LOADADAPTERRESPONSE"]._serialized_start = 371
_globals["_LOADADAPTERRESPONSE"]._serialized_end = 432
_globals["_UNLOADADAPTERREQUEST"]._serialized_start = 434
_globals["_UNLOADADAPTERREQUEST"]._serialized_end = 476
_globals["_UNLOADADAPTERRESPONSE"]._serialized_start = 478
_globals["_UNLOADADAPTERRESPONSE"]._serialized_end = 541
_globals["_V1"]._serialized_start = 544
_globals["_V1"]._serialized_end = 743
# @@protoc_insertion_point(module_scope)
Loading