diff --git a/cmd/ipdex/file/file.go b/cmd/ipdex/file/file.go index f8cd5c8..90a6bb8 100644 --- a/cmd/ipdex/file/file.go +++ b/cmd/ipdex/file/file.go @@ -15,7 +15,6 @@ import ( "os" "path/filepath" "regexp" - "slices" "github.com/crowdsecurity/crowdsec/pkg/cticlient" "github.com/pterm/pterm" @@ -26,6 +25,40 @@ var ( ipRegex = regexp.MustCompile(`(?:[0-9]{1,3}\.){3}[0-9]{1,3}|[a-fA-F0-9:]+`) ) +func collectIPsFromFile(filePath string) ([]string, error) { + readFile, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer readFile.Close() + + ipsToProcess := make([]string, 0) + seenIPs := make(map[string]struct{}) + fileScanner := bufio.NewScanner(readFile) + fileScanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + fileScanner.Split(bufio.ScanLines) + + for fileScanner.Scan() { + line := fileScanner.Text() + ipsMatch := ipRegex.FindAllString(line, -1) + for _, ipAddr := range ipsMatch { + if !config.IsValidIP(ipAddr) { + continue + } + if _, exists := seenIPs[ipAddr]; exists { + continue + } + seenIPs[ipAddr] = struct{}{} + ipsToProcess = append(ipsToProcess, ipAddr) + } + } + if err := fileScanner.Err(); err != nil { + return nil, err + } + + return ipsToProcess, nil +} + func FileCommand(file string, forceRefresh bool, yes bool) { outputFormat := viper.GetString(config.OutputFormatOption) filepath, err := filepath.Abs(file) @@ -52,26 +85,13 @@ func FileCommand(file string, forceRefresh bool, yes bool) { reportExist = false } if !reportExist { - readFile, err := os.Open(filepath) + if outputFormat == display.HumanFormat { + style.Infof("Scanning file '%s' for IPs...", filepath) + } + ipsToProcess, err = collectIPsFromFile(filepath) if err != nil { style.Fatal(err.Error()) } - - fileScanner := bufio.NewScanner(readFile) - fileScanner.Split(bufio.ScanLines) - for fileScanner.Scan() { - line := fileScanner.Text() - ipsMatch := ipRegex.FindAllString(line, -1) - for _, ipAddr := range ipsMatch { - if slices.Contains(ipsToProcess, ipAddr) { - continue - } - if !config.IsValidIP(ipAddr) { - continue - } - ipsToProcess = append(ipsToProcess, ipAddr) - } - } nbIPToProcess = len(ipsToProcess) if nbIPToProcess == 0 { if outputFormat == display.HumanFormat { @@ -79,6 +99,9 @@ func FileCommand(file string, forceRefresh bool, yes bool) { } return } + if outputFormat == display.HumanFormat { + style.Infof("Found %d unique IPs.", nbIPToProcess) + } } else { for _, ip := range report.IPs { diff --git a/cmd/ipdex/file/file_test.go b/cmd/ipdex/file/file_test.go new file mode 100644 index 0000000..db14051 --- /dev/null +++ b/cmd/ipdex/file/file_test.go @@ -0,0 +1,29 @@ +package file + +import ( + "os" + "path/filepath" + "reflect" + "testing" +) + +func TestCollectIPsFromFileDeduplicatesAndKeepsValidIPs(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "ips.txt") + content := "1.1.1.1\ninvalid\n2.2.2.2 extra 1.1.1.1\n2001:4860:4860::8888\n999.1.1.1\n" + if err := os.WriteFile(filePath, []byte(content), 0o600); err != nil { + t.Fatalf("write temp file: %v", err) + } + + got, err := collectIPsFromFile(filePath) + if err != nil { + t.Fatalf("collectIPsFromFile returned error: %v", err) + } + + want := []string{"1.1.1.1", "2.2.2.2", "2001:4860:4860::8888"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected IP list\nwant: %#v\ngot: %#v", want, got) + } +}