diff --git a/src/commands/main-action.test.ts b/src/commands/main-action.test.ts new file mode 100644 index 00000000..ddf7c1d0 --- /dev/null +++ b/src/commands/main-action.test.ts @@ -0,0 +1,251 @@ +import { createMainAction } from './main-action'; + +// eslint-disable-next-line @typescript-eslint/no-require-imports +jest.mock('../logger', () => require('../test-helpers/mock-logger.test-utils').loggerMockFactory()); +jest.mock('../docker-manager'); +jest.mock('../host-iptables'); +jest.mock('../cli-workflow'); +jest.mock('../redact-secrets'); +jest.mock('../option-parsers'); +jest.mock('./preflight'); +jest.mock('./signal-handler'); +jest.mock('./validate-options'); + +import { logger } from '../logger'; +import * as dockerManager from '../docker-manager'; +import * as hostIptables from '../host-iptables'; +import * as cliWorkflow from '../cli-workflow'; +import * as redactSecrets from '../redact-secrets'; +import * as optionParsers from '../option-parsers'; +import * as preflight from './preflight'; +import * as signalHandler from './signal-handler'; +import * as validateOptions from './validate-options'; + +const mockedLogger = logger as jest.Mocked; +const mockedDockerManager = dockerManager as jest.Mocked; +const mockedHostIptables = hostIptables as jest.Mocked; +const mockedCliWorkflow = cliWorkflow as jest.Mocked; +const mockedRedactSecrets = redactSecrets as jest.Mocked; +const mockedOptionParsers = optionParsers as jest.Mocked; +const mockedPreflight = preflight as jest.Mocked; +const mockedSignalHandler = signalHandler as jest.Mocked; +const mockedValidateOptions = validateOptions as jest.Mocked; + +/** Minimal WrapperConfig returned by the validateOptions mock. */ +const STUB_CONFIG = { + allowedDomains: ['github.com'], + blockedDomains: undefined, + agentCommand: 'echo hi', + logLevel: 'info', + keepContainers: false, + workDir: '/tmp/awf-test', + imageRegistry: 'ghcr.io/github/gh-aw-firewall', + imageTag: 'latest', + buildLocal: false, + dnsServers: ['8.8.8.8'], + awfDockerHost: undefined, + proxyLogsDir: undefined, + auditDir: undefined, + sessionStateDir: undefined, +} as unknown as import('../types').WrapperConfig; + +describe('createMainAction', () => { + let processExitSpy: jest.SpyInstance; + let consoleErrorSpy: jest.SpyInstance; + let getOptionValueSource: jest.Mock; + + beforeEach(() => { + jest.clearAllMocks(); + processExitSpy = jest.spyOn(process, 'exit').mockImplementation((code?: string | number | null) => { + if (code === 1) { + throw new Error(`process.exit: ${code}`); + } + return undefined as never; + }); + consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(); + getOptionValueSource = jest.fn().mockReturnValue(undefined); + + // Default mock implementations + mockedPreflight.applyConfigFilePrecedence.mockImplementation(() => {}); + mockedValidateOptions.validateOptions.mockReturnValue(STUB_CONFIG); + mockedDockerManager.setAwfDockerHost.mockImplementation(() => {}); + mockedRedactSecrets.redactSecrets.mockImplementation((s: string) => s); + mockedOptionParsers.joinShellArgs.mockImplementation((args: string[]) => args.join(' ')); + mockedSignalHandler.registerSignalHandlers.mockImplementation(() => {}); + mockedCliWorkflow.runMainWorkflow.mockResolvedValue(0); + }); + + afterEach(() => { + processExitSpy.mockRestore(); + consoleErrorSpy.mockRestore(); + }); + + describe('when args is empty', () => { + it('exits with code 1 and prints usage error', async () => { + const action = createMainAction(getOptionValueSource); + await expect(action([], {})).rejects.toThrow('process.exit: 1'); + expect(processExitSpy).toHaveBeenCalledWith(1); + expect(mockedOptionParsers.joinShellArgs).not.toHaveBeenCalled(); + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringContaining('No command specified') + ); + }); + }); + + describe('when single arg is provided', () => { + it('uses the single arg as-is (preserves shell variables)', async () => { + const action = createMainAction(getOptionValueSource); + await action(['echo $HOME'], {}); + expect(mockedOptionParsers.joinShellArgs).not.toHaveBeenCalled(); + expect(mockedValidateOptions.validateOptions).toHaveBeenCalledWith( + expect.anything(), + 'echo $HOME' + ); + }); + }); + + describe('when multiple args are provided', () => { + it('joins args with joinShellArgs', async () => { + const action = createMainAction(getOptionValueSource); + await action(['curl', '-H', 'Auth: token', 'https://api.github.com'], {}); + expect(mockedOptionParsers.joinShellArgs).toHaveBeenCalledWith([ + 'curl', + '-H', + 'Auth: token', + 'https://api.github.com', + ]); + expect(mockedValidateOptions.validateOptions).toHaveBeenCalledWith( + expect.anything(), + 'curl -H Auth: token https://api.github.com' + ); + }); + }); + + describe('happy path', () => { + it('calls workflow steps and exits with 0', async () => { + mockedCliWorkflow.runMainWorkflow.mockResolvedValue(0); + const action = createMainAction(getOptionValueSource); + await action(['echo hi'], {}); + expect(mockedCliWorkflow.runMainWorkflow).toHaveBeenCalled(); + expect(processExitSpy).toHaveBeenCalledWith(0); + }); + + it('calls applyConfigFilePrecedence with options and resolver', async () => { + const options = { keepContainers: false }; + const action = createMainAction(getOptionValueSource); + await action(['echo hi'], options); + expect(mockedPreflight.applyConfigFilePrecedence).toHaveBeenCalledWith( + options, + getOptionValueSource + ); + }); + + it('calls setAwfDockerHost with config.awfDockerHost', async () => { + const configWithDockerHost = { ...STUB_CONFIG, awfDockerHost: '/var/run/docker.sock' }; + mockedValidateOptions.validateOptions.mockReturnValue( + configWithDockerHost as unknown as import('../types').WrapperConfig + ); + const action = createMainAction(getOptionValueSource); + await action(['echo hi'], {}); + expect(mockedDockerManager.setAwfDockerHost).toHaveBeenCalledWith('/var/run/docker.sock'); + }); + + it('registers signal handlers', async () => { + const action = createMainAction(getOptionValueSource); + await action(['echo hi'], {}); + expect(mockedSignalHandler.registerSignalHandlers).toHaveBeenCalled(); + }); + + it('logs allowed domains', async () => { + const action = createMainAction(getOptionValueSource); + await action(['echo hi'], {}); + expect(mockedLogger.info).toHaveBeenCalledWith( + expect.stringContaining('github.com') + ); + }); + + it('logs blocked domains when present', async () => { + const configWithBlocked = { + ...STUB_CONFIG, + blockedDomains: ['evil.com'], + }; + mockedValidateOptions.validateOptions.mockReturnValue( + configWithBlocked as unknown as import('../types').WrapperConfig + ); + const action = createMainAction(getOptionValueSource); + await action(['echo hi'], {}); + expect(mockedLogger.info).toHaveBeenCalledWith( + expect.stringContaining('evil.com') + ); + }); + + it('does not log blocked domains when empty', async () => { + const action = createMainAction(getOptionValueSource); + await action(['echo hi'], {}); + const blockedCalls = mockedLogger.info.mock.calls.filter( + (args) => String(args[0]).includes('Blocked domains') + ); + expect(blockedCalls).toHaveLength(0); + }); + }); + + describe('when runMainWorkflow returns non-zero exit code', () => { + it('exits with the non-zero code', async () => { + mockedCliWorkflow.runMainWorkflow.mockResolvedValue(42); + const action = createMainAction(getOptionValueSource); + await action(['curl https://example.com'], {}); + expect(processExitSpy).toHaveBeenCalledWith(42); + }); + }); + + describe('when runMainWorkflow throws', () => { + it('calls performCleanup and exits with code 1', async () => { + mockedCliWorkflow.runMainWorkflow.mockRejectedValue(new Error('docker failed')); + const action = createMainAction(getOptionValueSource); + await expect(action(['echo hi'], {})).rejects.toThrow('process.exit: 1'); + expect(mockedLogger.error).toHaveBeenCalledWith( + 'Fatal error:', + expect.any(Error) + ); + expect(mockedDockerManager.cleanup).toHaveBeenCalledWith( + STUB_CONFIG.workDir, + false, + STUB_CONFIG.proxyLogsDir, + STUB_CONFIG.auditDir, + STUB_CONFIG.sessionStateDir + ); + expect(mockedHostIptables.cleanupHostIptables).not.toHaveBeenCalled(); + expect(processExitSpy).toHaveBeenCalledWith(1); + }); + }); + + describe('redaction of sensitive config fields', () => { + it('does not log API keys in debug output', async () => { + const configWithKeys = { + ...STUB_CONFIG, + openaiApiKey: 'sk-secret', + anthropicApiKey: 'ant-secret', + copilotGithubToken: 'ghp-secret', + copilotApiKey: 'cop-secret', + geminiApiKey: 'gem-secret', + }; + mockedValidateOptions.validateOptions.mockReturnValue( + configWithKeys as unknown as import('../types').WrapperConfig + ); + const action = createMainAction(getOptionValueSource); + await action(['echo hi'], {}); + // Debug call should be made but without raw API keys + const debugCalls = mockedLogger.debug.mock.calls; + const configDebugCall = debugCalls.find((args) => + String(args[0]).includes('Configuration') + ); + expect(configDebugCall).toBeDefined(); + const serialized = String(configDebugCall?.[1]); + expect(serialized).not.toContain('sk-secret'); + expect(serialized).not.toContain('ant-secret'); + expect(serialized).not.toContain('ghp-secret'); + expect(serialized).not.toContain('cop-secret'); + expect(serialized).not.toContain('gem-secret'); + }); + }); +}); diff --git a/src/commands/subcommands.test.ts b/src/commands/subcommands.test.ts new file mode 100644 index 00000000..e9229cd2 --- /dev/null +++ b/src/commands/subcommands.test.ts @@ -0,0 +1,242 @@ +import { Command } from 'commander'; +import { registerSubcommands } from './subcommands'; + +// eslint-disable-next-line @typescript-eslint/no-require-imports +jest.mock('../logger', () => require('../test-helpers/mock-logger.test-utils').loggerMockFactory()); +jest.mock('./logs', () => ({ logsCommand: jest.fn().mockResolvedValue(undefined) })); +jest.mock('./logs-stats', () => ({ statsCommand: jest.fn().mockResolvedValue(undefined) })); +jest.mock('./logs-summary', () => ({ summaryCommand: jest.fn().mockResolvedValue(undefined) })); +jest.mock('./logs-audit', () => ({ auditCommand: jest.fn().mockResolvedValue(undefined) })); +jest.mock('./predownload', () => ({ predownloadCommand: jest.fn().mockResolvedValue(undefined) })); + +import { logger } from '../logger'; + +const mockedLogger = logger as jest.Mocked; + +/** + * Creates a fresh Commander program with subcommands registered. + * exitOverride() prevents process.exit from actually killing the test runner. + */ +function makeProgram(): Command { + const program = new Command('awf'); + program.exitOverride(); + registerSubcommands(program); + return program; +} + +describe('registerSubcommands', () => { + let processExitSpy: jest.SpyInstance; + + beforeEach(() => { + jest.clearAllMocks(); + processExitSpy = jest.spyOn(process, 'exit').mockImplementation(() => undefined as never); + }); + + afterEach(() => { + processExitSpy.mockRestore(); + }); + + describe('command registration', () => { + it('registers predownload subcommand on the program', () => { + const program = makeProgram(); + const names = program.commands.map((c) => c.name()); + expect(names).toContain('predownload'); + }); + + it('registers logs subcommand on the program', () => { + const program = makeProgram(); + const names = program.commands.map((c) => c.name()); + expect(names).toContain('logs'); + }); + + it('registers logs stats sub-subcommand', () => { + const program = makeProgram(); + const logsCmd = program.commands.find((c) => c.name() === 'logs')!; + const subNames = logsCmd.commands.map((c) => c.name()); + expect(subNames).toContain('stats'); + }); + + it('registers logs summary sub-subcommand', () => { + const program = makeProgram(); + const logsCmd = program.commands.find((c) => c.name() === 'logs')!; + const subNames = logsCmd.commands.map((c) => c.name()); + expect(subNames).toContain('summary'); + }); + + it('registers logs audit sub-subcommand', () => { + const program = makeProgram(); + const logsCmd = program.commands.find((c) => c.name() === 'logs')!; + const subNames = logsCmd.commands.map((c) => c.name()); + expect(subNames).toContain('audit'); + }); + }); + + describe('predownload defaults', () => { + it('sets default image-registry to ghcr.io/github/gh-aw-firewall', async () => { + const program = makeProgram(); + const predownload = program.commands.find((c) => c.name() === 'predownload')!; + await program.parseAsync(['node', 'awf', 'predownload'], { from: 'node' }); + expect(predownload.opts().imageRegistry).toBe('ghcr.io/github/gh-aw-firewall'); + }); + + it('sets default image-tag to latest', async () => { + const program = makeProgram(); + const predownload = program.commands.find((c) => c.name() === 'predownload')!; + await program.parseAsync(['node', 'awf', 'predownload'], { from: 'node' }); + expect(predownload.opts().imageTag).toBe('latest'); + }); + + it('sets default enable-api-proxy to false', async () => { + const program = makeProgram(); + const predownload = program.commands.find((c) => c.name() === 'predownload')!; + await program.parseAsync(['node', 'awf', 'predownload'], { from: 'node' }); + expect(predownload.opts().enableApiProxy).toBe(false); + }); + + it('sets default agent-image to default', async () => { + const program = makeProgram(); + const predownload = program.commands.find((c) => c.name() === 'predownload')!; + await program.parseAsync(['node', 'awf', 'predownload'], { from: 'node' }); + expect(predownload.opts().agentImage).toBe('default'); + }); + }); + + describe('validateFormat (via logs action)', () => { + it('exits with code 1 for invalid format in logs subcommand', async () => { + const program = makeProgram(); + await program.parseAsync(['node', 'awf', 'logs', '--format', 'invalid'], { from: 'node' }); + expect(processExitSpy).toHaveBeenCalledWith(1); + expect(mockedLogger.error).toHaveBeenCalledWith(expect.stringContaining('Invalid format')); + }); + + it('does not exit for valid format "raw" in logs subcommand', async () => { + const program = makeProgram(); + await program.parseAsync(['node', 'awf', 'logs', '--format', 'raw'], { from: 'node' }); + expect(processExitSpy).not.toHaveBeenCalled(); + }); + + it('does not exit for valid format "pretty" in logs subcommand', async () => { + const program = makeProgram(); + await program.parseAsync(['node', 'awf', 'logs', '--format', 'pretty'], { from: 'node' }); + expect(processExitSpy).not.toHaveBeenCalled(); + }); + + it('does not exit for valid format "json" in logs subcommand', async () => { + const program = makeProgram(); + await program.parseAsync(['node', 'awf', 'logs', '--format', 'json'], { from: 'node' }); + expect(processExitSpy).not.toHaveBeenCalled(); + }); + + it('exits with code 1 for invalid format in logs stats subcommand', async () => { + const program = makeProgram(); + const logsCmd = program.commands.find((c) => c.name() === 'logs')!; + const statsCmd = logsCmd.commands.find((c) => c.name() === 'stats')!; + await statsCmd.parseAsync(['node', 'awf', '--format', 'bogus'], { from: 'node' }); + expect(processExitSpy).toHaveBeenCalledWith(1); + expect(mockedLogger.error).toHaveBeenCalledWith(expect.stringContaining('Invalid format')); + }); + + it('does not exit for valid format in logs stats subcommand', async () => { + const program = makeProgram(); + await program.parseAsync(['node', 'awf', 'logs', 'stats', '--format', 'json'], { from: 'node' }); + expect(processExitSpy).not.toHaveBeenCalled(); + }); + + it('exits with code 1 for invalid format in logs summary subcommand', async () => { + const program = makeProgram(); + const logsCmd = program.commands.find((c) => c.name() === 'logs')!; + const summaryCmd = logsCmd.commands.find((c) => c.name() === 'summary')!; + await summaryCmd.parseAsync(['node', 'awf', '--format', 'bogus'], { from: 'node' }); + expect(processExitSpy).toHaveBeenCalledWith(1); + expect(mockedLogger.error).toHaveBeenCalledWith(expect.stringContaining('Invalid format')); + }); + + it('does not exit for valid format in logs summary subcommand', async () => { + const program = makeProgram(); + await program.parseAsync(['node', 'awf', 'logs', 'summary', '--format', 'markdown'], { from: 'node' }); + expect(processExitSpy).not.toHaveBeenCalled(); + }); + + it('exits with code 1 for invalid format in logs audit subcommand', async () => { + const program = makeProgram(); + const logsCmd = program.commands.find((c) => c.name() === 'logs')!; + const auditCmd = logsCmd.commands.find((c) => c.name() === 'audit')!; + await auditCmd.parseAsync(['node', 'awf', '--format', 'bogus'], { from: 'node' }); + expect(processExitSpy).toHaveBeenCalledWith(1); + expect(mockedLogger.error).toHaveBeenCalledWith(expect.stringContaining('Invalid format')); + }); + + it('does not exit for valid format "pretty" in logs audit subcommand', async () => { + const program = makeProgram(); + await program.parseAsync(['node', 'awf', 'logs', 'audit', '--format', 'pretty'], { from: 'node' }); + expect(processExitSpy).not.toHaveBeenCalled(); + }); + + it('exits with code 1 for invalid decision in logs audit subcommand', async () => { + const program = makeProgram(); + await program.parseAsync( + ['node', 'awf', 'logs', 'audit', '--format', 'pretty', '--decision', 'badvalue'], + { from: 'node' } + ); + expect(processExitSpy).toHaveBeenCalledWith(1); + expect(mockedLogger.error).toHaveBeenCalledWith( + expect.stringContaining('Invalid decision filter') + ); + }); + + it('does not exit for valid decision "allowed" in logs audit subcommand', async () => { + const program = makeProgram(); + await program.parseAsync( + ['node', 'awf', 'logs', 'audit', '--decision', 'allowed'], + { from: 'node' } + ); + expect(processExitSpy).not.toHaveBeenCalled(); + }); + + it('does not exit for valid decision "denied" in logs audit subcommand', async () => { + const program = makeProgram(); + await program.parseAsync( + ['node', 'awf', 'logs', 'audit', '--decision', 'denied'], + { from: 'node' } + ); + expect(processExitSpy).not.toHaveBeenCalled(); + }); + + it('warns when --with-pid is used without -f in logs subcommand', async () => { + const program = makeProgram(); + await program.parseAsync(['node', 'awf', 'logs', '--with-pid'], { from: 'node' }); + expect(mockedLogger.warn).toHaveBeenCalledWith( + expect.stringContaining('--with-pid only works with real-time streaming') + ); + }); + + it('does not warn when --with-pid is used with -f', async () => { + const program = makeProgram(); + await program.parseAsync(['node', 'awf', 'logs', '--with-pid', '-f'], { from: 'node' }); + expect(mockedLogger.warn).not.toHaveBeenCalled(); + }); + }); + + describe('predownload action error handling', () => { + it('exits with predownload error exitCode when predownload throws', async () => { + const { predownloadCommand } = await import('./predownload'); + const mockedPredownload = predownloadCommand as jest.Mock; + const err = Object.assign(new Error('pull failed'), { exitCode: 2 }); + mockedPredownload.mockRejectedValueOnce(err); + + const program = makeProgram(); + await program.parseAsync(['node', 'awf', 'predownload'], { from: 'node' }); + expect(processExitSpy).toHaveBeenCalledWith(2); + }); + + it('exits with code 1 when predownload throws without exitCode', async () => { + const { predownloadCommand } = await import('./predownload'); + const mockedPredownload = predownloadCommand as jest.Mock; + mockedPredownload.mockRejectedValueOnce(new Error('unknown')); + + const program = makeProgram(); + await program.parseAsync(['node', 'awf', 'predownload'], { from: 'node' }); + expect(processExitSpy).toHaveBeenCalledWith(1); + }); + }); +});