Skip to content

Commit 70c49bd

Browse files
committed
GenerativeModelBuilder
1 parent de19a97 commit 70c49bd

File tree

4 files changed

+98
-26
lines changed

4 files changed

+98
-26
lines changed

gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ public List<Model> listModels() {
6767
});
6868
}
6969

70-
/***
70+
/**
7171
* Get information of a model. Can be used to create a {@link GenerativeModel}.
72+
*
7273
* @param model of which the information is wanted.
7374
* @see #listModels()
7475
*/

gemini-api/src/main/java/swiss/ameri/gemini/api/GenerativeModel.java

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package swiss.ameri.gemini.api;
22

3+
import java.util.ArrayList;
34
import java.util.List;
45

56
/**
@@ -17,14 +18,86 @@ public record GenerativeModel(
1718
GenerationConfig generationConfig
1819
) {
1920

20-
public static GenerativeModel of(
21-
ModelVariant modelVariant,
22-
List<Content> contents,
23-
List<SafetySetting> safetySettings,
24-
GenerationConfig generationConfig
25-
) {
26-
// todo add builder, which accepts modelVariant
27-
return new GenerativeModel(modelVariant.variant(), contents, safetySettings, generationConfig);
21+
/**
22+
* Create a {@link GenerativeModelBuilder}.
23+
*
24+
* @return an empty {@link GenerativeModelBuilder}
25+
*/
26+
public static GenerativeModelBuilder builder() {
27+
return new GenerativeModelBuilder();
28+
}
29+
30+
/**
31+
* A builder for {@link GenerativeModel}. Currently, does not validate the fields when building the model. Not thread-safe.
32+
*/
33+
public static class GenerativeModelBuilder {
34+
private String modelName;
35+
private GenerationConfig generationConfig;
36+
private final List<Content> contents = new ArrayList<>();
37+
private final List<SafetySetting> safetySettings = new ArrayList<>();
38+
39+
/**
40+
* Set the model name.
41+
*
42+
* @param modelName to be set
43+
* @return this
44+
*/
45+
public GenerativeModelBuilder modelName(String modelName) {
46+
this.modelName = modelName;
47+
return this;
48+
}
49+
50+
/**
51+
* Set the model name.
52+
*
53+
* @param modelVariant to be set
54+
* @return this
55+
*/
56+
public GenerativeModelBuilder modelName(ModelVariant modelVariant) {
57+
return modelName(modelVariant == null ? null : modelVariant.variant());
58+
}
59+
60+
/**
61+
* Add content
62+
*
63+
* @param content to be added
64+
* @return this
65+
*/
66+
public GenerativeModelBuilder addContent(Content content) {
67+
this.contents.add(content);
68+
return this;
69+
}
70+
71+
/**
72+
* Add safety setting
73+
*
74+
* @param safetySetting to be added
75+
* @return this
76+
*/
77+
public GenerativeModelBuilder addSafetySetting(SafetySetting safetySetting) {
78+
this.safetySettings.add(safetySetting);
79+
return this;
80+
}
81+
82+
/**
83+
* Set the generation config
84+
*
85+
* @param generationConfig to be set
86+
* @return this
87+
*/
88+
public GenerativeModelBuilder generationConfig(GenerationConfig generationConfig) {
89+
this.generationConfig = generationConfig;
90+
return this;
91+
}
92+
93+
/**
94+
* Build the model based on this builder.
95+
*
96+
* @return a completed (not necessarily validated) {@link GenerativeModel}
97+
*/
98+
public GenerativeModel build() {
99+
return new GenerativeModel(modelName, contents, safetySettings, generationConfig);
100+
}
28101
}
29102

30103
}

gemini-api/src/main/java/swiss/ameri/gemini/api/SafetySetting.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ public static SafetySetting of(
1717
HarmBlockThreshold threshold
1818
) {
1919
return new SafetySetting(
20-
category.name(),
21-
threshold.name()
20+
category == null ? null : category.name(),
21+
threshold == null ? null : threshold.name()
2222
);
2323
}
2424

gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import swiss.ameri.gemini.gson.GsonJsonParser;
55
import swiss.ameri.gemini.spi.JsonParser;
66

7-
import java.util.List;
87
import java.util.concurrent.TimeUnit;
98

109
/**
@@ -38,27 +37,26 @@ public static void main(String[] args) throws Exception {
3837
);
3938

4039
System.out.println("-----");
41-
var model = GenerativeModel.of(
42-
ModelVariant.GEMINI_1_0_PRO,
43-
List.of(
44-
new Content.TextContent(
45-
Content.Role.USER.roleName(),
46-
"Write a 300 word story about a magic backpack."
47-
)
48-
),
49-
List.of(
50-
SafetySetting.of(SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH)
51-
),
52-
new GenerationConfig(
40+
var model = GenerativeModel.builder()
41+
.modelName(ModelVariant.GEMINI_1_0_PRO)
42+
.addContent(new Content.TextContent(
43+
Content.Role.USER.roleName(),
44+
"Write a 300 word story about a magic backpack."
45+
))
46+
.addSafetySetting(SafetySetting.of(
47+
SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
48+
SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH
49+
))
50+
.generationConfig(new GenerationConfig(
5351
null,
5452
null,
5553
null,
5654
null,
5755
null,
5856
null,
5957
null
60-
)
61-
);
58+
))
59+
.build();
6260
genAi.generateContent(model)
6361
.thenAccept(System.out::println)
6462
.get(20, TimeUnit.SECONDS);

0 commit comments

Comments
 (0)