Skip to content

Commit b55af84

Browse files
committed
Add Gemini as image comparison algorithm #473
related to Visual-Regression-Tracker/Visual-Regression-Tracker#473
1 parent 0c63fb2 commit b55af84

13 files changed

Lines changed: 1085 additions & 286 deletions

package-lock.json

Lines changed: 302 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"dependencies": {
3131
"@aws-sdk/client-s3": "^3.922.0",
3232
"@aws-sdk/s3-request-presigner": "^3.922.0",
33+
"@google/genai": "^1.34.0",
3334
"@nestjs/cache-manager": "^3.0.1",
3435
"@nestjs/common": "^11.1.8",
3536
"@nestjs/config": "^4.0.2",
@@ -65,7 +66,8 @@
6566
"rxjs": "^7.8.2",
6667
"swagger-ui-express": "^4.6.3",
6768
"uuid-apikey": "^1.5.3",
68-
"zod": "^4.2.1"
69+
"zod": "^4.2.1",
70+
"zod-to-json-schema": "^3.25.0"
6971
},
7072
"devDependencies": {
7173
"@darraghor/eslint-plugin-nestjs-typed": "^6.9.3",

src/compare/compare.module.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ import { LookSameService } from './libs/looks-same/looks-same.service';
44
import { OdiffService } from './libs/odiff/odiff.service';
55
import { PixelmatchService } from './libs/pixelmatch/pixelmatch.service';
66
import { VlmService } from './libs/vlm/vlm.service';
7-
import { OllamaController } from './libs/vlm/ollama.controller';
8-
import { OllamaService } from './libs/vlm/ollama.service';
7+
import { OllamaController } from './libs/vlm/providers/ollama/ollama.controller';
8+
import { OllamaService } from './libs/vlm/providers/ollama/ollama.service';
9+
import { GeminiService } from './libs/vlm/providers/gemini/gemini.service';
910
import { StaticModule } from '../static/static.module';
1011

1112
@Module({
1213
controllers: [OllamaController],
13-
providers: [CompareService, PixelmatchService, LookSameService, OdiffService, VlmService, OllamaService],
14+
providers: [CompareService, PixelmatchService, LookSameService, OdiffService, VlmService, OllamaService, GeminiService],
1415
imports: [StaticModule],
1516
exports: [CompareService],
1617
})

src/compare/compare.service.spec.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import { LookSameService } from './libs/looks-same/looks-same.service';
66
import { OdiffService } from './libs/odiff/odiff.service';
77
import { PixelmatchService } from './libs/pixelmatch/pixelmatch.service';
88
import { VlmService } from './libs/vlm/vlm.service';
9-
import { OllamaService } from './libs/vlm/ollama.service';
9+
import { OllamaService } from './libs/vlm/providers/ollama/ollama.service';
10+
import { GeminiService } from './libs/vlm/providers/gemini/gemini.service';
1011
import { StaticModule } from '../static/static.module';
1112
import { ImageComparison } from '@prisma/client';
1213
import * as utils from '../static/utils';
@@ -26,6 +27,7 @@ describe('CompareService', () => {
2627
LookSameService,
2728
VlmService,
2829
OllamaService,
30+
GeminiService,
2931
PrismaService,
3032
{
3133
provide: ConfigService,
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import { Test, TestingModule } from '@nestjs/testing';
2+
import { GeminiService } from './gemini.service';
3+
import { GeminiVlmConfig } from '../../vlm.types';
4+
5+
// Stand-in for the Gemini SDK so the suite never performs a real network call.
// Jest hoists `jest.mock` above the imports; the factory may only reference
// outer variables whose names start with `mock`, hence the spy's name.
const mockGenerateContent = jest.fn();

jest.mock('@google/genai', () => ({
  // Each `new GoogleGenAI(...)` produces a client whose `models.generateContent`
  // is the shared spy, which lets every test inspect the outgoing request.
  GoogleGenAI: jest.fn().mockImplementation(() => {
    const models = { generateContent: mockGenerateContent };
    return { models };
  }),
}));
17+
18+
/**
 * Unit tests for GeminiService.generate.
 *
 * The `@google/genai` SDK is mocked at module level, so these tests only verify
 * the request that would be sent and how the SDK result/errors are surfaced.
 */
describe('GeminiService', () => {
  let service: GeminiService;

  beforeEach(async () => {
    // clearAllMocks wipes recorded calls/results only; the mocked GoogleGenAI
    // constructor implementation must survive between tests.
    jest.clearAllMocks();

    const module: TestingModule = await Test.createTestingModule({
      providers: [GeminiService],
    }).compile();

    service = module.get<GeminiService>(GeminiService);
  });

  /** Builds a valid Gemini config, optionally overriding individual fields. */
  const createConfig = (overrides?: Partial<GeminiVlmConfig>): GeminiVlmConfig => ({
    provider: 'gemini',
    model: 'gemini-1.5-pro',
    prompt: 'Test prompt',
    temperature: 0.1,
    apiKey: 'test-api-key',
    ...overrides,
  });

  /** Minimal shape of the SDK response that the service reads (`text`). */
  const createMockResponse = (text: string) => ({
    text,
  });

  describe('generate', () => {
    it('should call Gemini SDK with correct parameters and return VlmProviderResponse', async () => {
      const config = createConfig();
      const testBytes = new Uint8Array([1, 2, 3, 4]);
      const mockResponse = createMockResponse('{"identical": true, "description": "No differences"}');
      mockGenerateContent.mockResolvedValue(mockResponse);

      const result = await service.generate(config, [testBytes]);

      expect(mockGenerateContent).toHaveBeenCalledWith({
        model: config.model,
        contents: [
          { text: config.prompt },
          {
            inlineData: {
              data: expect.any(String),
              mimeType: 'image/png',
            },
          },
        ],
        config: {
          temperature: config.temperature,
          responseMimeType: 'application/json',
          responseJsonSchema: expect.any(Object),
        },
      });
      expect(result.content).toBe('{"identical": true, "description": "No differences"}');
    });

    it.each([
      ['single image', [new Uint8Array([137, 80, 78, 71])], 2],
      ['multiple images', [new Uint8Array([1, 2, 3]), new Uint8Array([4, 5, 6]), new Uint8Array([7, 8, 9])], 4],
    ])('should handle %s and convert to base64', async (_, images, expectedPartsCount) => {
      const config = createConfig();
      const mockResponse = createMockResponse('{"identical": true}');
      mockGenerateContent.mockResolvedValue(mockResponse);

      await service.generate(config, images);

      const callArgs = mockGenerateContent.mock.calls[0][0];
      expect(callArgs.contents.length).toBe(expectedPartsCount);

      // contents[0] is the prompt; every following part must carry its image, in order.
      images.forEach((image, index) => {
        const imagePart = callArgs.contents[index + 1];
        expect(imagePart.inlineData.mimeType).toBe('image/png');
        expect(imagePart.inlineData.data).toBe(Buffer.from(image).toString('base64'));
      });
    });

    it('should always include hardcoded JSON schema', async () => {
      const config = createConfig();
      const mockResponse = createMockResponse('{"identical": true}');
      mockGenerateContent.mockResolvedValue(mockResponse);

      await service.generate(config, []);

      const callArgs = mockGenerateContent.mock.calls[0][0];
      expect(callArgs.config.responseMimeType).toBe('application/json');
      expect(callArgs.config.responseJsonSchema).toEqual(
        expect.objectContaining({
          type: 'object',
          properties: expect.objectContaining({
            identical: expect.any(Object),
            description: expect.any(Object),
          }),
        })
      );
    });

    it('should throw error when API key is missing', async () => {
      const config = createConfig({ apiKey: '' });

      await expect(service.generate(config, [])).rejects.toThrow('Gemini API key is required');
      // The key check must fail before any request reaches the SDK.
      expect(mockGenerateContent).not.toHaveBeenCalled();
    });

    it('should throw error when SDK call fails', async () => {
      const config = createConfig({ apiKey: 'test-api-key' });
      // `Once` keeps the rejection from leaking into later tests: clearAllMocks
      // does not reset mock implementations.
      mockGenerateContent.mockRejectedValueOnce(new Error('API Error'));

      await expect(service.generate(config, [])).rejects.toThrow('API Error');
    });
  });
});
128+
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import { Injectable, Logger } from '@nestjs/common';
2+
import { GoogleGenAI } from '@google/genai';
3+
import { zodToJsonSchema } from 'zod-to-json-schema';
4+
import { VlmProvider, VlmProviderResponse } from '../../vlm-provider.interface';
5+
import { VlmConfig, GeminiVlmConfig } from '../../vlm.types';
6+
import { VlmComparisonResultSchema } from '../../vlm.service';
7+
8+
/**
 * VLM provider that sends the prompt and images to Google Gemini through the
 * `@google/genai` SDK and returns the model's raw text answer.
 *
 * The request always asks for a JSON answer constrained by
 * `VlmComparisonResultSchema`; parsing that JSON is left to the caller.
 */
@Injectable()
export class GeminiService implements VlmProvider {
  private readonly logger: Logger = new Logger(GeminiService.name);

  /**
   * Creates an SDK client for the given key. A new client is built per call
   * because the key comes from the per-request config.
   *
   * @throws Error when `apiKey` is empty.
   */
  private getGenAI(apiKey: string): GoogleGenAI {
    if (!apiKey) {
      throw new Error('Gemini API key is required');
    }
    return new GoogleGenAI({ apiKey });
  }

  /** Encodes raw image bytes as base64, the form used for `inlineData.data`. */
  private imageToBase64(imageBytes: Uint8Array): string {
    return Buffer.from(imageBytes).toString('base64');
  }

  /**
   * Runs one Gemini `generateContent` request.
   *
   * @param config Provider configuration; must have `provider === 'gemini'`.
   * @param images Raw image bytes, sent after the prompt in the given order.
   * @returns The text of the model response as `content`.
   * @throws Error when the config is for another provider, the API key is
   *   empty, the SDK call fails, or the response carries no text.
   */
  async generate(config: VlmConfig, images: Uint8Array[]): Promise<VlmProviderResponse> {
    // Type guard: ensure this is Gemini config
    if (config.provider !== 'gemini') {
      throw new Error(`GeminiService requires GeminiVlmConfig, got ${config.provider}`);
    }

    const geminiConfig: GeminiVlmConfig = config;

    const genAI = this.getGenAI(geminiConfig.apiKey);

    try {
      // Every image is declared as PNG — assumes callers only pass PNG bytes; TODO confirm.
      const imageParts = images.map((img) => ({
        inlineData: {
          data: this.imageToBase64(img),
          mimeType: 'image/png',
        },
      }));

      const parts = [{ text: geminiConfig.prompt }, ...imageParts];

      const result = await genAI.models.generateContent({
        model: geminiConfig.model,
        contents: parts,
        config: {
          temperature: geminiConfig.temperature,
          responseMimeType: 'application/json' as const,
          // NOTE(review): zod-to-json-schema is documented for Zod 3 schemas while
          // package.json pins zod ^4 — confirm VlmComparisonResultSchema converts
          // correctly, or consider Zod 4's native JSON Schema export.
          responseJsonSchema: zodToJsonSchema(VlmComparisonResultSchema),
        },
      });

      // The SDK's `text` accessor is undefined when the response has no text
      // part; fail here with a clear message rather than returning
      // `content: undefined` to the caller.
      const content = result.text;
      if (typeof content !== 'string') {
        throw new Error('Gemini response did not contain any text content');
      }

      return { content };
    } catch (error: unknown) {
      // Anything can be thrown, so narrow before reading Error-only fields.
      const message = error instanceof Error ? error.message : String(error);
      const stack = error instanceof Error ? error.stack : undefined;
      this.logger.error(`Gemini generate request failed: ${message}`, stack);
      throw error;
    }
  }
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import {
2+
Controller,
3+
Get,
4+
Post,
5+
Query,
6+
HttpException,
7+
HttpStatus,
8+
UseInterceptors,
9+
UploadedFiles,
10+
} from '@nestjs/common';
11+
import { FilesInterceptor } from '@nestjs/platform-express';
12+
import { ApiTags, ApiConsumes, ApiBody } from '@nestjs/swagger';
13+
import { OllamaService } from './ollama.service';
14+
import { VlmProviderResponse } from '../../vlm-provider.interface';
15+
import { OllamaVlmConfig } from '../../vlm.types';
16+
17+
/**
 * HTTP endpoints under `/ollama` that expose the Ollama provider directly:
 * listing models and running a one-off two-image comparison.
 */
@ApiTags('Ollama')
@Controller('ollama')
export class OllamaController {
  constructor(private readonly ollamaService: OllamaService) {}

  /** Returns the models reported by `OllamaService.listModels()` wrapped as `{ models }`. */
  @Get('models')
  async listModels() {
    return { models: await this.ollamaService.listModels() };
  }

  /**
   * Sends exactly two uploaded images to the Ollama provider.
   *
   * @param files Multipart files from the `images` field; exactly two are required.
   * @param model Model name, passed to the provider unchanged.
   * @param prompt Prompt text, passed to the provider unchanged.
   * @param temperature Sampling temperature as a query string; must parse to a finite number.
   * @returns The provider response from `OllamaService.generate`.
   * @throws HttpException 400 when the image count is not two or `temperature` is not numeric.
   */
  @Post('compare')
  @ApiConsumes('multipart/form-data')
  @ApiBody({
    schema: {
      type: 'object',
      required: ['images'],
      properties: {
        images: {
          type: 'array',
          items: { type: 'string', format: 'binary' },
          description: 'Two images to compare (baseline and comparison)',
        },
      },
    },
  })
  @UseInterceptors(FilesInterceptor('images', 2))
  async compareImages(
    @UploadedFiles() files: Express.Multer.File[],
    @Query('model') model: string,
    @Query('prompt') prompt: string,
    @Query('temperature') temperature: string
  ): Promise<VlmProviderResponse> {
    if (files?.length !== 2) {
      throw new HttpException('Two images required', HttpStatus.BAD_REQUEST);
    }

    // Query values arrive as strings (or undefined when omitted). Number() turns
    // a missing or non-numeric value into NaN, which must not reach the provider.
    const parsedTemperature = Number(temperature);
    if (!Number.isFinite(parsedTemperature)) {
      throw new HttpException('temperature must be a number', HttpStatus.BAD_REQUEST);
    }

    const config: OllamaVlmConfig = {
      provider: 'ollama',
      model,
      prompt,
      temperature: parsedTemperature,
    };
    const images = files.map((f) => new Uint8Array(f.buffer));

    return this.ollamaService.generate(config, images);
  }
}
64+

0 commit comments

Comments
 (0)