Skip to content

Commit 23314b9

Browse files
committed
Be more resilient to Gemini outputting thoughts in API responses
1 parent 9af4349 commit 23314b9

8 files changed

Lines changed: 80 additions & 127 deletions

File tree

.github/workflows/build-ruby-release.reusable.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: BAML Release - Build Ruby
33
on:
44
workflow_call: {}
55
push:
6-
branches: [bump-ruby]
6+
branches: [gemini-fix]
77

88
permissions:
99
contents: read

engine/Cargo.lock

Lines changed: 1 addition & 74 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/baml-runtime/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ notify-debouncer-full = "0.3.1"
172172
ring = { version = "0.17.4", features = ["std"] }
173173
tokio = { version = "1", features = ["full"] }
174174
reqwest.workspace = true
175-
rustls = "0.23.26"
176175
walkdir = "2.5.0"
177176
which = "6.0.3"
178177
indicatif = "0.17"

engine/baml-runtime/src/internal/llm_client/primitive/google/googleai_client.rs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -359,26 +359,6 @@ impl ToProviderMessage for GoogleAIClient {
359359
}
360360
}
361361

362-
/// The Google Gemini 2 model has an experimental feature
363-
/// called Flash Thinking Mode, which is turned on in a particular
364-
/// named model: gemini-2.0-flash-thinking-exp-1219
365-
///
366-
/// When run in this mode, Gemini returns `candidates` with 2 parts each.
367-
/// Part 0 is the chain of thought, part 1 is the actual output.
368-
/// Other Gemini models put the output data in part 0.
369-
///
370-
/// TODO: Explicitly represent Flash Thinking Mode response and
371-
/// do more thorough checking for the content part.
372-
/// For examples of how to introspect the response more safely, see:
373-
/// https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_2_0_flash_thinking_mode.ipynb
374-
fn content_part(model_name: &str) -> usize {
375-
if model_name.contains("gemini-2.0-flash-thinking-exp-1219") {
376-
1
377-
} else {
378-
0
379-
}
380-
}
381-
382362
impl CompletionToProviderBody for GoogleAIClient {
383363
fn completion_to_provider_body(
384364
&self,

engine/baml-runtime/src/internal/llm_client/primitive/google/response_handler.rs

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use anyhow::Result;
22

3-
use super::types::GoogleResponse;
3+
use super::types::{GoogleResponse, Part};
44
use crate::internal::llm_client::{
55
primitive::request::RequestBuilder, traits::WithClient, ErrorCode, LLMCompleteResponse,
66
LLMCompleteResponseMetadata, LLMErrorResponse, LLMResponse,
@@ -26,6 +26,7 @@ pub fn parse_google_response<C: WithClient + RequestBuilder>(
2626
instant_now: web_time::Instant,
2727
model_name: Option<String>,
2828
) -> LLMResponse {
29+
// baml_log::info!("Parsing Google response: {:#?}", response_body);
2930
let response = match GoogleResponse::deserialize(&response_body)
3031
.context(format!(
3132
"Failed to parse into a response accepted by {}: {}",
@@ -78,11 +79,11 @@ pub fn parse_google_response<C: WithClient + RequestBuilder>(
7879
};
7980

8081
let model_name = model_name.unwrap_or("<unknown>".to_string());
81-
let part_index = content_part(&model_name);
82+
let text_content = text_content_part(&content.parts);
8283
LLMResponse::Success(LLMCompleteResponse {
8384
client: client.context().name.to_string(),
8485
prompt: to_prompt(prompt),
85-
content: content.parts[part_index].text.clone(),
86+
content: text_content.unwrap_or_default(),
8687
start_time: system_now,
8788
latency: instant_now.elapsed(),
8889
request_options: client.request_options().clone(),
@@ -101,12 +102,11 @@ pub fn parse_google_response<C: WithClient + RequestBuilder>(
101102
})
102103
}
103104

104-
fn content_part(model_name: &str) -> usize {
105-
if model_name.contains("gemini-2.0-flash-thinking-exp-1219") {
106-
1
107-
} else {
108-
0
109-
}
105+
fn text_content_part(parts: &Vec<Part>) -> Option<String> {
106+
parts
107+
.iter()
108+
.position(|part| !part.text.is_empty() && part.thought.unwrap_or(false) == false)
109+
.map(|index| parts[index].text.clone())
110110
}
111111

112112
pub fn scan_google_response_stream(
@@ -145,18 +145,13 @@ pub fn scan_google_response_stream(
145145
Err(e) => return Err(e),
146146
};
147147
if let Some(choice) = event.candidates.get(0) {
148-
let part_index = content_part(
149-
model_name
150-
.as_ref()
151-
.map(|s| s.as_str())
152-
.unwrap_or("<unknown>"),
153-
);
154-
if let Some(content) = choice
148+
let text_content = &choice
155149
.content
156150
.as_ref()
157-
.and_then(|c| c.parts.get(part_index))
158-
{
159-
inner.content += &content.text;
151+
.and_then(|c| text_content_part(&c.parts));
152+
153+
if let Some(text_content) = text_content {
154+
inner.content += &text_content;
160155
}
161156
inner.metadata.finish_reason = choice.finish_reason.as_ref().map(|r| r.to_string());
162157
if choice
@@ -235,6 +230,7 @@ mod tests {
235230
function_call: None,
236231
function_response: None,
237232
video_metadata: None,
233+
thought: None,
238234
}],
239235
role: Some("model".to_string()),
240236
}),

engine/baml-runtime/src/internal/llm_client/primitive/google/types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ pub struct Part {
245245
pub function_call: Option<FunctionCall>,
246246
pub function_response: Option<FunctionResponse>,
247247
pub video_metadata: Option<VideoMetadata>,
248+
pub thought: Option<bool>,
248249
}
249250

250251
#[derive(Serialize, Deserialize, Debug)]

engine/baml-runtime/src/lib.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,16 @@ static TOKIO_SINGLETON: OnceLock<std::io::Result<Arc<tokio::runtime::Runtime>>>
115115

116116
static INIT: std::sync::Once = std::sync::Once::new();
117117

118-
fn setup_crypto_provider() {
119-
#[cfg(not(target_arch = "wasm32"))]
120-
{
121-
use rustls::crypto::CryptoProvider;
122-
INIT.call_once(|| {
123-
let provider = rustls::crypto::ring::default_provider();
124-
CryptoProvider::install_default(provider).expect("failed to install CryptoProvider");
125-
});
126-
}
127-
}
118+
// fn setup_crypto_provider() {
119+
// #[cfg(not(target_arch = "wasm32"))]
120+
// {
121+
// use rustls::crypto::CryptoProvider;
122+
// INIT.call_once(|| {
123+
// let provider = rustls::crypto::ring::default_provider();
124+
// CryptoProvider::install_default(provider).expect("failed to install CryptoProvider");
125+
// });
126+
// }
127+
// }
128128

129129
#[derive(Clone)]
130130
pub struct BamlRuntime {
@@ -187,7 +187,7 @@ impl BamlRuntime {
187187
path: &std::path::Path,
188188
env_vars: HashMap<T, T>,
189189
) -> Result<Self> {
190-
setup_crypto_provider();
190+
// setup_crypto_provider();
191191
let path = Self::parse_baml_src_path(path)?;
192192

193193
let copy = env_vars
@@ -210,7 +210,7 @@ impl BamlRuntime {
210210
files: &HashMap<T, T>,
211211
env_vars: HashMap<U, U>,
212212
) -> Result<Self> {
213-
setup_crypto_provider();
213+
// setup_crypto_provider();
214214
let copy = env_vars
215215
.iter()
216216
.map(|(k, v)| (k.as_ref().to_string(), v.as_ref().to_string()))

integ-tests/python/tests/test_functions.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,56 @@ async def test_streaming_gemini():
792792
assert msgs[-1] == final, "Expected last stream message to match final response."
793793

794794

795+
@pytest.mark.asyncio
796+
async def test_gemini_models():
797+
client_registry = baml_py.ClientRegistry()
798+
# # Test with gemini-1.5-flash-thinking-exp-1219
799+
# client_registry.add_llm_client(
800+
# "MyCustomGeminiClient",
801+
# "google-ai",
802+
# {"model": "gemini-1.5-flash-thinking-exp-1219"},
803+
# )
804+
# client_registry.set_primary("MyCustomGeminiClient")
805+
# res = await b.TestGemini(
806+
# input="Dr.Pepper", baml_options={"client_registry": client_registry}
807+
# )
808+
# assert len(res) > 0, "Expected non-empty result but got empty."
809+
810+
# Test with gemini-2.5-pro-preview-05-06
811+
# client_registry.add_llm_client(
812+
# "Gemini25ProMay", "google-ai", {"model": "gemini-2.5-pro-preview-05-06"}
813+
# )
814+
# client_registry.set_primary("Gemini25ProMay")
815+
# res = await b.TestGemini(
816+
# input="sea. Actually output the multiplication of 23*12/12+3 and take square root of 10.",
817+
# baml_options={"client_registry": client_registry},
818+
# )
819+
# assert len(res) > 0, "Expected non-empty result but got empty."
820+
821+
# Test with gemini-2.5-pro-preview-03-25
822+
# client_registry.add_llm_client(
823+
# "Gemini25ProMarch", "google-ai", {"model": "gemini-2.5-pro-preview-03-25"}
824+
# )
825+
# client_registry.set_primary("Gemini25ProMarch")
826+
# res = await b.TestGemini(
827+
# input="sea. Actually just output a json object with the keys 'name' and 'age'.",
828+
# baml_options={"client_registry": client_registry},
829+
# )
830+
# assert len(res) > 0, "Expected non-empty result but got empty."
831+
832+
# Test with gemini-2.0-flash-thinking-exp-1219
833+
client_registry.add_llm_client(
834+
"GeminiFlashThinking",
835+
"google-ai",
836+
{"model": "gemini-2.0-flash-thinking-exp-1219"},
837+
)
838+
client_registry.set_primary("GeminiFlashThinking")
839+
res = await b.TestGemini(
840+
input="sea", baml_options={"client_registry": client_registry}
841+
)
842+
assert len(res) > 0, "Expected non-empty result but got empty."
843+
844+
795845
@pytest.mark.asyncio
796846
async def test_tracing_async_only():
797847
@trace

0 commit comments

Comments
 (0)