Skip to content

Commit 5ad4bb2

Browse files
authored
Add Gemini as image comparison algorithm #473 (#340)
* Add Gemini as image comparison algorithm #473 related to Visual-Regression-Tracker/Visual-Regression-Tracker#473
1 parent 0c63fb2 commit 5ad4bb2

15 files changed

Lines changed: 1099 additions & 296 deletions

package-lock.json

Lines changed: 302 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"dependencies": {
3131
"@aws-sdk/client-s3": "^3.922.0",
3232
"@aws-sdk/s3-request-presigner": "^3.922.0",
33+
"@google/genai": "^1.34.0",
3334
"@nestjs/cache-manager": "^3.0.1",
3435
"@nestjs/common": "^11.1.8",
3536
"@nestjs/config": "^4.0.2",
@@ -65,7 +66,8 @@
6566
"rxjs": "^7.8.2",
6667
"swagger-ui-express": "^4.6.3",
6768
"uuid-apikey": "^1.5.3",
68-
"zod": "^4.2.1"
69+
"zod": "^4.2.1",
70+
"zod-to-json-schema": "^3.25.0"
6971
},
7072
"devDependencies": {
7173
"@darraghor/eslint-plugin-nestjs-typed": "^6.9.3",

src/app.module.ts

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import { Module } from '@nestjs/common';
2-
import { CacheInterceptor, CacheModule } from '@nestjs/cache-manager';
32
import { AppService } from './app.service';
43
import { AuthModule } from './auth/auth.module';
54
import { UsersModule } from './users/users.module';
@@ -9,7 +8,7 @@ import { TestRunsModule } from './test-runs/test-runs.module';
98
import { TestVariationsModule } from './test-variations/test-variations.module';
109
import { PrismaService } from './prisma/prisma.service';
1110
import { ConfigModule } from '@nestjs/config';
12-
import { APP_FILTER, APP_INTERCEPTOR } from '@nestjs/core';
11+
import { APP_FILTER } from '@nestjs/core';
1312
import { HttpExceptionFilter } from './http-exception.filter';
1413
import { CompareModule } from './compare/compare.module';
1514
import { ScheduleModule } from '@nestjs/schedule';
@@ -19,7 +18,6 @@ import { HealthController } from './health/health.controller';
1918
@Module({
2019
imports: [
2120
ConfigModule.forRoot({ isGlobal: true }),
22-
CacheModule.register(),
2321
ScheduleModule.forRoot(),
2422
AuthModule,
2523
UsersModule,
@@ -37,10 +35,6 @@ import { HealthController } from './health/health.controller';
3735
provide: APP_FILTER,
3836
useClass: HttpExceptionFilter,
3937
},
40-
{
41-
provide: APP_INTERCEPTOR,
42-
useClass: CacheInterceptor,
43-
},
4438
],
4539
controllers: [HealthController],
4640
})

src/compare/compare.module.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,22 @@ import { LookSameService } from './libs/looks-same/looks-same.service';
44
import { OdiffService } from './libs/odiff/odiff.service';
55
import { PixelmatchService } from './libs/pixelmatch/pixelmatch.service';
66
import { VlmService } from './libs/vlm/vlm.service';
7-
import { OllamaController } from './libs/vlm/ollama.controller';
8-
import { OllamaService } from './libs/vlm/ollama.service';
7+
import { OllamaController } from './libs/vlm/providers/ollama/ollama.controller';
8+
import { OllamaService } from './libs/vlm/providers/ollama/ollama.service';
9+
import { GeminiService } from './libs/vlm/providers/gemini/gemini.service';
910
import { StaticModule } from '../static/static.module';
1011

1112
@Module({
1213
controllers: [OllamaController],
13-
providers: [CompareService, PixelmatchService, LookSameService, OdiffService, VlmService, OllamaService],
14+
providers: [
15+
CompareService,
16+
PixelmatchService,
17+
LookSameService,
18+
OdiffService,
19+
VlmService,
20+
OllamaService,
21+
GeminiService,
22+
],
1423
imports: [StaticModule],
1524
exports: [CompareService],
1625
})

src/compare/compare.service.spec.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@ import { LookSameService } from './libs/looks-same/looks-same.service';
66
import { OdiffService } from './libs/odiff/odiff.service';
77
import { PixelmatchService } from './libs/pixelmatch/pixelmatch.service';
88
import { VlmService } from './libs/vlm/vlm.service';
9-
import { OllamaService } from './libs/vlm/ollama.service';
9+
import { OllamaService } from './libs/vlm/providers/ollama/ollama.service';
10+
import { GeminiService } from './libs/vlm/providers/gemini/gemini.service';
1011
import { StaticModule } from '../static/static.module';
1112
import { ImageComparison } from '@prisma/client';
1213
import * as utils from '../static/utils';
1314

15+
jest.mock('zod/v3', () => {
16+
const actualZod = jest.requireActual('zod');
17+
return actualZod;
18+
});
19+
1420
describe('CompareService', () => {
1521
let service: CompareService;
1622
let pixelmatchService: PixelmatchService;
@@ -26,6 +32,7 @@ describe('CompareService', () => {
2632
LookSameService,
2733
VlmService,
2834
OllamaService,
35+
GeminiService,
2936
PrismaService,
3037
{
3138
provide: ConfigService,

src/compare/libs/vlm/ollama.service.spec.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import { Test, TestingModule } from '@nestjs/testing';
22
import { ConfigService } from '@nestjs/config';
33
import { OllamaService } from './ollama.service';
44

5-
// Mock the ollama module
65
const mockChat = jest.fn();
76
const mockList = jest.fn();
87

@@ -20,7 +19,6 @@ describe('OllamaService', () => {
2019
let service: OllamaService;
2120

2221
beforeEach(async () => {
23-
// Reset mocks
2422
jest.clearAllMocks();
2523

2624
const module: TestingModule = await Test.createTestingModule({
@@ -100,7 +98,6 @@ describe('OllamaService', () => {
10098
};
10199
mockChat.mockResolvedValue(mockResponse);
102100

103-
// Use a longer base64 string
104101
const longBase64 =
105102
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==';
106103
const result = await service.generate({
@@ -109,7 +106,7 @@ describe('OllamaService', () => {
109106
{
110107
role: 'user',
111108
content: 'Test prompt',
112-
images: [longBase64], // base64 string - passed through as-is
109+
images: [longBase64],
113110
},
114111
],
115112
});
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import { Test, TestingModule } from '@nestjs/testing';
2+
import { GeminiService } from './gemini.service';
3+
import { GeminiVlmConfig } from '../../vlm.types';
4+
5+
jest.mock('zod/v3', () => {
6+
const actualZod = jest.requireActual('zod');
7+
return actualZod;
8+
});
9+
10+
const mockGenerateContent = jest.fn();
11+
12+
jest.mock('@google/genai', () => {
13+
return {
14+
GoogleGenAI: jest.fn().mockImplementation(() => ({
15+
models: {
16+
generateContent: mockGenerateContent,
17+
},
18+
})),
19+
};
20+
});
21+
22+
describe('GeminiService', () => {
23+
let service: GeminiService;
24+
25+
beforeEach(async () => {
26+
jest.clearAllMocks();
27+
28+
const module: TestingModule = await Test.createTestingModule({
29+
providers: [GeminiService],
30+
}).compile();
31+
32+
service = module.get<GeminiService>(GeminiService);
33+
});
34+
35+
const createConfig = (overrides?: Partial<GeminiVlmConfig>): GeminiVlmConfig => ({
36+
provider: 'gemini',
37+
model: 'gemini-1.5-pro',
38+
prompt: 'Test prompt',
39+
temperature: 0.1,
40+
apiKey: 'test-api-key',
41+
...overrides,
42+
});
43+
44+
const createMockResponse = (text: string) => ({
45+
text,
46+
});
47+
48+
describe('generate', () => {
49+
it('should call Gemini SDK with correct parameters and return VlmProviderResponse', async () => {
50+
const config = createConfig();
51+
const testBytes = new Uint8Array([1, 2, 3, 4]);
52+
const mockResponse = createMockResponse('{"identical": true, "description": "No differences"}');
53+
mockGenerateContent.mockResolvedValue(mockResponse);
54+
55+
const result = await service.generate(config, [testBytes]);
56+
57+
expect(mockGenerateContent).toHaveBeenCalledWith({
58+
model: config.model,
59+
contents: [
60+
{ text: config.prompt },
61+
{
62+
inlineData: {
63+
data: expect.any(String),
64+
mimeType: 'image/png',
65+
},
66+
},
67+
],
68+
config: {
69+
temperature: config.temperature,
70+
responseMimeType: 'application/json',
71+
responseJsonSchema: expect.any(Object),
72+
},
73+
});
74+
expect(result.content).toBe('{"identical": true, "description": "No differences"}');
75+
});
76+
77+
it.each([
78+
['single image', [new Uint8Array([137, 80, 78, 71])], 2],
79+
['multiple images', [new Uint8Array([1, 2, 3]), new Uint8Array([4, 5, 6]), new Uint8Array([7, 8, 9])], 4],
80+
])('should handle %s and convert to base64', async (_, images, expectedPartsCount) => {
81+
const config = createConfig();
82+
const mockResponse = createMockResponse('{"identical": true}');
83+
mockGenerateContent.mockResolvedValue(mockResponse);
84+
85+
await service.generate(config, images);
86+
87+
const callArgs = mockGenerateContent.mock.calls[0][0];
88+
expect(callArgs.contents.length).toBe(expectedPartsCount);
89+
90+
if (images.length > 0) {
91+
const imagePart = callArgs.contents[1];
92+
expect(imagePart.inlineData.mimeType).toBe('image/png');
93+
expect(imagePart.inlineData.data).toBe(Buffer.from(images[0]).toString('base64'));
94+
}
95+
});
96+
97+
it('should always include hardcoded JSON schema', async () => {
98+
const config = createConfig();
99+
const mockResponse = createMockResponse('{"identical": true}');
100+
mockGenerateContent.mockResolvedValue(mockResponse);
101+
102+
await service.generate(config, []);
103+
104+
const callArgs = mockGenerateContent.mock.calls[0][0];
105+
expect(callArgs.config.responseMimeType).toBe('application/json');
106+
const schema = callArgs.config.responseJsonSchema;
107+
expect(schema).toBeDefined();
108+
expect(schema).toBeTruthy();
109+
});
110+
111+
it.each([
112+
['API key is missing', { apiKey: '' }, 'Gemini API key is required'],
113+
['SDK call fails', { apiKey: 'test-api-key' }, 'API Error'],
114+
])('should throw error when %s', async (_, overrides, expectedError) => {
115+
const config = createConfig(overrides);
116+
117+
if (expectedError === 'API Error') {
118+
mockGenerateContent.mockRejectedValue(new Error(expectedError));
119+
}
120+
121+
await expect(service.generate(config, [])).rejects.toThrow(expectedError);
122+
});
123+
});
124+
});
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import { Injectable, Logger } from '@nestjs/common';
2+
import { GoogleGenAI } from '@google/genai';
3+
import { zodToJsonSchema } from 'zod-to-json-schema';
4+
import { VlmProvider, VlmProviderResponse } from '../../vlm-provider.interface';
5+
import { VlmConfig, GeminiVlmConfig } from '../../vlm.types';
6+
import { VlmComparisonResultSchema } from '../../vlm.service';
7+
8+
@Injectable()
9+
export class GeminiService implements VlmProvider {
10+
private readonly logger: Logger = new Logger(GeminiService.name);
11+
12+
private getGenAI(apiKey: string): GoogleGenAI {
13+
if (!apiKey) {
14+
throw new Error('Gemini API key is required');
15+
}
16+
return new GoogleGenAI({ apiKey });
17+
}
18+
19+
private imageToBase64(imageBytes: Uint8Array): string {
20+
const base64 = Buffer.from(imageBytes).toString('base64');
21+
return base64;
22+
}
23+
24+
async generate(config: VlmConfig, images: Uint8Array[]): Promise<VlmProviderResponse> {
25+
// Type guard: ensure this is Gemini config
26+
if (config.provider !== 'gemini') {
27+
throw new Error(`GeminiService requires GeminiVlmConfig, got ${config.provider}`);
28+
}
29+
30+
const geminiConfig: GeminiVlmConfig = config;
31+
32+
const genAI = this.getGenAI(geminiConfig.apiKey);
33+
34+
try {
35+
const imageParts = images.map((img) => ({
36+
inlineData: {
37+
data: this.imageToBase64(img),
38+
mimeType: 'image/png',
39+
},
40+
}));
41+
42+
const parts = [{ text: geminiConfig.prompt }, ...imageParts];
43+
44+
const result = await genAI.models.generateContent({
45+
model: geminiConfig.model,
46+
contents: parts,
47+
config: {
48+
temperature: geminiConfig.temperature,
49+
responseMimeType: 'application/json' as const,
50+
responseJsonSchema: zodToJsonSchema(VlmComparisonResultSchema),
51+
},
52+
});
53+
54+
return {
55+
content: result.text,
56+
};
57+
} catch (error) {
58+
this.logger.error(`Gemini generate request failed: ${error.message}`, error.stack);
59+
throw error;
60+
}
61+
}
62+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import {
2+
Controller,
3+
Get,
4+
Post,
5+
Query,
6+
HttpException,
7+
HttpStatus,
8+
UseInterceptors,
9+
UploadedFiles,
10+
} from '@nestjs/common';
11+
import { FilesInterceptor } from '@nestjs/platform-express';
12+
import { ApiTags, ApiConsumes, ApiBody } from '@nestjs/swagger';
13+
import { OllamaService } from './ollama.service';
14+
import { VlmProviderResponse } from '../../vlm-provider.interface';
15+
import { OllamaVlmConfig } from '../../vlm.types';
16+
17+
@ApiTags('Ollama')
18+
@Controller('ollama')
19+
export class OllamaController {
20+
constructor(private readonly ollamaService: OllamaService) {}
21+
22+
@Get('models')
23+
async listModels() {
24+
return { models: await this.ollamaService.listModels() };
25+
}
26+
27+
@Post('compare')
28+
@ApiConsumes('multipart/form-data')
29+
@ApiBody({
30+
schema: {
31+
type: 'object',
32+
required: ['images'],
33+
properties: {
34+
images: {
35+
type: 'array',
36+
items: { type: 'string', format: 'binary' },
37+
description: 'Two images to compare (baseline and comparison)',
38+
},
39+
},
40+
},
41+
})
42+
@UseInterceptors(FilesInterceptor('images', 2))
43+
async compareImages(
44+
@UploadedFiles() files: Express.Multer.File[],
45+
@Query('model') model: string,
46+
@Query('prompt') prompt: string,
47+
@Query('temperature') temperature: string
48+
): Promise<VlmProviderResponse> {
49+
if (files?.length !== 2) {
50+
throw new HttpException('Two images required', HttpStatus.BAD_REQUEST);
51+
}
52+
53+
const config: OllamaVlmConfig = {
54+
provider: 'ollama',
55+
model,
56+
prompt,
57+
temperature: Number(temperature),
58+
};
59+
const images = files.map((f) => new Uint8Array(f.buffer));
60+
61+
return this.ollamaService.generate(config, images);
62+
}
63+
}

0 commit comments

Comments
 (0)