Skip to content

Commit 662d391

Browse files
committed
test: add tests for filling out --project, --oas, --config via workspace
1 parent 22d9bd2 commit 662d391

4 files changed

Lines changed: 690 additions & 5 deletions

File tree

internal/mockstainless/mock.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,18 @@ type Mock struct {
3030

3131
mu sync.Mutex
3232
buildIndex map[string]*ProgressiveBuild
33+
nextBuildSeq int
3334
enableGitRepos bool
3435
gitRepos map[string]gitRepo // key: "owner/name"
3536
tempDir string
37+
requests []RecordedRequest
38+
}
39+
40+
type RecordedRequest struct {
41+
Method string
42+
Path string
43+
RawQuery string
44+
Body string
3645
}
3746

3847
type gitRepo struct {
@@ -55,6 +64,7 @@ func (m *Mock) init() {
5564
if m.CompareBuild != nil && m.CompareBuild.PreviewBuild != nil {
5665
m.buildIndex[m.CompareBuild.PreviewBuild.ID] = m.CompareBuild.PreviewBuild
5766
}
67+
m.nextBuildSeq = len(m.buildIndex)
5868
if m.enableGitRepos {
5969
m.initGitRepos()
6070
}
@@ -173,6 +183,26 @@ func (m *Mock) GetBuild(id string) *ProgressiveBuild {
173183
return m.buildIndex[id]
174184
}
175185

186+
func (m *Mock) RecordRequest(request RecordedRequest) {
187+
m.mu.Lock()
188+
defer m.mu.Unlock()
189+
m.requests = append(m.requests, request)
190+
}
191+
192+
func (m *Mock) Requests() []RecordedRequest {
193+
m.mu.Lock()
194+
defer m.mu.Unlock()
195+
out := make([]RecordedRequest, len(m.requests))
196+
copy(out, m.requests)
197+
return out
198+
}
199+
200+
func (m *Mock) ResetRequests() {
201+
m.mu.Lock()
202+
defer m.mu.Unlock()
203+
m.requests = nil
204+
}
205+
176206
func (m *Mock) Diagnostics(id string) []M {
177207
m.mu.Lock()
178208
defer m.mu.Unlock()
@@ -182,6 +212,77 @@ func (m *Mock) Diagnostics(id string) []M {
182212
return []M{}
183213
}
184214

215+
func (m *Mock) CreateBuildFromTemplate(template *ProgressiveBuild) *ProgressiveBuild {
216+
if template == nil {
217+
return nil
218+
}
219+
220+
m.mu.Lock()
221+
defer m.mu.Unlock()
222+
223+
m.nextBuildSeq++
224+
build := cloneProgressiveBuild(template)
225+
build.ID = fmt.Sprintf("bui_mock_%06d", m.nextBuildSeq)
226+
build.StartTime = time.Now()
227+
228+
m.Builds = append([]*ProgressiveBuild{build}, m.Builds...)
229+
m.buildIndex[build.ID] = build
230+
return build
231+
}
232+
233+
func cloneProgressiveBuild(template *ProgressiveBuild) *ProgressiveBuild {
234+
build := &ProgressiveBuild{
235+
ID: template.ID,
236+
ConfigCommit: template.ConfigCommit,
237+
Targets: append([]string(nil), template.Targets...),
238+
CompletedData: make(map[string]M, len(template.CompletedData)),
239+
Diagnostics: make([]M, len(template.Diagnostics)),
240+
Delay: template.Delay,
241+
}
242+
243+
for name, target := range template.CompletedData {
244+
build.CompletedData[name] = cloneMap(target)
245+
}
246+
for i, diagnostic := range template.Diagnostics {
247+
build.Diagnostics[i] = cloneMap(diagnostic)
248+
}
249+
250+
return build
251+
}
252+
253+
func cloneMap(src M) M {
254+
if src == nil {
255+
return nil
256+
}
257+
out := make(M, len(src))
258+
for key, value := range src {
259+
out[key] = cloneValue(value)
260+
}
261+
return out
262+
}
263+
264+
func cloneSlice(src []any) []any {
265+
if src == nil {
266+
return nil
267+
}
268+
out := make([]any, len(src))
269+
for i, value := range src {
270+
out[i] = cloneValue(value)
271+
}
272+
return out
273+
}
274+
275+
func cloneValue(value any) any {
276+
switch v := value.(type) {
277+
case M:
278+
return cloneMap(v)
279+
case []any:
280+
return cloneSlice(v)
281+
default:
282+
return v
283+
}
284+
}
285+
185286
// MockOption configures a Mock via NewMock.
186287
type MockOption func(*Mock)
187288

internal/mockstainless/server.go

Lines changed: 170 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package mockstainless
22

33
import (
4+
"bytes"
45
"encoding/json"
56
"fmt"
7+
"io"
68
"net/http"
79
"time"
10+
11+
"github.com/tidwall/gjson"
812
)
913

1014
func writeJSON(w http.ResponseWriter, status int, v any) {
@@ -39,8 +43,8 @@ func newServeMux(m *Mock) http.Handler {
3943
"user_code": "DEMO-CODE",
4044
"verification_uri": "https://app.stainless.com/activate",
4145
"verification_uri_complete": "https://app.stainless.com/activate?code=DEMO-CODE",
42-
"expires_in": 300,
43-
"interval": 1,
46+
"expires_in": 300,
47+
"interval": 1,
4448
})
4549
})
4650

@@ -85,11 +89,132 @@ func newServeMux(m *Mock) http.Handler {
8589
}
8690
})
8791

92+
mux.HandleFunc("PATCH /v0/projects/{project}", func(w http.ResponseWriter, r *http.Request) {
93+
project := r.PathValue("project")
94+
if project == "" {
95+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
96+
return
97+
}
98+
body := mustReadBody(r)
99+
writeJSON(w, http.StatusOK, M{
100+
"slug": project,
101+
"display_name": gjson.GetBytes(body, "display_name").String(),
102+
"object": "project",
103+
})
104+
})
105+
106+
mux.HandleFunc("POST /v0/projects/{project}/generate_commit_message", func(w http.ResponseWriter, r *http.Request) {
107+
if r.PathValue("project") == "" {
108+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
109+
return
110+
}
111+
writeJSON(w, http.StatusOK, M{
112+
"message": "mock commit message",
113+
})
114+
})
115+
88116
mux.HandleFunc("GET /v0/projects/{project}/configs", func(w http.ResponseWriter, r *http.Request) {
89117
writeJSON(w, http.StatusOK, m.ProjectConfigs)
90118
})
91119

120+
mux.HandleFunc("POST /v0/projects/{project}/configs/guess", func(w http.ResponseWriter, r *http.Request) {
121+
if r.PathValue("project") == "" {
122+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
123+
return
124+
}
125+
writeJSON(w, http.StatusOK, M{
126+
"stainless.yml": M{
127+
"content": "# guessed",
128+
},
129+
})
130+
})
131+
132+
mux.HandleFunc("POST /v0/projects/{project}/branches", func(w http.ResponseWriter, r *http.Request) {
133+
if r.PathValue("project") == "" {
134+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
135+
return
136+
}
137+
body := mustReadBody(r)
138+
writeJSON(w, http.StatusOK, M{
139+
"branch": gjson.GetBytes(body, "branch").String(),
140+
"object": "project_branch",
141+
})
142+
})
143+
144+
mux.HandleFunc("GET /v0/projects/{project}/branches", func(w http.ResponseWriter, r *http.Request) {
145+
if r.PathValue("project") == "" {
146+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
147+
return
148+
}
149+
writeJSON(w, http.StatusOK, Page([]M{
150+
{"branch": "main", "object": "project_branch"},
151+
}))
152+
})
153+
154+
mux.HandleFunc("GET /v0/projects/{project}/branches/{branch}", func(w http.ResponseWriter, r *http.Request) {
155+
if r.PathValue("project") == "" {
156+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
157+
return
158+
}
159+
writeJSON(w, http.StatusOK, M{
160+
"branch": r.PathValue("branch"),
161+
"object": "project_branch",
162+
})
163+
})
164+
165+
mux.HandleFunc("DELETE /v0/projects/{project}/branches/{branch}", func(w http.ResponseWriter, r *http.Request) {
166+
if r.PathValue("project") == "" {
167+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
168+
return
169+
}
170+
writeJSON(w, http.StatusOK, M{"deleted": true})
171+
})
172+
173+
mux.HandleFunc("PUT /v0/projects/{project}/branches/{branch}/rebase", func(w http.ResponseWriter, r *http.Request) {
174+
if r.PathValue("project") == "" {
175+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
176+
return
177+
}
178+
writeJSON(w, http.StatusOK, M{
179+
"branch": r.PathValue("branch"),
180+
"object": "project_branch",
181+
})
182+
})
183+
184+
mux.HandleFunc("PUT /v0/projects/{project}/branches/{branch}/reset", func(w http.ResponseWriter, r *http.Request) {
185+
if r.PathValue("project") == "" {
186+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
187+
return
188+
}
189+
writeJSON(w, http.StatusOK, M{
190+
"branch": r.PathValue("branch"),
191+
"object": "project_branch",
192+
})
193+
})
194+
195+
mux.HandleFunc("POST /v0/builds", func(w http.ResponseWriter, r *http.Request) {
196+
body := mustReadBody(r)
197+
if gjson.GetBytes(body, "project").String() == "" {
198+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
199+
return
200+
}
201+
if len(m.Builds) == 0 {
202+
writeJSON(w, http.StatusNotFound, M{"error": "missing build"})
203+
return
204+
}
205+
build := m.CreateBuildFromTemplate(m.Builds[0])
206+
if build == nil {
207+
writeJSON(w, http.StatusNotFound, M{"error": "missing build"})
208+
return
209+
}
210+
writeJSON(w, http.StatusOK, build.Snapshot())
211+
})
212+
92213
mux.HandleFunc("GET /v0/builds", func(w http.ResponseWriter, r *http.Request) {
214+
if r.URL.Query().Get("project") == "" {
215+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
216+
return
217+
}
93218
builds := make([]M, len(m.Builds))
94219
for i, b := range m.Builds {
95220
builds[i] = b.Snapshot()
@@ -171,21 +296,61 @@ func newServeMux(m *Mock) http.Handler {
171296

172297
if m.CompareBuild != nil {
173298
mux.HandleFunc("POST /v0/builds/compare", func(w http.ResponseWriter, r *http.Request) {
174-
if m.CompareBuild.PreviewBuild != nil {
175-
m.CompareBuild.PreviewBuild.Reset()
299+
body := mustReadBody(r)
300+
if gjson.GetBytes(body, "project").String() == "" {
301+
writeJSON(w, http.StatusBadRequest, M{"error": "project is required"})
302+
return
303+
}
304+
headBuild := m.CreateBuildFromTemplate(m.CompareBuild.PreviewBuild)
305+
if headBuild == nil {
306+
writeJSON(w, http.StatusNotFound, M{"error": "missing preview build"})
307+
return
176308
}
309+
head := cloneMap(m.CompareBuild.Head)
310+
head["id"] = headBuild.ID
311+
head["created_at"] = time.Now().Format(time.RFC3339)
177312
writeJSON(w, http.StatusOK, M{
178313
"base": m.CompareBuild.Base,
179-
"head": m.CompareBuild.Head,
314+
"head": head,
180315
})
181316
})
182317
}
183318

319+
mux.HandleFunc("POST /api/generate/spec", func(w http.ResponseWriter, r *http.Request) {
320+
body := mustReadBody(r)
321+
if gjson.GetBytes(body, "project").String() == "" ||
322+
gjson.GetBytes(body, "source.openapi_spec").String() == "" ||
323+
gjson.GetBytes(body, "source.stainless_config").String() == "" {
324+
writeJSON(w, http.StatusBadRequest, M{"error": "project, openapi_spec, and stainless_config are required"})
325+
return
326+
}
327+
writeJSON(w, http.StatusOK, M{
328+
"spec": M{
329+
"diagnostics": M{},
330+
},
331+
})
332+
})
333+
184334
// Add simulated latency to all requests (except health checks).
185335
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
336+
body := mustReadBody(r)
337+
m.RecordRequest(RecordedRequest{
338+
Method: r.Method,
339+
Path: r.URL.Path,
340+
RawQuery: r.URL.RawQuery,
341+
Body: string(body),
342+
})
343+
r.Body = io.NopCloser(bytes.NewReader(body))
344+
186345
if r.URL.Path != "/health" {
187346
time.Sleep(150 * time.Millisecond)
188347
}
189348
mux.ServeHTTP(w, r)
190349
})
191350
}
351+
352+
func mustReadBody(r *http.Request) []byte {
353+
body, _ := io.ReadAll(r.Body)
354+
r.Body = io.NopCloser(bytes.NewReader(body))
355+
return body
356+
}

pkg/cmd/cmd.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,21 @@ func getWorkspace(ctx context.Context) workspace.Config {
308308
return ctx.Value("workspace_config").(workspace.Config)
309309
}
310310

311+
func setFlagValue(cmd *cli.Command, name string, value string) error {
312+
for _, flag := range cmd.Flags {
313+
for _, flagName := range flag.Names() {
314+
if flagName != name {
315+
continue
316+
}
317+
if setter, ok := flag.(interface{ Set(string, string) error }); ok {
318+
return setter.Set(name, value)
319+
}
320+
return cmd.Set(name, value)
321+
}
322+
}
323+
return cmd.Set(name, value)
324+
}
325+
311326
func generateManpages(ctx context.Context, c *cli.Command) error {
312327
manpage, err := docs.ToManWithSection(Command, 1)
313328
if err != nil {

0 commit comments

Comments
 (0)