@@ -53,7 +53,7 @@ func TestWaitForJobOnce_Success(t *testing.T) {
5353 client := codersdk .New (srvURL )
5454
5555 version := & codersdk.TemplateVersion {ID : versionID }
56- logs , done , err := waitForJobOnce (context .Background (), client , version )
56+ logs , done , err := waitForJobOnce (context .Background (), client , version , 0 )
5757 require .NoError (t , err )
5858 require .True (t , done )
5959 require .Len (t , logs , 1 )
@@ -92,7 +92,7 @@ func TestWaitForJobOnce_JobFailed(t *testing.T) {
9292 client := codersdk .New (srvURL )
9393
9494 version := & codersdk.TemplateVersion {ID : versionID }
95- _ , done , err := waitForJobOnce (context .Background (), client , version )
95+ _ , done , err := waitForJobOnce (context .Background (), client , version , 0 )
9696 require .Error (t , err )
9797 require .False (t , done )
9898 require .Contains (t , err .Error (), "provisioner job did not succeed" )
@@ -130,7 +130,7 @@ func TestWaitForJobOnce_StillActive(t *testing.T) {
130130 client := codersdk .New (srvURL )
131131
132132 version := & codersdk.TemplateVersion {ID : versionID }
133- _ , done , err := waitForJobOnce (context .Background (), client , version )
133+ _ , done , err := waitForJobOnce (context .Background (), client , version , 0 )
134134 require .NoError (t , err )
135135 require .False (t , done )
136136}
@@ -178,18 +178,26 @@ func TestWaitForJob_SucceedsOnRetry(t *testing.T) {
178178 t .Parallel ()
179179 versionID := uuid .New ()
180180 var versionCallCount atomic.Int32
181+ var wsCallCount atomic.Int32
182+ var secondAfter atomic.Value
183+ secondAfter .Store ("" )
181184
182185 handler := http .NewServeMux ()
183186 handler .HandleFunc ("/api/v2/templateversions/" , func (w http.ResponseWriter , r * http.Request ) {
184187 if strings .Contains (r .URL .RawQuery , "follow" ) {
188+ call := wsCallCount .Add (1 )
189+ if call == 2 {
190+ secondAfter .Store (r .URL .Query ().Get ("after" ))
191+ }
192+
185193 conn , err := websocket .Accept (w , r , nil )
186194 if err != nil {
187195 http .Error (w , err .Error (), http .StatusInternalServerError )
188196 return
189197 }
190198 ctx := r .Context ()
191199 _ = wsjson .Write (ctx , conn , codersdk.ProvisionerJobLog {
192- ID : int64 (versionCallCount . Load () ),
200+ ID : int64 (call ),
193201 Output : "log line" ,
194202 })
195203 _ = conn .Close (websocket .StatusNormalClosure , "done" )
@@ -219,4 +227,110 @@ func TestWaitForJob_SucceedsOnRetry(t *testing.T) {
219227 logs , err := waitForJob (context .Background (), client , version )
220228 require .NoError (t , err )
221229 require .Len (t , logs , 2 )
230+ require .Equal (t , int64 (1 ), logs [0 ].ID )
231+ require .Equal (t , int64 (2 ), logs [1 ].ID )
232+ require .Equal (t , "1" , secondAfter .Load ())
233+ }
234+
235+ func TestWaitForJob_UsesAfterCursorAcrossRetries (t * testing.T ) {
236+ t .Parallel ()
237+ versionID := uuid .New ()
238+ var versionCallCount atomic.Int32
239+ var wsCallCount atomic.Int32
240+ var secondAfter atomic.Value
241+ secondAfter .Store ("" )
242+
243+ handler := http .NewServeMux ()
244+ handler .HandleFunc ("/api/v2/templateversions/" , func (w http.ResponseWriter , r * http.Request ) {
245+ if strings .Contains (r .URL .RawQuery , "follow" ) {
246+ call := wsCallCount .Add (1 )
247+ if call == 2 {
248+ secondAfter .Store (r .URL .Query ().Get ("after" ))
249+ }
250+
251+ conn , err := websocket .Accept (w , r , nil )
252+ if err != nil {
253+ http .Error (w , err .Error (), http .StatusInternalServerError )
254+ return
255+ }
256+ ctx := r .Context ()
257+ if call == 1 {
258+ _ = wsjson .Write (ctx , conn , codersdk.ProvisionerJobLog {ID : 1 , Output : "log 1" })
259+ _ = wsjson .Write (ctx , conn , codersdk.ProvisionerJobLog {ID : 2 , Output : "log 2" })
260+ _ = wsjson .Write (ctx , conn , codersdk.ProvisionerJobLog {ID : 3 , Output : "log 3" })
261+ } else {
262+ _ = wsjson .Write (ctx , conn , codersdk.ProvisionerJobLog {ID : 4 , Output : "log 4" })
263+ _ = wsjson .Write (ctx , conn , codersdk.ProvisionerJobLog {ID : 5 , Output : "log 5" })
264+ }
265+ _ = conn .Close (websocket .StatusNormalClosure , "done" )
266+ return
267+ }
268+
269+ count := versionCallCount .Add (1 )
270+ status := codersdk .ProvisionerJobRunning
271+ if count >= 2 {
272+ status = codersdk .ProvisionerJobSucceeded
273+ }
274+ w .Header ().Set ("Content-Type" , "application/json" )
275+ _ = json .NewEncoder (w ).Encode (codersdk.TemplateVersion {
276+ ID : versionID ,
277+ Job : codersdk.ProvisionerJob {Status : status },
278+ })
279+ })
280+
281+ srv := httptest .NewServer (handler )
282+ t .Cleanup (srv .Close )
283+ srvURL , err := url .Parse (srv .URL )
284+ require .NoError (t , err )
285+ client := codersdk .New (srvURL )
286+
287+ version := & codersdk.TemplateVersion {ID : versionID }
288+ logs , err := waitForJob (context .Background (), client , version )
289+ require .NoError (t , err )
290+ require .Len (t , logs , 5 )
291+ for i , log := range logs {
292+ require .Equal (t , int64 (i + 1 ), log .ID )
293+ }
294+ require .Equal (t , int32 (2 ), wsCallCount .Load ())
295+ require .Equal (t , "3" , secondAfter .Load ())
296+ }
297+
298+ func TestWaitForJob_ContextCanceledDuringBackoff (t * testing.T ) {
299+ t .Parallel ()
300+ versionID := uuid .New ()
301+ ctx , cancel := context .WithCancel (context .Background ())
302+ t .Cleanup (cancel )
303+ var statusCallCount atomic.Int32
304+
305+ handler := http .NewServeMux ()
306+ handler .HandleFunc ("/api/v2/templateversions/" , func (w http.ResponseWriter , r * http.Request ) {
307+ if strings .Contains (r .URL .RawQuery , "follow" ) {
308+ conn , err := websocket .Accept (w , r , nil )
309+ if err != nil {
310+ http .Error (w , err .Error (), http .StatusInternalServerError )
311+ return
312+ }
313+ _ = conn .Close (websocket .StatusNormalClosure , "done" )
314+ return
315+ }
316+
317+ w .Header ().Set ("Content-Type" , "application/json" )
318+ _ = json .NewEncoder (w ).Encode (codersdk.TemplateVersion {
319+ ID : versionID ,
320+ Job : codersdk.ProvisionerJob {Status : codersdk .ProvisionerJobRunning },
321+ })
322+ if statusCallCount .Add (1 ) == 1 {
323+ cancel ()
324+ }
325+ })
326+
327+ srv := httptest .NewServer (handler )
328+ t .Cleanup (srv .Close )
329+ srvURL , err := url .Parse (srv .URL )
330+ require .NoError (t , err )
331+ client := codersdk .New (srvURL )
332+
333+ version := & codersdk.TemplateVersion {ID : versionID }
334+ _ , err = waitForJob (ctx , client , version )
335+ require .ErrorIs (t , err , context .Canceled )
222336}
0 commit comments