Skip to content

Commit d8c3a0c

Browse files
committed
Add NVIDIA provider
1 parent eba1bc1 commit d8c3a0c

2 files changed

Lines changed: 94 additions & 1 deletion

File tree

nvidia.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package iteragent
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"time"
11+
)
12+
13+
type nvidiaProvider struct {
14+
cfg OpenAICompatConfig
15+
client *http.Client
16+
}
17+
18+
func NewNvidia(cfg OpenAICompatConfig) Provider {
19+
return &nvidiaProvider{
20+
cfg: cfg,
21+
client: &http.Client{Timeout: 120 * time.Second},
22+
}
23+
}
24+
25+
func (p *nvidiaProvider) Name() string {
26+
return fmt.Sprintf("nvidia(%s)", p.cfg.Model)
27+
}
28+
29+
func (p *nvidiaProvider) Complete(ctx context.Context, messages []Message) (string, error) {
30+
url := p.cfg.BaseURL + "/chat/completions"
31+
if p.cfg.BaseURL == "" {
32+
url = "https://integrate.api.nvidia.com/v1/chat/completions"
33+
}
34+
35+
reqBody := openaiRequest{
36+
Model: p.cfg.Model,
37+
Messages: messages,
38+
Stream: false,
39+
}
40+
41+
body, err := json.Marshal(reqBody)
42+
if err != nil {
43+
return "", fmt.Errorf("marshal request: %w", err)
44+
}
45+
46+
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
47+
if err != nil {
48+
return "", fmt.Errorf("create request: %w", err)
49+
}
50+
51+
req.Header.Set("Content-Type", "application/json")
52+
req.Header.Set("Authorization", "Bearer "+p.cfg.APIKey)
53+
54+
resp, err := p.client.Do(req)
55+
if err != nil {
56+
return "", fmt.Errorf("do request: %w", err)
57+
}
58+
defer resp.Body.Close()
59+
60+
respBody, err := io.ReadAll(resp.Body)
61+
if err != nil {
62+
return "", fmt.Errorf("read response: %w", err)
63+
}
64+
65+
if resp.StatusCode != 200 {
66+
return "", fmt.Errorf("nvidia API error (%d): %s", resp.StatusCode, string(respBody))
67+
}
68+
69+
var parsed openaiResponse
70+
if err := json.Unmarshal(respBody, &parsed); err != nil {
71+
return "", fmt.Errorf("unmarshal response: %w", err)
72+
}
73+
74+
if len(parsed.Choices) == 0 {
75+
return "", fmt.Errorf("no response from nvidia")
76+
}
77+
78+
return parsed.Choices[0].Message.Content, nil
79+
}

provider.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ type Provider interface {
1313
}
1414

1515
// NewProvider returns the provider selected by ITERATE_PROVIDER.
16-
// Supported values: ollama, openai, anthropic, groq, gemini (default: gemini)
16+
// Supported values: ollama, openai, anthropic, groq, gemini, nvidia (default: gemini)
1717
// If apiKey is provided, it takes priority over environment variables.
1818
func NewProvider(providerName string, apiKey ...string) (Provider, error) {
1919
providedKey := ""
@@ -94,6 +94,20 @@ func NewProvider(providerName string, apiKey ...string) (Provider, error) {
9494
APIKey: key,
9595
}), nil
9696

97+
case "nvidia":
98+
key := providedKey
99+
if key == "" {
100+
key = os.Getenv("NVIDIA_API_KEY")
101+
}
102+
if key == "" {
103+
return nil, fmt.Errorf("NVIDIA_API_KEY is required for nvidia provider (or use --api-key)")
104+
}
105+
return NewNvidia(OpenAICompatConfig{
106+
BaseURL: getEnvOr("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"),
107+
Model: getEnvOr("ITERATE_MODEL", "nvidia/llama-3.3-nemotron-70b-instruct"),
108+
APIKey: key,
109+
}), nil
110+
97111
default:
98112
baseURL := os.Getenv("ITERATE_BASE_URL")
99113
if baseURL == "" {

0 commit comments

Comments
 (0)