package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/threatwinds/go-sdk/catcher"
	"github.com/threatwinds/go-sdk/plugins"
	"github.com/utmstack/UTMStack/plugins/soc-ai/config"
	"github.com/utmstack/UTMStack/plugins/soc-ai/elastic"
	"github.com/utmstack/UTMStack/plugins/soc-ai/utils"
)

// AlertQueueItem represents an item in the processing queue
type AlertQueueItem struct {
	Alert     *plugins.Alert
	Timestamp time.Time
}

// AlertQueue manages the alert processing queue with workers
type AlertQueue struct {
	queue   chan *AlertQueueItem
	workers int
	ctx     context.Context
	cancel  context.CancelFunc
	wg      sync.WaitGroup

	// Metrics
	processedCount int64
	droppedCount   int64
	errorCount     int64
	queueSize      int64

	// Track consecutive drops for critical alerts
	consecutiveDrops int64
	lastDropAlert    time.Time
}

// Global queue instance
var alertQueue *AlertQueue

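// Queue sizing and backpressure: an enqueue waits up to QueueFullTimeout
// for a free slot before giving up and dropping the alert.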
const (
	DefaultQueueSize   = 1000
	DefaultWorkerCount = 5
	QueueFullTimeout   = 100 * time.Millisecond
)

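// InitializeQueue creates the global alert queue, starts the worker pool,
// and launches the periodic metrics logger.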
func InitializeQueue() {
	ctx, cancel := context.WithCancel(context.Background())

	alertQueue = &AlertQueue{
		queue:   make(chan *AlertQueueItem, DefaultQueueSize),
		workers: DefaultWorkerCount,
		ctx:     ctx,
		cancel:  cancel,
	}

	// Range over an integer requires Go 1.22 or later.
	for i := range DefaultWorkerCount {
		alertQueue.wg.Add(1)
		go alertQueue.worker(i)
	}

	go alertQueue.metricsLogger()

	utils.Logger.LogF(100, "Alert queue initialized with %d workers and queue size %d", DefaultWorkerCount, DefaultQueueSize)
}

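// EnqueueAlert places an alert on the processing queue. It returns false if
// the queue is not initialized or remains full after QueueFullTimeout, in
// which case the alert is dropped and the drop is reported.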
func EnqueueAlert(alert *plugins.Alert) bool {
	if alertQueue == nil {
		utils.Logger.LogF(500, "Alert queue not initialized")
		return false
	}

	item := &AlertQueueItem{
		Alert:     alert,
		Timestamp: time.Now(),
	}

	select {
	case alertQueue.queue <- item:
		atomic.AddInt64(&alertQueue.queueSize, 1)
		// Reset consecutive drops counter on successful enqueue
		atomic.StoreInt64(&alertQueue.consecutiveDrops, 0)
		utils.Logger.LogF(100, "Alert %s enqueued for processing", alert.Id)
		return true
	case <-time.After(QueueFullTimeout):
		// The queue stayed full for the whole timeout; drop the alert and
		// report the loss through notifications, logs, and Elasticsearch.
		atomic.AddInt64(&alertQueue.droppedCount, 1)
		atomic.AddInt64(&alertQueue.consecutiveDrops, 1)

		currentQueueSize := atomic.LoadInt64(&alertQueue.queueSize)
		totalDropped := atomic.LoadInt64(&alertQueue.droppedCount)
		consecutiveDrops := atomic.LoadInt64(&alertQueue.consecutiveDrops)

		_ = plugins.EnqueueNotification(plugins.TopicIntegrationFailure, plugins.Message{
			Id: uuid.NewString(),
			Message: catcher.Error("Alert Dropped", nil, map[string]any{
				"id":                alert.Id,
				"total_dropped":     totalDropped,
				"consecutive_drops": consecutiveDrops,
			}).Error(),
		})
		utils.Logger.ErrorF("QUEUE FULL - Alert %s DROPPED! Queue size: %d/%d, Total dropped: %d, Consecutive: %d.",
			alert.Id, currentQueueSize, DefaultQueueSize, totalDropped, consecutiveDrops)

		elastic.RegisterError(fmt.Sprintf("Alert dropped - Queue FULL (%d/%d)", currentQueueSize, DefaultQueueSize), alert.Id)
		// Best-effort timestamp: this write is not synchronized, so it is
		// only suitable for diagnostics.
		alertQueue.lastDropAlert = time.Now()
		return false
	}
}

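// worker consumes items from the queue until the queue context is cancelled,
// decrementing the queue-size gauge before processing each item.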
func (aq *AlertQueue) worker(workerID int) {
	defer aq.wg.Done()

	for {
		select {
		case <-aq.ctx.Done():
			return
		case item := <-aq.queue:
			// Skip nil items defensively; EnqueueAlert never sends them.
			if item == nil {
				continue
			}

			atomic.AddInt64(&aq.queueSize, -1)
			aq.processAlert(workerID, item)
		}
	}
}

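// processAlert runs the full pipeline for one queued alert: normalization,
// configuration and connectivity checks, the LLM request, and indexing the
// result in Elasticsearch.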
func (aq *AlertQueue) processAlert(workerID int, item *AlertQueueItem) {
	startTime := time.Now()

	// Install the recover handler before any processing so that a panic in
	// alertToAlertFields or cleanAlerts is also caught.
	defer func() {
		if r := recover(); r != nil {
			atomic.AddInt64(&aq.errorCount, 1)
			_ = catcher.Error("recovered from panic in alert processing", nil, map[string]any{
				"panic":    r,
				"alert":    item.Alert.Id,
				"workerID": workerID,
			})
			elastic.RegisterError(fmt.Sprintf("Panic in worker %d: %v", workerID, r), item.Alert.Id)
		}
	}()

	alert := cleanAlerts(alertToAlertFields(item.Alert))

	utils.Logger.LogF(100, "Worker %d processing alert: %s", workerID, alert.ID)

	// Fetch the configuration once so both checks observe the same snapshot.
	cfg := config.GetConfig()
	if cfg == nil || !cfg.ModuleActive {
		utils.Logger.LogF(100, "SOC-AI module is disabled, skipping alert: %s", alert.ID)
		atomic.AddInt64(&aq.processedCount, 1)
		return
	}

	if cfg.Provider == "openai" {
		if err := utils.ConnectionChecker(config.GPT_API_ENDPOINT); err != nil {
			atomic.AddInt64(&aq.errorCount, 1)
			_ = catcher.Error("Failed to establish internet connection", err, nil)
			elastic.RegisterError("Failed to establish internet connection", alert.ID)
			return
		}
	}

	if err := sendRequestToLLM(&alert); err != nil {
		atomic.AddInt64(&aq.errorCount, 1)
		elastic.RegisterError(err.Error(), alert.ID)
		return
	}

	if err := processAlertToElastic(&alert); err != nil {
		atomic.AddInt64(&aq.errorCount, 1)
		elastic.RegisterError(err.Error(), alert.ID)
		return
	}

	atomic.AddInt64(&aq.processedCount, 1)
	duration := time.Since(startTime)
	queueTime := startTime.Sub(item.Timestamp)

	utils.Logger.LogF(100, "Worker %d completed alert %s in %v (queue time: %v)",
		workerID, alert.ID, duration, queueTime)
}

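// metricsLogger logs queue throughput counters once per minute until the
// queue context is cancelled.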
func (aq *AlertQueue) metricsLogger() {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-aq.ctx.Done():
			return
		case <-ticker.C:
			processed := atomic.LoadInt64(&aq.processedCount)
			dropped := atomic.LoadInt64(&aq.droppedCount)
			errors := atomic.LoadInt64(&aq.errorCount)
			queueSize := atomic.LoadInt64(&aq.queueSize)

			utils.Logger.LogF(200, "Queue metrics - Processed: %d, Dropped: %d, Errors: %d, Current queue size: %d",
				processed, dropped, errors, queueSize)
		}
	}
}
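
// The queue carries a cancel function and a WaitGroup, but this change does
// not include a teardown path. The helper below is a hypothetical sketch of
// one, not part of the original code; it assumes the plugin calls it exactly
// once during shutdown. Note that workers exit on context cancellation
// without draining items still buffered in the channel.
func ShutdownQueue() {
	if alertQueue == nil {
		return
	}
	alertQueue.cancel()  // signal workers and the metrics logger to stop
	alertQueue.wg.Wait() // wait for each worker to finish its in-flight alert
	utils.Logger.LogF(100, "Alert queue shut down; %d alerts still buffered",
		atomic.LoadInt64(&alertQueue.queueSize))
}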