Skip to content

Commit 8a03441

Browse files
committed
#3 fix responseSchema type
1 parent 9b2682d commit 8a03441

File tree

5 files changed

+266
-12
lines changed

5 files changed

+266
-12
lines changed

gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@
88
import java.net.http.HttpClient;
99
import java.net.http.HttpRequest;
1010
import java.net.http.HttpResponse;
11-
import java.util.List;
12-
import java.util.Map;
13-
import java.util.Optional;
14-
import java.util.UUID;
11+
import java.util.*;
1512
import java.util.concurrent.CompletableFuture;
1613
import java.util.concurrent.ConcurrentHashMap;
1714
import java.util.stream.Collectors;
@@ -145,7 +142,10 @@ public List<SafetyRating> safetyRatings(UUID id) {
145142
return emptyList();
146143
}
147144
return response.candidates().stream()
148-
.flatMap(candidate -> candidate.safetyRatings().stream())
145+
// when streaming, we sometimes don't get a safety rating... (with 1.5 pro)
146+
.map(ResponseCandidate::safetyRatings)
147+
.filter(Objects::nonNull)
148+
.flatMap(Collection::stream)
149149
.toList();
150150
}
151151

gemini-api/src/main/java/swiss/ameri/gemini/api/GenerationConfig.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
public record GenerationConfig(
3737
List<String> stopSequences,
3838
String responseMimeType,
39-
String responseSchema,
39+
Schema responseSchema,
4040
Integer maxOutputTokens,
4141
Double temperature,
4242
Double topP,
@@ -53,7 +53,7 @@ public static GenerationConfigBuilder builder() {
5353
public static class GenerationConfigBuilder {
5454
private final List<String> stopSequences = new ArrayList<>();
5555
private String responseMimeType;
56-
private String responseSchema;
56+
private Schema responseSchema;
5757
private Integer maxOutputTokens;
5858
private Double temperature;
5959
private Double topP;
@@ -69,7 +69,7 @@ public GenerationConfigBuilder responseMimeType(String responseMimeType) {
6969
return this;
7070
}
7171

72-
public GenerationConfigBuilder responseSchema(String responseSchema) {
72+
public GenerationConfigBuilder responseSchema(Schema responseSchema) {
7373
this.responseSchema = responseSchema;
7474
return this;
7575
}
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
package swiss.ameri.gemini.api;
2+
3+
4+
import java.util.List;
5+
import java.util.Map;
6+
7+
/**
8+
* The Schema object allows the definition of input and output data types.
9+
* These types can be objects, but also primitives and arrays.
10+
* Represents a select subset of an OpenAPI 3.0 schema object.
11+
*
12+
* @param type Required. Data type.
13+
* @param format Optional. The format of the data. This is used only for primitive datatypes.
14+
* Supported formats:
15+
* for NUMBER type: float, double
16+
* for INTEGER type: int32, int64
17+
* for STRING type: enum
18+
* @param description Optional. A brief description of the parameter. This could contain examples of use.
19+
* Parameter description may be formatted as Markdown.
20+
* @param nullable Optional. Indicates if the value may be null.
21+
* @param ameri_swiss_enum Optional. <b>Note: the ameri_swiss prefix must be removed by the {@link swiss.ameri.gemini.spi.JsonParser}</b>.
22+
* Possible values of the element of Type.STRING with enum format.
23+
* For example we can define an Enum Direction as :
24+
* <code>{type:STRING, format:enum, enum:["EAST", NORTH", "SOUTH", "WEST"]}</code>
25+
* @param maxItems Optional. Maximum number of the elements for Type.ARRAY.
26+
* @param minItems Optional. Minimum number of the elements for Type.ARRAY.
27+
* @param properties Optional. Properties of Type.OBJECT.
28+
* An object containing a list of "key": value pairs. Example:
29+
* <code>{ "name": "wrench", "mass": "1.3kg", "count": "3" }</code>.
30+
* @param required Optional. Required properties of Type.OBJECT.
31+
* @param items Optional. Schema of the elements of Type.ARRAY.
32+
* @see <a href="https://ai.google.dev/api/caching#Schema">Schema</a> for further information.
33+
*/
34+
public record Schema(
35+
Type type,
36+
String format,
37+
String description,
38+
Boolean nullable,
39+
List<String> ameri_swiss_enum,
40+
String maxItems,
41+
String minItems,
42+
Map<String, Schema> properties,
43+
List<String> required,
44+
Schema items
45+
) {
46+
47+
48+
/**
49+
* Create a {@link SchemaBuilder}.
50+
*
51+
* @return an empty {@link SchemaBuilder}
52+
*/
53+
public static SchemaBuilder builder() {
54+
return new SchemaBuilder();
55+
}
56+
57+
/**
58+
* A builder for {@link Schema}. Currently, does not validate the fields when building the model. Not thread-safe.
59+
*/
60+
public static class SchemaBuilder {
61+
private Type type;
62+
private String format;
63+
private String description;
64+
private Boolean nullable;
65+
private List<String> ameri_swiss_enum;
66+
private String maxItems;
67+
private String minItems;
68+
private Map<String, Schema> properties;
69+
private List<String> required;
70+
private Schema items;
71+
72+
73+
private SchemaBuilder() {
74+
}
75+
76+
public Schema build() {
77+
return new Schema(
78+
this.type,
79+
this.format,
80+
this.description,
81+
this.nullable,
82+
this.ameri_swiss_enum,
83+
this.maxItems,
84+
this.minItems,
85+
this.properties,
86+
this.required,
87+
this.items
88+
);
89+
}
90+
91+
public SchemaBuilder type(Type type) {
92+
this.type = type;
93+
return this;
94+
}
95+
96+
public SchemaBuilder format(String format) {
97+
this.format = format;
98+
return this;
99+
}
100+
101+
public SchemaBuilder description(String description) {
102+
this.description = description;
103+
return this;
104+
}
105+
106+
public SchemaBuilder nullable(Boolean nullable) {
107+
this.nullable = nullable;
108+
return this;
109+
}
110+
111+
public SchemaBuilder ameri_swiss_enum(List<String> ameri_swiss_enum) {
112+
this.ameri_swiss_enum = ameri_swiss_enum;
113+
return this;
114+
}
115+
116+
public SchemaBuilder maxItems(String maxItems) {
117+
this.maxItems = maxItems;
118+
return this;
119+
}
120+
121+
public SchemaBuilder minItems(String minItems) {
122+
this.minItems = minItems;
123+
return this;
124+
}
125+
126+
public SchemaBuilder properties(Map<String, Schema> properties) {
127+
this.properties = properties;
128+
return this;
129+
}
130+
131+
public SchemaBuilder required(List<String> required) {
132+
this.required = required;
133+
return this;
134+
}
135+
136+
public SchemaBuilder items(Schema items) {
137+
this.items = items;
138+
return this;
139+
}
140+
}
141+
142+
/**
143+
* Type contains the list of OpenAPI data types.
144+
*
145+
* @see <a href="https://ai.google.dev/api/caching#Type">Data types</a>
146+
*/
147+
public enum Type {
148+
/**
149+
* Not specified, should not be used.
150+
*/
151+
TYPE_UNSPECIFIED,
152+
/**
153+
* String type.
154+
*/
155+
STRING,
156+
/**
157+
* Number type.
158+
*/
159+
NUMBER,
160+
/**
161+
* Integer type.
162+
*/
163+
INTEGER,
164+
/**
165+
* Boolean type.
166+
*/
167+
BOOLEAN,
168+
/**
169+
* Array type.
170+
*/
171+
ARRAY,
172+
/**
173+
* Object type.
174+
*/
175+
OBJECT
176+
}
177+
}
178+

gemini-gson/src/main/java/swiss/ameri/gemini/gson/GsonJsonParser.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
11
package swiss.ameri.gemini.gson;
22

3+
import com.google.gson.FieldNamingStrategy;
34
import com.google.gson.Gson;
5+
import com.google.gson.GsonBuilder;
6+
import swiss.ameri.gemini.api.Schema;
47
import swiss.ameri.gemini.spi.JsonParser;
58

69
/**
710
* Reference implementation of {@link JsonParser} using {@link Gson} dependency.
811
*/
912
public class GsonJsonParser implements JsonParser {
1013

14+
/**
15+
* Field naming strategy to avoid usage of illegal field names in java.
16+
* See e.g. {@link Schema#ameri_swiss_enum()}, which cannot be named {@code enum}.
17+
*/
18+
public static final FieldNamingStrategy FIELD_NAMING_STRATEGY = field -> {
19+
if (field.getName().startsWith("ameri_swiss_")) {
20+
return field.getName().substring("ameri_swiss_".length());
21+
}
22+
return field.getName();
23+
};
24+
1125
private final Gson gson;
1226

1327
/**
@@ -23,7 +37,7 @@ public GsonJsonParser(Gson gson) {
2337
* Create a default {@link JsonParser} instance.
2438
*/
2539
public GsonJsonParser() {
26-
this(new Gson());
40+
this(new GsonBuilder().setFieldNamingStrategy(FIELD_NAMING_STRATEGY).create());
2741
}
2842

2943
@Override

gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.io.InputStream;
99
import java.util.Base64;
1010
import java.util.List;
11+
import java.util.Map;
1112
import java.util.concurrent.ExecutionException;
1213
import java.util.concurrent.TimeUnit;
1314
import java.util.concurrent.TimeoutException;
@@ -38,6 +39,8 @@ public static void main(String[] args) throws Exception {
3839
countTokens(genAi);
3940
generateContent(genAi);
4041
generateContentStream(genAi);
42+
generateWithResponseSchema(genAi);
43+
generateContentStreamWithResponseSchema(genAi);
4144
multiChatTurn(genAi);
4245
textAndImage(genAi);
4346
embedContents(genAi);
@@ -90,7 +93,7 @@ private static void countTokens(GenAi genAi) {
9093
private static void multiChatTurn(GenAi genAi) {
9194
System.out.println("----- multi turn chat");
9295
GenerativeModel chatModel = GenerativeModel.builder()
93-
.modelName(ModelVariant.GEMINI_1_0_PRO)
96+
.modelName(ModelVariant.GEMINI_1_5_PRO)
9497
.addContent(new Content.TextContent(
9598
Content.Role.USER.roleName(),
9699
"Write the first line of a story about a magic backpack."
@@ -133,9 +136,68 @@ private static void generateContent(GenAi genAi) throws InterruptedException, Ex
133136
.get(20, TimeUnit.SECONDS);
134137
}
135138

139+
140+
private static void generateContentStreamWithResponseSchema(GenAi genAi) {
141+
System.out.println("----- Generate content (streaming) with response schema -- with usage meta data");
142+
var model = createResponseSchemaModel();
143+
genAi.generateContentStream(model)
144+
.forEach(x -> {
145+
System.out.println(x);
146+
// note that the usage metadata is updated as it arrives
147+
System.out.println(genAi.usageMetadata(x.id()));
148+
System.out.println(genAi.safetyRatings(x.id()));
149+
});
150+
}
151+
152+
private static void generateWithResponseSchema(GenAi genAi) throws InterruptedException, ExecutionException, TimeoutException {
153+
var model = createResponseSchemaModel();
154+
System.out.println("----- Generate with response schema (blocking)");
155+
genAi.generateContent(model)
156+
.thenAccept(gcr -> {
157+
System.out.println(gcr);
158+
System.out.println("----- Generate with response schema (blocking) usage meta data & safety ratings");
159+
System.out.println(genAi.usageMetadata(gcr.id()));
160+
System.out.println(genAi.safetyRatings(gcr.id()).stream().map(GenAi.SafetyRating::toTypedSafetyRating).toList());
161+
})
162+
.get(20, TimeUnit.SECONDS);
163+
}
164+
165+
private static GenerativeModel createResponseSchemaModel() {
166+
return GenerativeModel.builder()
167+
.modelName(ModelVariant.GEMINI_1_5_FLASH)
168+
.addContent(Content.textContent(
169+
Content.Role.USER,
170+
"List 3 popular cookie recipes."
171+
))
172+
.addSafetySetting(SafetySetting.of(
173+
SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
174+
SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH
175+
))
176+
.generationConfig(new GenerationConfig(
177+
null,
178+
"application/json",
179+
Schema.builder()
180+
.type(Schema.Type.ARRAY)
181+
.items(Schema.builder()
182+
.type(Schema.Type.OBJECT)
183+
.properties(Map.of(
184+
"recipe_name", Schema.builder()
185+
.type(Schema.Type.STRING)
186+
.build()
187+
))
188+
.build())
189+
.build(),
190+
null,
191+
null,
192+
null,
193+
null
194+
))
195+
.build();
196+
}
197+
136198
private static GenerativeModel createStoryModel() {
137199
return GenerativeModel.builder()
138-
.modelName(ModelVariant.GEMINI_1_0_PRO)
200+
.modelName(ModelVariant.GEMINI_1_5_PRO)
139201
.addContent(Content.textContent(
140202
Content.Role.USER,
141203
"Write a 50 word story about a magic backpack."
@@ -159,7 +221,7 @@ private static GenerativeModel createStoryModel() {
159221
private static void getModel(GenAi genAi) {
160222
System.out.println("----- Get Model");
161223
System.out.println(
162-
genAi.getModel(ModelVariant.GEMINI_1_0_PRO)
224+
genAi.getModel(ModelVariant.GEMINI_1_5_PRO)
163225
);
164226
}
165227

0 commit comments

Comments
 (0)