1313// limitations under the License.
1414
1515#include < deque>
16+ #include < exception>
1617#include < map>
18+ #include < memory>
19+ #include < optional>
1720#include < sstream>
1821#include < stdexcept>
22+ #include < string>
1923#include < string_view>
2024#include < utility>
25+ #include < variant>
26+ #include < vector>
2127
2228#include " nanobind/nanobind.h"
2329#include " nanobind/stl/optional.h" // IWYU pragma: keep
24- #include " nanobind/stl/shared_ptr.h"
30+ #include " nanobind/stl/shared_ptr.h" // IWYU pragma: keep
2531#include " nanobind/stl/string_view.h" // IWYU pragma: keep
2632#include " nanobind/stl/unique_ptr.h" // IWYU pragma: keep
2733#include " nanobind/stl/variant.h" // IWYU pragma: keep
2834#include " nanobind/stl/vector.h" // IWYU pragma: keep
2935#include " absl/base/log_severity.h" // from @com_google_absl
36+ #include " absl/base/thread_annotations.h" // from @com_google_absl
3037#include " absl/log/globals.h" // from @com_google_absl
38+ #include " absl/status/status.h" // from @com_google_absl
3139#include " absl/status/statusor.h" // from @com_google_absl
40+ #include " absl/strings/string_view.h" // from @com_google_absl
3241#include " absl/synchronization/mutex.h" // from @com_google_absl
3342#include " absl/time/time.h" // from @com_google_absl
43+ #include " nlohmann/json_fwd.hpp" // from @nlohmann_json
3444#include " nanobind_json/nanobind_json.hpp" // from @nanobind_json // IWYU pragma: keep
3545#include " litert/c/internal/litert_logging.h" // from @litert
3646#include " runtime/conversation/conversation.h"
47+ #include " runtime/conversation/io_types.h"
3748#include " runtime/engine/engine.h"
3849#include " runtime/engine/engine_factory.h"
39- #include " tflite/core/c/c_api_types.h" // from @litert
50+ #include " runtime/engine/engine_settings.h"
51+ #include " runtime/engine/io_types.h"
52+ #include " runtime/executor/executor_settings_base.h"
4053#include " tflite/logger.h" // from @litert
4154#include " tflite/minimal_logging.h" // from @litert
4255
@@ -468,7 +481,8 @@ NB_MODULE(litert_lm_ext, module) {
468481 .def (
469482 " create_conversation" ,
470483 [](const nb::object& self, const nb::handle& messages,
471- const nb::handle& tools, const nb::handle& tool_event_handler) {
484+ const nb::handle& tools, const nb::handle& tool_event_handler,
485+ const nb::handle& extra_context) {
472486 Engine& engine = nb::cast<Engine&>(self);
473487
474488 auto builder = ConversationConfig::Builder ();
@@ -503,6 +517,12 @@ NB_MODULE(litert_lm_ext, module) {
503517 has_preface = true ;
504518 }
505519
520+ if (!extra_context.is_none ()) {
521+ json_preface.extra_context =
522+ nb::cast<nlohmann::json>(extra_context);
523+ has_preface = true ;
524+ }
525+
506526 if (has_preface) {
507527 builder.SetPreface (json_preface);
508528 }
@@ -515,6 +535,7 @@ NB_MODULE(litert_lm_ext, module) {
515535 nb::object py_conversation = nb::cast (std::move (conversation));
516536 py_conversation.attr (" _tool_map" ) = py_tool_map;
517537 py_conversation.attr (" tool_event_handler" ) = tool_event_handler;
538+ py_conversation.attr (" extra_context" ) = extra_context;
518539 if (messages.is_none ()) {
519540 py_conversation.attr (" messages" ) = nb::list ();
520541 } else {
@@ -529,7 +550,8 @@ NB_MODULE(litert_lm_ext, module) {
529550 },
530551 nb::kw_only (), nb::arg (" messages" ) = nb::none (),
531552 nb::arg (" tools" ) = nb::none (),
532- nb::arg (" tool_event_handler" ) = nb::none ())
553+ nb::arg (" tool_event_handler" ) = nb::none (),
554+ nb::arg (" extra_context" ) = nb::none ())
533555 .def (
534556 " create_session" ,
535557 [](Engine& self) {
@@ -557,6 +579,7 @@ NB_MODULE(litert_lm_ext, module) {
557579 " run_prefill" ,
558580 [](Engine::Session& self, const std::vector<std::string>& contents) {
559581 std::vector<InputData> input_data;
582+ input_data.reserve (contents.size ());
560583 for (const auto & text : contents) {
561584 input_data.emplace_back (InputText (text));
562585 }
@@ -580,6 +603,7 @@ NB_MODULE(litert_lm_ext, module) {
580603 [](Engine::Session& self, const std::vector<std::string>& target_text,
581604 bool store_token_lengths) {
582605 std::vector<absl::string_view> target_text_views;
606+ target_text_views.reserve (target_text.size ());
583607 for (const auto & text : target_text) {
584608 target_text_views.push_back (text);
585609 }
0 commit comments