Skip to content

Commit f36a047

Browse files
matthewchan-gcopybara-github
authored andcommitted
Internal change
LiteRT-LM-PiperOrigin-RevId: 890131978
1 parent 0543ddc commit f36a047

6 files changed

Lines changed: 67 additions & 15 deletions

File tree

python/litert_lm/BUILD

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,22 +59,28 @@ cc_binary(
5959
"nowindows",
6060
],
6161
deps = [
62+
"@com_google_absl//absl/base:core_headers",
6263
"@com_google_absl//absl/base:log_severity",
6364
"@com_google_absl//absl/log:globals",
65+
"@com_google_absl//absl/status",
6466
"@com_google_absl//absl/status:statusor",
67+
"@com_google_absl//absl/strings:string_view",
6568
"@com_google_absl//absl/synchronization",
6669
"@com_google_absl//absl/time",
70+
"@nlohmann_json//:json",
6771
"@nanobind",
6872
"@nanobind_json",
6973
"@litert//litert/c/internal:litert_logging",
7074
"//runtime/conversation",
75+
"//runtime/conversation:io_types",
7176
"//runtime/core:engine_impl",
7277
"//runtime/engine:engine_factory",
7378
"//runtime/engine:engine_interface",
7479
"//runtime/engine:engine_settings",
75-
"@rules_python//python/cc:current_py_cc_headers",
80+
"//runtime/engine:io_types",
81+
"//runtime/executor:executor_settings_base",
82+
"@rules_python//python/cc:current_py_cc_headers", # buildcleaner: keep
7683
"@litert//tflite:minimal_logging",
77-
"@litert//tflite/core/c:private_c_api_types",
7884
],
7985
)
8086

python/litert_lm/engine_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,14 @@ def test_create_conversation_with_messages(self):
148148
):
149149
self.assertEqual(conversation.messages, messages)
150150

151+
def test_create_conversation_with_extra_context(self):
152+
extra_context = {"key": "value"}
153+
with (
154+
self._create_engine() as engine,
155+
engine.create_conversation(extra_context=extra_context) as conversation,
156+
):
157+
self.assertEqual(conversation.extra_context, extra_context)
158+
151159
def test_str_input_support(self):
152160
with (
153161
self._create_engine() as engine,

python/litert_lm/interfaces.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import collections.abc
2121
import dataclasses
2222
import enum
23-
import pathlib
2423
from typing import Any
2524

2625

@@ -103,6 +102,7 @@ def create_conversation(
103102
collections.abc.Sequence[collections.abc.Callable[..., Any]] | None
104103
) = None,
105104
tool_event_handler: ToolEventHandler | None = None,
105+
extra_context: collections.abc.Mapping[str, Any] | None = None,
106106
) -> AbstractConversation:
107107
"""Creates a new conversation for this engine.
108108
@@ -111,6 +111,7 @@ def create_conversation(
111111
message is a mapping that should contain 'role' and 'content' keys.
112112
tools: A list of Python functions to be used as tools.
113113
tool_event_handler: A handler for tool call and tool response events.
114+
extra_context: Extra context for the conversation.
114115
"""
115116

116117
@abc.abstractmethod
@@ -129,6 +130,7 @@ class AbstractConversation(abc.ABC):
129130
messages: A sequence of messages for the conversation preface.
130131
tools: A list of Python functions to be used as tools.
131132
tool_event_handler: A handler for tool call and tool response events.
133+
extra_context: Extra context for the chat template.
132134
"""
133135

134136
def __init__(
@@ -141,6 +143,7 @@ def __init__(
141143
collections.abc.Sequence[collections.abc.Callable[..., Any]] | None
142144
) = None,
143145
tool_event_handler: ToolEventHandler | None = None,
146+
extra_context: collections.abc.Mapping[str, Any] | None = None,
144147
):
145148
"""Initializes the instance.
146149
@@ -149,10 +152,12 @@ def __init__(
149152
message is a mapping that should contain 'role' and 'content' keys.
150153
tools: A list of Python functions to be used as tools.
151154
tool_event_handler: A handler for tool call and tool response events.
155+
extra_context: Extra context for the chat template.
152156
"""
153157
self.messages = messages or []
154158
self.tools = tools or []
155159
self.tool_event_handler = tool_event_handler
160+
self.extra_context = extra_context or {}
156161

157162
def __enter__(self) -> AbstractConversation:
158163
"""Initializes the conversation."""

python/litert_lm/litert_lm.cc

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,43 @@
1313
// limitations under the License.
1414

1515
#include <deque>
16+
#include <exception>
1617
#include <map>
18+
#include <memory>
19+
#include <optional>
1720
#include <sstream>
1821
#include <stdexcept>
22+
#include <string>
1923
#include <string_view>
2024
#include <utility>
25+
#include <variant>
26+
#include <vector>
2127

2228
#include "nanobind/nanobind.h"
2329
#include "nanobind/stl/optional.h" // IWYU pragma: keep
24-
#include "nanobind/stl/shared_ptr.h"
30+
#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep
2531
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
2632
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep
2733
#include "nanobind/stl/variant.h" // IWYU pragma: keep
2834
#include "nanobind/stl/vector.h" // IWYU pragma: keep
2935
#include "absl/base/log_severity.h" // from @com_google_absl
36+
#include "absl/base/thread_annotations.h" // from @com_google_absl
3037
#include "absl/log/globals.h" // from @com_google_absl
38+
#include "absl/status/status.h" // from @com_google_absl
3139
#include "absl/status/statusor.h" // from @com_google_absl
40+
#include "absl/strings/string_view.h" // from @com_google_absl
3241
#include "absl/synchronization/mutex.h" // from @com_google_absl
3342
#include "absl/time/time.h" // from @com_google_absl
43+
#include "nlohmann/json_fwd.hpp" // from @nlohmann_json
3444
#include "nanobind_json/nanobind_json.hpp" // from @nanobind_json // IWYU pragma: keep
3545
#include "litert/c/internal/litert_logging.h" // from @litert
3646
#include "runtime/conversation/conversation.h"
47+
#include "runtime/conversation/io_types.h"
3748
#include "runtime/engine/engine.h"
3849
#include "runtime/engine/engine_factory.h"
39-
#include "tflite/core/c/c_api_types.h" // from @litert
50+
#include "runtime/engine/engine_settings.h"
51+
#include "runtime/engine/io_types.h"
52+
#include "runtime/executor/executor_settings_base.h"
4053
#include "tflite/logger.h" // from @litert
4154
#include "tflite/minimal_logging.h" // from @litert
4255

@@ -468,7 +481,8 @@ NB_MODULE(litert_lm_ext, module) {
468481
.def(
469482
"create_conversation",
470483
[](const nb::object& self, const nb::handle& messages,
471-
const nb::handle& tools, const nb::handle& tool_event_handler) {
484+
const nb::handle& tools, const nb::handle& tool_event_handler,
485+
const nb::handle& extra_context) {
472486
Engine& engine = nb::cast<Engine&>(self);
473487

474488
auto builder = ConversationConfig::Builder();
@@ -503,6 +517,12 @@ NB_MODULE(litert_lm_ext, module) {
503517
has_preface = true;
504518
}
505519

520+
if (!extra_context.is_none()) {
521+
json_preface.extra_context =
522+
nb::cast<nlohmann::json>(extra_context);
523+
has_preface = true;
524+
}
525+
506526
if (has_preface) {
507527
builder.SetPreface(json_preface);
508528
}
@@ -515,6 +535,7 @@ NB_MODULE(litert_lm_ext, module) {
515535
nb::object py_conversation = nb::cast(std::move(conversation));
516536
py_conversation.attr("_tool_map") = py_tool_map;
517537
py_conversation.attr("tool_event_handler") = tool_event_handler;
538+
py_conversation.attr("extra_context") = extra_context;
518539
if (messages.is_none()) {
519540
py_conversation.attr("messages") = nb::list();
520541
} else {
@@ -529,7 +550,8 @@ NB_MODULE(litert_lm_ext, module) {
529550
},
530551
nb::kw_only(), nb::arg("messages") = nb::none(),
531552
nb::arg("tools") = nb::none(),
532-
nb::arg("tool_event_handler") = nb::none())
553+
nb::arg("tool_event_handler") = nb::none(),
554+
nb::arg("extra_context") = nb::none())
533555
.def(
534556
"create_session",
535557
[](Engine& self) {
@@ -557,6 +579,7 @@ NB_MODULE(litert_lm_ext, module) {
557579
"run_prefill",
558580
[](Engine::Session& self, const std::vector<std::string>& contents) {
559581
std::vector<InputData> input_data;
582+
input_data.reserve(contents.size());
560583
for (const auto& text : contents) {
561584
input_data.emplace_back(InputText(text));
562585
}
@@ -580,6 +603,7 @@ NB_MODULE(litert_lm_ext, module) {
580603
[](Engine::Session& self, const std::vector<std::string>& target_text,
581604
bool store_token_lengths) {
582605
std::vector<absl::string_view> target_text_views;
606+
target_text_views.reserve(target_text.size());
583607
for (const auto& text : target_text) {
584608
target_text_views.push_back(text);
585609
}

python/litert_lm_cli/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,10 @@ def add_numbers(a: float, b: float) -> float:
226226
"""Adds two numbers."""
227227
return a + b
228228
229-
# Provides the "system instruction" and "tools"
229+
# Provides the "system instruction", "tools", and "extra_context"
230230
system_instruction = "You are a helpful assistant."
231231
tools = [add_numbers]
232+
extra_context = {"key": "value"}
232233
```
233234
234235
Args:

python/litert_lm_cli/model.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,16 @@
3636

3737

3838
def load_preset(preset: str):
39-
"""Loads a preset file and returns the tools and messages."""
39+
"""Loads a preset file and returns the tools, messages and extra_context."""
4040
print(f"Loading preset from {preset}:")
4141
if not os.path.exists(preset):
4242
print(f"Preset file not found: {preset}")
43-
return None, None
43+
return None, None, None
4444

4545
spec = importlib.util.spec_from_file_location("user_tools", preset)
4646
if not spec or not spec.loader:
4747
print(f"Failed to load tools from {preset}")
48-
return None, None
48+
return None, None, None
4949

5050
user_tools = importlib.util.module_from_spec(spec)
5151
spec.loader.exec_module(user_tools)
@@ -71,7 +71,11 @@ def load_preset(preset: str):
7171
for tool in tools:
7272
print(f" - {getattr(tool, "__name__", str(tool))}")
7373

74-
return tools, messages
74+
extra_context = getattr(user_tools, "extra_context", None)
75+
if extra_context:
76+
print(f"- Extra context: {extra_context}")
77+
78+
return tools, messages, extra_context
7579

7680

7781
_GREEN = "\033[92m"
@@ -147,9 +151,10 @@ def run_interactive(
147151

148152
tools = None
149153
messages = None
154+
extra_context = None
150155
if preset:
151-
tools, messages = load_preset(preset)
152-
if tools is None:
156+
tools, messages, extra_context = load_preset(preset)
157+
if tools is None and messages is None and extra_context is None:
153158
return
154159

155160
handler = LoggingToolEventHandler() if tools else None
@@ -164,7 +169,10 @@ def run_interactive(
164169
with (
165170
engine_cm as engine,
166171
engine.create_conversation(
167-
tools=tools, messages=messages, tool_event_handler=handler
172+
tools=tools,
173+
messages=messages,
174+
tool_event_handler=handler,
175+
extra_context=extra_context,
168176
) as conversation,
169177
):
170178
if prompt:

0 commit comments

Comments
 (0)