Skip to content

Commit c3a19a1

Browse files
committed
Make modelName optional if providing ChatModel or StreamingChatModel
1 parent 0a40557 commit c3a19a1

2 files changed

Lines changed: 76 additions & 1 deletion

File tree

contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
import java.util.HashMap;
7777
import java.util.List;
7878
import java.util.Map;
79+
import java.util.Objects;
7980
import java.util.UUID;
8081
import org.jspecify.annotations.Nullable;
8182
import org.slf4j.Logger;
@@ -100,6 +101,7 @@ public abstract class LangChain4j extends BaseLlm {
100101

101102
public abstract ObjectMapper objectMapper();
102103

104+
@Nullable
103105
public abstract String modelName();
104106

105107
@Nullable
@@ -126,7 +128,32 @@ public abstract static class Builder {
126128

127129
public abstract Builder modelName(String modelName);
128130

129-
public abstract LangChain4j build();
131+
abstract @Nullable ChatModel chatModel();
132+
133+
abstract @Nullable StreamingChatModel streamingChatModel();
134+
135+
abstract @Nullable String modelName();
136+
137+
abstract LangChain4j autoBuild();
138+
139+
public LangChain4j build() {
140+
if (Objects.isNull(modelName())) {
141+
// Try to extract modelName from chatModel or streamingChatModel
142+
if (!Objects.isNull(chatModel())
143+
&& !Objects.isNull(chatModel().defaultRequestParameters())) {
144+
modelName(chatModel().defaultRequestParameters().modelName());
145+
} else if (!Objects.isNull(streamingChatModel())
146+
&& !Objects.isNull(streamingChatModel().defaultRequestParameters())) {
147+
modelName(streamingChatModel().defaultRequestParameters().modelName());
148+
}
149+
}
150+
// Up to this step, if modelName still null - Fail fast
151+
if (modelName() == null) {
152+
throw new IllegalStateException(
153+
"modelName is required. Either set it explicitly via modelName() or provide a ChatModel/StreamingChatModel");
154+
}
155+
return autoBuild();
156+
}
130157
}
131158

132159
public LangChain4j(ChatModel chatModel) {

contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import dev.langchain4j.model.chat.ChatModel;
3232
import dev.langchain4j.model.chat.StreamingChatModel;
3333
import dev.langchain4j.model.chat.request.ChatRequest;
34+
import dev.langchain4j.model.chat.request.ChatRequestParameters;
3435
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
3536
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
3637
import dev.langchain4j.model.chat.response.ChatResponse;
@@ -1145,4 +1146,51 @@ void testGenerateContentWithMalformedCharsetFallback() {
11451146

11461147
assertThat(textContent.text()).isEqualTo(textPayload);
11471148
}
1149+
1150+
@DisplayName("Should auto-detect model name from ChatModel when not explicitly set")
1151+
void testBuilderAutoDetectsModelNameFromChatModel() {
1152+
// Given
1153+
final ChatRequestParameters params = mock(ChatRequestParameters.class);
1154+
when(params.modelName()).thenReturn("auto-detected-model");
1155+
when(chatModel.defaultRequestParameters()).thenReturn(params);
1156+
1157+
// When
1158+
final LangChain4j lc4j = LangChain4j.builder().chatModel(chatModel).build();
1159+
1160+
// Then
1161+
assertThat(lc4j.modelName()).isEqualTo("auto-detected-model");
1162+
assertThat(lc4j.model()).isEqualTo("auto-detected-model");
1163+
}
1164+
1165+
@Test
1166+
@DisplayName("Should auto-detect model name from StreamingChatModel when not explicitly set")
1167+
void testBuilderAutoDetectsModelNameFromStreamingChatModel() {
1168+
// Given
1169+
final ChatRequestParameters params = mock(ChatRequestParameters.class);
1170+
when(params.modelName()).thenReturn("auto-detected-streaming-model");
1171+
when(streamingChatModel.defaultRequestParameters()).thenReturn(params);
1172+
1173+
// When
1174+
final LangChain4j lc4j = LangChain4j.builder().streamingChatModel(streamingChatModel).build();
1175+
1176+
// Then
1177+
assertThat(lc4j.modelName()).isEqualTo("auto-detected-streaming-model");
1178+
assertThat(lc4j.model()).isEqualTo("auto-detected-streaming-model");
1179+
}
1180+
1181+
@Test
1182+
@DisplayName("Should prefer explicit model name over auto-detected one")
1183+
void testBuilderPrefersExplicitModelName() {
1184+
// Given
1185+
final ChatRequestParameters params = mock(ChatRequestParameters.class);
1186+
when(params.modelName()).thenReturn("auto-detected-model");
1187+
when(chatModel.defaultRequestParameters()).thenReturn(params);
1188+
1189+
// When
1190+
final LangChain4j lc4j =
1191+
LangChain4j.builder().chatModel(chatModel).modelName("explicit-model").build();
1192+
1193+
// Then
1194+
assertThat(lc4j.modelName()).isEqualTo("explicit-model");
1195+
}
11481196
}

0 commit comments

Comments
 (0)