Skip to content

Commit 98c7fd3

Browse files
Copilotsawka
andcommitted
Add Google AI file summarization package
Co-authored-by: sawka <2722291+sawka@users.noreply.github.com>
1 parent b4319a2 commit 98c7fd3

3 files changed

Lines changed: 356 additions & 0 deletions

File tree

pkg/aiusechat/google/doc.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright 2025, Command Line Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
// Package google provides Google Generative AI integration for WaveTerm.
5+
//
6+
// This package implements file summarization using Google's Gemini models.
7+
// Unlike other AI provider implementations in the aiusechat package, this
8+
// package does NOT implement full SSE streaming. It uses a simple
9+
// request-response API for file summarization.
10+
//
11+
// # Supported File Types
12+
//
13+
// The package supports the same file types as defined in wshcmd-ai.go:
14+
// - Images (PNG, JPEG, etc.): up to 7MB
15+
// - PDFs: up to 5MB
16+
// - Text files: up to 200KB
17+
//
18+
// Binary files are rejected unless they are recognized as images or PDFs.
19+
//
20+
// # Usage
21+
//
22+
// To summarize a file:
23+
//
24+
// ctx := context.Background()
25+
// summary, usage, err := google.SummarizeFile(ctx, "/path/to/file.txt", "YOUR_API_KEY")
26+
// if err != nil {
27+
// log.Fatal(err)
28+
// }
29+
// fmt.Println("Summary:", summary)
30+
// fmt.Printf("Tokens used: %d\n", usage.TotalTokenCount)
31+
//
32+
// # Configuration
33+
//
34+
// The summarization behavior can be customized by modifying the constants:
35+
// - SummarizeModel: The Gemini model to use (default: "gemini-2.5-flash-lite")
36+
// - SummarizePrompt: The prompt sent to the model
37+
// - GoogleAPIURL: The base URL for the API (for reference, not currently used by the SDK)
38+
package google
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Copyright 2025, Command Line Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package google
5+
6+
import (
7+
"context"
8+
"fmt"
9+
"net/http"
10+
"os"
11+
"strings"
12+
13+
"github.com/google/generative-ai-go/genai"
14+
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
15+
"google.golang.org/api/option"
16+
)
17+
18+
const (
19+
// GoogleAPIURL is the base URL for the Google Generative AI API
20+
GoogleAPIURL = "https://generativelanguage.googleapis.com"
21+
22+
// SummarizePrompt is the prompt used for file summarization
23+
SummarizePrompt = "Please provide a concise summary of this file. Include the main topics, key points, and any notable information."
24+
25+
// SummarizeModel is the model used for file summarization
26+
SummarizeModel = "gemini-2.5-flash-lite"
27+
)
28+
29+
// GoogleUsage represents token usage information from Google's Generative AI API
30+
type GoogleUsage struct {
31+
PromptTokenCount int32 `json:"prompt_token_count"`
32+
CachedContentTokenCount int32 `json:"cached_content_token_count"`
33+
CandidatesTokenCount int32 `json:"candidates_token_count"`
34+
TotalTokenCount int32 `json:"total_token_count"`
35+
}
36+
37+
func detectMimeType(data []byte) string {
38+
mimeType := http.DetectContentType(data)
39+
return strings.Split(mimeType, ";")[0]
40+
}
41+
42+
func getMaxFileSize(mimeType string) (int, string) {
43+
if mimeType == "application/pdf" {
44+
return 5 * 1024 * 1024, "5MB"
45+
}
46+
if strings.HasPrefix(mimeType, "image/") {
47+
return 7 * 1024 * 1024, "7MB"
48+
}
49+
return 200 * 1024, "200KB"
50+
}
51+
52+
// SummarizeFile reads a file and generates a summary using Google's Generative AI.
53+
// It supports images, PDFs, and text files based on the limits defined in wshcmd-ai.go.
54+
// Returns the summary text, usage information, and any error encountered.
55+
func SummarizeFile(ctx context.Context, filename string, apiKey string) (string, *GoogleUsage, error) {
56+
// Read the file
57+
data, err := os.ReadFile(filename)
58+
if err != nil {
59+
return "", nil, fmt.Errorf("reading file: %w", err)
60+
}
61+
62+
// Detect MIME type
63+
mimeType := detectMimeType(data)
64+
65+
isPDF := mimeType == "application/pdf"
66+
isImage := strings.HasPrefix(mimeType, "image/")
67+
68+
if !isPDF && !isImage {
69+
mimeType = "text/plain"
70+
if utilfn.ContainsBinaryData(data) {
71+
return "", nil, fmt.Errorf("file contains binary data and cannot be summarized")
72+
}
73+
}
74+
75+
// Validate file size
76+
maxSize, sizeStr := getMaxFileSize(mimeType)
77+
if len(data) > maxSize {
78+
return "", nil, fmt.Errorf("file exceeds maximum size of %s for %s files", sizeStr, mimeType)
79+
}
80+
81+
// Create client
82+
client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
83+
if err != nil {
84+
return "", nil, fmt.Errorf("creating Google AI client: %w", err)
85+
}
86+
defer client.Close()
87+
88+
// Create model
89+
model := client.GenerativeModel(SummarizeModel)
90+
91+
// Prepare the content parts
92+
var parts []genai.Part
93+
94+
// Add the prompt
95+
parts = append(parts, genai.Text(SummarizePrompt))
96+
97+
// Add the file content based on type
98+
if isImage {
99+
// For images, use Blob
100+
parts = append(parts, genai.Blob{
101+
MIMEType: mimeType,
102+
Data: data,
103+
})
104+
} else if isPDF {
105+
// For PDFs, use Blob
106+
parts = append(parts, genai.Blob{
107+
MIMEType: mimeType,
108+
Data: data,
109+
})
110+
} else {
111+
// For text files, convert to string
112+
parts = append(parts, genai.Text(string(data)))
113+
}
114+
115+
// Generate content
116+
resp, err := model.GenerateContent(ctx, parts...)
117+
if err != nil {
118+
return "", nil, fmt.Errorf("generating content: %w", err)
119+
}
120+
121+
// Check if we got any candidates
122+
if len(resp.Candidates) == 0 {
123+
return "", nil, fmt.Errorf("no response candidates returned")
124+
}
125+
126+
// Extract the text from the first candidate
127+
candidate := resp.Candidates[0]
128+
if candidate.Content == nil || len(candidate.Content.Parts) == 0 {
129+
return "", nil, fmt.Errorf("no content in response")
130+
}
131+
132+
var summary strings.Builder
133+
for _, part := range candidate.Content.Parts {
134+
if textPart, ok := part.(genai.Text); ok {
135+
summary.WriteString(string(textPart))
136+
}
137+
}
138+
139+
// Convert usage metadata
140+
var usage *GoogleUsage
141+
if resp.UsageMetadata != nil {
142+
usage = &GoogleUsage{
143+
PromptTokenCount: resp.UsageMetadata.PromptTokenCount,
144+
CachedContentTokenCount: resp.UsageMetadata.CachedContentTokenCount,
145+
CandidatesTokenCount: resp.UsageMetadata.CandidatesTokenCount,
146+
TotalTokenCount: resp.UsageMetadata.TotalTokenCount,
147+
}
148+
}
149+
150+
return summary.String(), usage, nil
151+
}
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
// Copyright 2025, Command Line Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package google
5+
6+
import (
7+
"context"
8+
"os"
9+
"path/filepath"
10+
"testing"
11+
"time"
12+
)
13+
14+
func TestDetectMimeType(t *testing.T) {
15+
tests := []struct {
16+
name string
17+
data []byte
18+
expected string
19+
}{
20+
{
21+
name: "plain text",
22+
data: []byte("Hello, World!"),
23+
expected: "text/plain",
24+
},
25+
{
26+
name: "empty file",
27+
data: []byte{},
28+
expected: "text/plain",
29+
},
30+
}
31+
32+
for _, tt := range tests {
33+
t.Run(tt.name, func(t *testing.T) {
34+
result := detectMimeType(tt.data)
35+
if !containsMimeType(result, tt.expected) {
36+
t.Errorf("detectMimeType() = %v, want to contain %v", result, tt.expected)
37+
}
38+
})
39+
}
40+
}
41+
42+
func containsMimeType(got, want string) bool {
43+
// DetectContentType may return variations like "text/plain; charset=utf-8"
44+
return got == want || (want == "text/plain" && got == "text/plain; charset=utf-8")
45+
}
46+
47+
func TestGetMaxFileSize(t *testing.T) {
48+
tests := []struct {
49+
name string
50+
mimeType string
51+
expectedSize int
52+
expectedStr string
53+
}{
54+
{
55+
name: "PDF file",
56+
mimeType: "application/pdf",
57+
expectedSize: 5 * 1024 * 1024,
58+
expectedStr: "5MB",
59+
},
60+
{
61+
name: "PNG image",
62+
mimeType: "image/png",
63+
expectedSize: 7 * 1024 * 1024,
64+
expectedStr: "7MB",
65+
},
66+
{
67+
name: "JPEG image",
68+
mimeType: "image/jpeg",
69+
expectedSize: 7 * 1024 * 1024,
70+
expectedStr: "7MB",
71+
},
72+
{
73+
name: "text file",
74+
mimeType: "text/plain",
75+
expectedSize: 200 * 1024,
76+
expectedStr: "200KB",
77+
},
78+
}
79+
80+
for _, tt := range tests {
81+
t.Run(tt.name, func(t *testing.T) {
82+
size, sizeStr := getMaxFileSize(tt.mimeType)
83+
if size != tt.expectedSize {
84+
t.Errorf("getMaxFileSize() size = %v, want %v", size, tt.expectedSize)
85+
}
86+
if sizeStr != tt.expectedStr {
87+
t.Errorf("getMaxFileSize() sizeStr = %v, want %v", sizeStr, tt.expectedStr)
88+
}
89+
})
90+
}
91+
}
92+
93+
func TestSummarizeFile_FileNotFound(t *testing.T) {
94+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
95+
defer cancel()
96+
97+
_, _, err := SummarizeFile(ctx, "/nonexistent/file.txt", "fake-api-key")
98+
if err == nil {
99+
t.Error("SummarizeFile() expected error for nonexistent file, got nil")
100+
}
101+
}
102+
103+
func TestSummarizeFile_BinaryFile(t *testing.T) {
104+
// Create a temporary binary file
105+
tmpDir := t.TempDir()
106+
binFile := filepath.Join(tmpDir, "test.bin")
107+
108+
// Create binary data (not text, image, or PDF)
109+
binaryData := []byte{0x00, 0x01, 0x02, 0x03, 0x7F, 0x80, 0xFF}
110+
if err := os.WriteFile(binFile, binaryData, 0644); err != nil {
111+
t.Fatalf("Failed to create test file: %v", err)
112+
}
113+
114+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
115+
defer cancel()
116+
117+
_, _, err := SummarizeFile(ctx, binFile, "fake-api-key")
118+
if err == nil {
119+
t.Error("SummarizeFile() expected error for binary file, got nil")
120+
}
121+
if err != nil && !containsString(err.Error(), "binary data") {
122+
t.Errorf("SummarizeFile() error = %v, want error containing 'binary data'", err)
123+
}
124+
}
125+
126+
func TestSummarizeFile_FileTooLarge(t *testing.T) {
127+
// Create a temporary text file that exceeds the limit
128+
tmpDir := t.TempDir()
129+
textFile := filepath.Join(tmpDir, "large.txt")
130+
131+
// Create a file larger than 200KB (text file limit)
132+
largeData := make([]byte, 201*1024)
133+
for i := range largeData {
134+
largeData[i] = 'a'
135+
}
136+
if err := os.WriteFile(textFile, largeData, 0644); err != nil {
137+
t.Fatalf("Failed to create test file: %v", err)
138+
}
139+
140+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
141+
defer cancel()
142+
143+
_, _, err := SummarizeFile(ctx, textFile, "fake-api-key")
144+
if err == nil {
145+
t.Error("SummarizeFile() expected error for file too large, got nil")
146+
}
147+
if err != nil && !containsString(err.Error(), "exceeds maximum size") {
148+
t.Errorf("SummarizeFile() error = %v, want error containing 'exceeds maximum size'", err)
149+
}
150+
}
151+
152+
func containsString(s, substr string) bool {
153+
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
154+
(len(s) > 0 && len(substr) > 0 && stringContains(s, substr)))
155+
}
156+
157+
func stringContains(s, substr string) bool {
158+
for i := 0; i <= len(s)-len(substr); i++ {
159+
if s[i:i+len(substr)] == substr {
160+
return true
161+
}
162+
}
163+
return false
164+
}
165+
166+
// Note: We don't test the actual API call without a real API key
167+
// Integration tests would require setting GOOGLE_API_KEY environment variable

0 commit comments

Comments
 (0)