Skip to content

Commit aa16543

Browse files
committed
Fix domain enum only option
1 parent 05d8fe2 commit aa16543

2 files changed

Lines changed: 101 additions & 55 deletions

File tree

internal/collector/collector.go

Lines changed: 62 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ type Config struct {
5757
Debug bool
5858

5959
// Collection options
60-
DomainEnumOnly bool
61-
SkipLinkedServerEnum bool
62-
CollectFromLinkedServers bool
63-
SkipPrivateAddress bool
64-
ScanAllComputers bool
65-
SkipADNodeCreation bool
60+
DomainEnumOnly bool
61+
SkipLinkedServerEnum bool
62+
CollectFromLinkedServers bool
63+
SkipPrivateAddress bool
64+
ScanAllComputers bool
65+
SkipADNodeCreation bool
6666
DisableNontraversableEdges bool
6767
DisablePossibleEdges bool
6868
SkipIPDedupe bool // Skip DNS-based IP deduplication of targets
@@ -289,64 +289,74 @@ func (c *Collector) Run() error {
289289
return fmt.Errorf("no servers to process")
290290
}
291291

292-
c.config.Logger.Info("Processing SQL Servers", "count", len(c.serversToProcess))
293-
c.config.Logger.Log(context.Background(), logging.LevelVerbose, "Memory usage", "usage", c.getMemoryUsage())
294-
295-
// Track all processed servers to avoid duplicates
296-
processedServers := make(map[string]bool)
297-
298-
// Process servers (concurrently if workers > 0)
299-
if c.config.Workers > 0 {
300-
c.processServersConcurrently()
301-
// Mark all initial servers as processed
292+
if c.config.DomainEnumOnly {
293+
sort.Slice(c.serversToProcess, func(i, j int) bool {
294+
return strings.ToLower(c.serversToProcess[i].ConnectionString) < strings.ToLower(c.serversToProcess[j].ConnectionString)
295+
})
296+
c.config.Logger.Info("Domain enumeration only enabled; skipping MSSQL collection", "count", len(c.serversToProcess))
302297
for _, server := range c.serversToProcess {
303-
processedServers[strings.ToLower(server.Hostname)] = true
298+
c.config.Logger.Info("Discovered server", "server", server.ConnectionString, "objectID", server.ObjectIdentifier)
304299
}
305300
} else {
306-
// Sequential processing
307-
for i, server := range c.serversToProcess {
308-
log := c.config.Logger.With("target", server.ConnectionString)
309-
log.Info("Processing server", "progress", fmt.Sprintf("%d/%d", i+1, len(c.serversToProcess)))
310-
processedServers[strings.ToLower(server.Hostname)] = true
311-
312-
if err := c.processServer(server); err != nil {
313-
log.Warn("Failed to process server", "error", err)
314-
// Continue with other servers
301+
c.config.Logger.Info("Processing SQL Servers", "count", len(c.serversToProcess))
302+
c.config.Logger.Log(context.Background(), logging.LevelVerbose, "Memory usage", "usage", c.getMemoryUsage())
303+
304+
// Track all processed servers to avoid duplicates
305+
processedServers := make(map[string]bool)
306+
307+
// Process servers (concurrently if workers > 0)
308+
if c.config.Workers > 0 {
309+
c.processServersConcurrently()
310+
// Mark all initial servers as processed
311+
for _, server := range c.serversToProcess {
312+
processedServers[strings.ToLower(server.Hostname)] = true
313+
}
314+
} else {
315+
// Sequential processing
316+
for i, server := range c.serversToProcess {
317+
log := c.config.Logger.With("target", server.ConnectionString)
318+
log.Info("Processing server", "progress", fmt.Sprintf("%d/%d", i+1, len(c.serversToProcess)))
319+
processedServers[strings.ToLower(server.Hostname)] = true
320+
321+
if err := c.processServer(server); err != nil {
322+
log.Warn("Failed to process server", "error", err)
323+
// Continue with other servers
324+
}
315325
}
316326
}
317-
}
318327

319-
// Process linked servers recursively if enabled
320-
if c.config.CollectFromLinkedServers {
321-
c.processLinkedServersQueue(processedServers)
322-
}
328+
// Process linked servers recursively if enabled
329+
if c.config.CollectFromLinkedServers {
330+
c.processLinkedServersQueue(processedServers)
331+
}
323332

324-
// Write accumulated AD nodes to separate files (computers.json, users.json, groups.json)
325-
if !c.config.SkipADNodeCreation {
326-
if err := c.writeADFiles(); err != nil {
327-
return fmt.Errorf("failed to write AD files: %w", err)
333+
// Write accumulated AD nodes to separate files (computers.json, users.json, groups.json)
334+
if !c.config.SkipADNodeCreation {
335+
if err := c.writeADFiles(); err != nil {
336+
return fmt.Errorf("failed to write AD files: %w", err)
337+
}
328338
}
329-
}
330339

331-
// Create zip file
332-
if len(c.outputFiles) > 0 {
333-
var err error
334-
zipPath, err = c.createZipFile()
335-
if err != nil {
336-
return fmt.Errorf("failed to create zip file: %w", err)
340+
// Create zip file
341+
if len(c.outputFiles) > 0 {
342+
var err error
343+
zipPath, err = c.createZipFile()
344+
if err != nil {
345+
return fmt.Errorf("failed to create zip file: %w", err)
346+
}
347+
c.config.Logger.Info("Output written", "path", zipPath)
348+
} else {
349+
c.config.Logger.Info("No data collected - no output file created")
337350
}
338-
c.config.Logger.Info("Output written", "path", zipPath)
339-
} else {
340-
c.config.Logger.Info("No data collected - no output file created")
341-
}
342351

343-
// Create separate zip for per-target log files
344-
if c.config.LogPerTarget && len(c.logFiles) > 0 {
345-
logsZipPath, err := c.createLogsZipFile()
346-
if err != nil {
347-
return fmt.Errorf("failed to create logs zip file: %w", err)
352+
// Create separate zip for per-target log files
353+
if c.config.LogPerTarget && len(c.logFiles) > 0 {
354+
logsZipPath, err := c.createLogsZipFile()
355+
if err != nil {
356+
return fmt.Errorf("failed to create logs zip file: %w", err)
357+
}
358+
c.config.Logger.Info("Per-target logs written", "path", logsZipPath)
348359
}
349-
c.config.Logger.Info("Per-target logs written", "path", logsZipPath)
350360
}
351361
} else {
352362
c.config.Logger.Info("Skipping collection (--skip-collection)")

internal/collector/collector_test.go

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -777,9 +777,9 @@ func TestOutputFormat(t *testing.T) {
777777

778778
// Write a test edge
779779
edge := &bloodhound.Edge{
780-
Start: bloodhound.EdgeEndpoint{Value: "source-1"},
781-
End: bloodhound.EdgeEndpoint{Value: "target-1"},
782-
Kind: "MSSQL_Contains",
780+
Start: bloodhound.EdgeEndpoint{Value: "source-1"},
781+
End: bloodhound.EdgeEndpoint{Value: "target-1"},
782+
Kind: "MSSQL_Contains",
783783
Properties: map[string]interface{}{},
784784
}
785785
if err := writer.WriteEdge(edge); err != nil {
@@ -1035,3 +1035,39 @@ func TestSkipIPDedupe(t *testing.T) {
10351035
t.Fatalf("expected 2 servers (dedupe skipped), got %d", len(c.serversToProcess))
10361036
}
10371037
}
1038+
1039+
func TestRunDomainEnumOnlySkipsCollection(t *testing.T) {
1040+
tmpDir := t.TempDir()
1041+
1042+
config := &Config{
1043+
TempDir: tmpDir,
1044+
ServerInstance: "sql.example.com",
1045+
DomainEnumOnly: true,
1046+
SkipIPDedupe: true,
1047+
}
1048+
1049+
c, err := New(config)
1050+
if err != nil {
1051+
t.Fatalf("failed to create collector: %v", err)
1052+
}
1053+
1054+
if err := c.Run(); err != nil {
1055+
t.Fatalf("expected domain-enum-only run to succeed, got error: %v", err)
1056+
}
1057+
1058+
if len(c.serversToProcess) != 1 {
1059+
t.Fatalf("expected 1 discovered server, got %d", len(c.serversToProcess))
1060+
}
1061+
1062+
if len(c.outputFiles) != 0 {
1063+
t.Fatalf("expected no output files when domain-enum-only is enabled, got %d", len(c.outputFiles))
1064+
}
1065+
1066+
entries, err := os.ReadDir(tmpDir)
1067+
if err != nil {
1068+
t.Fatalf("failed to read temp dir: %v", err)
1069+
}
1070+
if len(entries) != 0 {
1071+
t.Fatalf("expected no files to be created, found %d", len(entries))
1072+
}
1073+
}

0 commit comments

Comments
 (0)