Skip to content

Commit fb73534

Browse files
committed
v1.0.0: Add Azure, Bedrock, Vertex providers, SSE streaming, MCP transport, context management
1 parent afc2e9c commit fb73534

8 files changed

Lines changed: 1782 additions & 64 deletions

File tree

azure.go

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
package iteragent
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"strings"
11+
)
12+
13+
type AzureOpenAIConfig struct {
14+
APIKey string
15+
Endpoint string
16+
Deployment string
17+
APIVersion string
18+
MaxTokens int
19+
Temperature float32
20+
ThinkingLevel ThinkingLevel
21+
}
22+
23+
type AzureOpenAIProvider struct {
24+
config AzureOpenAIConfig
25+
client *http.Client
26+
}
27+
28+
func NewAzureOpenAI(config AzureOpenAIConfig) *AzureOpenAIProvider {
29+
return &AzureOpenAIProvider{
30+
config: config,
31+
client: &http.Client{},
32+
}
33+
}
34+
35+
func (p *AzureOpenAIProvider) Name() string {
36+
return "azure_openai"
37+
}
38+
39+
func (p *AzureOpenAIProvider) Complete(ctx context.Context, messages []Message) (string, error) {
40+
apiVersion := p.config.APIVersion
41+
if apiVersion == "" {
42+
apiVersion = "2024-02-15-preview"
43+
}
44+
45+
url := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s",
46+
p.config.Endpoint, p.config.Deployment, apiVersion)
47+
48+
body := map[string]interface{}{
49+
"messages": messagesToAzureFormat(messages),
50+
"stream": false,
51+
}
52+
53+
if p.config.MaxTokens > 0 {
54+
body["max_tokens"] = p.config.MaxTokens
55+
}
56+
if p.config.Temperature > 0 {
57+
body["temperature"] = p.config.Temperature
58+
}
59+
60+
jsonBody, err := json.Marshal(body)
61+
if err != nil {
62+
return "", fmt.Errorf("marshal request: %w", err)
63+
}
64+
65+
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
66+
if err != nil {
67+
return "", fmt.Errorf("create request: %w", err)
68+
}
69+
70+
req.Header.Set("Content-Type", "application/json")
71+
req.Header.Set("api-key", p.config.APIKey)
72+
73+
resp, err := p.client.Do(req)
74+
if err != nil {
75+
return "", fmt.Errorf("request failed: %w", err)
76+
}
77+
defer resp.Body.Close()
78+
79+
respBody, err := io.ReadAll(resp.Body)
80+
if err != nil {
81+
return "", fmt.Errorf("read response: %w", err)
82+
}
83+
84+
if resp.StatusCode != http.StatusOK {
85+
return "", fmt.Errorf("Azure OpenAI error (%d): %s", resp.StatusCode, string(respBody))
86+
}
87+
88+
var response struct {
89+
Choices []struct {
90+
Message struct {
91+
Content string `json:"content"`
92+
} `json:"message"`
93+
} `json:"choices"`
94+
}
95+
96+
if err := json.Unmarshal(respBody, &response); err != nil {
97+
return "", fmt.Errorf("parse response: %w", err)
98+
}
99+
100+
if len(response.Choices) == 0 {
101+
return "", fmt.Errorf("no response choices")
102+
}
103+
104+
return response.Choices[0].Message.Content, nil
105+
}
106+
107+
func messagesToAzureFormat(messages []Message) []map[string]interface{} {
108+
result := make([]map[string]interface{}, len(messages))
109+
for i, msg := range messages {
110+
m := map[string]interface{}{
111+
"role": msg.Role,
112+
"content": msg.Content,
113+
}
114+
result[i] = m
115+
}
116+
return result
117+
}
118+
119+
func (p *AzureOpenAIProvider) Stream(ctx context.Context, config StreamConfig, messages []Message, onEvent func(StreamEvent)) (Message, error) {
120+
apiVersion := p.config.APIVersion
121+
if apiVersion == "" {
122+
apiVersion = "2024-02-15-preview"
123+
}
124+
125+
url := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s",
126+
p.config.Endpoint, p.config.Deployment, apiVersion)
127+
128+
body := map[string]interface{}{
129+
"messages": messagesToAzureFormat(messages),
130+
"stream": true,
131+
}
132+
133+
if config.MaxTokens > 0 {
134+
body["max_tokens"] = config.MaxTokens
135+
}
136+
if config.Temperature > 0 {
137+
body["temperature"] = config.Temperature
138+
}
139+
140+
jsonBody, err := json.Marshal(body)
141+
if err != nil {
142+
return Message{}, fmt.Errorf("marshal request: %w", err)
143+
}
144+
145+
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
146+
if err != nil {
147+
return Message{}, fmt.Errorf("create request: %w", err)
148+
}
149+
150+
req.Header.Set("Content-Type", "application/json")
151+
req.Header.Set("api-key", p.config.APIKey)
152+
153+
resp, err := p.client.Do(req)
154+
if err != nil {
155+
return Message{}, fmt.Errorf("request failed: %w", err)
156+
}
157+
defer resp.Body.Close()
158+
159+
if resp.StatusCode != http.StatusOK {
160+
respBody, _ := io.ReadAll(resp.Body)
161+
return Message{}, fmt.Errorf("Azure OpenAI error (%d): %s", resp.StatusCode, string(respBody))
162+
}
163+
164+
var content strings.Builder
165+
decoder := NewSSEDecoder(resp.Body)
166+
167+
for {
168+
event, err := decoder.Decode()
169+
if err == io.EOF {
170+
break
171+
}
172+
if err != nil {
173+
break
174+
}
175+
176+
if event.Type == "content" || event.Type == "content_block" {
177+
content.WriteString(event.Content)
178+
onEvent(event)
179+
}
180+
}
181+
182+
return Message{
183+
Role: "assistant",
184+
Content: content.String(),
185+
}, nil
186+
}
187+
188+
type SSEDecoder struct {
189+
reader io.Reader
190+
}
191+
192+
func NewSSEDecoder(reader io.Reader) *SSEDecoder {
193+
return &SSEDecoder{reader: reader}
194+
}
195+
196+
func (d *SSEDecoder) Decode() (StreamEvent, error) {
197+
var line string
198+
for {
199+
buf := make([]byte, 1024)
200+
n, err := d.reader.Read(buf)
201+
if n == 0 || err != nil {
202+
return StreamEvent{}, err
203+
}
204+
line = string(buf[:n])
205+
if strings.HasPrefix(line, "data:") {
206+
break
207+
}
208+
}
209+
210+
line = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
211+
if line == "[DONE]" {
212+
return StreamEvent{}, io.EOF
213+
}
214+
215+
var delta struct {
216+
Choices []struct {
217+
Delta struct {
218+
Content string `json:"content"`
219+
} `json:"delta"`
220+
} `json:"choices"`
221+
}
222+
223+
if err := json.Unmarshal([]byte(line), &delta); err != nil {
224+
return StreamEvent{}, err
225+
}
226+
227+
content := ""
228+
if len(delta.Choices) > 0 {
229+
content = delta.Choices[0].Delta.Content
230+
}
231+
232+
return StreamEvent{
233+
Type: StreamEventContent,
234+
Content: content,
235+
}, nil
236+
}

0 commit comments

Comments
 (0)