@@ -12,6 +12,7 @@ import (
1212 "github.com/warpdotdev/oz-agent-worker/internal/common"
1313 "github.com/warpdotdev/oz-agent-worker/internal/log"
1414 "github.com/warpdotdev/oz-agent-worker/internal/types"
15+ "golang.org/x/sync/semaphore"
1516)
1617
1718const (
@@ -25,12 +26,13 @@ const (
2526)
2627
2728type Config struct {
28- APIKey string
29- WorkerID string
30- WebSocketURL string
31- ServerRootURL string
32- LogLevel string
33- BackendType string // "docker" or "direct"
29+ APIKey string
30+ WorkerID string
31+ WebSocketURL string
32+ ServerRootURL string
33+ LogLevel string
34+ BackendType string // "docker" or "direct"
35+ MaxConcurrentTasks int // 0 means unlimited
3436
3537 // Backend-specific configs. Only the one matching BackendType should be set.
3638 Docker * DockerBackendConfig
@@ -49,6 +51,7 @@ type Worker struct {
4951 activeTasks map [string ]context.CancelFunc
5052 tasksMutex sync.Mutex
5153 backend Backend
54+ taskSemaphore * semaphore.Weighted // nil when unlimited
5255}
5356
5457func New (ctx context.Context , config Config ) (* Worker , error ) {
@@ -79,6 +82,11 @@ func New(ctx context.Context, config Config) (*Worker, error) {
7982 return nil , err
8083 }
8184
85+ var taskSemaphore * semaphore.Weighted
86+ if config .MaxConcurrentTasks > 0 {
87+ taskSemaphore = semaphore .NewWeighted (int64 (config .MaxConcurrentTasks ))
88+ }
89+
8290 return & Worker {
8391 config : config ,
8492 ctx : workerCtx ,
@@ -87,6 +95,7 @@ func New(ctx context.Context, config Config) (*Worker, error) {
8795 sendChan : make (chan []byte , 256 ),
8896 activeTasks : make (map [string ]context.CancelFunc ),
8997 backend : backend ,
98+ taskSemaphore : taskSemaphore ,
9099 }, nil
91100}
92101
@@ -298,6 +307,17 @@ func (w *Worker) handleMessage(message []byte) {
298307func (w * Worker ) handleTaskAssignment (assignment * types.TaskAssignmentMessage ) {
299308 log .Infof (w .ctx , "Received task assignment: taskID=%s, title=%s" , assignment .TaskID , assignment .Task .Title )
300309
310+ // Check concurrency limit before claiming the task.
311+ if w .taskSemaphore != nil {
312+ if ! w .taskSemaphore .TryAcquire (1 ) {
313+ log .Warnf (w .ctx , "Rejecting task %s: worker at maximum concurrency (%d)" , assignment .TaskID , w .config .MaxConcurrentTasks )
314+ if err := w .sendTaskRejected (assignment .TaskID , "worker at maximum concurrency" ); err != nil {
315+ log .Errorf (w .ctx , "Failed to send task rejected message: %v" , err )
316+ }
317+ return
318+ }
319+ }
320+
301321 // It's important to update the task state to claimed as the task lifecycle treats this as a dependency to advance to further states.
302322 if err := w .sendTaskClaimed (assignment .TaskID ); err != nil {
303323 log .Errorf (w .ctx , "Failed to send task claimed message: %v" , err )
@@ -379,6 +399,10 @@ func (w *Worker) executeTask(ctx context.Context, assignment *types.TaskAssignme
379399 w .tasksMutex .Lock ()
380400 delete (w .activeTasks , assignment .TaskID )
381401 w .tasksMutex .Unlock ()
402+
403+ if w .taskSemaphore != nil {
404+ w .taskSemaphore .Release (1 )
405+ }
382406 }()
383407
384408 taskID := assignment .TaskID
@@ -420,6 +444,30 @@ func (w *Worker) sendTaskClaimed(taskID string) error {
420444 return w .sendMessage (msgBytes )
421445}
422446
447+ func (w * Worker ) sendTaskRejected (taskID , reason string ) error {
448+ rejectedMsg := types.TaskRejectedMessage {
449+ TaskID : taskID ,
450+ Reason : reason ,
451+ }
452+
453+ data , err := json .Marshal (rejectedMsg )
454+ if err != nil {
455+ return fmt .Errorf ("failed to marshal task rejected message: %w" , err )
456+ }
457+
458+ msg := types.WebSocketMessage {
459+ Type : types .MessageTypeTaskRejected ,
460+ Data : data ,
461+ }
462+
463+ msgBytes , err := json .Marshal (msg )
464+ if err != nil {
465+ return fmt .Errorf ("failed to marshal websocket message: %w" , err )
466+ }
467+
468+ return w .sendMessage (msgBytes )
469+ }
470+
423471func (w * Worker ) sendTaskFailed (taskID , message string ) error {
424472 failedMsg := types.TaskFailedMessage {
425473 TaskID : taskID ,
0 commit comments