From 2c680e2d849fd3a4192de2bbb2d1b1cc2ebfe473 Mon Sep 17 00:00:00 2001 From: pk910 Date: Tue, 8 Jul 2025 02:47:54 +0200 Subject: [PATCH 1/4] add AI payload scenario --- scenarios/aitx/README.md | 352 +++++++++++++ scenarios/aitx/ai_service.go | 713 +++++++++++++++++++++++++++ scenarios/aitx/aitx.go | 562 +++++++++++++++++++++ scenarios/aitx/feedback_collector.go | 140 ++++++ scenarios/aitx/geas_processor.go | 96 ++++ scenarios/aitx/payload_processor.go | 88 ++++ scenarios/aitx/payload_template.go | 231 +++++++++ scenarios/scenarios.go | 2 + 8 files changed, 2184 insertions(+) create mode 100644 scenarios/aitx/README.md create mode 100644 scenarios/aitx/ai_service.go create mode 100644 scenarios/aitx/aitx.go create mode 100644 scenarios/aitx/feedback_collector.go create mode 100644 scenarios/aitx/geas_processor.go create mode 100644 scenarios/aitx/payload_processor.go create mode 100644 scenarios/aitx/payload_template.go diff --git a/scenarios/aitx/README.md b/scenarios/aitx/README.md new file mode 100644 index 00000000..d21bdcc6 --- /dev/null +++ b/scenarios/aitx/README.md @@ -0,0 +1,352 @@ +# AI Transaction Generator (aitx) + +The AI Transaction Generator scenario leverages OpenRouter's API to generate diverse, AI-powered Ethereum transaction payloads for comprehensive stress testing and validation. + +## Overview + +This scenario uses AI models to create dynamic transaction patterns that go beyond static predefined scenarios. It supports multiple generation modes and includes a feedback loop to improve payload diversity over time. 
+ +## Features + +- **AI-Powered Generation**: Uses OpenRouter API with configurable AI models (default: Claude 3.5 Sonnet) +- **Multiple Generation Modes**: Support for geas assembly, calldata, transfers, or mixed mode +- **Feedback Loop**: AI learns from transaction execution results to generate better payloads +- **Batch Processing**: Generates multiple payloads per API call to minimize costs +- **Placeholder System**: Dynamic parameter substitution for transaction variations +- **Safety Validation**: Prevents malicious code generation with built-in safety checks +- **Cost Management**: Configurable limits on API calls and token consumption + +## Configuration + +### Required Configuration + +```yaml +openrouter_api_key: "your-api-key-here" # OpenRouter API key (or set OPENROUTER_API_KEY env var) +``` + +### AI Configuration Options + +```yaml +# AI Model Settings +model: "anthropic/claude-3.5-sonnet" # AI model to use +test_direction: "focus on gas optimization" # Directional guidance for AI +payloads_per_request: 50 # Payloads generated per API call +max_ai_calls: 10 # Maximum API calls limit +max_tokens: 100000 # Maximum token consumption limit + +# Generation Settings +generation_mode: "mixed" # "geas", "calldata", "transfer", "mixed" + +# Feedback Settings +feedback_batch_size: 20 # Transaction results included in feedback +enable_feedback_loop: true # Enable AI learning from results + +# Debug Settings +log_ai_conversations: false # Enable detailed AI conversation logging +``` + +### Standard Transaction Options + +```yaml +total_count: 1000 # Total transactions to send (0 = unlimited) +throughput: 10 # Transactions per slot +max_pending: 100 # Maximum pending transactions +max_wallets: 50 # Maximum child wallets +rebroadcast: 1 # Enable transaction rebroadcast +basefee: 20 # Base fee in gwei +tipfee: 2 # Tip fee in gwei +gaslimit: 1000000 # Gas limit for transactions +timeout: "30m" # Scenario timeout +client_group: "" # Client group preference +log_txs: 
false # Log individual transactions +``` + +## Usage + +### Basic Usage + +```bash +./spamoor aitx --openrouter-api-key="your-key" --count=100 --throughput=5 +``` + +### Advanced Configuration + +```bash +./spamoor aitx \ + --openrouter-api-key="your-key" \ + --model="anthropic/claude-3.5-sonnet" \ + --generation-mode="mixed" \ + --test-direction="focus on complex contract interactions" \ + --payloads-per-request=30 \ + --max-ai-calls=5 \ + --enable-feedback-loop=true \ + --count=500 \ + --throughput=10 \ + --gaslimit=2000000 +``` + +### YAML Configuration + +```yaml +# aitx-config.yaml +scenarios: + aitx: + # AI Configuration + openrouter_api_key: "your-api-key-here" + model: "anthropic/claude-3.5-sonnet" + test_direction: "explore edge cases and gas optimization patterns" + generation_mode: "mixed" + payloads_per_request: 40 + max_ai_calls: 8 + max_tokens: 80000 + + # Feedback Configuration + feedback_batch_size: 25 + enable_feedback_loop: true + + # Debug Configuration + log_ai_conversations: true + + # Transaction Configuration + total_count: 2000 + throughput: 15 + max_pending: 150 + basefee: 25 + tipfee: 3 + gaslimit: 1500000 + rebroadcast: 1 + log_txs: true +``` + +## Generation Modes + +### Geas Mode (`generation_mode: "geas"`) +**PREFERRED MODE** - Provides the most interesting and comprehensive EVM testing. 
+ +Two geas deployment methods are supported: + +#### Simple Method (for arbitrary opcode/precompile testing) +- Generates standalone geas assembly code +- Code gets deployed as a contract and can be called with specific calldata +- Best for: testing specific opcodes, precompiles, edge cases +- Example: Testing modular exponentiation precompile, specific memory operations + +#### Init/Run Method (for performance testing) +- Generates separate initialization code and run code +- Init code executes once during contract deployment +- Run code executes in a loop until gas is consumed (like gasburnertx) +- Best for: performance benchmarks, gas burning, stress testing +- Example: Testing complex loops, storage operations, cryptographic operations + +Both methods automatically validate for security (blocks dangerous operations like selfdestruct, delegatecall, create2) + +### Calldata Mode (`generation_mode: "calldata"`) +- Generates raw calldata for contract interactions +- Creates function calls with ABI-encoded parameters +- Supports common patterns like transfers, approvals, and complex calls +- Includes both simple and complex contract interaction patterns + +### Transfer Mode (`generation_mode: "transfer"`) +- Generates simple ETH transfers between addresses +- Focuses on value transfer patterns and address variations +- Uses placeholder system for dynamic amounts and recipients + +### Mixed Mode (`generation_mode: "mixed"`) +- Combines all generation modes in a single scenario +- **Prioritizes geas generation** (70-80% of payloads will be geas-based) +- Provides maximum diversity in transaction patterns +- AI automatically balances between geas methods and other transaction types +- Recommended for comprehensive testing scenarios + +## Placeholder System + +The AI can use placeholders that are dynamically substituted during transaction building: + +- `${WALLET_ADDRESS}`: Random wallet from pool +- `${RANDOM_ADDRESS}`: Randomly generated address +- 
`${ETH_AMOUNT_SMALL/MEDIUM/LARGE}`: Dynamic ETH amounts +- `${GAS_LIMIT_LOW/MEDIUM/HIGH}`: Dynamic gas limits +- `${RANDOM_UINT256}`: Random 256-bit integer +- `${RANDOM_BYTES32}`: Random 32-byte value +- `${LOOP_COUNT_SMALL/MEDIUM/LARGE}`: Loop iteration counts + +## Feedback Loop + +When enabled, the scenario: +1. Collects transaction execution results (success/failure, gas usage, errors) +2. Provides statistical analysis to the AI for subsequent generations +3. Encourages the AI to avoid failing patterns and explore successful variations +4. Builds context across multiple AI calls for progressive improvement + +## Cost Management + +- **API Call Limits**: Prevents runaway costs with `max_ai_calls` +- **Token Limits**: Controls total token consumption with `max_tokens` +- **Batch Processing**: Generates multiple payloads per API call +- **Efficient Caching**: Reuses generated payloads until cache is exhausted + +## Debugging + +### AI Conversation Logging + +Enable detailed logging of AI conversations with `--log-ai-conversations=true` or in YAML: + +```yaml +log_ai_conversations: true +``` + +This will log: +- Full conversation history for each AI request +- All retry attempts with error feedback +- AI responses before and after parsing +- Truncated content for readability (messages over 2000 chars) + +Use with debug logging level for maximum detail: +```bash +./spamoor aitx --log-ai-conversations=true --log-level=debug +``` + +### Geas Compilation Validation + +The scenario automatically validates all geas code during the AI conversation: + +- **Real-time Compilation**: Every geas payload is compiled using the geas compiler before being accepted +- **Immediate Feedback**: Compilation errors are sent back to the AI in the same conversation for immediate correction +- **Comprehensive Error Guidance**: AI receives specific guidance on: +  - Valid EVM opcodes and syntax +  - Proper formatting requirements (newlines, hex format) +  - Common compilation issues and fixes +  - Stack 
management requirements + +**Example Error Feedback to AI**: +``` +GEAS COMPILATION ERROR DETECTED: +Error: geas compilation failed: unknown opcode 'keccak256' + +Your geas assembly code failed to compile. Please fix the following issues: + +GEAS CODE REQUIREMENTS: +1. Use VALID EVM opcodes only (e.g., PUSH1, ADD, MUL, SSTORE, SLOAD, etc.) +2. Format: ONE opcode per line, separated by \n +3. Use correct syntax: 'PUSH1 0x20' with uppercase opcodes +4. Hexadecimal values must start with 0x +... +``` + +### Common Debug Scenarios + +- **Parsing Errors**: When AI generates invalid JSON, conversation logs show exactly what was sent and received +- **Geas Compilation Errors**: Invalid assembly code is caught immediately with specific error feedback to guide AI fixes +- **Retry Logic**: Track how error feedback is provided and how AI responds to corrections +- **Token Usage**: Monitor conversation length and token consumption patterns +- **Response Quality**: Analyze AI output quality and prompt effectiveness + +## Security + +- **Code Validation**: Blocks dangerous operations (selfdestruct, delegatecall, create2) +- **Geas Compilation**: Validates assembly code before execution +- **Payload Validation**: Ensures all generated payloads meet safety requirements +- **API Key Security**: Supports environment variable configuration + +## Troubleshooting + +### Common Issues + +#### API Key Not Found +``` +Error: OpenRouter API key is required +``` +**Solution**: Set the API key using `--openrouter-api-key` flag or `OPENROUTER_API_KEY` environment variable. 
+ +#### AI Generation Failures +``` +Error: AI payload generation failed: HTTP request failed +``` +**Solutions**: +- Check internet connectivity +- Verify API key validity +- Ensure OpenRouter service is available +- Try reducing `payloads_per_request` if hitting rate limits + +#### Invalid Generation Mode +``` +Error: invalid generation mode 'invalid', must be one of: geas, calldata, transfer, mixed +``` +**Solution**: Use a valid generation mode: `geas`, `calldata`, `transfer`, or `mixed`. + +#### Geas Compilation Errors +``` +Error: geas compilation failed: syntax error +``` +**Solutions**: +- AI-generated geas code may be invalid +- Check the AI model configuration +- Try different `test_direction` guidance +- Verify the AI is generating valid assembly syntax + +### Performance Optimization + +- **Throughput**: Start with low throughput (5-10) and increase gradually +- **Wallet Count**: Ensure adequate wallets for transaction distribution +- **Gas Limits**: Adjust based on payload complexity +- **Batch Size**: Increase `payloads_per_request` for better cost efficiency +- **Feedback**: Enable feedback loop for improved AI generation over time + +### Monitoring + +- Monitor token consumption to stay within budget +- Track API call usage against limits +- Watch transaction success rates for payload quality +- Review gas usage patterns for optimization opportunities + +## Examples + +### Gas Optimization Focus +```bash +./spamoor aitx \ + --openrouter-api-key="your-key" \ + --test-direction="create gas-efficient transaction patterns" \ + --generation-mode="mixed" \ + --enable-feedback-loop=true \ + --count=1000 +``` + +### Complex Contract Testing +```bash +./spamoor aitx \ + --openrouter-api-key="your-key" \ + --test-direction="generate complex contract interactions with edge cases" \ + --generation-mode="calldata" \ + --gaslimit=3000000 \ + --count=500 +``` + +### Assembly Code Generation +```bash +./spamoor aitx \ + --openrouter-api-key="your-key" \ + 
--test-direction="create diverse EVM assembly patterns for stress testing" \ + --generation-mode="geas" \ + --gaslimit=5000000 \ + --payloads-per-request=20 +``` + +### Performance Testing with Init/Run Geas +```bash +./spamoor aitx \ + --openrouter-api-key="your-key" \ + --test-direction="generate performance benchmarks using init/run pattern" \ + --generation-mode="geas" \ + --gaslimit=8000000 \ + --count=200 +``` + +### Opcode Testing with Simple Geas +```bash +./spamoor aitx \ + --openrouter-api-key="your-key" \ + --test-direction="test specific opcodes and precompiles using simple deployment" \ + --generation-mode="geas" \ + --gaslimit=3000000 \ + --payloads-per-request=15 +``` \ No newline at end of file diff --git a/scenarios/aitx/ai_service.go b/scenarios/aitx/ai_service.go new file mode 100644 index 00000000..e4e53d5b --- /dev/null +++ b/scenarios/aitx/ai_service.go @@ -0,0 +1,713 @@ +package aitx + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/sirupsen/logrus" +) + +type AIService struct { + client *http.Client + apiKey string + model string + baseURL string + basePrompt string + tokenCount uint64 + callCount uint64 + logger logrus.FieldLogger + logConversations bool +} + +type GenerationRequest struct { + BasePrompt string + TestDirection string + GenerationMode string + PayloadCount uint64 + PreviousSummary string + TransactionFeedback *TransactionFeedback +} + +type TransactionFeedback struct { + TotalTransactions uint64 `json:"total_transactions"` + SuccessfulTxs uint64 `json:"successful_txs"` + FailedTxs uint64 `json:"failed_txs"` + AverageGasUsed uint64 `json:"average_gas_used"` + MedianGasUsed uint64 `json:"median_gas_used"` + AverageBlockExecTime string `json:"average_block_exec_time"` + RecentResults []TransactionResult `json:"recent_results"` + Summary string `json:"summary"` +} + +type TransactionResult struct { + PayloadType string `json:"payload_type"` + 
PayloadDescription string `json:"payload_description"` + Status string `json:"status"` + GasUsed uint64 `json:"gas_used"` + BlockExecTime string `json:"block_exec_time"` + ErrorMessage string `json:"error_message,omitempty"` + LogData []string `json:"log_data,omitempty"` +} + +type GenerationResponse struct { + Payloads []PayloadTemplate + Summary string + TokensUsed uint64 +} + +type ConversationContinuation struct { + History []Message + Feedback string +} + +type OpenRouterRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + MaxTokens int `json:"max_tokens"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type OpenRouterResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Index int `json:"index"` + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + +func NewAIService(apiKey, model string, logConversations bool, logger logrus.FieldLogger) *AIService { + if apiKey == "" { + apiKey = os.Getenv("OPENROUTER_API_KEY") + } + + return &AIService{ + client: &http.Client{ + Timeout: 60 * time.Second, + }, + apiKey: apiKey, + model: model, + baseURL: "https://openrouter.ai/api/v1/chat/completions", + logger: logger.WithField("component", "ai_service"), + logConversations: logConversations, + } +} + +func (ai *AIService) SetBasePrompt(prompt string) { + ai.basePrompt = prompt +} + +func (ai *AIService) GeneratePayloads(ctx context.Context, req GenerationRequest, processor *PayloadProcessor) (*GenerationResponse, error) { + maxRetries := 3 + var lastError error + var conversationHistory []Message + + // Build initial prompt + req.BasePrompt = ai.basePrompt + initialPrompt 
:= ai.buildPrompt(req) + + conversationHistory = append(conversationHistory, Message{ + Role: "user", + Content: initialPrompt, + }) + + for attempt := 0; attempt < maxRetries; attempt++ { + ai.callCount++ + ai.logger.Debugf("making AI request #%d (attempt %d/%d) for %d payloads", + ai.callCount, attempt+1, maxRetries, req.PayloadCount) + + openRouterReq := OpenRouterRequest{ + Model: ai.model, + Messages: conversationHistory, + MaxTokens: 10000, + } + + response, err := ai.callOpenRouter(ctx, openRouterReq) + if err != nil { + lastError = fmt.Errorf("AI API call failed: %w", err) + continue + } + + ai.tokenCount += uint64(response.Usage.TotalTokens) + ai.logger.Infof("AI call #%d completed: %d tokens used, %d total tokens", + ai.callCount, response.Usage.TotalTokens, ai.tokenCount) + + // Try to parse the response + result, parseErr := ai.parseResponse(response) + if parseErr == nil { + // Validate payloads (including geas compilation) + validPayloads, validationErr := processor.ProcessPayloads(result.Payloads) + if validationErr == nil { + // Log AI response for debugging if enabled + if ai.logConversations { + ai.logConversation(conversationHistory, attempt+1) + } + + // Success! 
Update result with validated payloads and return + result.Payloads = validPayloads + ai.logger.Infof("AI conversation #%d completed successfully after %d attempt(s)", ai.callCount, attempt+1) + return result, nil + } + // Validation failed, treat as parsing error for retry + parseErr = validationErr + } + + // Parsing failed, add the AI response and error feedback to conversation + lastError = parseErr + + // Add AI response to conversation history + if len(response.Choices) > 0 { + conversationHistory = append(conversationHistory, Message{ + Role: "assistant", + Content: response.Choices[0].Message.Content, + }) + } + + // Add error feedback for retry + errorFeedback := ai.buildErrorFeedback(parseErr, attempt+1, maxRetries) + conversationHistory = append(conversationHistory, Message{ + Role: "user", + Content: errorFeedback, + }) + + ai.logger.Warnf("AI request #%d failed (attempt %d/%d): %v, retrying...", + ai.callCount, attempt+1, maxRetries, parseErr) + } + + return nil, fmt.Errorf("failed to generate valid payloads after %d attempts, last error: %w", maxRetries, lastError) +} + +func (ai *AIService) GeneratePayloadsWithConversation(ctx context.Context, req GenerationRequest, processor *PayloadProcessor, continuation *ConversationContinuation) (*GenerationResponse, []Message, error) { + maxRetries := 3 + var lastError error + var conversationHistory []Message + + if continuation != nil { + // Continue existing conversation + conversationHistory = continuation.History + conversationHistory = append(conversationHistory, Message{ + Role: "user", + Content: continuation.Feedback, + }) + } else { + // Start new conversation + req.BasePrompt = ai.basePrompt + initialPrompt := ai.buildPrompt(req) + conversationHistory = append(conversationHistory, Message{ + Role: "user", + Content: initialPrompt, + }) + } + + for attempt := 0; attempt < maxRetries; attempt++ { + ai.callCount++ + ai.logger.Debugf("making AI request #%d (attempt %d/%d) for conversation", + 
ai.callCount, attempt+1, maxRetries) + + openRouterReq := OpenRouterRequest{ + Model: ai.model, + Messages: conversationHistory, + MaxTokens: 10000, + } + + response, err := ai.callOpenRouter(ctx, openRouterReq) + if err != nil { + lastError = fmt.Errorf("AI API call failed: %w", err) + continue + } + + ai.tokenCount += uint64(response.Usage.TotalTokens) + ai.logger.Infof("AI call #%d completed: %d tokens used, %d total tokens", + ai.callCount, response.Usage.TotalTokens, ai.tokenCount) + + // Try to parse the response + result, parseErr := ai.parseResponse(response) + if parseErr == nil { + // Validate payloads (including geas compilation) + validPayloads, validationErr := processor.ProcessPayloads(result.Payloads) + if validationErr == nil { + // Success! Update result with validated payloads and add AI response to history + result.Payloads = validPayloads + if len(response.Choices) > 0 { + conversationHistory = append(conversationHistory, Message{ + Role: "assistant", + Content: response.Choices[0].Message.Content, + }) + } + + // Log AI response for debugging if enabled + if ai.logConversations { + ai.logConversation(conversationHistory, attempt+1) + } + + ai.logger.Infof("AI conversation #%d completed successfully after %d attempt(s)", ai.callCount, attempt+1) + return result, conversationHistory, nil + } + // Validation failed, treat as parsing error for retry + parseErr = validationErr + } + + // Parsing failed, add the AI response and error feedback to conversation + lastError = parseErr + + // Add AI response to conversation history + if len(response.Choices) > 0 { + conversationHistory = append(conversationHistory, Message{ + Role: "assistant", + Content: response.Choices[0].Message.Content, + }) + } + + // Add error feedback for retry + errorFeedback := ai.buildErrorFeedback(parseErr, attempt+1, maxRetries) + conversationHistory = append(conversationHistory, Message{ + Role: "user", + Content: errorFeedback, + }) + + ai.logger.Warnf("AI request #%d 
failed (attempt %d/%d): %v, retrying...", + ai.callCount, attempt+1, maxRetries, parseErr) + } + + return nil, conversationHistory, fmt.Errorf("failed to generate valid payloads after %d attempts, last error: %w", maxRetries, lastError) +} + +func (ai *AIService) buildPrompt(req GenerationRequest) string { + var promptBuilder strings.Builder + + promptBuilder.WriteString(req.BasePrompt) + promptBuilder.WriteString("\n\n") + + if req.TestDirection != "" { + promptBuilder.WriteString(fmt.Sprintf("TEST DIRECTION: %s\n\n", req.TestDirection)) + } + + promptBuilder.WriteString(fmt.Sprintf("Generate %d transaction payload(s).\n", req.PayloadCount)) + + if req.TransactionFeedback != nil { + promptBuilder.WriteString("FEEDBACK FROM PREVIOUS TRANSACTIONS:\n") + promptBuilder.WriteString(fmt.Sprintf("Total executed: %d (Success: %d, Failed: %d)\n", + req.TransactionFeedback.TotalTransactions, + req.TransactionFeedback.SuccessfulTxs, + req.TransactionFeedback.FailedTxs)) + promptBuilder.WriteString(fmt.Sprintf("Gas usage - Average: %d, Median: %d\n", + req.TransactionFeedback.AverageGasUsed, + req.TransactionFeedback.MedianGasUsed)) + promptBuilder.WriteString(fmt.Sprintf("Average block execution time: %s\n", + req.TransactionFeedback.AverageBlockExecTime)) + + if len(req.TransactionFeedback.RecentResults) > 0 { + promptBuilder.WriteString("\nRecent transaction results:\n") + for _, result := range req.TransactionFeedback.RecentResults { + promptBuilder.WriteString(fmt.Sprintf("- %s: %s (gas: %d, block_time: %s)\n", + result.PayloadDescription, result.Status, + result.GasUsed, result.BlockExecTime)) + if result.ErrorMessage != "" { + promptBuilder.WriteString(fmt.Sprintf(" Error: %s\n", result.ErrorMessage)) + } + if len(result.LogData) > 0 { + promptBuilder.WriteString(fmt.Sprintf(" Logs: %v\n", result.LogData)) + } + } + } + + if req.TransactionFeedback.Summary != "" { + promptBuilder.WriteString(fmt.Sprintf("\nPrevious summary: %s\n", req.TransactionFeedback.Summary)) + } 
+ + promptBuilder.WriteString("\nPlease generate NEW, DIFFERENT payloads that:\n") + promptBuilder.WriteString("1. Avoid patterns that consistently failed\n") + promptBuilder.WriteString("2. Explore different gas usage patterns\n") + promptBuilder.WriteString("3. Consider block execution time impact\n") + promptBuilder.WriteString("4. Build on successful patterns but with variations\n") + promptBuilder.WriteString("5. Consider log data from successful transactions\n\n") + } + + if req.PreviousSummary != "" { + promptBuilder.WriteString(fmt.Sprintf("Previous generation summary: %s\n\n", req.PreviousSummary)) + } + + return promptBuilder.String() +} + +func (ai *AIService) buildBasePrompt(generationMode string) string { + var promptBuilder strings.Builder + + promptBuilder.WriteString("You are an Ethereum transaction generator for the Spamoor testing framework.\n") + promptBuilder.WriteString("Your role is to create geas init/run contracts for comprehensive EVM testing.\n\n") + + promptBuilder.WriteString("GEAS INIT/RUN CONTRACT GENERATION:\n\n") + + promptBuilder.WriteString("CONCEPT:\n") + promptBuilder.WriteString("The init/run pattern deploys a contract with two phases:\n") + promptBuilder.WriteString("1. INIT PHASE: Executes ONCE during contract deployment (constructor)\n") + promptBuilder.WriteString("2. RUN PHASE: Executes in a LOOP when the contract is called, consuming all available gas\n\n") + + promptBuilder.WriteString("EXECUTION MODEL:\n") + promptBuilder.WriteString("1. Contract is deployed with init_code executing once\n") + promptBuilder.WriteString("2. Contract is then CALLED with optional calldata\n") + promptBuilder.WriteString("3. Run code executes repeatedly until gas is almost exhausted\n") + promptBuilder.WriteString("4. Post code executes ONCE at the end when gas is low (for final LOGs/cleanup)\n") + promptBuilder.WriteString("5. 
Each run iteration MUST maintain clean stack (no pollution)\n\n") + + promptBuilder.WriteString("CRITICAL REQUIREMENTS:\n") + promptBuilder.WriteString("1. RUN CODE should reuse previous iteration results for subsequent operations to avoid intermediate result caching in the EVM\n") + promptBuilder.WriteString("2. RUN CODE may modify stack to keep track of previous results - push empty value from init code, modify via SWAPn in loop\n") + promptBuilder.WriteString("3. Stack must be same size at the end of each run iteration (but may contain different values)\n") + promptBuilder.WriteString("4. POST CODE executes once at end when gas is low - ideal for LOG events to report final results\n") + promptBuilder.WriteString("5. Init, run, and post code can access calldata using CALLDATALOAD, CALLDATASIZE, CALLDATACOPY\n") + promptBuilder.WriteString("6. Avoid LOG events in run code (expensive) - use post code for final result logging\n\n") + + promptBuilder.WriteString("CALLDATA ACCESS:\n") + promptBuilder.WriteString("- CALLDATASIZE: Get size of input data\n") + promptBuilder.WriteString("- PUSH1 0x00 CALLDATALOAD: Load first 32 bytes of calldata\n") + promptBuilder.WriteString("- PUSH1 0x20 CALLDATALOAD: Load second 32 bytes of calldata\n") + promptBuilder.WriteString("- CALLDATACOPY: Copy calldata to memory\n\n") + + promptBuilder.WriteString("GEAS CODE FORMAT:\n") + promptBuilder.WriteString("- ONE opcode per line, separated by \\n\n") + promptBuilder.WriteString("- Uppercase opcodes only\n") + promptBuilder.WriteString("- Hex values with 0x prefix\n") + promptBuilder.WriteString("- Example: PUSH1 0x20\\nPUSH1 0x00\\nMSTORE\n\n") + + promptBuilder.WriteString("EXAMPLE PATTERNS:\n") + promptBuilder.WriteString("1. Parameter processing: Load calldata, perform operations, store results\n") + promptBuilder.WriteString("2. Computation loops: Mathematical operations with clean stack management\n") + promptBuilder.WriteString("3. 
Storage patterns: Read/write with counters or mappings\n") + promptBuilder.WriteString("4. Event emission: Log computation results or state changes\n") + promptBuilder.WriteString("5. Memory operations: Expand memory, hash data, manipulate arrays\n\n") + + promptBuilder.WriteString("AVAILABLE PLACEHOLDERS:\n") + promptBuilder.WriteString("- ${RANDOM_UINT256}: Random 256-bit unsigned integer\n") + promptBuilder.WriteString("- ${RANDOM_BYTES32}: Random 32-byte value\n") + promptBuilder.WriteString("- ${CURRENT_BLOCK}: Current block number\n\n") + + promptBuilder.WriteString("RESPONSE FORMAT:\n") + promptBuilder.WriteString("CRITICAL: Your response is parsed programmatically. Return ONLY JSON objects in ```json blocks with NO explanations.\n") + promptBuilder.WriteString("Generate at least 20 separate JSON objects (do not stop before), each wrapped in ```json and ``` tags:\n\n") + + promptBuilder.WriteString(`{ + "type": "geas", + "description": "Brief description of what this contract does", + "init_code": "PUSH1 0x00\nSSTORE", + "run_code": "PUSH1 0x00\nSLOAD\nPUSH1 0x01\nADD\nDUP1\nPUSH1 0x00\nSSTORE\nPOP", + "post_code": "PUSH1 0x00\nSLOAD\nPUSH1 0x00\nMSTORE\nPUSH1 0x20\nPUSH1 0x00\nLOG0", + "gas_remainder": "10000", + "calldata": "0x1234567800000000000000000000000000000000000000000000000000000005" +}` + "\n\n") + + promptBuilder.WriteString("POST_CODE FIELD:\n") + promptBuilder.WriteString("- Optional code that executes ONCE at the end when gas is low\n") + promptBuilder.WriteString("- Ideal for LOG events to report final computation results\n") + promptBuilder.WriteString("- Can access stack values accumulated during run iterations\n") + promptBuilder.WriteString("- Example: LOG0 to emit final counter value or computation result\n\n") + + promptBuilder.WriteString("CALLDATA FIELD:\n") + promptBuilder.WriteString("- Optional hex-encoded calldata for the contract call\n") + promptBuilder.WriteString("- Can be used to pass parameters to the run code\n") + 
promptBuilder.WriteString("- Access in run code via calldataload, calldatasize, etc.\n") + promptBuilder.WriteString("- Example: \"0x\" + 32-byte parameter as hex\n\n") + + promptBuilder.WriteString("IMPORTANT:\n") + promptBuilder.WriteString("- Generate ONLY geas init_run contracts (type=\\\"geas\\\")\n") + promptBuilder.WriteString("- Focus on diverse EVM testing patterns\n") + promptBuilder.WriteString("- Reuse previous iteration results to avoid EVM caching\n") + promptBuilder.WriteString("- Use SWAPn to manage persistent values on stack\n") + promptBuilder.WriteString("- Use calldata for dynamic behavior\n") + promptBuilder.WriteString("- NO explanatory text - ONLY JSON objects\n\n") + + return promptBuilder.String() +} + +func (ai *AIService) callOpenRouter(ctx context.Context, req OpenRouterRequest) (*OpenRouterResponse, error) { + jsonData, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", ai.baseURL, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+ai.apiKey) + httpReq.Header.Set("HTTP-Referer", "https://github.com/ethpandaops/spamoor") + httpReq.Header.Set("X-Title", "Spamoor AI Transaction Generator") + + resp, err := ai.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("OpenRouter API error %d: %s", resp.StatusCode, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var openRouterResp OpenRouterResponse + if err := json.Unmarshal(body, &openRouterResp); err != nil { + return nil, 
fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &openRouterResp, nil +} + +func (ai *AIService) parseResponse(response *OpenRouterResponse) (*GenerationResponse, error) { + if len(response.Choices) == 0 { + return nil, fmt.Errorf("no choices in AI response") + } + + content := response.Choices[0].Message.Content + ai.logger.Infof("AI response content: %s", content) + + var payloads []PayloadTemplate + + // Try parsing as direct JSON array first + err := json.Unmarshal([]byte(content), &payloads) + if err != nil { + ai.logger.Debugf("failed to parse as direct JSON array, extracting from conversational response: %v", err) + + // Extract individual JSON objects from conversational text + payloads, err = ai.extractJSONObjectsFromText(content) + if err != nil { + // Fallback to old extraction method + ai.logger.Debugf("failed to extract JSON objects, trying array extraction: %v", err) + payloads, err = ai.extractJSONFromText(content) + if err != nil { + return nil, fmt.Errorf("failed to parse AI response as JSON: %w", err) + } + } + } + + if len(payloads) == 0 { + return nil, fmt.Errorf("no payloads found in AI response") + } + + ai.logger.Infof("Successfully parsed %d payloads from AI response", len(payloads)) + summary := fmt.Sprintf("Generated %d payloads using %s", len(payloads), ai.model) + + return &GenerationResponse{ + Payloads: payloads, + Summary: summary, + TokensUsed: uint64(response.Usage.TotalTokens), + }, nil +} + +func (ai *AIService) extractJSONObjectsFromText(content string) ([]PayloadTemplate, error) { + var payloads []PayloadTemplate + + // Look for JSON code blocks marked with ```json + lines := strings.Split(content, "\n") + var jsonBlock strings.Builder + inJSONBlock := false + + for _, line := range lines { + line = strings.TrimSpace(line) + + if strings.HasPrefix(line, "```json") { + inJSONBlock = true + jsonBlock.Reset() + continue + } + + if strings.HasPrefix(line, "```") && inJSONBlock { + // End of JSON block, try to parse 
it + jsonStr := jsonBlock.String() + ai.logger.Infof("Attempting to parse JSON block: %s", jsonStr) + + var payload PayloadTemplate + if err := json.Unmarshal([]byte(jsonStr), &payload); err == nil { + payloads = append(payloads, payload) + ai.logger.Infof("Successfully parsed payload: %s", payload.Description) + } else { + ai.logger.Errorf("Failed to parse JSON block: %v", err) + } + + inJSONBlock = false + continue + } + + if inJSONBlock { + jsonBlock.WriteString(line) + jsonBlock.WriteString("\n") + } + } + + // If we found payloads, return them + if len(payloads) > 0 { + return payloads, nil + } + + // Fallback: look for individual JSON objects using regex-like approach + return ai.extractJSONObjectsWithRegex(content) +} + +func (ai *AIService) extractJSONObjectsWithRegex(content string) ([]PayloadTemplate, error) { + var payloads []PayloadTemplate + + // Look for patterns like { ... } that might be JSON objects + braceLevel := 0 + var currentObj strings.Builder + inObject := false + + for i, r := range content { + if r == '{' { + if braceLevel == 0 { + inObject = true + currentObj.Reset() + } + braceLevel++ + currentObj.WriteRune(r) + } else if r == '}' { + braceLevel-- + currentObj.WriteRune(r) + + if braceLevel == 0 && inObject { + // Try to parse this object + objStr := strings.TrimSpace(currentObj.String()) + ai.logger.Infof("Attempting to parse JSON object: %s", objStr) + + var payload PayloadTemplate + if err := json.Unmarshal([]byte(objStr), &payload); err == nil { + payloads = append(payloads, payload) + ai.logger.Infof("Successfully parsed payload: %s", payload.Description) + } else { + ai.logger.Errorf("Failed to parse JSON object at position %d: %v", i, err) + } + + inObject = false + } + } else if inObject { + currentObj.WriteRune(r) + } + } + + if len(payloads) == 0 { + return nil, fmt.Errorf("no valid JSON objects found in response") + } + + return payloads, nil +} + +func (ai *AIService) extractJSONFromText(content string) ([]PayloadTemplate, 
error) { + start := strings.Index(content, "[") + end := strings.LastIndex(content, "]") + + if start == -1 || end == -1 || start >= end { + return nil, fmt.Errorf("no JSON array found in response") + } + + jsonStr := content[start : end+1] + + var payloads []PayloadTemplate + err := json.Unmarshal([]byte(jsonStr), &payloads) + if err != nil { + return nil, fmt.Errorf("failed to parse extracted JSON: %w", err) + } + + return payloads, nil +} + +func (ai *AIService) GetTokenCount() uint64 { + return ai.tokenCount +} + +func (ai *AIService) buildErrorFeedback(parseErr error, attempt int, maxRetries int) string { + var feedbackBuilder strings.Builder + + errorStr := parseErr.Error() + + // Check if this is a geas compilation error + if strings.Contains(errorStr, "geas compilation failed") { + feedbackBuilder.WriteString("GEAS COMPILATION ERROR DETECTED:\n") + feedbackBuilder.WriteString(fmt.Sprintf("Error: %v\n\n", parseErr)) + + feedbackBuilder.WriteString("Your geas assembly code failed to compile. Please fix the following issues:\n\n") + + feedbackBuilder.WriteString("GEAS CODE REQUIREMENTS:\n") + feedbackBuilder.WriteString("1. Use VALID EVM opcodes only (e.g., PUSH1, ADD, MUL, SSTORE, SLOAD, etc.)\n") + feedbackBuilder.WriteString("2. Format: ONE opcode per line, separated by \\n\n") + feedbackBuilder.WriteString("3. Use correct syntax: 'PUSH1 0x20' with uppercase opcodes\n") + feedbackBuilder.WriteString("4. Hexadecimal values must start with 0x\n") + feedbackBuilder.WriteString("5. All opcodes are allowed including selfdestruct, delegatecall, create2\n") + feedbackBuilder.WriteString("6. Ensure stack balance (don't leave extra items on stack)\n") + feedbackBuilder.WriteString("7. 
CRITICAL: Run code MUST have clean stack after each iteration\n\n") + + feedbackBuilder.WriteString("COMMON FIXES:\n") + feedbackBuilder.WriteString("- Check opcode spelling and case sensitivity\n") + feedbackBuilder.WriteString("- Verify hex values format (0x prefix)\n") + feedbackBuilder.WriteString("- Ensure proper stack management with 'pop'\n") + feedbackBuilder.WriteString("- Use 'pop' to clean up ALL unused stack items\n") + feedbackBuilder.WriteString("- Remember: sha3 not keccak256 for EVM opcode\n\n") + + feedbackBuilder.WriteString("EXAMPLE VALID GEAS CODE:\n") + feedbackBuilder.WriteString("\"PUSH1 0x20\\nPUSH1 0x00\\nMSTORE\\nPUSH1 0x20\\nPUSH1 0x00\\nSHA3\\nPOP\"\n\n") + } else { + feedbackBuilder.WriteString("PARSING/VALIDATION ERROR DETECTED:\n") + feedbackBuilder.WriteString(fmt.Sprintf("Error: %v\n\n", parseErr)) + + feedbackBuilder.WriteString("Your previous response could not be parsed or validated correctly. ") + feedbackBuilder.WriteString("Please ensure your response follows the exact JSON format specified.\n\n") + + feedbackBuilder.WriteString("REQUIREMENTS:\n") + feedbackBuilder.WriteString("1. Wrap JSON payload in ```json and ``` code blocks\n") + feedbackBuilder.WriteString("2. Return ONLY ONE payload object (not an array)\n") + feedbackBuilder.WriteString("3. Include all required fields: type, description, init_code, run_code\n") + feedbackBuilder.WriteString("4. Set type=\"geas\" (init_run method is implied)\n") + feedbackBuilder.WriteString("5. Use proper JSON syntax with quotes around strings\n") + feedbackBuilder.WriteString("6. GEAS CODE FORMAT: Use newlines (\\n) to separate opcodes - ONE opcode per line\n") + feedbackBuilder.WriteString("7. Include optional 'calldata' and 'post_code' fields\n") + feedbackBuilder.WriteString("8. Do NOT include geas_method or placeholders fields\n\n") + } + + if attempt < maxRetries { + feedbackBuilder.WriteString(fmt.Sprintf("This is attempt %d of %d. 
Please try again with the corrected code.\n", attempt, maxRetries)) + } else { + feedbackBuilder.WriteString("This is the final attempt. Please ensure your response is properly formatted and valid.\n") + } + + return feedbackBuilder.String() +} + +func (ai *AIService) logConversation(conversationHistory []Message, attempt int) { + ai.logger.Infof("=== AI Conversation #%d (Attempt %d) ===", ai.callCount, attempt) + + for i, message := range conversationHistory { + role := strings.ToUpper(message.Role) + content := message.Content + + ai.logger.Infof("--- Message %d: %s ---\n%s\n", i+1, role, content) + } + + ai.logger.Infof("=== End Conversation #%d ===", ai.callCount) +} + +func (ai *AIService) GetCallCount() uint64 { + return ai.callCount +} diff --git a/scenarios/aitx/aitx.go b/scenarios/aitx/aitx.go new file mode 100644 index 00000000..97104401 --- /dev/null +++ b/scenarios/aitx/aitx.go @@ -0,0 +1,562 @@ +package aitx + +import ( + "context" + "encoding/hex" + "fmt" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/holiman/uint256" + "github.com/sirupsen/logrus" + "github.com/spf13/pflag" + + "github.com/ethpandaops/spamoor/scenario" + "github.com/ethpandaops/spamoor/spamoor" + "github.com/ethpandaops/spamoor/txbuilder" + "github.com/ethpandaops/spamoor/utils" +) + +type ScenarioOptions struct { + TotalCount uint64 `yaml:"total_count"` + Throughput uint64 `yaml:"throughput"` + MaxPending uint64 `yaml:"max_pending"` + MaxWallets uint64 `yaml:"max_wallets"` + Rebroadcast uint64 `yaml:"rebroadcast"` + BaseFee float64 `yaml:"base_fee"` + TipFee float64 `yaml:"tip_fee"` + GasLimit uint64 `yaml:"gas_limit"` + Timeout string `yaml:"timeout"` + ClientGroup string `yaml:"client_group"` + LogTxs bool `yaml:"log_txs"` + + // AI-specific options + OpenRouterAPIKey string `yaml:"openrouter_api_key"` + Model string `yaml:"model"` + TestDirection string 
// Scenario holds the runtime state of the "aitx" scenario: the effective
// options, the helper components created in Init, and the cache of
// AI-generated payload templates that transactions are built from.
type Scenario struct {
	options    ScenarioOptions     // starts as ScenarioDefaultOptions; overridden by flags and by the config parsed in Init
	logger     *logrus.Entry       // scenario logger, tagged with the scenario name in newScenario
	walletPool *spamoor.WalletPool // taken from the scenario.Options passed to Init

	// Helper components, all constructed in Init.
	aiService              *AIService
	processor              *PayloadProcessor
	placeholderSubstituter *PlaceholderSubstituter
	feedbackCollector      *FeedbackCollector
	geasProcessor          *GeasProcessor

	payloadCache          []PayloadTemplate // templates returned by the most recent AI call
	cacheIndex            int               // index into payloadCache of the next template to hand out
	aiMutex               sync.Mutex        // Protects AI calls and payload cache; held for the whole of getNextPayloadTemplate
	conversationHistory   []Message         // Persisted conversation history; cleared when the conversation is reset
	conversationResponses int               // Number of AI responses in current conversation; the conversation is reset once this reaches 10
}
ScenarioDefaultOptions, + logger: logger.WithField("scenario", ScenarioName), + } +} + +func (s *Scenario) Flags(flags *pflag.FlagSet) error { + flags.Uint64VarP(&s.options.TotalCount, "count", "c", ScenarioDefaultOptions.TotalCount, "Total number of AI transactions to send") + flags.Uint64VarP(&s.options.Throughput, "throughput", "t", ScenarioDefaultOptions.Throughput, "Number of AI transactions to send per slot") + flags.Uint64Var(&s.options.MaxPending, "max-pending", ScenarioDefaultOptions.MaxPending, "Maximum number of pending transactions") + flags.Uint64Var(&s.options.MaxWallets, "max-wallets", ScenarioDefaultOptions.MaxWallets, "Maximum number of child wallets to use") + flags.Uint64Var(&s.options.Rebroadcast, "rebroadcast", ScenarioDefaultOptions.Rebroadcast, "Enable reliable rebroadcast system") + flags.Float64Var(&s.options.BaseFee, "basefee", ScenarioDefaultOptions.BaseFee, "Max fee per gas to use in transactions (in gwei)") + flags.Float64Var(&s.options.TipFee, "tipfee", ScenarioDefaultOptions.TipFee, "Max tip per gas to use in transactions (in gwei)") + flags.Uint64Var(&s.options.GasLimit, "gaslimit", ScenarioDefaultOptions.GasLimit, "Gas limit to use in transactions") + flags.StringVar(&s.options.Timeout, "timeout", ScenarioDefaultOptions.Timeout, "Timeout for the scenario (e.g. 
'1h', '30m', '5s') - empty means no timeout") + flags.StringVar(&s.options.ClientGroup, "client-group", ScenarioDefaultOptions.ClientGroup, "Client group to use for sending transactions") + flags.BoolVar(&s.options.LogTxs, "log-txs", ScenarioDefaultOptions.LogTxs, "Log all submitted transactions") + + // AI-specific flags + flags.StringVar(&s.options.OpenRouterAPIKey, "openrouter-api-key", ScenarioDefaultOptions.OpenRouterAPIKey, "OpenRouter API key (can also use OPENROUTER_API_KEY env var)") + flags.StringVar(&s.options.Model, "model", ScenarioDefaultOptions.Model, "AI model to use for generation") + flags.StringVar(&s.options.TestDirection, "test-direction", ScenarioDefaultOptions.TestDirection, "Directional guidance for AI test generation") + flags.Uint64Var(&s.options.PayloadsPerRequest, "payloads-per-request", ScenarioDefaultOptions.PayloadsPerRequest, "Number of payload templates to generate per AI request") + flags.Uint64Var(&s.options.MaxAICalls, "max-ai-calls", ScenarioDefaultOptions.MaxAICalls, "Maximum number of AI API calls to make") + flags.Uint64Var(&s.options.MaxTokens, "max-tokens", ScenarioDefaultOptions.MaxTokens, "Maximum total tokens to consume") + + // Generation flags (removed - only geas init_run is supported) + + // Feedback flags + flags.Uint64Var(&s.options.FeedbackBatchSize, "feedback-batch-size", ScenarioDefaultOptions.FeedbackBatchSize, "Number of transaction results to include in feedback") + flags.BoolVar(&s.options.EnableFeedbackLoop, "enable-feedback-loop", ScenarioDefaultOptions.EnableFeedbackLoop, "Enable result feedback to AI for learning") + + // Debug flags + flags.BoolVar(&s.options.LogAIConversations, "log-ai-conversations", ScenarioDefaultOptions.LogAIConversations, "Enable detailed logging of AI conversations for debugging") + + return nil +} + +func (s *Scenario) Init(options *scenario.Options) error { + s.walletPool = options.WalletPool + + if options.Config != "" { + err := 
scenario.ParseAndValidateConfig(&ScenarioDescriptor, options.Config, &s.options, s.logger) + if err != nil { + return err + } + } + + // Validate AI configuration + if s.options.OpenRouterAPIKey == "" { + return fmt.Errorf("OpenRouter API key is required (use --openrouter-api-key flag or OPENROUTER_API_KEY env var)") + } + + // Generation mode is fixed to "geas" for init_run only + + // Configure wallet count + if s.options.MaxWallets > 0 { + s.walletPool.SetWalletCount(s.options.MaxWallets) + } else if s.options.TotalCount > 0 { + maxWallets := s.options.TotalCount / 50 + if maxWallets < 10 { + maxWallets = 10 + } else if maxWallets > 1000 { + maxWallets = 1000 + } + s.walletPool.SetWalletCount(maxWallets) + } else { + if s.options.Throughput*10 < 1000 { + s.walletPool.SetWalletCount(s.options.Throughput * 10) + } else { + s.walletPool.SetWalletCount(1000) + } + } + + if s.options.TotalCount == 0 && s.options.Throughput == 0 { + return fmt.Errorf("neither total count nor throughput limit set, must define at least one of them") + } + + // Initialize AI components + s.aiService = NewAIService(s.options.OpenRouterAPIKey, s.options.Model, s.options.LogAIConversations, s.logger) + s.processor = NewPayloadProcessor(s.logger) + s.placeholderSubstituter = NewPlaceholderSubstituter(s.walletPool, s.walletPool.GetClient(spamoor.SelectClientByIndex, 0, ""), s.logger) + s.feedbackCollector = NewFeedbackCollector(s.options.FeedbackBatchSize, s.logger) + s.geasProcessor = NewGeasProcessor(s.logger) + + // Set AI base prompt based on generation mode + basePrompt := s.aiService.buildBasePrompt(s.options.GenerationMode) + s.aiService.SetBasePrompt(basePrompt) + + return nil +} + +func (s *Scenario) Run(ctx context.Context) error { + s.logger.Infof("starting AI transaction generator scenario") + defer s.logger.Infof("AI transaction generator scenario finished") + + maxPending := s.options.MaxPending + if maxPending == 0 { + maxPending = s.options.Throughput * 10 + if maxPending == 0 
{ + maxPending = 1000 + } + if maxPending > s.walletPool.GetConfiguredWalletCount()*10 { + maxPending = s.walletPool.GetConfiguredWalletCount() * 10 + } + } + + // Parse timeout duration + var timeout time.Duration + if s.options.Timeout != "" { + var err error + timeout, err = time.ParseDuration(s.options.Timeout) + if err != nil { + return fmt.Errorf("invalid timeout format '%s': %w", s.options.Timeout, err) + } + } + + err := scenario.RunTransactionScenario(ctx, scenario.TransactionScenarioOptions{ + TotalCount: s.options.TotalCount, + Throughput: s.options.Throughput, + MaxPending: maxPending, + ThroughputIncrementInterval: 0, + Timeout: timeout, + WalletPool: s.walletPool, + + Logger: s.logger, + ProcessNextTxFn: func(ctx context.Context, txIdx uint64, onComplete func()) (func(), error) { + logger := s.logger + tx, client, wallet, err := s.sendAITransaction(ctx, txIdx, onComplete) + if client != nil { + logger = logger.WithField("rpc", client.GetName()) + } + if tx != nil { + logger = logger.WithField("nonce", tx.Nonce()) + } + if wallet != nil { + logger = logger.WithField("wallet", s.walletPool.GetWalletName(wallet.GetAddress())) + } + + return func() { + if err != nil { + logger.Warnf("could not send AI transaction: %v", err) + } else if s.options.LogTxs { + logger.Infof("sent AI tx #%6d: %v", txIdx+1, tx.Hash().String()) + } else { + logger.Debugf("sent AI tx #%6d: %v", txIdx+1, tx.Hash().String()) + } + }, err + }, + }) + + return err +} + +func (s *Scenario) sendAITransaction(ctx context.Context, txIdx uint64, onComplete func()) (*types.Transaction, *spamoor.Client, *spamoor.Wallet, error) { + // Deploy a contract and send 10 call transactions using batch sending + defer onComplete() + + client := s.walletPool.GetClient(spamoor.SelectClientByIndex, int(txIdx), s.options.ClientGroup) + wallet := s.walletPool.GetWallet(spamoor.SelectWalletByPendingTxCount, int(txIdx)) + + if client == nil { + return nil, client, wallet, fmt.Errorf("no client available") + 
} + + // Get next payload template from AI or cache + template, err := s.getNextPayloadTemplate(ctx) + if err != nil { + s.logger.Errorf("failed to get AI payload template: %v", err) + dummyPayload := &PayloadInstance{Type: "geas", Description: "failed_generation"} + s.feedbackCollector.RecordFailure(dummyPayload, "payload_generation_failed", err.Error()) + return nil, client, wallet, err + } + + // Substitute placeholders + payload, err := template.Substitute(s.placeholderSubstituter) + if err != nil { + s.logger.Errorf("failed to substitute placeholders: %v", err) + dummyPayload := &PayloadInstance{Type: template.Type, Description: template.Description} + s.feedbackCollector.RecordFailure(dummyPayload, "placeholder_substitution_failed", err.Error()) + return nil, client, wallet, err + } + + // Build deployment transaction + s.logger.Infof("deploying contract for payload: %s", payload.Description) + deployTx, contractAddress, err := s.deployGeasContract(ctx, wallet, client, payload) + if err != nil { + s.logger.Errorf("failed to deploy contract: %v", err) + s.feedbackCollector.RecordFailure(payload, "deployment_failed", err.Error()) + return nil, client, wallet, err + } + + // Deploy contract and wait for confirmation using SendAndAwaitTransaction + deployReceipt, err := s.walletPool.GetTxPool().SendAndAwaitTransaction(ctx, wallet, deployTx, &spamoor.SendTransactionOptions{ + Client: client, + Rebroadcast: s.options.Rebroadcast > 0, + LogFn: spamoor.GetDefaultLogFn(s.logger, "deploy", fmt.Sprintf("%6d", txIdx+1), deployTx), + }) + + if err != nil { + s.logger.Errorf("failed to deploy contract: %v", err) + s.feedbackCollector.RecordFailure(payload, "deployment_failed", err.Error()) + return nil, client, wallet, err + } + + if deployReceipt.Status != 1 { + s.logger.Errorf("contract deployment failed (status: %d)", deployReceipt.Status) + s.feedbackCollector.RecordFailure(payload, "deployment_reverted", "deployment transaction reverted") + return deployTx, client, 
wallet, nil + } + + s.logger.Infof("contract deployed successfully at %s for payload: %s", contractAddress.Hex(), payload.Description) + + // Build 10 call transactions + var callTxs []*types.Transaction + for i := 0; i < 10; i++ { + callTx, err := s.callGeasContract(ctx, wallet, client, contractAddress, payload) + if err != nil { + s.logger.Errorf("failed to build call transaction %d: %v", i+1, err) + s.feedbackCollector.RecordFailure(payload, "call_build_failed", err.Error()) + return deployTx, client, wallet, err + } + callTxs = append(callTxs, callTx) + } + + // Send all call transactions as a batch from same wallet + _, err = s.walletPool.GetTxPool().SendTransactionBatch(ctx, wallet, callTxs, &spamoor.BatchOptions{ + SendTransactionOptions: spamoor.SendTransactionOptions{ + Client: client, + Rebroadcast: s.options.Rebroadcast > 0, + OnConfirm: func(tx *types.Transaction, receipt *types.Receipt) { + // Collect execution results for feedback from call transactions + s.collectTransactionResult(payload, tx, receipt) + }, + }, + }) + + if err != nil { + s.logger.Errorf("failed to send call transaction batch: %v", err) + s.feedbackCollector.RecordFailure(payload, "batch_send_failed", err.Error()) + return deployTx, client, wallet, err + } + + return deployTx, client, wallet, nil +} + +func (s *Scenario) getNextPayloadTemplate(ctx context.Context) (*PayloadTemplate, error) { + // Lock to ensure only one AI call happens at a time + s.aiMutex.Lock() + defer s.aiMutex.Unlock() + + // Check if we have cached payloads + if s.cacheIndex < len(s.payloadCache) { + template := s.payloadCache[s.cacheIndex] + s.cacheIndex++ + return &template, nil + } + + // Generate new batch of payloads + if s.aiService.GetCallCount() >= s.options.MaxAICalls { + return nil, fmt.Errorf("maximum AI calls limit reached (%d)", s.options.MaxAICalls) + } + + if s.aiService.GetTokenCount() >= s.options.MaxTokens { + return nil, fmt.Errorf("maximum token limit reached (%d)", s.options.MaxTokens) + } 
+ + // Check if we need to start a new conversation (after 10 responses) + if s.conversationResponses >= 10 { + s.logger.Infof("resetting conversation after %d responses", s.conversationResponses) + s.conversationHistory = nil + s.conversationResponses = 0 + } + + if len(s.conversationHistory) == 0 { + s.logger.Infof("making AI call #%d - starting new conversation (other transactions waiting)", s.aiService.GetCallCount()+1) + } else { + s.logger.Infof("making AI call #%d - continuing conversation with %d messages (other transactions waiting)", + s.aiService.GetCallCount()+1, len(s.conversationHistory)) + } + + // Generate payloads using conversation continuation + var response *GenerationResponse + var err error + + if len(s.conversationHistory) == 0 { + // Start new conversation + req := GenerationRequest{ + TestDirection: s.options.TestDirection, + GenerationMode: s.options.GenerationMode, + PayloadCount: s.options.PayloadsPerRequest, + PreviousSummary: "", + TransactionFeedback: nil, + } + + // Add feedback if enabled + if s.options.EnableFeedbackLoop { + req.TransactionFeedback = s.feedbackCollector.GenerateFeedback() + } + + response, s.conversationHistory, err = s.aiService.GeneratePayloadsWithConversation(ctx, req, s.processor, nil) + } else { + // Continue existing conversation + feedback := "" + if s.options.EnableFeedbackLoop { + txFeedback := s.feedbackCollector.GenerateFeedback() + if txFeedback != nil { + feedback = fmt.Sprintf("Transaction feedback: %d total (%d success, %d failed), avg gas: %d. 
Generate more diverse patterns based on this data.", + txFeedback.TotalTransactions, txFeedback.SuccessfulTxs, txFeedback.FailedTxs, txFeedback.AverageGasUsed) + } + } + + if feedback == "" { + feedback = fmt.Sprintf("Generate %d more unique geas init_run contracts with different patterns and behaviors.", s.options.PayloadsPerRequest) + } + + response, s.conversationHistory, err = s.aiService.GeneratePayloadsWithConversation(ctx, GenerationRequest{}, s.processor, &ConversationContinuation{ + History: s.conversationHistory, + Feedback: feedback, + }) + } + + if err != nil { + return nil, fmt.Errorf("AI payload generation failed: %w", err) + } + + // Increment conversation response count + s.conversationResponses++ + + // Payloads are already validated by the AI service + validPayloads := response.Payloads + + // Update cache + s.payloadCache = validPayloads + s.cacheIndex = 1 // Return first, set index to second + + if len(validPayloads) == 0 { + return nil, fmt.Errorf("no valid payloads generated") + } + + s.logger.Infof("AI call completed, generated %d payloads (conversation: %d responses, cache refilled)", + len(validPayloads), s.conversationResponses) + + return &validPayloads[0], nil +} + +func (s *Scenario) deployGeasContract(ctx context.Context, wallet *spamoor.Wallet, client *spamoor.Client, payload *PayloadInstance) (*types.Transaction, common.Address, error) { + // Compile geas code + bytecode, err := s.geasProcessor.CompileGeasPayload(payload) + if err != nil { + return nil, common.Address{}, fmt.Errorf("geas compilation failed: %w", err) + } + + // Get suggested fees + feeCap, tipCap, err := s.walletPool.GetTxPool().GetSuggestedFees(client, s.options.BaseFee, s.options.TipFee) + if err != nil { + return nil, common.Address{}, fmt.Errorf("failed to get suggested fees: %w", err) + } + + // Build deployment transaction + txData, err := txbuilder.DynFeeTx(&txbuilder.TxMetadata{ + GasFeeCap: uint256.MustFromBig(feeCap), + GasTipCap: 
uint256.MustFromBig(tipCap), + Gas: s.options.GasLimit, + To: nil, // Contract creation + Value: uint256.NewInt(0), // No value for contract deployment + Data: bytecode, + }) + if err != nil { + return nil, common.Address{}, fmt.Errorf("failed to build transaction data: %w", err) + } + + tx, err := wallet.BuildDynamicFeeTx(txData) + if err != nil { + return nil, common.Address{}, err + } + + // Calculate contract address + contractAddr := crypto.CreateAddress(wallet.GetAddress(), tx.Nonce()) + + return tx, contractAddr, nil +} + +func (s *Scenario) callGeasContract(ctx context.Context, wallet *spamoor.Wallet, client *spamoor.Client, contractAddr common.Address, payload *PayloadInstance) (*types.Transaction, error) { + // Get suggested fees + feeCap, tipCap, err := s.walletPool.GetTxPool().GetSuggestedFees(client, s.options.BaseFee, s.options.TipFee) + if err != nil { + return nil, fmt.Errorf("failed to get suggested fees: %w", err) + } + + // Build call transaction with calldata + txData, err := txbuilder.DynFeeTx(&txbuilder.TxMetadata{ + GasFeeCap: uint256.MustFromBig(feeCap), + GasTipCap: uint256.MustFromBig(tipCap), + Gas: s.options.GasLimit, + To: &contractAddr, + Value: uint256.NewInt(0), + Data: payload.Calldata, + }) + if err != nil { + return nil, fmt.Errorf("failed to build transaction data: %w", err) + } + + return wallet.BuildDynamicFeeTx(txData) +} + +func (s *Scenario) collectTransactionResult(payload *PayloadInstance, tx *types.Transaction, receipt *types.Receipt) { + if receipt == nil { + s.feedbackCollector.RecordFailure(payload, "receipt_nil", "receipt was nil") + return + } + + // Determine status + status := "success" + errorMsg := "" + if receipt.Status == 0 { + status = "reverted" + } + + // Calculate transaction fees + txFees := utils.GetTransactionFees(tx, receipt) + + // Extract log data + var logData []string + for _, log := range receipt.Logs { + // Convert log data to hex for analysis + logData = append(logData, fmt.Sprintf("addr:%s 
topics:%d data:%s", + log.Address.Hex(), + len(log.Topics), + hex.EncodeToString(log.Data))) + } + + s.logger.Debugf("transaction confirmed: %s (%s) - %s, gas: %d, fees: %s, logs: %d", + payload.Description, payload.Type, status, receipt.GasUsed, txFees.TotalFeeGweiString(), len(receipt.Logs)) + + // Record result for feedback + result := TransactionResult{ + PayloadType: payload.Type, + PayloadDescription: payload.Description, + Status: status, + GasUsed: receipt.GasUsed, + BlockExecTime: "N/A", // Placeholder for external system + ErrorMessage: errorMsg, + LogData: logData, + } + + s.feedbackCollector.RecordResult(result) +} diff --git a/scenarios/aitx/feedback_collector.go b/scenarios/aitx/feedback_collector.go new file mode 100644 index 00000000..ae6f32da --- /dev/null +++ b/scenarios/aitx/feedback_collector.go @@ -0,0 +1,140 @@ +package aitx + +import ( + "fmt" + "sort" + "strings" + "sync" + + "github.com/sirupsen/logrus" +) + +type FeedbackCollector struct { + results []TransactionResult + mutex sync.RWMutex + maxResults uint64 + totalTransactions uint64 + successfulTxs uint64 + failedTxs uint64 + logger logrus.FieldLogger +} + +func NewFeedbackCollector(maxResults uint64, logger logrus.FieldLogger) *FeedbackCollector { + return &FeedbackCollector{ + results: make([]TransactionResult, 0, maxResults), + maxResults: maxResults, + logger: logger.WithField("component", "feedback_collector"), + } +} + +func (fc *FeedbackCollector) RecordResult(result TransactionResult) { + fc.mutex.Lock() + defer fc.mutex.Unlock() + + fc.results = append(fc.results, result) + if uint64(len(fc.results)) > fc.maxResults { + fc.results = fc.results[1:] + } + + fc.totalTransactions++ + if result.Status == "success" { + fc.successfulTxs++ + } else { + fc.failedTxs++ + } + + fc.logger.Debugf("recorded transaction result: %s (%s) - %s", + result.PayloadDescription, result.PayloadType, result.Status) +} + +func (fc *FeedbackCollector) RecordFailure(payload *PayloadInstance, status, 
errorMsg string) { + result := TransactionResult{ + PayloadType: payload.Type, + PayloadDescription: payload.Description, + Status: status, + GasUsed: 0, + BlockExecTime: "N/A", + ErrorMessage: errorMsg, + } + fc.RecordResult(result) +} + +func (fc *FeedbackCollector) GenerateFeedback() *TransactionFeedback { + fc.mutex.RLock() + defer fc.mutex.RUnlock() + + if len(fc.results) == 0 { + return nil + } + + gasValues := make([]uint64, 0, len(fc.results)) + for _, result := range fc.results { + if result.Status == "success" && result.GasUsed > 0 { + gasValues = append(gasValues, result.GasUsed) + } + } + + var avgGas, medianGas uint64 + if len(gasValues) > 0 { + sort.Slice(gasValues, func(i, j int) bool { return gasValues[i] < gasValues[j] }) + + var total uint64 + for _, gas := range gasValues { + total += gas + } + avgGas = total / uint64(len(gasValues)) + medianGas = gasValues[len(gasValues)/2] + } + + summary := fc.generateSummary() + + return &TransactionFeedback{ + TotalTransactions: fc.totalTransactions, + SuccessfulTxs: fc.successfulTxs, + FailedTxs: fc.failedTxs, + AverageGasUsed: avgGas, + MedianGasUsed: medianGas, + AverageBlockExecTime: "N/A", + RecentResults: fc.getRecentResults(10), + Summary: summary, + } +} + +func (fc *FeedbackCollector) generateSummary() string { + if len(fc.results) == 0 { + return "No transaction results yet." 
+ } + + typeSuccess := make(map[string]int) + typeTotal := make(map[string]int) + + for _, result := range fc.results { + typeTotal[result.PayloadType]++ + if result.Status == "success" { + typeSuccess[result.PayloadType]++ + } + } + + var summaryParts []string + for payloadType, total := range typeTotal { + success := typeSuccess[payloadType] + successRate := float64(success) / float64(total) * 100 + summaryParts = append(summaryParts, + fmt.Sprintf("%s: %.1f%% success (%d/%d)", payloadType, successRate, success, total)) + } + + return fmt.Sprintf("Pattern analysis: %s", strings.Join(summaryParts, ", ")) +} + +func (fc *FeedbackCollector) getRecentResults(count int) []TransactionResult { + if len(fc.results) <= count { + return fc.results + } + return fc.results[len(fc.results)-count:] +} + +func (fc *FeedbackCollector) GetStats() (uint64, uint64, uint64) { + fc.mutex.RLock() + defer fc.mutex.RUnlock() + return fc.totalTransactions, fc.successfulTxs, fc.failedTxs +} diff --git a/scenarios/aitx/geas_processor.go b/scenarios/aitx/geas_processor.go new file mode 100644 index 00000000..2cf186e5 --- /dev/null +++ b/scenarios/aitx/geas_processor.go @@ -0,0 +1,96 @@ +package aitx + +import ( + "fmt" + "strings" + + geas "github.com/fjl/geas/asm" + "github.com/sirupsen/logrus" +) + +type GeasProcessor struct { + logger logrus.FieldLogger +} + +func NewGeasProcessor(logger logrus.FieldLogger) *GeasProcessor { + return &GeasProcessor{ + logger: logger.WithField("component", "geas_processor"), + } +} + +func (gp *GeasProcessor) CompileGeasPayload(payload *PayloadInstance) ([]byte, error) { + compiler := geas.NewCompiler(nil) + return gp.compileInitRunGeas(payload, compiler) +} + +func (gp *GeasProcessor) compileInitRunGeas(payload *PayloadInstance, compiler *geas.Compiler) ([]byte, error) { + + // Build init code that deploys the contract + initcodeGeas := ` + ;; Init code + push @.start + codesize + sub + dup1 + push @.start + push0 + codecopy + push0 + return + + .start: + 
` + + // Build the contract template with init, run, and post code + contractGeasTpl := ` + %s + gas ;; [gas, custom] + push 0 ;; [loop_counter, gas, custom] + jump @loop + + exit: + ;; Execute post code once at the end + %s + stop ;; [custom] + + loop: + push %d ;; [gas_remainder, loop_counter, gas, custom] + gas ;; [gas, gas_remainder, loop_counter, gas, custom] + lt ;; [gas < gas_remainder, loop_counter, gas, custom] + jumpi @exit ;; [loop_counter, gas, custom] + + ;; increase loop_counter + push 1 ;; [1, loop_counter, gas, custom] + add ;; [loop_counter+1, gas, custom] + + ;; run the performance test code + %s + + jump @loop + ` + + gp.logger.Debugf("compiling init_run geas - init: %s, run: %s, post: %s", + strings.ReplaceAll(payload.InitCode, "\n", "\\n"), + strings.ReplaceAll(payload.RunCode, "\n", "\\n"), + strings.ReplaceAll(payload.PostCode, "\n", "\\n")) + + // Compile init code + initcode := compiler.CompileString(initcodeGeas) + if initcode == nil { + return nil, fmt.Errorf("failed to compile geas init code: %v", compiler.Errors()) + } + + // Compile the contract code with init, run, and post parts + contractCode := compiler.CompileString(fmt.Sprintf(contractGeasTpl, payload.InitCode, payload.PostCode, payload.GasRemainder, payload.RunCode)) + if contractCode == nil { + return nil, fmt.Errorf("failed to compile geas contract code: %v", compiler.Errors()) + } + + // Combine init code and contract code + combinedCode := append(initcode, contractCode...) 
+ + gp.logger.Debugf("compiled init_run geas code to %d bytes (init: %d, contract: %d)", + len(combinedCode), len(initcode), len(contractCode)) + + return combinedCode, nil +} diff --git a/scenarios/aitx/payload_processor.go b/scenarios/aitx/payload_processor.go new file mode 100644 index 00000000..91b67e9e --- /dev/null +++ b/scenarios/aitx/payload_processor.go @@ -0,0 +1,88 @@ +package aitx + +import ( + "fmt" + + "github.com/sirupsen/logrus" +) + +type PayloadProcessor struct { + logger logrus.FieldLogger + geasProcessor *GeasProcessor +} + +func NewPayloadProcessor(logger logrus.FieldLogger) *PayloadProcessor { + return &PayloadProcessor{ + logger: logger.WithField("component", "payload_processor"), + geasProcessor: NewGeasProcessor(logger), + } +} + +func (pp *PayloadProcessor) ValidatePayload(payload *PayloadTemplate) error { + if payload.Type != "geas" { + return fmt.Errorf("only 'geas' type is supported, got: %s", payload.Type) + } + + if payload.Description == "" { + return fmt.Errorf("payload description is required") + } + + if payload.InitCode == "" { + return fmt.Errorf("geas payload requires init_code") + } + + if payload.RunCode == "" { + return fmt.Errorf("geas payload requires run_code") + } + + // Validate geas compilation + if err := pp.validateGeasCompilation(payload); err != nil { + return fmt.Errorf("geas compilation failed: %w", err) + } + + return nil +} + +func (pp *PayloadProcessor) ProcessPayloads(templates []PayloadTemplate) ([]PayloadTemplate, error) { + var validPayloads []PayloadTemplate + + for i, template := range templates { + err := pp.ValidatePayload(&template) + if err != nil { + pp.logger.Errorf("invalid payload #%d (%s): %v", i+1, template.Description, err) + continue + } + + pp.logger.Infof("payload #%d (%s) validated successfully", i+1, template.Description) + validPayloads = append(validPayloads, template) + } + + if len(validPayloads) == 0 { + return nil, fmt.Errorf("no valid payloads found") + } + + 
pp.logger.Infof("processed %d payloads, %d valid", len(templates), len(validPayloads)) + return validPayloads, nil +} + +func (pp *PayloadProcessor) validateGeasCompilation(payload *PayloadTemplate) error { + // Create a temporary payload instance for compilation testing + tempPayload := &PayloadInstance{ + Type: payload.Type, + Description: payload.Description, + InitCode: payload.InitCode, + RunCode: payload.RunCode, + PostCode: payload.PostCode, + GasRemainder: 10000, // Default value for validation + } + + // Attempt to compile the geas code + _, err := pp.geasProcessor.CompileGeasPayload(tempPayload) + if err != nil { + pp.logger.Debugf("geas compilation validation failed for payload '%s': %v", payload.Description, err) + return err + } + + pp.logger.Debugf("geas compilation validation passed for payload '%s'", payload.Description) + return nil +} diff --git a/scenarios/aitx/payload_template.go b/scenarios/aitx/payload_template.go new file mode 100644 index 00000000..987a4484 --- /dev/null +++ b/scenarios/aitx/payload_template.go @@ -0,0 +1,231 @@ +package aitx + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "math/big" + "strconv" + "strings" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/sirupsen/logrus" + + "github.com/ethpandaops/spamoor/spamoor" +) + +type PayloadTemplate struct { + Type string `json:"type"` + Description string `json:"description"` + InitCode string `json:"init_code"` + RunCode string `json:"run_code"` + PostCode string `json:"post_code,omitempty"` // Optional - executes once at end + GasRemainder string `json:"gas_remainder,omitempty"` // Optional - defaults to 10000 + Calldata string `json:"calldata,omitempty"` // Optional - calldata for contract call +} + +type PayloadInstance struct { + Type string + Description string + InitCode string + RunCode string + PostCode string + GasRemainder uint64 + Calldata []byte +} + +type PlaceholderSubstituter struct { + walletPool *spamoor.WalletPool + client 
*spamoor.Client + logger logrus.FieldLogger +} + +func NewPlaceholderSubstituter(walletPool *spamoor.WalletPool, client *spamoor.Client, logger logrus.FieldLogger) *PlaceholderSubstituter { + return &PlaceholderSubstituter{ + walletPool: walletPool, + client: client, + logger: logger, + } +} + +func (pt *PayloadTemplate) Substitute(substituter *PlaceholderSubstituter) (*PayloadInstance, error) { + instance := &PayloadInstance{ + Type: pt.Type, + Description: pt.Description, + } + + var err error + instance.InitCode, err = substituter.SubstitutePlaceholders(pt.InitCode) + if err != nil { + return nil, fmt.Errorf("failed to substitute init code placeholders: %w", err) + } + + instance.RunCode, err = substituter.SubstitutePlaceholders(pt.RunCode) + if err != nil { + return nil, fmt.Errorf("failed to substitute run code placeholders: %w", err) + } + + instance.PostCode, err = substituter.SubstitutePlaceholders(pt.PostCode) + if err != nil { + return nil, fmt.Errorf("failed to substitute post code placeholders: %w", err) + } + + gasRemainderStr, err := substituter.SubstitutePlaceholders(pt.GasRemainder) + if err != nil { + return nil, fmt.Errorf("failed to substitute gas remainder placeholders: %w", err) + } + + if gasRemainderStr != "" { + gasRemainder, err := strconv.ParseUint(gasRemainderStr, 10, 64) + if err != nil { + instance.GasRemainder = 10000 // Default value + } else { + instance.GasRemainder = gasRemainder + } + } else { + instance.GasRemainder = 10000 // Default value + } + + // Handle calldata + if pt.Calldata != "" { + calldataStr, err := substituter.SubstitutePlaceholders(pt.Calldata) + if err != nil { + return nil, fmt.Errorf("failed to substitute calldata placeholders: %w", err) + } + + // Parse hex calldata + calldataStr = strings.TrimPrefix(calldataStr, "0x") + instance.Calldata, err = hex.DecodeString(calldataStr) + if err != nil { + return nil, fmt.Errorf("failed to decode calldata: %w", err) + } + } + + return instance, instance.Validate() +} + 
+func (pi *PayloadInstance) Validate() error { + if pi.Type != "geas" { + return fmt.Errorf("only 'geas' type is supported, got: %s", pi.Type) + } + + if pi.Description == "" { + return fmt.Errorf("payload description is required") + } + + if pi.InitCode == "" { + return fmt.Errorf("geas payload requires init_code") + } + + if pi.RunCode == "" { + return fmt.Errorf("geas payload requires run_code") + } + + return nil +} + +func (ps *PlaceholderSubstituter) SubstitutePlaceholders(input string) (string, error) { + result := input + + substitutions := map[string]func() string{ + "${WALLET_ADDRESS}": func() string { + if ps.walletPool.GetWalletCount() == 0 { + return "0x0000000000000000000000000000000000000000" + } + walletIdx := time.Now().UnixNano() % int64(ps.walletPool.GetWalletCount()) + wallet := ps.walletPool.GetWallet(spamoor.SelectWalletByIndex, int(walletIdx)) + return wallet.GetAddress().Hex() + }, + "${RANDOM_ADDRESS}": func() string { + bytes := make([]byte, 20) + rand.Read(bytes) + return common.BytesToAddress(bytes).Hex() + }, + "${ETH_AMOUNT_SMALL}": func() string { + min := big.NewInt(1000000000000000) // 0.001 ETH + max := big.NewInt(10000000000000000) // 0.01 ETH + diff := new(big.Int).Sub(max, min) + randInt, _ := rand.Int(rand.Reader, diff) + return new(big.Int).Add(min, randInt).String() + }, + "${ETH_AMOUNT_MEDIUM}": func() string { + min := big.NewInt(10000000000000000) // 0.01 ETH + max := big.NewInt(100000000000000000) // 0.1 ETH + diff := new(big.Int).Sub(max, min) + randInt, _ := rand.Int(rand.Reader, diff) + return new(big.Int).Add(min, randInt).String() + }, + "${ETH_AMOUNT_LARGE}": func() string { + min := big.NewInt(100000000000000000) // 0.1 ETH + max := big.NewInt(1000000000000000000) // 1.0 ETH + diff := new(big.Int).Sub(max, min) + randInt, _ := rand.Int(rand.Reader, diff) + return new(big.Int).Add(min, randInt).String() + }, + "${GAS_LIMIT_LOW}": func() string { + randBytes := make([]byte, 4) + rand.Read(randBytes) + randVal := 
int64(randBytes[0])<<24 | int64(randBytes[1])<<16 | int64(randBytes[2])<<8 | int64(randBytes[3]) + if randVal < 0 { + randVal = -randVal + } + return fmt.Sprintf("%d", 21000+(randVal%29000)) + }, + "${GAS_LIMIT_MEDIUM}": func() string { + randBytes := make([]byte, 4) + rand.Read(randBytes) + randVal := int64(randBytes[0])<<24 | int64(randBytes[1])<<16 | int64(randBytes[2])<<8 | int64(randBytes[3]) + if randVal < 0 { + randVal = -randVal + } + return fmt.Sprintf("%d", 50000+(randVal%150000)) + }, + "${GAS_LIMIT_HIGH}": func() string { + randBytes := make([]byte, 4) + rand.Read(randBytes) + randVal := int64(randBytes[0])<<24 | int64(randBytes[1])<<16 | int64(randBytes[2])<<8 | int64(randBytes[3]) + if randVal < 0 { + randVal = -randVal + } + return fmt.Sprintf("%d", 200000+(randVal%800000)) + }, + "${RANDOM_UINT256}": func() string { + bytes := make([]byte, 32) + rand.Read(bytes) + return new(big.Int).SetBytes(bytes).String() + }, + "${RANDOM_BYTES32}": func() string { + bytes := make([]byte, 32) + rand.Read(bytes) + return "0x" + hex.EncodeToString(bytes) + }, + "${CURRENT_BLOCK}": func() string { + return "0" + }, + "${LOOP_COUNT_SMALL}": func() string { + randBytes := make([]byte, 1) + rand.Read(randBytes) + return fmt.Sprintf("%d", 1+(int(randBytes[0])%10)) + }, + "${LOOP_COUNT_MEDIUM}": func() string { + randBytes := make([]byte, 1) + rand.Read(randBytes) + return fmt.Sprintf("%d", 10+(int(randBytes[0])%90)) + }, + "${LOOP_COUNT_LARGE}": func() string { + randBytes := make([]byte, 2) + rand.Read(randBytes) + randVal := int(randBytes[0])<<8 | int(randBytes[1]) + return fmt.Sprintf("%d", 100+(randVal%900)) + }, + } + + for placeholder, substituteFn := range substitutions { + if strings.Contains(result, placeholder) { + result = strings.ReplaceAll(result, placeholder, substituteFn()) + } + } + + return result, nil +} diff --git a/scenarios/scenarios.go b/scenarios/scenarios.go index 0283f354..aae0e417 100644 --- a/scenarios/scenarios.go +++ b/scenarios/scenarios.go 
@@ -3,6 +3,7 @@ package scenarios import ( "github.com/ethpandaops/spamoor/scenario" + "github.com/ethpandaops/spamoor/scenarios/aitx" blobcombined "github.com/ethpandaops/spamoor/scenarios/blob-combined" blobconflicting "github.com/ethpandaops/spamoor/scenarios/blob-conflicting" blobreplacements "github.com/ethpandaops/spamoor/scenarios/blob-replacements" @@ -27,6 +28,7 @@ import ( // This registry includes scenarios for testing various Ethereum transaction types and patterns. // Each descriptor defines the configuration, constructor, and metadata for a specific test scenario. var ScenarioDescriptors = []*scenario.Descriptor{ + &aitx.ScenarioDescriptor, &blobcombined.ScenarioDescriptor, &blobconflicting.ScenarioDescriptor, &blobs.ScenarioDescriptor, From 382728f17ce95e7a7c56ec2ff72bfc19fe752a59 Mon Sep 17 00:00:00 2001 From: pk910 Date: Tue, 8 Jul 2025 03:16:39 +0200 Subject: [PATCH 2/4] continue aitx scenario --- scenarios/aitx/aitx.go | 524 +++++++++++++++++++++++++++++------------ 1 file changed, 368 insertions(+), 156 deletions(-) diff --git a/scenarios/aitx/aitx.go b/scenarios/aitx/aitx.go index 97104401..5d47e981 100644 --- a/scenarios/aitx/aitx.go +++ b/scenarios/aitx/aitx.go @@ -4,6 +4,7 @@ import ( "context" "encoding/hex" "fmt" + "sort" "sync" "time" @@ -52,6 +53,17 @@ type ScenarioOptions struct { LogAIConversations bool `yaml:"log_ai_conversations"` } +type PayloadState struct { + Template PayloadTemplate + IsDeployed bool + IsDeploying bool + ContractAddress common.Address + SuccessCount int + FailCount int + LastUsed time.Time + mutex sync.Mutex // Protects individual payload state +} + type Scenario struct { options ScenarioOptions logger *logrus.Entry @@ -63,11 +75,17 @@ type Scenario struct { feedbackCollector *FeedbackCollector geasProcessor *GeasProcessor - payloadCache []PayloadTemplate - cacheIndex int - aiMutex sync.Mutex // Protects AI calls and payload cache - conversationHistory []Message // Persisted conversation history - 
conversationResponses int // Number of AI responses in current conversation + // Async payload management + payloadStates []*PayloadState + payloadMutex sync.RWMutex // Protects payload states slice + payloadRoundRobin int // Round-robin index + aiRequestChan chan struct{} // Signals need for more payloads + aiReadyChan chan struct{} // Signals AI has returned payloads + shutdownChan chan struct{} // Signals shutdown + conversationHistory []Message // Persisted conversation history + conversationResponses int // Number of AI responses in current conversation + aiWorkerRunning bool // Tracks if AI worker is running + aiWorkerMutex sync.Mutex // Protects AI worker state } var ScenarioName = "aitx" @@ -201,6 +219,12 @@ func (s *Scenario) Init(options *scenario.Options) error { basePrompt := s.aiService.buildBasePrompt(s.options.GenerationMode) s.aiService.SetBasePrompt(basePrompt) + // Initialize async payload management + s.payloadStates = make([]*PayloadState, 0, 100) + s.aiRequestChan = make(chan struct{}, 1) + s.aiReadyChan = make(chan struct{}, 1) + s.shutdownChan = make(chan struct{}) + return nil } @@ -208,6 +232,25 @@ func (s *Scenario) Run(ctx context.Context) error { s.logger.Infof("starting AI transaction generator scenario") defer s.logger.Infof("AI transaction generator scenario finished") + // Start background AI worker + go s.aiWorker(ctx) + + // Initial AI request to get started + select { + case s.aiRequestChan <- struct{}{}: + default: + } + + // Wait for AI to be ready + s.logger.Infof("waiting for AI payloads to be ready") + select { + case <-s.aiReadyChan: + case <-ctx.Done(): + return ctx.Err() + } + + s.logger.Infof("AI payloads ready, starting transaction generation") + maxPending := s.options.MaxPending if maxPending == 0 { maxPending = s.options.Throughput * 10 @@ -263,11 +306,14 @@ func (s *Scenario) Run(ctx context.Context) error { }, }) + // Signal shutdown to AI worker + close(s.shutdownChan) + return err } func (s *Scenario) 
sendAITransaction(ctx context.Context, txIdx uint64, onComplete func()) (*types.Transaction, *spamoor.Client, *spamoor.Wallet, error) { - // Deploy a contract and send 10 call transactions using batch sending + // Send single call transaction using round-robin payload selection defer onComplete() client := s.walletPool.GetClient(spamoor.SelectClientByIndex, int(txIdx), s.options.ClientGroup) @@ -277,185 +323,101 @@ func (s *Scenario) sendAITransaction(ctx context.Context, txIdx uint64, onComple return nil, client, wallet, fmt.Errorf("no client available") } - // Get next payload template from AI or cache - template, err := s.getNextPayloadTemplate(ctx) + // Get next payload using round-robin selection + payloadState, err := s.getNextPayload(ctx) if err != nil { - s.logger.Errorf("failed to get AI payload template: %v", err) - dummyPayload := &PayloadInstance{Type: "geas", Description: "failed_generation"} - s.feedbackCollector.RecordFailure(dummyPayload, "payload_generation_failed", err.Error()) + s.logger.Errorf("failed to get payload: %v", err) return nil, client, wallet, err } // Substitute placeholders - payload, err := template.Substitute(s.placeholderSubstituter) + payload, err := payloadState.Template.Substitute(s.placeholderSubstituter) if err != nil { s.logger.Errorf("failed to substitute placeholders: %v", err) - dummyPayload := &PayloadInstance{Type: template.Type, Description: template.Description} - s.feedbackCollector.RecordFailure(dummyPayload, "placeholder_substitution_failed", err.Error()) + s.recordPayloadFailure(payloadState, "placeholder_substitution_failed", err.Error()) return nil, client, wallet, err } - // Build deployment transaction - s.logger.Infof("deploying contract for payload: %s", payload.Description) - deployTx, contractAddress, err := s.deployGeasContract(ctx, wallet, client, payload) + // Handle deployment if needed + contractAddress, deployTx, err := s.ensureContractDeployed(ctx, payloadState, payload, wallet, client, txIdx) if 
err != nil { s.logger.Errorf("failed to deploy contract: %v", err) - s.feedbackCollector.RecordFailure(payload, "deployment_failed", err.Error()) - return nil, client, wallet, err + s.recordPayloadFailure(payloadState, "deployment_failed", err.Error()) + return deployTx, client, wallet, err } - // Deploy contract and wait for confirmation using SendAndAwaitTransaction - deployReceipt, err := s.walletPool.GetTxPool().SendAndAwaitTransaction(ctx, wallet, deployTx, &spamoor.SendTransactionOptions{ - Client: client, - Rebroadcast: s.options.Rebroadcast > 0, - LogFn: spamoor.GetDefaultLogFn(s.logger, "deploy", fmt.Sprintf("%6d", txIdx+1), deployTx), - }) - + // Build call transaction + callTx, err := s.callGeasContract(ctx, wallet, client, contractAddress, payload) if err != nil { - s.logger.Errorf("failed to deploy contract: %v", err) - s.feedbackCollector.RecordFailure(payload, "deployment_failed", err.Error()) - return nil, client, wallet, err - } - - if deployReceipt.Status != 1 { - s.logger.Errorf("contract deployment failed (status: %d)", deployReceipt.Status) - s.feedbackCollector.RecordFailure(payload, "deployment_reverted", "deployment transaction reverted") - return deployTx, client, wallet, nil + s.logger.Errorf("failed to build call transaction: %v", err) + s.recordPayloadFailure(payloadState, "call_build_failed", err.Error()) + return deployTx, client, wallet, err } - s.logger.Infof("contract deployed successfully at %s for payload: %s", contractAddress.Hex(), payload.Description) - - // Build 10 call transactions - var callTxs []*types.Transaction - for i := 0; i < 10; i++ { - callTx, err := s.callGeasContract(ctx, wallet, client, contractAddress, payload) - if err != nil { - s.logger.Errorf("failed to build call transaction %d: %v", i+1, err) - s.feedbackCollector.RecordFailure(payload, "call_build_failed", err.Error()) - return deployTx, client, wallet, err - } - callTxs = append(callTxs, callTx) - } - - // Send all call transactions as a batch from same 
wallet - _, err = s.walletPool.GetTxPool().SendTransactionBatch(ctx, wallet, callTxs, &spamoor.BatchOptions{ - SendTransactionOptions: spamoor.SendTransactionOptions{ - Client: client, - Rebroadcast: s.options.Rebroadcast > 0, - OnConfirm: func(tx *types.Transaction, receipt *types.Receipt) { - // Collect execution results for feedback from call transactions - s.collectTransactionResult(payload, tx, receipt) - }, + // Send call transaction + err = s.walletPool.GetTxPool().SendTransaction(ctx, wallet, callTx, &spamoor.SendTransactionOptions{ + Client: client, + Rebroadcast: s.options.Rebroadcast > 0, + OnConfirm: func(tx *types.Transaction, receipt *types.Receipt) { + // Record success + s.recordPayloadSuccess(payloadState, payload, tx, receipt) }, + OnComplete: func(tx *types.Transaction, receipt *types.Receipt, err error) { + if err != nil { + s.recordPayloadFailure(payloadState, "call_failed", err.Error()) + } + }, + LogFn: spamoor.GetDefaultLogFn(s.logger, "call", fmt.Sprintf("%6d", txIdx+1), callTx), }) if err != nil { - s.logger.Errorf("failed to send call transaction batch: %v", err) - s.feedbackCollector.RecordFailure(payload, "batch_send_failed", err.Error()) - return deployTx, client, wallet, err + s.logger.Errorf("failed to send call transaction: %v", err) + s.recordPayloadFailure(payloadState, "call_send_failed", err.Error()) + return callTx, client, wallet, err } - return deployTx, client, wallet, nil + return callTx, client, wallet, nil } -func (s *Scenario) getNextPayloadTemplate(ctx context.Context) (*PayloadTemplate, error) { - // Lock to ensure only one AI call happens at a time - s.aiMutex.Lock() - defer s.aiMutex.Unlock() - - // Check if we have cached payloads - if s.cacheIndex < len(s.payloadCache) { - template := s.payloadCache[s.cacheIndex] - s.cacheIndex++ - return &template, nil - } - - // Generate new batch of payloads - if s.aiService.GetCallCount() >= s.options.MaxAICalls { - return nil, fmt.Errorf("maximum AI calls limit reached (%d)", 
s.options.MaxAICalls) - } - - if s.aiService.GetTokenCount() >= s.options.MaxTokens { - return nil, fmt.Errorf("maximum token limit reached (%d)", s.options.MaxTokens) - } - - // Check if we need to start a new conversation (after 10 responses) - if s.conversationResponses >= 10 { - s.logger.Infof("resetting conversation after %d responses", s.conversationResponses) - s.conversationHistory = nil - s.conversationResponses = 0 - } - - if len(s.conversationHistory) == 0 { - s.logger.Infof("making AI call #%d - starting new conversation (other transactions waiting)", s.aiService.GetCallCount()+1) - } else { - s.logger.Infof("making AI call #%d - continuing conversation with %d messages (other transactions waiting)", - s.aiService.GetCallCount()+1, len(s.conversationHistory)) - } - - // Generate payloads using conversation continuation - var response *GenerationResponse - var err error - - if len(s.conversationHistory) == 0 { - // Start new conversation - req := GenerationRequest{ - TestDirection: s.options.TestDirection, - GenerationMode: s.options.GenerationMode, - PayloadCount: s.options.PayloadsPerRequest, - PreviousSummary: "", - TransactionFeedback: nil, +// getNextPayload selects the next payload using round-robin, preferring payloads with < 20 successes +func (s *Scenario) getNextPayload(ctx context.Context) (*PayloadState, error) { + s.payloadMutex.Lock() + defer s.payloadMutex.Unlock() + + // Check for payloads with < 20 successes + for i := 0; i < len(s.payloadStates); i++ { + idx := (s.payloadRoundRobin + i) % len(s.payloadStates) + payloadState := s.payloadStates[idx] + if payloadState.SuccessCount < 20 && !payloadState.IsDeploying { + s.payloadRoundRobin = (idx + 1) % len(s.payloadStates) + payloadState.LastUsed = time.Now() + return payloadState, nil } - - // Add feedback if enabled - if s.options.EnableFeedbackLoop { - req.TransactionFeedback = s.feedbackCollector.GenerateFeedback() - } - - response, s.conversationHistory, err = 
s.aiService.GeneratePayloadsWithConversation(ctx, req, s.processor, nil) - } else { - // Continue existing conversation - feedback := "" - if s.options.EnableFeedbackLoop { - txFeedback := s.feedbackCollector.GenerateFeedback() - if txFeedback != nil { - feedback = fmt.Sprintf("Transaction feedback: %d total (%d success, %d failed), avg gas: %d. Generate more diverse patterns based on this data.", - txFeedback.TotalTransactions, txFeedback.SuccessfulTxs, txFeedback.FailedTxs, txFeedback.AverageGasUsed) - } - } - - if feedback == "" { - feedback = fmt.Sprintf("Generate %d more unique geas init_run contracts with different patterns and behaviors.", s.options.PayloadsPerRequest) - } - - response, s.conversationHistory, err = s.aiService.GeneratePayloadsWithConversation(ctx, GenerationRequest{}, s.processor, &ConversationContinuation{ - History: s.conversationHistory, - Feedback: feedback, - }) - } - - if err != nil { - return nil, fmt.Errorf("AI payload generation failed: %w", err) } - // Increment conversation response count - s.conversationResponses++ - - // Payloads are already validated by the AI service - validPayloads := response.Payloads + // No payload with < 20 successes, request more payloads + s.requestMorePayloads() - // Update cache - s.payloadCache = validPayloads - s.cacheIndex = 1 // Return first, set index to second - - if len(validPayloads) == 0 { - return nil, fmt.Errorf("no valid payloads generated") + // If we have any payloads, return the next one + if len(s.payloadStates) > 0 { + idx := s.payloadRoundRobin % len(s.payloadStates) + payloadState := s.payloadStates[idx] + s.payloadRoundRobin = (idx + 1) % len(s.payloadStates) + payloadState.LastUsed = time.Now() + return payloadState, nil } - s.logger.Infof("AI call completed, generated %d payloads (conversation: %d responses, cache refilled)", - len(validPayloads), s.conversationResponses) + return nil, fmt.Errorf("no payloads available") +} - return &validPayloads[0], nil +// requestMorePayloads 
signals the AI worker to generate more payloads +func (s *Scenario) requestMorePayloads() { + select { + case s.aiRequestChan <- struct{}{}: + s.logger.Debugf("requested more payloads from AI worker") + default: + // Channel full, request already pending + } } func (s *Scenario) deployGeasContract(ctx context.Context, wallet *spamoor.Wallet, client *spamoor.Client, payload *PayloadInstance) (*types.Transaction, common.Address, error) { @@ -518,7 +480,11 @@ func (s *Scenario) callGeasContract(ctx context.Context, wallet *spamoor.Wallet, return wallet.BuildDynamicFeeTx(txData) } -func (s *Scenario) collectTransactionResult(payload *PayloadInstance, tx *types.Transaction, receipt *types.Receipt) { +func (s *Scenario) recordPayloadSuccess(payloadState *PayloadState, payload *PayloadInstance, tx *types.Transaction, receipt *types.Receipt) { + payloadState.mutex.Lock() + payloadState.SuccessCount++ + payloadState.mutex.Unlock() + if receipt == nil { s.feedbackCollector.RecordFailure(payload, "receipt_nil", "receipt was nil") return @@ -529,6 +495,11 @@ func (s *Scenario) collectTransactionResult(payload *PayloadInstance, tx *types. errorMsg := "" if receipt.Status == 0 { status = "reverted" + // Record as failure instead of success + payloadState.mutex.Lock() + payloadState.SuccessCount-- + payloadState.FailCount++ + payloadState.mutex.Unlock() } // Calculate transaction fees @@ -544,8 +515,9 @@ func (s *Scenario) collectTransactionResult(payload *PayloadInstance, tx *types. 
hex.EncodeToString(log.Data))) } - s.logger.Debugf("transaction confirmed: %s (%s) - %s, gas: %d, fees: %s, logs: %d", - payload.Description, payload.Type, status, receipt.GasUsed, txFees.TotalFeeGweiString(), len(receipt.Logs)) + s.logger.Debugf("transaction confirmed: %s (%s) - %s, gas: %d, fees: %s, logs: %d, success: %d, fail: %d", + payload.Description, payload.Type, status, receipt.GasUsed, txFees.TotalFeeGweiString(), len(receipt.Logs), + payloadState.SuccessCount, payloadState.FailCount) // Record result for feedback result := TransactionResult{ @@ -560,3 +532,243 @@ func (s *Scenario) collectTransactionResult(payload *PayloadInstance, tx *types. s.feedbackCollector.RecordResult(result) } + +// ensureContractDeployed ensures the contract is deployed for the payload +func (s *Scenario) ensureContractDeployed(ctx context.Context, payloadState *PayloadState, payload *PayloadInstance, wallet *spamoor.Wallet, client *spamoor.Client, txIdx uint64) (common.Address, *types.Transaction, error) { + payloadState.mutex.Lock() + defer payloadState.mutex.Unlock() + + // Check if already deployed + if payloadState.IsDeployed { + return payloadState.ContractAddress, nil, nil + } + + payloadState.IsDeploying = true + defer func() { + payloadState.IsDeploying = false + }() + + // Deploy the contract + s.logger.Infof("deploying contract for payload: %s", payload.Description) + deployTx, contractAddress, err := s.deployGeasContract(ctx, wallet, client, payload) + if err != nil { + return common.Address{}, nil, fmt.Errorf("failed to build deployment transaction: %w", err) + } + + // Deploy contract and wait for confirmation + deployReceipt, err := s.walletPool.GetTxPool().SendAndAwaitTransaction(ctx, wallet, deployTx, &spamoor.SendTransactionOptions{ + Client: client, + Rebroadcast: s.options.Rebroadcast > 0, + LogFn: spamoor.GetDefaultLogFn(s.logger, "deploy", fmt.Sprintf("%6d", txIdx+1), deployTx), + }) + + if err != nil { + return common.Address{}, deployTx, 
fmt.Errorf("failed to deploy contract: %w", err) + } + + if deployReceipt.Status != 1 { + return common.Address{}, deployTx, fmt.Errorf("contract deployment failed (status: %d)", deployReceipt.Status) + } + + // Mark as deployed + payloadState.IsDeployed = true + payloadState.ContractAddress = contractAddress + + s.logger.Infof("contract deployed successfully at %s for payload: %s", contractAddress.Hex(), payload.Description) + return contractAddress, deployTx, nil +} + +// aiWorker runs in background to generate payloads asynchronously +func (s *Scenario) aiWorker(ctx context.Context) { + s.aiWorkerMutex.Lock() + if s.aiWorkerRunning { + s.aiWorkerMutex.Unlock() + return + } + s.aiWorkerRunning = true + s.aiWorkerMutex.Unlock() + + s.logger.Infof("starting AI worker for background payload generation") + defer s.logger.Infof("AI worker stopped") + + for { + select { + case <-ctx.Done(): + return + case <-s.shutdownChan: + return + case <-s.aiRequestChan: + // Generate new payloads + if err := s.generatePayloads(ctx); err != nil { + s.logger.Errorf("failed to generate payloads: %v", err) + // Wait before retrying + select { + case <-time.After(5 * time.Second): + case <-ctx.Done(): + return + case <-s.shutdownChan: + return + } + } else { + select { + case s.aiReadyChan <- struct{}{}: + default: + } + select { + case <-s.aiRequestChan: + default: + } + } + } + } +} + +// generatePayloads generates new payloads from AI and manages the payload pool +func (s *Scenario) generatePayloads(ctx context.Context) error { + // Check AI limits + if s.aiService.GetCallCount() >= s.options.MaxAICalls { + s.logger.Warnf("maximum AI calls limit reached (%d)", s.options.MaxAICalls) + return fmt.Errorf("maximum AI calls limit reached") + } + + if s.aiService.GetTokenCount() >= s.options.MaxTokens { + s.logger.Warnf("maximum token limit reached (%d)", s.options.MaxTokens) + return fmt.Errorf("maximum token limit reached") + } + + // Check if we need to start a new conversation (after 
10 responses) + if s.conversationResponses >= 10 { + s.logger.Infof("resetting conversation after %d responses", s.conversationResponses) + s.conversationHistory = nil + s.conversationResponses = 0 + } + + if len(s.conversationHistory) == 0 { + s.logger.Infof("making AI call #%d - starting new conversation", s.aiService.GetCallCount()+1) + } else { + s.logger.Infof("making AI call #%d - continuing conversation with %d messages", + s.aiService.GetCallCount()+1, len(s.conversationHistory)) + } + + // Generate payloads using conversation continuation + var response *GenerationResponse + var err error + + if len(s.conversationHistory) == 0 { + // Start new conversation + req := GenerationRequest{ + TestDirection: s.options.TestDirection, + GenerationMode: s.options.GenerationMode, + PayloadCount: s.options.PayloadsPerRequest, + PreviousSummary: "", + TransactionFeedback: nil, + } + + // Add feedback if enabled + if s.options.EnableFeedbackLoop { + req.TransactionFeedback = s.feedbackCollector.GenerateFeedback() + } + + response, s.conversationHistory, err = s.aiService.GeneratePayloadsWithConversation(ctx, req, s.processor, nil) + } else { + // Continue existing conversation + feedback := "" + if s.options.EnableFeedbackLoop { + txFeedback := s.feedbackCollector.GenerateFeedback() + if txFeedback != nil { + feedback = fmt.Sprintf("Transaction feedback: %d total (%d success, %d failed), avg gas: %d. 
Generate more diverse patterns based on this data.", + txFeedback.TotalTransactions, txFeedback.SuccessfulTxs, txFeedback.FailedTxs, txFeedback.AverageGasUsed) + } + } + + if feedback == "" { + feedback = fmt.Sprintf("Generate %d more unique geas init_run contracts with different patterns and behaviors.", s.options.PayloadsPerRequest) + } + + response, s.conversationHistory, err = s.aiService.GeneratePayloadsWithConversation(ctx, GenerationRequest{}, s.processor, &ConversationContinuation{ + History: s.conversationHistory, + Feedback: feedback, + }) + } + + if err != nil { + return fmt.Errorf("AI payload generation failed: %w", err) + } + + // Increment conversation response count + s.conversationResponses++ + + // Add new payloads to the pool + s.addPayloadsToPool(response.Payloads) + + s.logger.Infof("AI call completed, generated %d payloads (conversation: %d responses)", + len(response.Payloads), s.conversationResponses) + + return nil +} + +// addPayloadsToPool adds new payloads to the pool and manages the 100-payload limit +func (s *Scenario) addPayloadsToPool(templates []PayloadTemplate) { + s.payloadMutex.Lock() + defer s.payloadMutex.Unlock() + + // Add new payloads + for _, template := range templates { + payloadState := &PayloadState{ + Template: template, + IsDeployed: false, + SuccessCount: 0, + FailCount: 0, + LastUsed: time.Now(), + } + s.payloadStates = append(s.payloadStates, payloadState) + } + + // Clean up if we exceed 100 payloads + if len(s.payloadStates) > 100 { + s.cleanupPayloads() + } + + s.logger.Infof("added %d payloads to pool, total: %d", len(templates), len(s.payloadStates)) +} + +// cleanupPayloads removes failing payloads first, then payloads with highest success count +func (s *Scenario) cleanupPayloads() { + // Sort by fail count (descending), then by success count (descending) + sort.Slice(s.payloadStates, func(i, j int) bool { + if s.payloadStates[i].FailCount != s.payloadStates[j].FailCount { + return 
s.payloadStates[i].FailCount > s.payloadStates[j].FailCount + } + return s.payloadStates[i].SuccessCount > s.payloadStates[j].SuccessCount + }) + + // Remove the worst 25% to get back to 75 payloads + targetSize := 75 + if len(s.payloadStates) > targetSize { + removedCount := len(s.payloadStates) - targetSize + s.payloadStates = s.payloadStates[removedCount:] + s.logger.Infof("cleaned up %d payloads, remaining: %d", removedCount, len(s.payloadStates)) + } + + // Reset round-robin index if needed + if s.payloadRoundRobin >= len(s.payloadStates) { + s.payloadRoundRobin = 0 + } +} + +func (s *Scenario) recordPayloadFailure(payloadState *PayloadState, errorType string, errorMsg string) { + payloadState.mutex.Lock() + payloadState.FailCount++ + payloadState.mutex.Unlock() + + // Create dummy payload for feedback + dummyPayload := &PayloadInstance{ + Type: payloadState.Template.Type, + Description: payloadState.Template.Description, + } + s.feedbackCollector.RecordFailure(dummyPayload, errorType, errorMsg) + + s.logger.Debugf("payload failure: %s - %s: %s, success: %d, fail: %d", + payloadState.Template.Description, errorType, errorMsg, + payloadState.SuccessCount, payloadState.FailCount) +} From 955d229b22f43573cd8eb2ec80d1b5173f8a4253 Mon Sep 17 00:00:00 2001 From: pk910 Date: Tue, 8 Jul 2025 05:09:37 +0200 Subject: [PATCH 3/4] continue implementation --- .hack/devnet/kurtosis.devnet.config.yaml | 1 + scenarios/aitx/ai_service.go | 25 +- scenarios/aitx/aitx.go | 327 ++++++++++++++++++++--- scenarios/aitx/feedback_collector.go | 236 ++++++++++++---- scenarios/aitx/payload_processor.go | 21 +- 5 files changed, 499 insertions(+), 111 deletions(-) diff --git a/.hack/devnet/kurtosis.devnet.config.yaml b/.hack/devnet/kurtosis.devnet.config.yaml index e33d20f3..a2e95679 100644 --- a/.hack/devnet/kurtosis.devnet.config.yaml +++ b/.hack/devnet/kurtosis.devnet.config.yaml @@ -9,6 +9,7 @@ network_params: preset: mainnet gas_limit: 100000000 genesis_gaslimit: 100000000 
+snooper_enabled: true additional_services: - spamoor # required for config extraction - dora diff --git a/scenarios/aitx/ai_service.go b/scenarios/aitx/ai_service.go index e4e53d5b..5612517e 100644 --- a/scenarios/aitx/ai_service.go +++ b/scenarios/aitx/ai_service.go @@ -394,7 +394,25 @@ func (ai *AIService) buildBasePrompt(generationMode string) string { promptBuilder.WriteString("2. Computation loops: Mathematical operations with clean stack management\n") promptBuilder.WriteString("3. Storage patterns: Read/write with counters or mappings\n") promptBuilder.WriteString("4. Event emission: Log computation results or state changes\n") - promptBuilder.WriteString("5. Memory operations: Expand memory, hash data, manipulate arrays\n\n") + promptBuilder.WriteString("5. Memory operations: Expand memory, hash data, manipulate arrays\n") + promptBuilder.WriteString("6. Precompile/contract calls: Use CALL opcode to interact with other contracts\n\n") + + promptBuilder.WriteString("PRECOMPILE/CONTRACT CALL PATTERN:\n") + promptBuilder.WriteString("To call precompiles (addresses 1-9) or other contracts, use this pattern:\n") + promptBuilder.WriteString("```\n") + promptBuilder.WriteString("PUSH1 0x20 ; retSize\n") + promptBuilder.WriteString("PUSH1 0x00 ; retOffset\n") + promptBuilder.WriteString("PUSH1 0x20 ; argsSize\n") + promptBuilder.WriteString("PUSH1 0x00 ; argsOffset\n") + promptBuilder.WriteString("PUSH1 0x00 ; value\n") + promptBuilder.WriteString("PUSH1 0x05 ; address (example: precompile 5 = modexp)\n") + promptBuilder.WriteString("PUSH2 0xC350 ; gas (50000)\n") + promptBuilder.WriteString("GAS\n") + promptBuilder.WriteString("SUB\n") + promptBuilder.WriteString("CALL\n") + promptBuilder.WriteString("POP ; remove success flag\n") + promptBuilder.WriteString("```\n") + promptBuilder.WriteString("Common precompiles: 1=ecrecover, 2=sha256, 3=ripemd160, 4=identity, 5=modexp, 6=ecadd, 7=ecmul, 8=ecpairing, 9=blake2f\n\n") promptBuilder.WriteString("AVAILABLE 
PLACEHOLDERS:\n") promptBuilder.WriteString("- ${RANDOM_UINT256}: Random 256-bit unsigned integer\n") @@ -484,7 +502,6 @@ func (ai *AIService) parseResponse(response *OpenRouterResponse) (*GenerationRes } content := response.Choices[0].Message.Content - ai.logger.Infof("AI response content: %s", content) var payloads []PayloadTemplate @@ -539,12 +556,12 @@ func (ai *AIService) extractJSONObjectsFromText(content string) ([]PayloadTempla if strings.HasPrefix(line, "```") && inJSONBlock { // End of JSON block, try to parse it jsonStr := jsonBlock.String() - ai.logger.Infof("Attempting to parse JSON block: %s", jsonStr) + ai.logger.Debugf("Attempting to parse JSON block: %s", jsonStr) var payload PayloadTemplate if err := json.Unmarshal([]byte(jsonStr), &payload); err == nil { payloads = append(payloads, payload) - ai.logger.Infof("Successfully parsed payload: %s", payload.Description) + ai.logger.Debugf("Successfully parsed payload: %s", payload.Description) } else { ai.logger.Errorf("Failed to parse JSON block: %v", err) } diff --git a/scenarios/aitx/aitx.go b/scenarios/aitx/aitx.go index 5d47e981..cae7895d 100644 --- a/scenarios/aitx/aitx.go +++ b/scenarios/aitx/aitx.go @@ -3,7 +3,9 @@ package aitx import ( "context" "encoding/hex" + "encoding/json" "fmt" + "os" "sort" "sync" "time" @@ -51,6 +53,15 @@ type ScenarioOptions struct { // Debug options LogAIConversations bool `yaml:"log_ai_conversations"` + + // Persistence options + MaxPayloads int `yaml:"max_payloads"` + PersistenceFile string `yaml:"persistence_file"` + SavePersistence bool `yaml:"save_persistence"` + LoadPersistence bool `yaml:"load_persistence"` + + // Payload management options + SuccessThreshold int `yaml:"success_threshold"` } type PayloadState struct { @@ -61,6 +72,7 @@ type PayloadState struct { SuccessCount int FailCount int LastUsed time.Time + BatchID int // AI batch ID this payload belongs to mutex sync.Mutex // Protects individual payload state } @@ -119,6 +131,15 @@ var 
ScenarioDefaultOptions = ScenarioOptions{ // Debug defaults LogAIConversations: false, + + // Persistence defaults + MaxPayloads: 100, + PersistenceFile: "", + SavePersistence: true, + LoadPersistence: true, + + // Payload management defaults + SuccessThreshold: 20, } var ScenarioDescriptor = scenario.Descriptor{ @@ -165,6 +186,15 @@ func (s *Scenario) Flags(flags *pflag.FlagSet) error { // Debug flags flags.BoolVar(&s.options.LogAIConversations, "log-ai-conversations", ScenarioDefaultOptions.LogAIConversations, "Enable detailed logging of AI conversations for debugging") + // Persistence flags + flags.IntVar(&s.options.MaxPayloads, "max-payloads", ScenarioDefaultOptions.MaxPayloads, "Maximum number of payloads to keep in memory") + flags.StringVar(&s.options.PersistenceFile, "persistence-file", ScenarioDefaultOptions.PersistenceFile, "File to save/load payloads for persistence") + flags.BoolVar(&s.options.SavePersistence, "save-persistence", ScenarioDefaultOptions.SavePersistence, "Save payloads to persistence file on shutdown") + flags.BoolVar(&s.options.LoadPersistence, "load-persistence", ScenarioDefaultOptions.LoadPersistence, "Load payloads from persistence file on startup") + + // Payload management flags + flags.IntVar(&s.options.SuccessThreshold, "success-threshold", ScenarioDefaultOptions.SuccessThreshold, "Number of successful calls before requesting new payloads") + return nil } @@ -210,8 +240,8 @@ func (s *Scenario) Init(options *scenario.Options) error { // Initialize AI components s.aiService = NewAIService(s.options.OpenRouterAPIKey, s.options.Model, s.options.LogAIConversations, s.logger) - s.processor = NewPayloadProcessor(s.logger) s.placeholderSubstituter = NewPlaceholderSubstituter(s.walletPool, s.walletPool.GetClient(spamoor.SelectClientByIndex, 0, ""), s.logger) + s.processor = NewPayloadProcessor(s.logger, s.placeholderSubstituter) s.feedbackCollector = NewFeedbackCollector(s.options.FeedbackBatchSize, s.logger) s.geasProcessor = 
NewGeasProcessor(s.logger) @@ -220,11 +250,22 @@ func (s *Scenario) Init(options *scenario.Options) error { s.aiService.SetBasePrompt(basePrompt) // Initialize async payload management - s.payloadStates = make([]*PayloadState, 0, 100) + s.payloadStates = make([]*PayloadState, 0, s.options.MaxPayloads) s.aiRequestChan = make(chan struct{}, 1) s.aiReadyChan = make(chan struct{}, 1) s.shutdownChan = make(chan struct{}) + // Load payloads from persistence file if enabled + if s.options.LoadPersistence && s.options.PersistenceFile != "" { + if err := s.loadPayloadsFromFile(); err != nil { + s.logger.Warnf("failed to load payloads from persistence file: %v", err) + } else { + s.logger.Infof("loaded %d payloads from persistence file", len(s.payloadStates)) + // Verify contract deployments + s.verifyDeployedContracts() + } + } + return nil } @@ -235,18 +276,20 @@ func (s *Scenario) Run(ctx context.Context) error { // Start background AI worker go s.aiWorker(ctx) - // Initial AI request to get started - select { - case s.aiRequestChan <- struct{}{}: - default: - } + if len(s.payloadStates) == 0 { + // Initial AI request to get started + select { + case s.aiRequestChan <- struct{}{}: + default: + } - // Wait for AI to be ready - s.logger.Infof("waiting for AI payloads to be ready") - select { - case <-s.aiReadyChan: - case <-ctx.Done(): - return ctx.Err() + // Wait for AI to be ready + s.logger.Infof("waiting for AI payloads to be ready") + select { + case <-s.aiReadyChan: + case <-ctx.Done(): + return ctx.Err() + } } s.logger.Infof("AI payloads ready, starting transaction generation") @@ -309,16 +352,30 @@ func (s *Scenario) Run(ctx context.Context) error { // Signal shutdown to AI worker close(s.shutdownChan) + // Save payloads to persistence file if enabled + if s.options.SavePersistence && s.options.PersistenceFile != "" { + if saveErr := s.savePayloadsToFile(); saveErr != nil { + s.logger.Errorf("failed to save payloads to persistence file: %v", saveErr) + } else { + 
s.logger.Infof("saved %d payloads to persistence file", len(s.payloadStates)) + } + } + return err } func (s *Scenario) sendAITransaction(ctx context.Context, txIdx uint64, onComplete func()) (*types.Transaction, *spamoor.Client, *spamoor.Wallet, error) { // Send single call transaction using round-robin payload selection - defer onComplete() - client := s.walletPool.GetClient(spamoor.SelectClientByIndex, int(txIdx), s.options.ClientGroup) wallet := s.walletPool.GetWallet(spamoor.SelectWalletByPendingTxCount, int(txIdx)) + txSubmitted := false + defer func() { + if !txSubmitted { + onComplete() + } + }() + if client == nil { return nil, client, wallet, fmt.Errorf("no client available") } @@ -355,6 +412,7 @@ func (s *Scenario) sendAITransaction(ctx context.Context, txIdx uint64, onComple } // Send call transaction + txSubmitted = true err = s.walletPool.GetTxPool().SendTransaction(ctx, wallet, callTx, &spamoor.SendTransactionOptions{ Client: client, Rebroadcast: s.options.Rebroadcast > 0, @@ -363,6 +421,7 @@ func (s *Scenario) sendAITransaction(ctx context.Context, txIdx uint64, onComple s.recordPayloadSuccess(payloadState, payload, tx, receipt) }, OnComplete: func(tx *types.Transaction, receipt *types.Receipt, err error) { + onComplete() if err != nil { s.recordPayloadFailure(payloadState, "call_failed", err.Error()) } @@ -384,18 +443,34 @@ func (s *Scenario) getNextPayload(ctx context.Context) (*PayloadState, error) { s.payloadMutex.Lock() defer s.payloadMutex.Unlock() - // Check for payloads with < 20 successes + // Check for payloads that haven't reached success threshold and haven't failed too much for i := 0; i < len(s.payloadStates); i++ { idx := (s.payloadRoundRobin + i) % len(s.payloadStates) payloadState := s.payloadStates[idx] - if payloadState.SuccessCount < 20 && !payloadState.IsDeploying { + + // Skip if currently deploying + if payloadState.IsDeploying { + continue + } + + // Use payload if it hasn't reached success threshold AND hasn't failed 
excessively + // A payload is considered "exhausted" if it has reached either: + // - Success threshold (working well) + // - Failure threshold (not working, give up) + totalCalls := payloadState.SuccessCount + payloadState.FailCount + hasReachedSuccessThreshold := payloadState.SuccessCount >= s.options.SuccessThreshold + hasReachedFailureThreshold := payloadState.FailCount >= s.options.SuccessThreshold + + if !hasReachedSuccessThreshold && !hasReachedFailureThreshold { s.payloadRoundRobin = (idx + 1) % len(s.payloadStates) payloadState.LastUsed = time.Now() + s.logger.Debugf("selected payload: %s (success: %d, fail: %d, total: %d)", + payloadState.Template.Description, payloadState.SuccessCount, payloadState.FailCount, totalCalls) return payloadState, nil } } - // No payload with < 20 successes, request more payloads + // No payload available that hasn't reached threshold, request more payloads s.requestMorePayloads() // If we have any payloads, return the next one @@ -482,11 +557,12 @@ func (s *Scenario) callGeasContract(ctx context.Context, wallet *spamoor.Wallet, func (s *Scenario) recordPayloadSuccess(payloadState *PayloadState, payload *PayloadInstance, tx *types.Transaction, receipt *types.Receipt) { payloadState.mutex.Lock() + batchID := payloadState.BatchID payloadState.SuccessCount++ payloadState.mutex.Unlock() if receipt == nil { - s.feedbackCollector.RecordFailure(payload, "receipt_nil", "receipt was nil") + s.feedbackCollector.RecordFailure(payload, "receipt_nil", "receipt was nil", batchID) return } @@ -515,9 +591,9 @@ func (s *Scenario) recordPayloadSuccess(payloadState *PayloadState, payload *Pay hex.EncodeToString(log.Data))) } - s.logger.Debugf("transaction confirmed: %s (%s) - %s, gas: %d, fees: %s, logs: %d, success: %d, fail: %d", + s.logger.Debugf("transaction confirmed: %s (%s) - %s, gas: %d, fees: %s, logs: %d, success: %d, fail: %d, batch: %d", payload.Description, payload.Type, status, receipt.GasUsed, txFees.TotalFeeGweiString(), 
len(receipt.Logs), - payloadState.SuccessCount, payloadState.FailCount) + payloadState.SuccessCount, payloadState.FailCount, batchID) // Record result for feedback result := TransactionResult{ @@ -530,7 +606,7 @@ func (s *Scenario) recordPayloadSuccess(payloadState *PayloadState, payload *Pay LogData: logData, } - s.feedbackCollector.RecordResult(result) + s.feedbackCollector.RecordResult(result, batchID) } // ensureContractDeployed ensures the contract is deployed for the payload @@ -698,38 +774,44 @@ func (s *Scenario) generatePayloads(ctx context.Context) error { // Increment conversation response count s.conversationResponses++ - // Add new payloads to the pool - s.addPayloadsToPool(response.Payloads) + // Start a new feedback batch for the new payloads + s.feedbackCollector.StartNewBatch() + batchID := s.feedbackCollector.GetCurrentBatchID() - s.logger.Infof("AI call completed, generated %d payloads (conversation: %d responses)", - len(response.Payloads), s.conversationResponses) + // Add new payloads to the pool with the current batch ID + s.addPayloadsToPool(response.Payloads, batchID) + + s.logger.Infof("AI call completed, generated %d payloads (conversation: %d responses, batch: %d)", + len(response.Payloads), s.conversationResponses, batchID) return nil } -// addPayloadsToPool adds new payloads to the pool and manages the 100-payload limit -func (s *Scenario) addPayloadsToPool(templates []PayloadTemplate) { +// addPayloadsToPool adds new payloads to the pool and manages the max payload limit +func (s *Scenario) addPayloadsToPool(templates []PayloadTemplate, batchID int) { s.payloadMutex.Lock() defer s.payloadMutex.Unlock() - // Add new payloads + // Add new payloads with batch ID for _, template := range templates { payloadState := &PayloadState{ Template: template, IsDeployed: false, + IsDeploying: false, SuccessCount: 0, FailCount: 0, LastUsed: time.Now(), + BatchID: batchID, } s.payloadStates = append(s.payloadStates, payloadState) } - // Clean up if 
we exceed 100 payloads - if len(s.payloadStates) > 100 { + // Clean up if we exceed max payloads + if len(s.payloadStates) > s.options.MaxPayloads { s.cleanupPayloads() } - s.logger.Infof("added %d payloads to pool, total: %d", len(templates), len(s.payloadStates)) + s.logger.Infof("added %d payloads to pool (batch %d), total: %d", len(templates), batchID, len(s.payloadStates)) } // cleanupPayloads removes failing payloads first, then payloads with highest success count @@ -742,8 +824,8 @@ func (s *Scenario) cleanupPayloads() { return s.payloadStates[i].SuccessCount > s.payloadStates[j].SuccessCount }) - // Remove the worst 25% to get back to 75 payloads - targetSize := 75 + // Remove the worst 25% to get back to 75% of max + targetSize := (s.options.MaxPayloads * 3) / 4 // 75% of max if len(s.payloadStates) > targetSize { removedCount := len(s.payloadStates) - targetSize s.payloadStates = s.payloadStates[removedCount:] @@ -756,8 +838,179 @@ func (s *Scenario) cleanupPayloads() { } } +// PayloadStatePersistence represents the data to persist for a payload state +type PayloadStatePersistence struct { + Template PayloadTemplate `json:"template"` + IsDeployed bool `json:"is_deployed"` + ContractAddress string `json:"contract_address,omitempty"` + SuccessCount int `json:"success_count"` + FailCount int `json:"fail_count"` + LastUsed time.Time `json:"last_used"` + BatchID int `json:"batch_id"` +} + +// PayloadsPersistenceData represents the complete persistence data +type PayloadsPersistenceData struct { + Payloads []PayloadStatePersistence `json:"payloads"` + ConversationHistory []Message `json:"conversation_history,omitempty"` + ConversationResponses int `json:"conversation_responses"` + SavedAt time.Time `json:"saved_at"` +} + +// savePayloadsToFile saves the current payload states to a JSON file +func (s *Scenario) savePayloadsToFile() error { + s.payloadMutex.RLock() + defer s.payloadMutex.RUnlock() + + var persistenceData PayloadsPersistenceData + 
persistenceData.ConversationHistory = s.conversationHistory + persistenceData.ConversationResponses = s.conversationResponses + persistenceData.SavedAt = time.Now() + + // Convert payload states to persistence format + persistenceData.Payloads = make([]PayloadStatePersistence, len(s.payloadStates)) + for i, state := range s.payloadStates { + state.mutex.Lock() + persistenceData.Payloads[i] = PayloadStatePersistence{ + Template: state.Template, + IsDeployed: state.IsDeployed, + ContractAddress: state.ContractAddress.Hex(), + SuccessCount: state.SuccessCount, + FailCount: state.FailCount, + LastUsed: state.LastUsed, + BatchID: state.BatchID, + } + state.mutex.Unlock() + } + + // Marshal to JSON + data, err := json.MarshalIndent(persistenceData, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal persistence data: %w", err) + } + + // Write to file + err = os.WriteFile(s.options.PersistenceFile, data, 0644) + if err != nil { + return fmt.Errorf("failed to write persistence file: %w", err) + } + + s.logger.Infof("saved %d payloads to persistence file: %s", len(s.payloadStates), s.options.PersistenceFile) + return nil +} + +// loadPayloadsFromFile loads payload states from a JSON file +func (s *Scenario) loadPayloadsFromFile() error { + // Check if file exists + if _, err := os.Stat(s.options.PersistenceFile); os.IsNotExist(err) { + s.logger.Infof("persistence file does not exist: %s", s.options.PersistenceFile) + return nil + } + + // Read file + data, err := os.ReadFile(s.options.PersistenceFile) + if err != nil { + return fmt.Errorf("failed to read persistence file: %w", err) + } + + // Unmarshal JSON + var persistenceData PayloadsPersistenceData + err = json.Unmarshal(data, &persistenceData) + if err != nil { + return fmt.Errorf("failed to unmarshal persistence data: %w", err) + } + + s.payloadMutex.Lock() + defer s.payloadMutex.Unlock() + + // Restore conversation state + s.conversationHistory = persistenceData.ConversationHistory + 
s.conversationResponses = persistenceData.ConversationResponses + + // Convert persistence format to payload states + s.payloadStates = make([]*PayloadState, len(persistenceData.Payloads)) + for i, persistedState := range persistenceData.Payloads { + contractAddr := common.Address{} + if persistedState.ContractAddress != "" && persistedState.ContractAddress != "0x0000000000000000000000000000000000000000" { + contractAddr = common.HexToAddress(persistedState.ContractAddress) + } + + s.payloadStates[i] = &PayloadState{ + Template: persistedState.Template, + IsDeployed: persistedState.IsDeployed, + IsDeploying: false, // Always start with false + ContractAddress: contractAddr, + SuccessCount: persistedState.SuccessCount, + FailCount: persistedState.FailCount, + LastUsed: persistedState.LastUsed, + BatchID: persistedState.BatchID, + } + } + + s.logger.Infof("loaded %d payloads from persistence file: %s (saved at: %s)", + len(s.payloadStates), s.options.PersistenceFile, persistenceData.SavedAt.Format(time.RFC3339)) + + return nil +} + +// verifyDeployedContracts checks if contracts are actually deployed at stored addresses +func (s *Scenario) verifyDeployedContracts() { + if len(s.payloadStates) == 0 { + return + } + + // Get a client for verification + client := s.walletPool.GetClient(spamoor.SelectClientByIndex, 0, s.options.ClientGroup) + if client == nil { + s.logger.Warnf("no client available for contract verification") + return + } + + s.payloadMutex.Lock() + defer s.payloadMutex.Unlock() + + verifiedCount := 0 + invalidatedCount := 0 + + for _, payloadState := range s.payloadStates { + payloadState.mutex.Lock() + + if payloadState.IsDeployed && payloadState.ContractAddress != (common.Address{}) { + // Check if code exists at the address + code, err := client.GetEthClient().CodeAt(context.Background(), payloadState.ContractAddress, nil) + if err != nil { + s.logger.Warnf("failed to check contract code at %s: %v", payloadState.ContractAddress.Hex(), err) + // On 
error, assume contract is not deployed to be safe + payloadState.IsDeployed = false + payloadState.ContractAddress = common.Address{} + invalidatedCount++ + } else if len(code) == 0 { + // No code at address, contract not deployed + s.logger.Debugf("no code found at %s, marking payload as not deployed: %s", + payloadState.ContractAddress.Hex(), payloadState.Template.Description) + payloadState.IsDeployed = false + payloadState.ContractAddress = common.Address{} + invalidatedCount++ + } else { + // Code exists, contract is deployed + s.logger.Debugf("verified contract at %s for payload: %s", + payloadState.ContractAddress.Hex(), payloadState.Template.Description) + verifiedCount++ + } + } + + payloadState.mutex.Unlock() + } + + if verifiedCount > 0 || invalidatedCount > 0 { + s.logger.Infof("contract verification complete: %d verified, %d invalidated", + verifiedCount, invalidatedCount) + } +} + func (s *Scenario) recordPayloadFailure(payloadState *PayloadState, errorType string, errorMsg string) { payloadState.mutex.Lock() + batchID := payloadState.BatchID payloadState.FailCount++ payloadState.mutex.Unlock() @@ -766,9 +1019,9 @@ func (s *Scenario) recordPayloadFailure(payloadState *PayloadState, errorType st Type: payloadState.Template.Type, Description: payloadState.Template.Description, } - s.feedbackCollector.RecordFailure(dummyPayload, errorType, errorMsg) + s.feedbackCollector.RecordFailure(dummyPayload, errorType, errorMsg, batchID) - s.logger.Debugf("payload failure: %s - %s: %s, success: %d, fail: %d", + s.logger.Debugf("payload failure: %s - %s: %s, success: %d, fail: %d, batch: %d", payloadState.Template.Description, errorType, errorMsg, - payloadState.SuccessCount, payloadState.FailCount) + payloadState.SuccessCount, payloadState.FailCount, batchID) } diff --git a/scenarios/aitx/feedback_collector.go b/scenarios/aitx/feedback_collector.go index ae6f32da..c1ba1455 100644 --- a/scenarios/aitx/feedback_collector.go +++ b/scenarios/aitx/feedback_collector.go 
@@ -9,45 +9,82 @@ import ( "github.com/sirupsen/logrus" ) +type PayloadStats struct { + Description string + SuccessCount int + FailureCount int + TotalGasUsed uint64 + SuccessfulCalls []TransactionResult // Successful transactions for this payload + FailedCalls []TransactionResult // Failed transactions for this payload +} + type FeedbackCollector struct { - results []TransactionResult - mutex sync.RWMutex - maxResults uint64 - totalTransactions uint64 - successfulTxs uint64 - failedTxs uint64 - logger logrus.FieldLogger + payloadStats map[string]*PayloadStats // Per-payload statistics + mutex sync.RWMutex + maxResults uint64 + currentBatchID int // ID of current AI response batch + logger logrus.FieldLogger } func NewFeedbackCollector(maxResults uint64, logger logrus.FieldLogger) *FeedbackCollector { return &FeedbackCollector{ - results: make([]TransactionResult, 0, maxResults), - maxResults: maxResults, - logger: logger.WithField("component", "feedback_collector"), + payloadStats: make(map[string]*PayloadStats), + maxResults: maxResults, + currentBatchID: 0, + logger: logger.WithField("component", "feedback_collector"), } } -func (fc *FeedbackCollector) RecordResult(result TransactionResult) { +func (fc *FeedbackCollector) RecordResult(result TransactionResult, batchID int) { fc.mutex.Lock() defer fc.mutex.Unlock() - fc.results = append(fc.results, result) - if uint64(len(fc.results)) > fc.maxResults { - fc.results = fc.results[1:] + // Only record results for the current batch + if batchID != fc.currentBatchID { + fc.logger.Debugf("ignoring result from old batch %d (current: %d): %s", + batchID, fc.currentBatchID, result.PayloadDescription) + return + } + + // Get or create payload stats + stats, exists := fc.payloadStats[result.PayloadDescription] + if !exists { + stats = &PayloadStats{ + Description: result.PayloadDescription, + SuccessfulCalls: make([]TransactionResult, 0), + FailedCalls: make([]TransactionResult, 0), + } + 
fc.payloadStats[result.PayloadDescription] = stats } - fc.totalTransactions++ + // Record result if result.Status == "success" { - fc.successfulTxs++ + stats.SuccessCount++ + stats.TotalGasUsed += result.GasUsed + stats.SuccessfulCalls = append(stats.SuccessfulCalls, result) + + // Keep only recent successful calls (limit to maxResults/2 per payload) + maxPerPayload := int(fc.maxResults / 2) + if len(stats.SuccessfulCalls) > maxPerPayload { + stats.SuccessfulCalls = stats.SuccessfulCalls[len(stats.SuccessfulCalls)-maxPerPayload:] + } } else { - fc.failedTxs++ + stats.FailureCount++ + stats.FailedCalls = append(stats.FailedCalls, result) + + // Keep only recent failed calls (limit to maxResults/2 per payload) + maxPerPayload := int(fc.maxResults / 2) + if len(stats.FailedCalls) > maxPerPayload { + stats.FailedCalls = stats.FailedCalls[len(stats.FailedCalls)-maxPerPayload:] + } } - fc.logger.Debugf("recorded transaction result: %s (%s) - %s", - result.PayloadDescription, result.PayloadType, result.Status) + fc.logger.Debugf("recorded transaction result for batch %d: %s (%s) - %s (success: %d, failed: %d)", + batchID, result.PayloadDescription, result.PayloadType, result.Status, + stats.SuccessCount, stats.FailureCount) } -func (fc *FeedbackCollector) RecordFailure(payload *PayloadInstance, status, errorMsg string) { +func (fc *FeedbackCollector) RecordFailure(payload *PayloadInstance, status, errorMsg string, batchID int) { result := TransactionResult{ PayloadType: payload.Type, PayloadDescription: payload.Description, @@ -56,85 +93,166 @@ func (fc *FeedbackCollector) RecordFailure(payload *PayloadInstance, status, err BlockExecTime: "N/A", ErrorMessage: errorMsg, } - fc.RecordResult(result) + fc.RecordResult(result, batchID) } func (fc *FeedbackCollector) GenerateFeedback() *TransactionFeedback { fc.mutex.RLock() defer fc.mutex.RUnlock() - if len(fc.results) == 0 { + if len(fc.payloadStats) == 0 { return nil } - gasValues := make([]uint64, 0, len(fc.results)) - for 
_, result := range fc.results { - if result.Status == "success" && result.GasUsed > 0 { - gasValues = append(gasValues, result.GasUsed) + var totalTxs, successfulTxs, failedTxs uint64 + var allGasValues []uint64 + var allResults []TransactionResult + + // Aggregate stats from all payloads + for _, stats := range fc.payloadStats { + totalTxs += uint64(stats.SuccessCount + stats.FailureCount) + successfulTxs += uint64(stats.SuccessCount) + failedTxs += uint64(stats.FailureCount) + + // Collect gas values from successful calls + for _, result := range stats.SuccessfulCalls { + if result.GasUsed > 0 { + allGasValues = append(allGasValues, result.GasUsed) + } + allResults = append(allResults, result) } + + // Include failed calls in results + allResults = append(allResults, stats.FailedCalls...) } var avgGas, medianGas uint64 - if len(gasValues) > 0 { - sort.Slice(gasValues, func(i, j int) bool { return gasValues[i] < gasValues[j] }) + if len(allGasValues) > 0 { + sort.Slice(allGasValues, func(i, j int) bool { return allGasValues[i] < allGasValues[j] }) var total uint64 - for _, gas := range gasValues { + for _, gas := range allGasValues { total += gas } - avgGas = total / uint64(len(gasValues)) - medianGas = gasValues[len(gasValues)/2] + avgGas = total / uint64(len(allGasValues)) + medianGas = allGasValues[len(allGasValues)/2] } - summary := fc.generateSummary() + summary := fc.generateDetailedSummary() return &TransactionFeedback{ - TotalTransactions: fc.totalTransactions, - SuccessfulTxs: fc.successfulTxs, - FailedTxs: fc.failedTxs, + TotalTransactions: totalTxs, + SuccessfulTxs: successfulTxs, + FailedTxs: failedTxs, AverageGasUsed: avgGas, MedianGasUsed: medianGas, AverageBlockExecTime: "N/A", - RecentResults: fc.getRecentResults(10), + RecentResults: allResults, Summary: summary, } } -func (fc *FeedbackCollector) generateSummary() string { - if len(fc.results) == 0 { - return "No transaction results yet." 
+func (fc *FeedbackCollector) generateDetailedSummary() string { + if len(fc.payloadStats) == 0 { + return "No transaction results for current batch yet." } - typeSuccess := make(map[string]int) - typeTotal := make(map[string]int) + var summaryParts []string - for _, result := range fc.results { - typeTotal[result.PayloadType]++ - if result.Status == "success" { - typeSuccess[result.PayloadType]++ + for description, stats := range fc.payloadStats { + total := stats.SuccessCount + stats.FailureCount + if total == 0 { + continue } - } - var summaryParts []string - for payloadType, total := range typeTotal { - success := typeSuccess[payloadType] - successRate := float64(success) / float64(total) * 100 - summaryParts = append(summaryParts, - fmt.Sprintf("%s: %.1f%% success (%d/%d)", payloadType, successRate, success, total)) + successRate := float64(stats.SuccessCount) / float64(total) * 100 + + // Calculate average gas for this payload + var avgGas uint64 + if stats.SuccessCount > 0 { + avgGas = stats.TotalGasUsed / uint64(stats.SuccessCount) + } + + // Get recent logs from successful calls + var recentLogs []string + for _, result := range stats.SuccessfulCalls { + if len(result.LogData) > 0 { + recentLogs = append(recentLogs, result.LogData[0]) // Take first log + if len(recentLogs) >= 2 { // Limit to 2 logs per payload + break + } + } + } + + // Get recent error messages from failed calls + var recentErrors []string + for _, result := range stats.FailedCalls { + if result.ErrorMessage != "" { + recentErrors = append(recentErrors, result.ErrorMessage) + if len(recentErrors) >= 2 { // Limit to 2 errors per payload + break + } + } + } + + payloadSummary := fmt.Sprintf("'%s': %.1f%% success (%d/%d), avg_gas: %d", + description, successRate, stats.SuccessCount, total, avgGas) + + if len(recentLogs) > 0 { + payloadSummary += fmt.Sprintf(", recent_logs: [%s]", strings.Join(recentLogs, ", ")) + } + + if len(recentErrors) > 0 { + payloadSummary += fmt.Sprintf(", 
recent_errors: [%s]", strings.Join(recentErrors, ", ")) + } + + summaryParts = append(summaryParts, payloadSummary) } - return fmt.Sprintf("Pattern analysis: %s", strings.Join(summaryParts, ", ")) + return fmt.Sprintf("Detailed payload analysis: %s", strings.Join(summaryParts, " | ")) +} + +// StartNewBatch resets the feedback collector for a new AI response batch +func (fc *FeedbackCollector) StartNewBatch() { + fc.mutex.Lock() + defer fc.mutex.Unlock() + + fc.currentBatchID++ + fc.payloadStats = make(map[string]*PayloadStats) + fc.logger.Infof("started new feedback batch %d", fc.currentBatchID) +} + +func (fc *FeedbackCollector) GetCurrentBatchID() int { + fc.mutex.RLock() + defer fc.mutex.RUnlock() + return fc.currentBatchID } -func (fc *FeedbackCollector) getRecentResults(count int) []TransactionResult { - if len(fc.results) <= count { - return fc.results +func (fc *FeedbackCollector) GetCurrentBatchStats() (uint64, uint64, uint64) { + fc.mutex.RLock() + defer fc.mutex.RUnlock() + + var totalTxs, successfulTxs, failedTxs uint64 + for _, stats := range fc.payloadStats { + totalTxs += uint64(stats.SuccessCount + stats.FailureCount) + successfulTxs += uint64(stats.SuccessCount) + failedTxs += uint64(stats.FailureCount) } - return fc.results[len(fc.results)-count:] + + return totalTxs, successfulTxs, failedTxs } -func (fc *FeedbackCollector) GetStats() (uint64, uint64, uint64) { +// GetFailedPayloads returns payloads that have failures for immediate feedback +func (fc *FeedbackCollector) GetFailedPayloads() map[string]*PayloadStats { fc.mutex.RLock() defer fc.mutex.RUnlock() - return fc.totalTransactions, fc.successfulTxs, fc.failedTxs + + failedPayloads := make(map[string]*PayloadStats) + for description, stats := range fc.payloadStats { + if stats.FailureCount > 0 { + failedPayloads[description] = stats + } + } + + return failedPayloads } diff --git a/scenarios/aitx/payload_processor.go b/scenarios/aitx/payload_processor.go index 91b67e9e..42e3dc81 100644 --- 
a/scenarios/aitx/payload_processor.go +++ b/scenarios/aitx/payload_processor.go @@ -9,12 +9,14 @@ import ( type PayloadProcessor struct { logger logrus.FieldLogger geasProcessor *GeasProcessor + substituter *PlaceholderSubstituter } -func NewPayloadProcessor(logger logrus.FieldLogger) *PayloadProcessor { +func NewPayloadProcessor(logger logrus.FieldLogger, substituter *PlaceholderSubstituter) *PayloadProcessor { return &PayloadProcessor{ logger: logger.WithField("component", "payload_processor"), geasProcessor: NewGeasProcessor(logger), + substituter: substituter, } } @@ -66,18 +68,15 @@ func (pp *PayloadProcessor) ProcessPayloads(templates []PayloadTemplate) ([]Payl } func (pp *PayloadProcessor) validateGeasCompilation(payload *PayloadTemplate) error { - // Create a temporary payload instance for compilation testing - tempPayload := &PayloadInstance{ - Type: payload.Type, - Description: payload.Description, - InitCode: payload.InitCode, - RunCode: payload.RunCode, - PostCode: payload.PostCode, - GasRemainder: 10000, // Default value for validation + // Substitute placeholders first before validation + tempPayloadInstance, err := payload.Substitute(pp.substituter) + if err != nil { + pp.logger.Debugf("placeholder substitution failed for payload '%s': %v", payload.Description, err) + return fmt.Errorf("placeholder substitution failed: %w", err) } - // Attempt to compile the geas code - _, err := pp.geasProcessor.CompileGeasPayload(tempPayload) + // Attempt to compile the geas code with substituted placeholders + _, err = pp.geasProcessor.CompileGeasPayload(tempPayloadInstance) if err != nil { pp.logger.Debugf("geas compilation validation failed for payload '%s': %v", payload.Description, err) return err From 695e2bca38b4515b998ee905f264d62e45b39116 Mon Sep 17 00:00:00 2001 From: pk910 Date: Thu, 17 Jul 2025 16:11:08 +0200 Subject: [PATCH 4/4] continue implementation --- Makefile | 2 +- scenarios/aitx/ai_service.go | 311 +++++++++++++++++++++++---- 
scenarios/aitx/feedback_collector.go | 268 +++++++++++++---------- scenarios/aitx/payload_template.go | 3 + scenarios/aitx/streaming_callback.go | 234 ++++++++++++++++++++ 5 files changed, 661 insertions(+), 157 deletions(-) create mode 100644 scenarios/aitx/streaming_callback.go diff --git a/Makefile b/Makefile index da749b21..f6a9e458 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ devnet: .hack/devnet/run.sh devnet-run: devnet docs build - bin/spamoor-daemon --rpchost-file .hack/devnet/generated-hosts.txt --privkey 3fd98b5187bf6526734efaa644ffbb4e3670d66f5d0268ce0323ec09124bff61 --port 8080 --db .hack/devnet/custom-spamoor.db + bin/spamoor-daemon --rpchost-file .hack/devnet/generated-hosts.txt --privkey 3fd98b5187bf6526734efaa644ffbb4e3670d66f5d0268ce0323ec09124bff61 --port 8080 --db .hack/devnet/custom-spamoor.db -v devnet-clean: .hack/devnet/cleanup.sh diff --git a/scenarios/aitx/ai_service.go b/scenarios/aitx/ai_service.go index 5612517e..44f5f322 100644 --- a/scenarios/aitx/ai_service.go +++ b/scenarios/aitx/ai_service.go @@ -1,6 +1,7 @@ package aitx import ( + "bufio" "bytes" "context" "encoding/json" @@ -68,9 +69,15 @@ type ConversationContinuation struct { } type OpenRouterRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - MaxTokens int `json:"max_tokens"` + Model string `json:"model"` + Messages []Message `json:"messages"` + MaxTokens int `json:"max_tokens"` + Stream bool `json:"stream,omitempty"` + Reasoning *ReasoningConfig `json:"reasoning,omitempty"` +} + +type ReasoningConfig struct { + MaxTokens int `json:"max_tokens"` } type Message struct { @@ -79,20 +86,36 @@ type Message struct { } type OpenRouterResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Message Message `json:"message"` - FinishReason string `json:"finish_reason"` + ID string `json:"id"` + Provider string 
`json:"provider,omitempty"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Index int `json:"index"` + Message *Message `json:"message,omitempty"` + Delta *Delta `json:"delta,omitempty"` + FinishReason *string `json:"finish_reason,omitempty"` + NativeFinishReason *string `json:"native_finish_reason,omitempty"` + Logprobs *string `json:"logprobs,omitempty"` } `json:"choices"` - Usage struct { + Usage *struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` - } `json:"usage"` + } `json:"usage,omitempty"` +} + +type Delta struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Reasoning string `json:"reasoning,omitempty"` + ReasoningDetails []string `json:"reasoning_details,omitempty"` +} + +type StreamingCallback interface { + OnContent(content string) error + OnComplete(fullContent string) error } func NewAIService(apiKey, model string, logConversations bool, logger logrus.FieldLogger) *AIService { @@ -102,7 +125,7 @@ func NewAIService(apiKey, model string, logConversations bool, logger logrus.Fie return &AIService{ client: &http.Client{ - Timeout: 60 * time.Second, + Timeout: 30 * time.Minute, // Increased timeout for longer AI requests }, apiKey: apiKey, model: model, @@ -139,6 +162,9 @@ func (ai *AIService) GeneratePayloads(ctx context.Context, req GenerationRequest Model: ai.model, Messages: conversationHistory, MaxTokens: 10000, + Reasoning: &ReasoningConfig{ + MaxTokens: 5000, + }, } response, err := ai.callOpenRouter(ctx, openRouterReq) @@ -147,9 +173,13 @@ func (ai *AIService) GeneratePayloads(ctx context.Context, req GenerationRequest continue } - ai.tokenCount += uint64(response.Usage.TotalTokens) - ai.logger.Infof("AI call #%d completed: %d tokens used, %d total tokens", - ai.callCount, response.Usage.TotalTokens, ai.tokenCount) + if response.Usage != nil { + ai.tokenCount 
+= uint64(response.Usage.TotalTokens) + ai.logger.Infof("AI call #%d completed: %d tokens used, %d total tokens", + ai.callCount, response.Usage.TotalTokens, ai.tokenCount) + } else { + ai.logger.Infof("AI call #%d completed (token usage not available)", ai.callCount) + } // Try to parse the response result, parseErr := ai.parseResponse(response) @@ -226,33 +256,56 @@ func (ai *AIService) GeneratePayloadsWithConversation(ctx context.Context, req G openRouterReq := OpenRouterRequest{ Model: ai.model, Messages: conversationHistory, - MaxTokens: 10000, + MaxTokens: 50000, + Stream: true, + Reasoning: &ReasoningConfig{ + MaxTokens: 20000, + }, } - response, err := ai.callOpenRouter(ctx, openRouterReq) + // Create streaming callback for real-time payload processing + callback := &PayloadStreamingCallback{ + processor: processor, + logger: ai.logger, + payloadBuffer: &strings.Builder{}, + } + + response, fullContent, err := ai.callOpenRouterStreaming(ctx, openRouterReq, callback) if err != nil { - lastError = fmt.Errorf("AI API call failed: %w", err) + ai.logger.Warnf("AI streaming call failed: %v", err) + lastError = fmt.Errorf("AI streaming call failed: %w", err) continue } - ai.tokenCount += uint64(response.Usage.TotalTokens) - ai.logger.Infof("AI call #%d completed: %d tokens used, %d total tokens", - ai.callCount, response.Usage.TotalTokens, ai.tokenCount) + if response.Usage != nil { + ai.tokenCount += uint64(response.Usage.TotalTokens) + ai.logger.Infof("AI call #%d completed: %d tokens used, %d total tokens", + ai.callCount, response.Usage.TotalTokens, ai.tokenCount) + } - // Try to parse the response - result, parseErr := ai.parseResponse(response) + // Try to parse the final response + mockResponse := &OpenRouterResponse{ + Choices: []struct { + Index int `json:"index"` + Message *Message `json:"message,omitempty"` + Delta *Delta `json:"delta,omitempty"` + FinishReason *string `json:"finish_reason,omitempty"` + NativeFinishReason *string 
`json:"native_finish_reason,omitempty"` + Logprobs *string `json:"logprobs,omitempty"` + }{{Message: &Message{Content: fullContent}}}, + } + + result, parseErr := ai.parseResponse(mockResponse) if parseErr == nil { // Validate payloads (including geas compilation) validPayloads, validationErr := processor.ProcessPayloads(result.Payloads) if validationErr == nil { // Success! Update result with validated payloads and add AI response to history result.Payloads = validPayloads - if len(response.Choices) > 0 { - conversationHistory = append(conversationHistory, Message{ - Role: "assistant", - Content: response.Choices[0].Message.Content, - }) - } + conversationHistory = append(conversationHistory, Message{ + Role: "assistant", + Content: fullContent, + }) // Log AI response for debugging if enabled if ai.logConversations { @@ -270,12 +323,10 @@ func (ai *AIService) GeneratePayloadsWithConversation(ctx context.Context, req G lastError = parseErr // Add AI response to conversation history - if len(response.Choices) > 0 { - conversationHistory = append(conversationHistory, Message{ - Role: "assistant", - Content: response.Choices[0].Message.Content, - }) - } + conversationHistory = append(conversationHistory, Message{ + Role: "assistant", + Content: fullContent, + }) // Add error feedback for retry errorFeedback := ai.buildErrorFeedback(parseErr, attempt+1, maxRetries) @@ -301,7 +352,7 @@ func (ai *AIService) buildPrompt(req GenerationRequest) string { promptBuilder.WriteString(fmt.Sprintf("TEST DIRECTION: %s\n\n", req.TestDirection)) } - promptBuilder.WriteString(fmt.Sprintf("Generate %d transaction payload(s).\n", req.PayloadCount)) + promptBuilder.WriteString(fmt.Sprintf("Generate 10 transaction payload(s) with placeholder variations, so we can test at least %v different patterns.\n", req.PayloadCount)) if req.TransactionFeedback != nil { promptBuilder.WriteString("FEEDBACK FROM PREVIOUS TRANSACTIONS:\n") @@ -424,6 +475,7 @@ func (ai *AIService) 
buildBasePrompt(generationMode string) string { promptBuilder.WriteString("Generate at least 20 separate JSON objects (do not stop before), each wrapped in ```json and ``` tags:\n\n") promptBuilder.WriteString(`{ + "id": "unique_payload_id_1", "type": "geas", "description": "Brief description of what this contract does", "init_code": "PUSH1 0x00\nSSTORE", @@ -447,6 +499,7 @@ func (ai *AIService) buildBasePrompt(generationMode string) string { promptBuilder.WriteString("IMPORTANT:\n") promptBuilder.WriteString("- Generate ONLY geas init_run contracts (type=\\\"geas\\\")\n") + promptBuilder.WriteString("- Each payload MUST have a unique 'id' field (e.g., 'payload_1', 'payload_2', etc.)\n") promptBuilder.WriteString("- Focus on diverse EVM testing patterns\n") promptBuilder.WriteString("- Reuse previous iteration results to avoid EVM caching\n") promptBuilder.WriteString("- Use SWAPn to manage persistent values on stack\n") @@ -496,12 +549,185 @@ func (ai *AIService) callOpenRouter(ctx context.Context, req OpenRouterRequest) return &openRouterResp, nil } +func (ai *AIService) callOpenRouterStreaming(ctx context.Context, req OpenRouterRequest, callback StreamingCallback) (*OpenRouterResponse, string, error) { + jsonData, err := json.Marshal(req) + if err != nil { + return nil, "", fmt.Errorf("failed to marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", ai.baseURL, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, "", fmt.Errorf("failed to create HTTP request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+ai.apiKey) + httpReq.Header.Set("HTTP-Referer", "https://github.com/ethpandaops/spamoor") + httpReq.Header.Set("X-Title", "Spamoor AI Transaction Generator") + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Cache-Control", "no-cache") + + resp, err := ai.client.Do(httpReq) + if err != nil { + return nil, "", 
fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, "", fmt.Errorf("OpenRouter API error %d: %s", resp.StatusCode, string(body)) + } + + return ai.parseStreamingResponse(ctx, resp.Body, callback) +} + +func (ai *AIService) parseStreamingResponse(ctx context.Context, body io.Reader, callback StreamingCallback) (*OpenRouterResponse, string, error) { + scanner := bufio.NewScanner(body) + var fullContent strings.Builder + var reasoningBuffer strings.Builder + var reasoningDetailsBuffer []string + var lastResponse *OpenRouterResponse + + for scanner.Scan() { + select { + case <-ctx.Done(): + return nil, fullContent.String(), ctx.Err() + default: + } + + line := scanner.Text() + + //ai.logger.Debugf("streaming rsp: %s", line) + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, ":") { + continue + } + + // Parse SSE data line + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + + // Check for stream end + if data == "[DONE]" { + break + } + + var streamResp OpenRouterResponse + if err := json.Unmarshal([]byte(data), &streamResp); err != nil { + ai.logger.Warnf("failed to parse streaming response chunk: %v", err) + continue + } + + lastResponse = &streamResp + + // Extract content from delta + if len(streamResp.Choices) > 0 && streamResp.Choices[0].Delta != nil { + delta := streamResp.Choices[0].Delta + + // Buffer reasoning and print complete lines only + if delta.Reasoning != "" { + reasoningBuffer.WriteString(delta.Reasoning) + + // Check for complete lines in the buffer + bufferContent := reasoningBuffer.String() + lines := strings.Split(bufferContent, "\n") + + // Print all complete lines (all but the last if it doesn't end with newline) + for i := 0; i < len(lines)-1; i++ { + if lines[i] != "" { + ai.logger.Debugf("AI reasoning: %s", lines[i]) + } + } + + // Keep the incomplete line in the 
buffer + if len(lines) > 0 && !strings.HasSuffix(bufferContent, "\n") { + reasoningBuffer.Reset() + reasoningBuffer.WriteString(lines[len(lines)-1]) + } else { + reasoningBuffer.Reset() + } + } + + // Accumulate reasoning details + if len(delta.ReasoningDetails) > 0 { + reasoningDetailsBuffer = append(reasoningDetailsBuffer, delta.ReasoningDetails...) + } + + // Process content + if delta.Content != "" { + fullContent.WriteString(delta.Content) + + // Call streaming callback with new content + if callback != nil { + if err := callback.OnContent(delta.Content); err != nil { + ai.logger.Warnf("streaming callback error: %v", err) + } + } + } + } + } + } + + if err := scanner.Err(); err != nil { + return nil, fullContent.String(), fmt.Errorf("error reading streaming response: %w", err) + } + + // Print any remaining reasoning content + if reasoningBuffer.Len() > 0 { + ai.logger.Debugf("AI reasoning: %s", reasoningBuffer.String()) + } + + // Print accumulated reasoning details + if len(reasoningDetailsBuffer) > 0 { + // Join all details and split by lines for cleaner output + allDetails := strings.Join(reasoningDetailsBuffer, "\n") + lines := strings.Split(allDetails, "\n") + + for _, line := range lines { + if line != "" { + ai.logger.Debugf("AI reasoning detail: %s", line) + } + } + } + + // Call completion callback + if callback != nil { + if err := callback.OnComplete(fullContent.String()); err != nil { + ai.logger.Warnf("streaming completion callback error: %v", err) + } + } + + // Return the last response (which should contain usage info) or create a mock response + if lastResponse == nil { + lastResponse = &OpenRouterResponse{ + Choices: []struct { + Index int `json:"index"` + Message *Message `json:"message,omitempty"` + Delta *Delta `json:"delta,omitempty"` + FinishReason *string `json:"finish_reason,omitempty"` + NativeFinishReason *string `json:"native_finish_reason,omitempty"` + Logprobs *string `json:"logprobs,omitempty"` + }{{Message: &Message{Content: 
fullContent.String()}}}, + } + } + + return lastResponse, fullContent.String(), nil +} + func (ai *AIService) parseResponse(response *OpenRouterResponse) (*GenerationResponse, error) { if len(response.Choices) == 0 { return nil, fmt.Errorf("no choices in AI response") } - content := response.Choices[0].Message.Content + var content string + if response.Choices[0].Message != nil { + content = response.Choices[0].Message.Content + } else if response.Choices[0].Delta != nil { + content = response.Choices[0].Delta.Content + } else { + return nil, fmt.Errorf("no message or delta content in AI response") + } var payloads []PayloadTemplate @@ -529,10 +755,15 @@ func (ai *AIService) parseResponse(response *OpenRouterResponse) (*GenerationRes ai.logger.Infof("Successfully parsed %d payloads from AI response", len(payloads)) summary := fmt.Sprintf("Generated %d payloads using %s", len(payloads), ai.model) + var tokensUsed uint64 + if response.Usage != nil { + tokensUsed = uint64(response.Usage.TotalTokens) + } + return &GenerationResponse{ Payloads: payloads, Summary: summary, - TokensUsed: uint64(response.Usage.TotalTokens), + TokensUsed: tokensUsed, }, nil } diff --git a/scenarios/aitx/feedback_collector.go b/scenarios/aitx/feedback_collector.go index c1ba1455..21df47d9 100644 --- a/scenarios/aitx/feedback_collector.go +++ b/scenarios/aitx/feedback_collector.go @@ -9,32 +9,98 @@ import ( "github.com/sirupsen/logrus" ) -type PayloadStats struct { - Description string - SuccessCount int - FailureCount int - TotalGasUsed uint64 - SuccessfulCalls []TransactionResult // Successful transactions for this payload - FailedCalls []TransactionResult // Failed transactions for this payload +type PayloadFeedback struct { + PayloadID string `json:"payload_id"` + PayloadIndex int `json:"payload_index"` + Description string `json:"description"` + CompilationStatus string `json:"compilation_status"` // "success", "failed" + CompilationError string `json:"compilation_error,omitempty"` + 
ExecutionStatus string `json:"execution_status"` // "success", "failed", "not_executed" + ExecutionError string `json:"execution_error,omitempty"` + GasUsed uint64 `json:"gas_used"` + ExecutionCount int `json:"execution_count"` + LastExecutionResult string `json:"last_execution_result,omitempty"` } type FeedbackCollector struct { - payloadStats map[string]*PayloadStats // Per-payload statistics - mutex sync.RWMutex - maxResults uint64 - currentBatchID int // ID of current AI response batch - logger logrus.FieldLogger + payloadFeedbacks []PayloadFeedback // Ordered list of payload feedback + payloadLookup map[string]int // Map payload ID to index for quick lookup + mutex sync.RWMutex + maxResults uint64 + currentBatchID int // ID of current AI response batch + logger logrus.FieldLogger } func NewFeedbackCollector(maxResults uint64, logger logrus.FieldLogger) *FeedbackCollector { return &FeedbackCollector{ - payloadStats: make(map[string]*PayloadStats), - maxResults: maxResults, - currentBatchID: 0, - logger: logger.WithField("component", "feedback_collector"), + payloadFeedbacks: make([]PayloadFeedback, 0), + payloadLookup: make(map[string]int), + maxResults: maxResults, + currentBatchID: 0, + logger: logger.WithField("component", "feedback_collector"), } } +// RegisterPayload adds a new payload in order for tracking +func (fc *FeedbackCollector) RegisterPayload(payloadID, description string, batchID int) { + fc.mutex.Lock() + defer fc.mutex.Unlock() + + // Only register payloads for the current batch + if batchID != fc.currentBatchID { + fc.logger.Debugf("ignoring payload registration from old batch %d (current: %d): %s", + batchID, fc.currentBatchID, payloadID) + return + } + + // Check if already registered + if _, exists := fc.payloadLookup[payloadID]; exists { + return + } + + // Add new payload feedback entry + index := len(fc.payloadFeedbacks) + fc.payloadFeedbacks = append(fc.payloadFeedbacks, PayloadFeedback{ + PayloadID: payloadID, + PayloadIndex: index, + 
Description: description, + CompilationStatus: "pending", + ExecutionStatus: "not_executed", + GasUsed: 0, + ExecutionCount: 0, + }) + fc.payloadLookup[payloadID] = index + + fc.logger.Debugf("registered payload %d for batch %d: %s (%s)", index, batchID, payloadID, description) +} + +// RecordCompilationResult records the compilation status for a payload +func (fc *FeedbackCollector) RecordCompilationResult(payloadID string, success bool, errorMsg string, batchID int) { + fc.mutex.Lock() + defer fc.mutex.Unlock() + + if batchID != fc.currentBatchID { + return + } + + index, exists := fc.payloadLookup[payloadID] + if !exists { + fc.logger.Warnf("compilation result for unknown payload: %s", payloadID) + return + } + + if success { + fc.payloadFeedbacks[index].CompilationStatus = "success" + fc.payloadFeedbacks[index].CompilationError = "" + } else { + fc.payloadFeedbacks[index].CompilationStatus = "failed" + fc.payloadFeedbacks[index].CompilationError = errorMsg + } + + fc.logger.Debugf("recorded compilation result for payload %d: %s - %s", + index, payloadID, fc.payloadFeedbacks[index].CompilationStatus) +} + func (fc *FeedbackCollector) RecordResult(result TransactionResult, batchID int) { fc.mutex.Lock() defer fc.mutex.Unlock() @@ -46,42 +112,29 @@ func (fc *FeedbackCollector) RecordResult(result TransactionResult, batchID int) return } - // Get or create payload stats - stats, exists := fc.payloadStats[result.PayloadDescription] + index, exists := fc.payloadLookup[result.PayloadDescription] if !exists { - stats = &PayloadStats{ - Description: result.PayloadDescription, - SuccessfulCalls: make([]TransactionResult, 0), - FailedCalls: make([]TransactionResult, 0), - } - fc.payloadStats[result.PayloadDescription] = stats + fc.logger.Warnf("execution result for unknown payload: %s", result.PayloadDescription) + return } - // Record result + // Update execution feedback + fc.payloadFeedbacks[index].ExecutionCount++ + fc.payloadFeedbacks[index].GasUsed = result.GasUsed 
+ if result.Status == "success" { - stats.SuccessCount++ - stats.TotalGasUsed += result.GasUsed - stats.SuccessfulCalls = append(stats.SuccessfulCalls, result) - - // Keep only recent successful calls (limit to maxResults/2 per payload) - maxPerPayload := int(fc.maxResults / 2) - if len(stats.SuccessfulCalls) > maxPerPayload { - stats.SuccessfulCalls = stats.SuccessfulCalls[len(stats.SuccessfulCalls)-maxPerPayload:] - } + fc.payloadFeedbacks[index].ExecutionStatus = "success" + fc.payloadFeedbacks[index].ExecutionError = "" + fc.payloadFeedbacks[index].LastExecutionResult = fmt.Sprintf("gas:%d", result.GasUsed) } else { - stats.FailureCount++ - stats.FailedCalls = append(stats.FailedCalls, result) - - // Keep only recent failed calls (limit to maxResults/2 per payload) - maxPerPayload := int(fc.maxResults / 2) - if len(stats.FailedCalls) > maxPerPayload { - stats.FailedCalls = stats.FailedCalls[len(stats.FailedCalls)-maxPerPayload:] - } + fc.payloadFeedbacks[index].ExecutionStatus = "failed" + fc.payloadFeedbacks[index].ExecutionError = result.ErrorMessage + fc.payloadFeedbacks[index].LastExecutionResult = fmt.Sprintf("error:%s", result.ErrorMessage) } - fc.logger.Debugf("recorded transaction result for batch %d: %s (%s) - %s (success: %d, failed: %d)", - batchID, result.PayloadDescription, result.PayloadType, result.Status, - stats.SuccessCount, stats.FailureCount) + fc.logger.Debugf("recorded execution result for payload %d: %s - %s (count: %d)", + index, result.PayloadDescription, fc.payloadFeedbacks[index].ExecutionStatus, + fc.payloadFeedbacks[index].ExecutionCount) } func (fc *FeedbackCollector) RecordFailure(payload *PayloadInstance, status, errorMsg string, batchID int) { @@ -100,7 +153,7 @@ func (fc *FeedbackCollector) GenerateFeedback() *TransactionFeedback { fc.mutex.RLock() defer fc.mutex.RUnlock() - if len(fc.payloadStats) == 0 { + if len(fc.payloadFeedbacks) == 0 { return nil } @@ -108,22 +161,31 @@ func (fc *FeedbackCollector) GenerateFeedback() 
*TransactionFeedback { var allGasValues []uint64 var allResults []TransactionResult - // Aggregate stats from all payloads - for _, stats := range fc.payloadStats { - totalTxs += uint64(stats.SuccessCount + stats.FailureCount) - successfulTxs += uint64(stats.SuccessCount) - failedTxs += uint64(stats.FailureCount) + // Process payloads in order + for _, feedback := range fc.payloadFeedbacks { + if feedback.ExecutionCount > 0 { + totalTxs += uint64(feedback.ExecutionCount) + + if feedback.ExecutionStatus == "success" { + successfulTxs += uint64(feedback.ExecutionCount) + if feedback.GasUsed > 0 { + allGasValues = append(allGasValues, feedback.GasUsed) + } + } else { + failedTxs += uint64(feedback.ExecutionCount) + } - // Collect gas values from successful calls - for _, result := range stats.SuccessfulCalls { - if result.GasUsed > 0 { - allGasValues = append(allGasValues, result.GasUsed) + // Create a result entry for this payload + result := TransactionResult{ + PayloadType: "geas", + PayloadDescription: feedback.Description, + Status: feedback.ExecutionStatus, + GasUsed: feedback.GasUsed, + BlockExecTime: "N/A", + ErrorMessage: feedback.ExecutionError, } allResults = append(allResults, result) } - - // Include failed calls in results - allResults = append(allResults, stats.FailedCalls...) } var avgGas, medianGas uint64 @@ -153,63 +215,31 @@ func (fc *FeedbackCollector) GenerateFeedback() *TransactionFeedback { } func (fc *FeedbackCollector) generateDetailedSummary() string { - if len(fc.payloadStats) == 0 { - return "No transaction results for current batch yet." + if len(fc.payloadFeedbacks) == 0 { + return "No payloads received for current batch yet." 
} var summaryParts []string - for description, stats := range fc.payloadStats { - total := stats.SuccessCount + stats.FailureCount - if total == 0 { - continue - } - - successRate := float64(stats.SuccessCount) / float64(total) * 100 - - // Calculate average gas for this payload - var avgGas uint64 - if stats.SuccessCount > 0 { - avgGas = stats.TotalGasUsed / uint64(stats.SuccessCount) - } - - // Get recent logs from successful calls - var recentLogs []string - for _, result := range stats.SuccessfulCalls { - if len(result.LogData) > 0 { - recentLogs = append(recentLogs, result.LogData[0]) // Take first log - if len(recentLogs) >= 2 { // Limit to 2 logs per payload - break - } - } - } - - // Get recent error messages from failed calls - var recentErrors []string - for _, result := range stats.FailedCalls { - if result.ErrorMessage != "" { - recentErrors = append(recentErrors, result.ErrorMessage) - if len(recentErrors) >= 2 { // Limit to 2 errors per payload - break - } - } - } - - payloadSummary := fmt.Sprintf("'%s': %.1f%% success (%d/%d), avg_gas: %d", - description, successRate, stats.SuccessCount, total, avgGas) - - if len(recentLogs) > 0 { - payloadSummary += fmt.Sprintf(", recent_logs: [%s]", strings.Join(recentLogs, ", ")) - } - - if len(recentErrors) > 0 { - payloadSummary += fmt.Sprintf(", recent_errors: [%s]", strings.Join(recentErrors, ", ")) + // Generate feedback for each payload in order + for _, feedback := range fc.payloadFeedbacks { + var status string + if feedback.CompilationStatus == "failed" { + status = fmt.Sprintf("compilation_failed: %s", feedback.CompilationError) + } else if feedback.ExecutionStatus == "not_executed" { + status = "not_executed" + } else if feedback.ExecutionStatus == "failed" { + status = fmt.Sprintf("execution_failed: %s", feedback.ExecutionError) + } else { + status = fmt.Sprintf("success: gas=%d", feedback.GasUsed) } + payloadSummary := fmt.Sprintf("%s ('%s'): %s", + feedback.PayloadID, feedback.Description, status) 
summaryParts = append(summaryParts, payloadSummary) } - return fmt.Sprintf("Detailed payload analysis: %s", strings.Join(summaryParts, " | ")) + return fmt.Sprintf("ORDERED PAYLOAD FEEDBACK: %s", strings.Join(summaryParts, " | ")) } // StartNewBatch resets the feedback collector for a new AI response batch @@ -218,7 +248,8 @@ func (fc *FeedbackCollector) StartNewBatch() { defer fc.mutex.Unlock() fc.currentBatchID++ - fc.payloadStats = make(map[string]*PayloadStats) + fc.payloadFeedbacks = make([]PayloadFeedback, 0) + fc.payloadLookup = make(map[string]int) fc.logger.Infof("started new feedback batch %d", fc.currentBatchID) } @@ -233,24 +264,29 @@ func (fc *FeedbackCollector) GetCurrentBatchStats() (uint64, uint64, uint64) { defer fc.mutex.RUnlock() var totalTxs, successfulTxs, failedTxs uint64 - for _, stats := range fc.payloadStats { - totalTxs += uint64(stats.SuccessCount + stats.FailureCount) - successfulTxs += uint64(stats.SuccessCount) - failedTxs += uint64(stats.FailureCount) + for _, feedback := range fc.payloadFeedbacks { + if feedback.ExecutionCount > 0 { + totalTxs += uint64(feedback.ExecutionCount) + if feedback.ExecutionStatus == "success" { + successfulTxs += uint64(feedback.ExecutionCount) + } else { + failedTxs += uint64(feedback.ExecutionCount) + } + } } return totalTxs, successfulTxs, failedTxs } // GetFailedPayloads returns payloads that have failures for immediate feedback -func (fc *FeedbackCollector) GetFailedPayloads() map[string]*PayloadStats { +func (fc *FeedbackCollector) GetFailedPayloads() []PayloadFeedback { fc.mutex.RLock() defer fc.mutex.RUnlock() - failedPayloads := make(map[string]*PayloadStats) - for description, stats := range fc.payloadStats { - if stats.FailureCount > 0 { - failedPayloads[description] = stats + var failedPayloads []PayloadFeedback + for _, feedback := range fc.payloadFeedbacks { + if feedback.CompilationStatus == "failed" || feedback.ExecutionStatus == "failed" { + failedPayloads = append(failedPayloads, 
feedback) } } diff --git a/scenarios/aitx/payload_template.go b/scenarios/aitx/payload_template.go index 987a4484..276682e6 100644 --- a/scenarios/aitx/payload_template.go +++ b/scenarios/aitx/payload_template.go @@ -16,6 +16,7 @@ import ( ) type PayloadTemplate struct { + ID string `json:"id"` Type string `json:"type"` Description string `json:"description"` InitCode string `json:"init_code"` @@ -26,6 +27,7 @@ type PayloadTemplate struct { } type PayloadInstance struct { + ID string Type string Description string InitCode string @@ -51,6 +53,7 @@ func NewPlaceholderSubstituter(walletPool *spamoor.WalletPool, client *spamoor.C func (pt *PayloadTemplate) Substitute(substituter *PlaceholderSubstituter) (*PayloadInstance, error) { instance := &PayloadInstance{ + ID: pt.ID, Type: pt.Type, Description: pt.Description, } diff --git a/scenarios/aitx/streaming_callback.go b/scenarios/aitx/streaming_callback.go new file mode 100644 index 00000000..5d75a6fb --- /dev/null +++ b/scenarios/aitx/streaming_callback.go @@ -0,0 +1,234 @@ +package aitx + +import ( + "encoding/json" + "fmt" + "strings" + "sync" + + "github.com/sirupsen/logrus" +) + +type PayloadStreamingCallback struct { + processor *PayloadProcessor + logger logrus.FieldLogger + payloadBuffer *strings.Builder + + // Track parsing state + mutex sync.Mutex + parsedPayloads []PayloadTemplate + totalContent strings.Builder +} + +func (c *PayloadStreamingCallback) OnContent(content string) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + // Simply accumulate all content + c.totalContent.WriteString(content) + c.payloadBuffer.WriteString(content) + + // Check for complete JSON blocks in the accumulated content + c.processCompleteJSONBlocks() + + return nil +} + +// processCompleteJSONBlocks searches for and processes complete JSON blocks +func (c *PayloadStreamingCallback) processCompleteJSONBlocks() { + content := c.payloadBuffer.String() + + //c.logger.Debugf("streaming rsp: %s", content) + + // Look for complete 
JSON blocks (```json ... ```) + for { + startIdx := strings.Index(content, "```json") + if startIdx == -1 { + break + } + + // Find the end of this JSON block + endMarker := "```" + endIdx := strings.Index(content[startIdx+7:], endMarker) // Skip past "```json" + if endIdx == -1 { + // Incomplete block, wait for more content + break + } + + // Extract the JSON content + jsonStart := startIdx + 7 // Skip "```json" + jsonEnd := startIdx + 7 + endIdx + jsonStr := strings.TrimSpace(content[jsonStart:jsonEnd]) + + if jsonStr != "" { + c.logger.Debugf("Found complete JSON block: %s", jsonStr) + + // Try to parse and validate immediately + var payload PayloadTemplate + if err := json.Unmarshal([]byte(jsonStr), &payload); err == nil { + c.logger.Infof("Streaming: Parsed payload '%s', validating...", payload.Description) + + // Validate payload in real-time + if validatedPayloads, err := c.processor.ProcessPayloads([]PayloadTemplate{payload}); err == nil { + c.parsedPayloads = append(c.parsedPayloads, validatedPayloads...) + c.logger.Infof("Streaming: Validated and ready payload '%s'", payload.Description) + } else { + c.logger.Warnf("Streaming: Payload validation failed for '%s': %v", payload.Description, err) + } + } else { + c.logger.Warnf("Streaming: Failed to parse JSON block: %v", err) + } + } + + // Remove the processed block from buffer and continue searching + content = content[jsonEnd+3:] // Skip past the closing ``` + } + + // Update the buffer with remaining unprocessed content + c.payloadBuffer.Reset() + c.payloadBuffer.WriteString(content) +} + +func (c *PayloadStreamingCallback) OnComplete(fullContent string) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.logger.Infof("Streaming completed. 
Total payloads parsed in real-time: %d", len(c.parsedPayloads)) + + // Always try to extract additional payloads from the full content in case streaming missed any + // This handles cases where JSON blocks span multiple content chunks or have formatting issues + additionalPayloads, err := c.extractJSONObjectsFromText(fullContent) + if err == nil && len(additionalPayloads) > len(c.parsedPayloads) { + c.logger.Infof("Found %d additional payloads in complete content (streaming found %d)", + len(additionalPayloads)-len(c.parsedPayloads), len(c.parsedPayloads)) + + // Validate the additional payloads + if validatedPayloads, err := c.processor.ProcessPayloads(additionalPayloads); err == nil { + c.parsedPayloads = validatedPayloads + c.logger.Infof("Final: Validated %d total payloads", len(c.parsedPayloads)) + } else { + c.logger.Warnf("Failed to validate additional payloads: %v", err) + } + } else if len(c.parsedPayloads) == 0 { + c.logger.Debugf("No payloads parsed during streaming, attempting full content parsing") + + // Try to extract payloads from the complete content as fallback + payloads, err := c.extractJSONObjectsFromText(fullContent) + if err != nil { + return fmt.Errorf("failed to parse any payloads from complete content: %w", err) + } + + // Validate all payloads at once as fallback + if validatedPayloads, err := c.processor.ProcessPayloads(payloads); err == nil { + c.parsedPayloads = validatedPayloads + c.logger.Infof("Fallback: Validated %d payloads from complete content", len(c.parsedPayloads)) + } else { + return fmt.Errorf("failed to validate fallback payloads: %w", err) + } + } + + return nil +} + +func (c *PayloadStreamingCallback) GetParsedPayloads() []PayloadTemplate { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.parsedPayloads +} + +// extractJSONObjectsFromText is a fallback method copied from ai_service.go +func (c *PayloadStreamingCallback) extractJSONObjectsFromText(content string) ([]PayloadTemplate, error) { + var payloads 
[]PayloadTemplate + + // Look for JSON code blocks marked with ```json + lines := strings.Split(content, "\n") + var jsonBlock strings.Builder + inJSONBlock := false + + for _, line := range lines { + line = strings.TrimSpace(line) + + if strings.HasPrefix(line, "```json") { + inJSONBlock = true + jsonBlock.Reset() + continue + } + + if strings.HasPrefix(line, "```") && inJSONBlock { + // End of JSON block, try to parse it + jsonStr := jsonBlock.String() + c.logger.Debugf("Attempting to parse JSON block: %s", jsonStr) + + var payload PayloadTemplate + if err := json.Unmarshal([]byte(jsonStr), &payload); err == nil { + payloads = append(payloads, payload) + c.logger.Debugf("Successfully parsed payload: %s", payload.Description) + } else { + c.logger.Errorf("Failed to parse JSON block: %v", err) + } + + inJSONBlock = false + continue + } + + if inJSONBlock { + jsonBlock.WriteString(line) + jsonBlock.WriteString("\n") + } + } + + // If we found payloads, return them + if len(payloads) > 0 { + return payloads, nil + } + + // Fallback: look for individual JSON objects using regex-like approach + return c.extractJSONObjectsWithRegex(content) +} + +func (c *PayloadStreamingCallback) extractJSONObjectsWithRegex(content string) ([]PayloadTemplate, error) { + var payloads []PayloadTemplate + + // Look for patterns like { ... 
} that might be JSON objects + braceLevel := 0 + var currentObj strings.Builder + inObject := false + + for i, r := range content { + if r == '{' { + if braceLevel == 0 { + inObject = true + currentObj.Reset() + } + braceLevel++ + currentObj.WriteRune(r) + } else if r == '}' { + braceLevel-- + currentObj.WriteRune(r) + + if braceLevel == 0 && inObject { + // Try to parse this object + objStr := strings.TrimSpace(currentObj.String()) + c.logger.Infof("Attempting to parse JSON object: %s", objStr) + + var payload PayloadTemplate + if err := json.Unmarshal([]byte(objStr), &payload); err == nil { + payloads = append(payloads, payload) + c.logger.Infof("Successfully parsed payload: %s", payload.Description) + } else { + c.logger.Errorf("Failed to parse JSON object at position %d: %v", i, err) + } + + inObject = false + } + } else if inObject { + currentObj.WriteRune(r) + } + } + + if len(payloads) == 0 { + return nil, fmt.Errorf("no valid JSON objects found in response") + } + + return payloads, nil +}