From 0f71f36bf4406733624cab4f178fc163431a9669 Mon Sep 17 00:00:00 2001 From: Ritish Srivastava <121374890+Ritish134@users.noreply.github.com> Date: Mon, 30 Dec 2024 01:41:20 +0530 Subject: [PATCH] refactor recursiveScan Signed-off-by: Ritish Srivastava <121374890+Ritish134@users.noreply.github.com> --- pkg/action/scan.go | 363 ++++++++++++++++++++++----------------------- 1 file changed, 181 insertions(+), 182 deletions(-) diff --git a/pkg/action/scan.go b/pkg/action/scan.go index 2b29465ba..a8a17d12f 100644 --- a/pkg/action/scan.go +++ b/pkg/action/scan.go @@ -264,227 +264,226 @@ func CachedRules(ctx context.Context, fss []fs.FS) (*yara.Rules, error) { return compiledRuleCache, err } +type matchResult struct { + fr *malcontent.FileReport + err error +} // recursiveScan recursively YARA scans the configured paths - handling archives and OCI images. // //nolint:gocognit,cyclop,nestif // ignoring complexity of 101,38 func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report, error) { logger := clog.FromContext(ctx) - r := &malcontent.Report{ - Files: sync.Map{}, - } - if len(c.IgnoreTags) > 0 { - r.Filter = strings.Join(c.IgnoreTags, ",") - } + r := initializeReport(c) - // Store the first hit or miss result - type matchResult struct { - fr *malcontent.FileReport - err error - } - matchChan := make(chan matchResult, 1) - var matchOnce sync.Once + matchChan, cancel := setupMatchHandling(ctx, c, logger, r) + defer cancel() for _, scanPath := range c.ScanPaths { if c.Renderer != nil { c.Renderer.Scanning(ctx, scanPath) + } + processedPath, cleanup, err := prepareScanPath(ctx, c, scanPath, logger) + if cleanup != nil { + defer cleanup() } - imageURI := "" - ociExtractPath := "" - var err error - - if c.OCI { - // store the image URI for later use - imageURI = scanPath - ociExtractPath, err = archive.OCI(ctx, imageURI) - logger.Debug("oci image", slog.Any("scanPath", scanPath), slog.Any("ociExtractPath", ociExtractPath)) - if err != nil { - return nil, fmt.Errorf("failed to prepare OCI image for scanning: %w", err) - } - scanPath = ociExtractPath - } - - paths, err := findFilesRecursively(ctx, scanPath) if err != nil { - if len(c.ScanPaths) == 1 { - return nil, fmt.Errorf("find: %w", err) - } - // try to scan remaining scan paths - logger.Errorf("find failed: %v", err) + logger.Errorf("failed to prepare scan path %s: %v", scanPath, err) continue } - maxConcurrency := c.Concurrency - if maxConcurrency < 1 { - maxConcurrency = 1 + err = processPaths(ctx, c, processedPath, logger, r, matchChan) + if err != nil { + return finalizeReport(c, r, err, matchChan) } + } + + return finalizeReport(c, r, nil, matchChan) +} - // path refers to a real local path, not the requested scanPath - pc := make(chan string, len(paths)) - for _, path := range paths { - pc <- path +// Initialize the report structure +func initializeReport(c malcontent.Config) *malcontent.Report { + r := &malcontent.Report{Files: sync.Map{}} + if len(c.IgnoreTags) > 0 { + r.Filter = strings.Join(c.IgnoreTags, ",") + } + return r +} + +// Setup match handling for early exits based on hits or misses +func setupMatchHandling(ctx context.Context, c malcontent.Config, logger *clog.Logger, r *malcontent.Report) (chan matchResult, context.CancelFunc) { + matchChan := make(chan matchResult, 1) + scanCtx, cancel := context.WithCancel(ctx) + + go func() { + select { + case match := <-matchChan: + handleMatch(ctx, c, logger, r, match) + cancel() + case <-scanCtx.Done(): } - close(pc) + }() - handleArchive := func(path string) error { - frs, err := processArchive(ctx, c, c.RuleFS, path, logger) - if err != nil { - logger.Errorf("unable to process %s: %v", path, err) - } + return matchChan, cancel +} - if !c.OCI && (c.ExitFirstHit || c.ExitFirstMiss) { - match, err := exitIfHitOrMiss(frs, path, c.ExitFirstHit, c.ExitFirstMiss) - if err != nil { - matchOnce.Do(func() { - matchChan <- matchResult{fr: match, err: err} - }) - return err - } - } +// Prepare the scan path (handle OCI images or plain paths) +func prepareScanPath(ctx context.Context, c malcontent.Config, scanPath string, logger *clog.Logger) (string, func(), error) { + if c.OCI { + // Store the original image URI for reference + imageURI := scanPath - if frs != nil { - frs.Range(func(key, value any) bool { - if key == nil || value == nil { - return true - } - if k, ok := key.(string); ok { - if fr, ok := value.(*malcontent.FileReport); ok { - if len(c.TrimPrefixes) > 0 { - k = report.TrimPrefixes(k, c.TrimPrefixes) - } - r.Files.Store(k, fr) - if c.Renderer != nil && r.Diff == nil && fr.RiskScore >= c.MinFileRisk { - if err := c.Renderer.File(ctx, fr); err != nil { - logger.Errorf("render error: %v", err) - } - } - } - } - return true - }) - } - return nil - } + // Extract the OCI image + ociExtractPath, err := archive.OCI(ctx, imageURI) + if err != nil { + return "", nil, fmt.Errorf("failed to prepare OCI image for scanning: %w", err) + } - handleFile := func(path string) error { - trimPath := "" - if c.OCI { - scanPath = imageURI - trimPath = ociExtractPath - } + // Log debug information + logger.Debug("OCI image prepared", + slog.Any("scanPath", scanPath), + slog.Any("ociExtractPath", ociExtractPath), + ) - fr, err := processFile(ctx, c, c.RuleFS, path, scanPath, trimPath, logger) - if err != nil { - if len(c.TrimPrefixes) > 0 { - path = report.TrimPrefixes(path, c.TrimPrefixes) - } - r.Files.Store(path, &malcontent.FileReport{}) - return fmt.Errorf("process: %w", err) - } - if fr == nil { - return nil - } + // Return the extracted path and a cleanup function + return ociExtractPath, func() { os.RemoveAll(ociExtractPath) }, nil + } - if !c.OCI && (c.ExitFirstHit || c.ExitFirstMiss) { - var frMap sync.Map - frMap.Store(path, fr) - match, err := exitIfHitOrMiss(&frMap, path, c.ExitFirstHit, c.ExitFirstMiss) - if err != nil { - matchOnce.Do(func() { - matchChan <- matchResult{fr: match, err: err} - }) - return err - } - } + // Non-OCI paths are returned as-is + return scanPath, nil, nil - if len(c.TrimPrefixes) > 0 { - path = report.TrimPrefixes(path, c.TrimPrefixes) - } - r.Files.Store(path, fr) - if c.Renderer != nil && r.Diff == nil && fr.RiskScore >= c.MinFileRisk { - if err := c.Renderer.File(ctx, fr); err != nil { - return fmt.Errorf("render: %w", err) - } - } - return nil - } +} - scanCtx, cancel := context.WithCancel(ctx) - var g errgroup.Group - g.SetLimit(maxConcurrency) +// Process paths (files or directories) concurrently +func processPaths(ctx context.Context, c malcontent.Config, scanPath string, logger *clog.Logger, r *malcontent.Report, matchChan chan matchResult) error { + paths, err := findFilesRecursively(ctx, scanPath) + if err != nil { + return fmt.Errorf("find files: %w", err) + } - // Poll the match channel for the first hit or miss - go func() { - select { - case match := <-matchChan: - if match.fr != nil && c.Renderer != nil && match.fr.RiskScore >= c.MinFileRisk { - if err := c.Renderer.File(ctx, match.fr); err != nil { - logger.Errorf("render error: %v", err) - } - } - cancel() - case <-scanCtx.Done(): - return + pc := make(chan string, len(paths)) + for _, path := range paths { + pc <- path + } + close(pc) + + return processWithConcurrency(ctx, c, pc, logger, r, matchChan) +} + +// Manage concurrency for processing paths +func processWithConcurrency(ctx context.Context, c malcontent.Config, pc chan string, logger *clog.Logger, r *malcontent.Report, matchChan chan matchResult) error { + var g errgroup.Group + g.SetLimit(maxConcurrency(c)) + + for path := range pc { + path := path // avoid closure capture + g.Go(func() error { + if programkind.IsSupportedArchive(path) { + return handleArchive(ctx, c, path, logger, r, matchChan) } - }() + return processSingleFile(ctx, c, path, logger, r, matchChan) + }) + } + + return g.Wait() +} + +// Maximum concurrency setting +func maxConcurrency(c malcontent.Config) int { + if c.Concurrency < 1 { + return 1 + } + return c.Concurrency +} - for path := range pc { - g.Go(func() error { - select { - case <-scanCtx.Done(): - return scanCtx.Err() - default: - if programkind.IsSupportedArchive(path) { - return handleArchive(path) - } - return handleFile(path) - } - }) +// Finalize the report, handle matches if any +func finalizeReport(c malcontent.Config, r *malcontent.Report, err error, matchChan chan matchResult) (*malcontent.Report, error) { + select { + case match := <-matchChan: + if match.fr != nil { + r.Files.Store(match.fr.Path, match.fr) } + return r, match.err + default: + return r, err + } +} + +// Handle individual files +func processSingleFile(ctx context.Context, c malcontent.Config, path string, logger *clog.Logger, r *malcontent.Report, matchChan chan matchResult) error { + fr, err := processFile(ctx, c, c.RuleFS, path, "", "", logger) + if err != nil { + r.Files.Store(path, &malcontent.FileReport{}) + return fmt.Errorf("process file: %w", err) + } + if fr == nil { + return nil + } + + storeFileReport(c, r, path, fr) + handleExitFirst(ctx, c, path, fr, matchChan) + return nil +} + +// Handle archives +func handleArchive(ctx context.Context, c malcontent.Config, path string, logger *clog.Logger, r *malcontent.Report, matchChan chan matchResult) error { + frs, err := processArchive(ctx, c, c.RuleFS, path, logger) + if err != nil { + logger.Errorf("unable to process archive %s: %v", path, err) + return err + } - if err := g.Wait(); err != nil { - if c.OCI { - if cleanErr := os.RemoveAll(ociExtractPath); cleanErr != nil { - logger.Errorf("remove %s: %v", scanPath, cleanErr) - } + storeArchiveReports(c, r, frs) + handleExitFirst(ctx, c, path, nil, matchChan) + return nil +} + +// Store individual file reports +func storeFileReport(c malcontent.Config, r *malcontent.Report, path string, fr *malcontent.FileReport) { + if len(c.TrimPrefixes) > 0 { + path = report.TrimPrefixes(path, c.TrimPrefixes) + } + r.Files.Store(path, fr) +} + +// Store archive file reports +func storeArchiveReports(c malcontent.Config, r *malcontent.Report, frs *sync.Map) { + frs.Range(func(key, value any) bool { + if key == nil || value == nil { + return true + } + if k, ok := key.(string); ok { + if fr, ok := value.(*malcontent.FileReport); ok { + r.Files.Store(k, fr) } + } + return true + }) +} +// Handle early exits on first hit/miss +func handleExitFirst(ctx context.Context, c malcontent.Config, path string, fr *malcontent.FileReport, matchChan chan matchResult) { + if c.ExitFirstHit || c.ExitFirstMiss { + var frMap sync.Map + if fr != nil { + frMap.Store(path, fr) + } + match, err := exitIfHitOrMiss(&frMap, path, c.ExitFirstHit, c.ExitFirstMiss) + if err != nil { select { - case match := <-matchChan: - r := &malcontent.Report{ - Files: sync.Map{}, - } - if match.fr != nil { - if len(c.TrimPrefixes) > 0 { - match.fr.Path = report.TrimPrefixes(match.fr.Path, c.TrimPrefixes) - } - r.Files.Store(match.fr.Path, match.fr) - } - return r, match.err + case matchChan <- matchResult{fr: match, err: err}: default: - return r, err } } + } +} - // OCI images hadle their match his/miss logic per scanPath - if c.OCI { - match, err := exitIfHitOrMiss(&r.Files, imageURI, c.ExitFirstHit, c.ExitFirstMiss) - if err != nil && c.Renderer != nil && match.RiskScore >= c.MinFileRisk { - if match != nil && c.Renderer != nil && match.RiskScore >= c.MinFileRisk { - if renderErr := c.Renderer.File(ctx, match); renderErr != nil { - logger.Errorf("render error: %v", renderErr) - } - } - cancel() - return r, err - } - - if err := os.RemoveAll(ociExtractPath); err != nil { - logger.Errorf("remove %s: %v", scanPath, err) - } +// Handle match rendering +func handleMatch(ctx context.Context, c malcontent.Config, logger *clog.Logger, r *malcontent.Report, match matchResult) { + if match.fr != nil && c.Renderer != nil && match.fr.RiskScore >= c.MinFileRisk { + if err := c.Renderer.File(ctx, match.fr); err != nil { + logger.Errorf("render error: %v", err) } - cancel() - } // loop: next scan path - return r, nil + } } // processArchive extracts and scans a single archive file.