diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index a69745d..3933184 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -69,7 +69,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 diff --git a/HISTORY.rst b/HISTORY.rst index 66d0761..1812015 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -3,6 +3,93 @@ Release History =============== +0.2.1b7 ++++++++ + +Build stage re-entry fix +~~~~~~~~~~~~~~~~~~~~~~~~~~ +* **QA remediation failure now retries on re-entry** — fixed a bug where + ``mark_stage_generated()`` was called after each remediation attempt + inside ``_run_stage_qa()``, leaving the stage with status ``"generated"`` + even when QA subsequently failed. On re-entry, the stage was skipped + instead of retried. Changed to ``mark_stage_validating()`` so failed + stages remain in the retry list. + +QA checklist hardening +~~~~~~~~~~~~~~~~~~~~~~~~ +* **Aligned response_export_values directive** — QA checklist now requires + ``response_export_values = ["*"]`` on EVERY ``azapi_resource``, matching + the terraform agent's mandatory rule (was conditional on output usage). +* **Added deploy.sh -state= flag check** — QA checklist now flags use of + ``terraform output -state=`` which was removed in Terraform 1.9. +* **Added UUID hex validation** — QA checklist now checks that UUID values + in role assignment names contain only valid hex characters ``[0-9a-f]``. + +Full stage retry on QA exhaustion +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* **Full stage retry when QA remediation fails** -- when QA remediation + exhausts all attempts for a stage, the build now retries the entire + stage from scratch (clean artifacts, regenerate, QA) instead of + stopping the build immediately. Previous QA findings are injected + into the new generation prompt — framed as guidance rather than + file-specific instructions — so the model avoids the same classes + of mistakes on the fresh attempt. + + In practice, the same generation prompt produces passing code ~90% + of the time. The remaining ~10% failure rate is stochastic — not a + systematic prompt deficiency — meaning a fresh generation with + knowledge of what went wrong almost always succeeds. Without this + retry, that 10% forces the user to manually re-run the entire build, + losing the progress of all previously generated stages. The retry + doubles the token cost of one stage in the worst case, but saves + the full cost of restarting a 16-stage build from scratch. + + Controlled by ``_MAX_FULL_STAGE_ATTEMPTS`` (default 2: 1 initial + + 1 fresh retry). + +Generation prompt improvements +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* **Front-loaded remote state no-dead-code directive** — when upstream + stages exist, a ``CROSS-STAGE DEPENDENCIES — NO DEAD CODE`` section + now appears before the architecture context in the generation prompt, + reducing unused ``terraform_remote_state`` data sources. + +Agent-level service filtering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* **Agent governance checks now filter by service namespace** — added + ``stage_services`` field to ``AgentContext``, populated by + ``_agent_build_context()``. ``_apply_governance_check()`` now passes + stage services to ``validate_response()``, reducing false positive + anti-pattern warnings for irrelevant service namespaces. + +ReDoS fix in transform handlers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* **Replaced nested-quantifier regex with brace counting** — extracted + shared ``_find_azapi_blocks()`` helper and rewrote + ``_add_response_export_values``, ``_add_resource_group_parent_id``, + and ``_remove_private_endpoint_resources`` to use it. Eliminates + potential exponential backtracking on pathological input. + +Test suite consolidation +~~~~~~~~~~~~~~~~~~~~~~~~~~ +* **Consolidated and enhanced unit test coverage** — migrated flat test + files to a mirrored directory structure (1:1 test-to-source mapping), + merged split test files, and removed ~114 duplicate tests across 10 + files. Test suite reduced from 3,644 to 3,530 tests with zero loss + of unique coverage. + +QA review continuation for large stages +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* **QA review collects complete response before evaluating** — when the + QA review response is truncated (``finish_reason=length``), the build + session now continues requesting until the full review is received, + then evaluates the concatenated result. Uses the existing + ``_execute_with_continuation()`` pattern with a review-specific + continuation prompt that prevents the QA agent from generating code + in the continuation. Conversation history is saved and restored + around QA calls to prevent review messages from contaminating + subsequent stage generation. + 0.2.1b6 +++++++ diff --git a/azext_prototype/agents/base.py b/azext_prototype/agents/base.py index f6dddba..d0e17a5 100644 --- a/azext_prototype/agents/base.py +++ b/azext_prototype/agents/base.py @@ -81,6 +81,7 @@ class AgentContext: artifacts: dict[str, Any] = field(default_factory=dict) shared_state: dict[str, Any] = field(default_factory=dict) mcp_manager: Any = None # MCPManager | None — typed as Any to avoid circular import + stage_services: list[str] | None = None # ARM namespaces for service filtering def add_artifact(self, key: str, value: Any): """Store an artifact for other agents to reference.""" @@ -299,7 +300,7 @@ def _apply_governance_check(self, response: AIResponse, context: AgentContext) - avoid duplicating the governance warning block. """ iac_tool = context.project_config.get("project", {}).get("iac_tool") if context.project_config else None - warnings = self.validate_response(response.content, iac_tool=iac_tool, services=None) + warnings = self.validate_response(response.content, iac_tool=iac_tool, services=context.stage_services) if warnings: for w in warnings: logger.warning("Governance: %s", w) diff --git a/azext_prototype/agents/builtin/qa_engineer.py b/azext_prototype/agents/builtin/qa_engineer.py index 96adf53..893495b 100644 --- a/azext_prototype/agents/builtin/qa_engineer.py +++ b/azext_prototype/agents/builtin/qa_engineer.py @@ -233,6 +233,8 @@ def _encode_image(path: str) -> str: - [ ] deploy.sh includes error handling (set -euo pipefail, trap) - [ ] deploy.sh exports outputs to JSON file for downstream stages - [ ] deploy.sh includes Azure login verification +- [ ] deploy.sh does NOT use `terraform output -state=` — this flag was removed + in Terraform 1.9. Use `jq` on the state file or `cd` into the stage directory ### 4. Output Completeness - [ ] outputs.tf exports resource group name(s) @@ -251,8 +253,8 @@ def _encode_image(path: str) -> str: - [ ] All referenced variables are defined in variables.tf - [ ] All referenced locals are defined in locals.tf - [ ] Application code includes all referenced classes/models/DTOs -- [ ] Every azapi_resource whose `.output.properties` is referenced in - outputs.tf MUST have `response_export_values = ["*"]` declared +- [ ] EVERY `azapi_resource` block MUST have `response_export_values = ["*"]` + declared — no exceptions, even if outputs.tf does not reference its properties - [ ] No .tf file is empty or contains only comments (dead files) ### 7. Terraform File Structure @@ -314,6 +316,8 @@ def _encode_image(path: str) -> str: **NOT** string interpolation on the storage account ID - [ ] RBAC assignments for the worker identity (Stage 1) are **unconditional** (no `count`). The worker identity exists before any service stage runs. +- [ ] UUID values in role assignment names contain only valid hex characters + `[0-9a-f]` — letters `g`-`z` are invalid and ARM rejects with `InvalidName` ### 13. Application Code (app stages only) - [ ] Application source code is syntactically correct and complete diff --git a/azext_prototype/ai/token_tracker.py b/azext_prototype/ai/token_tracker.py index ef7c50b..86f9bad 100644 --- a/azext_prototype/ai/token_tracker.py +++ b/azext_prototype/ai/token_tracker.py @@ -44,30 +44,35 @@ # GitHub Copilot Premium Request Unit (PRU) multipliers. # Each API call costs (1 × multiplier) PRUs. Only applies to the # Copilot provider — models not in this table produce 0 PRUs. -# Source: https://docs.github.com/en/copilot/concepts/billing/copilot-requests +# Source: https://docs.github.com/en/copilot/managing-copilot/monitoring-usage-and-entitlements/about-premium-requests +# Last updated: 2026-04-08 _PRU_MULTIPLIERS: dict[str, float] = { # Included with paid plans (0 PRUs) "gpt-5-mini": 0, "gpt-4.1": 0, "gpt-4o": 0, + "raptor-mini": 0, # Low-cost (0.25–0.33 PRUs per request) "grok-code-fast-1": 0.25, "claude-haiku-4.5": 0.33, "gemini-3-flash": 0.33, - "gpt-5.1-codex-mini": 0.33, "gpt-5.4-mini": 0.33, # Standard (1 PRU per request) "claude-sonnet-4": 1, "claude-sonnet-4.5": 1, "claude-sonnet-4.6": 1, + "gemini-2.5-pro": 1, "gemini-3-pro": 1, - "gemini-3-pro-1.5": 1, + "gemini-3.1-pro": 1, "gpt-5.1": 1, "gpt-5.2": 1, + "gpt-5.2-codex": 1, + "gpt-5.3-codex": 1, "gpt-5.4": 1, # Premium (3+ PRUs per request) "claude-opus-4.5": 3, "claude-opus-4.6": 3, + "claude-opus-4.6-fast": 30, } diff --git a/azext_prototype/azext_metadata.json b/azext_prototype/azext_metadata.json index 33cd7d7..5cb05a4 100644 --- a/azext_prototype/azext_metadata.json +++ b/azext_prototype/azext_metadata.json @@ -2,7 +2,7 @@ "azext.isPreview": true, "azext.minCliCoreVersion": "2.50.0", "name": "prototype", - "version": "0.2.1b6", + "version": "0.2.1b7", "azext.summary": "Azure CLI extension for building rapid prototypes with GitHub Copilot.", "license": "MIT", "classifiers": [ diff --git a/azext_prototype/governance/transforms/__init__.py b/azext_prototype/governance/transforms/__init__.py index f5ff700..af7359a 100644 --- a/azext_prototype/governance/transforms/__init__.py +++ b/azext_prototype/governance/transforms/__init__.py @@ -268,6 +268,32 @@ def _remove_unused_remote_state(content: str, stage_content: str | None = None) return result +def _find_azapi_blocks(content: str) -> list[tuple[int, int, str, str]]: + """Find all ``azapi_resource`` blocks using brace counting. + + Returns a list of ``(start, end, resource_name, block_text)`` tuples + where *start*/*end* are character offsets into *content*. + """ + pattern = re.compile(r'resource\s+"azapi_resource"\s+"(\w+)"\s*\{') + blocks: list[tuple[int, int, str, str]] = [] + for match in pattern.finditer(content): + name = match.group(1) + start = match.start() + brace_start = match.end() - 1 + depth = 1 + pos = brace_start + 1 + while pos < len(content) and depth > 0: + if content[pos] == "{": + depth += 1 + elif content[pos] == "}": + depth -= 1 + pos += 1 + if depth != 0: + continue # malformed block + blocks.append((start, pos, name, content[start:pos])) + return blocks + + def _remove_private_endpoint_resources(content: str) -> str: """Remove private endpoint and DNS zone resources from non-networking stages. @@ -287,42 +313,20 @@ def _remove_private_endpoint_resources(content: str) -> str: "virtualnetworklinks", ) - # Find resource block starts and use brace counting to find the end - block_start_pattern = re.compile( - r'resource\s+"azapi_resource"\s+"(\w+)"\s*\{', - ) - removed_names: list[str] = [] result = content - for match in reversed(list(block_start_pattern.finditer(result))): - resource_name = match.group(1) - # Find the matching closing brace using brace counting - start = match.start() - brace_start = match.end() - 1 # position of opening { - depth = 1 - pos = brace_start + 1 - while pos < len(result) and depth > 0: - if result[pos] == "{": - depth += 1 - elif result[pos] == "}": - depth -= 1 - pos += 1 - if depth != 0: - continue # malformed block, skip - - block_text = result[start:pos] - # Check if this block's type is a PE/DNS type + for start, end, resource_name, block_text in reversed(_find_azapi_blocks(result)): type_match = re.search(r'type\s*=\s*"([^"]+)"', block_text) if not type_match: continue resource_type = type_match.group(1).lower() if any(pt in resource_type for pt in pe_types): # Remove the block plus any trailing whitespace/newlines - end = pos - while end < len(result) and result[end] in ("\n", "\r", " "): - end += 1 - result = result[:start] + result[end:] + trim_end = end + while trim_end < len(result) and result[trim_end] in ("\n", "\r", " "): + trim_end += 1 + result = result[:start] + result[trim_end:] removed_names.append(resource_name) logger.debug("Removed PE/DNS resource: azapi_resource.%s", resource_name) @@ -351,29 +355,26 @@ def _remove_private_endpoint_resources(content: str) -> str: def _add_response_export_values(content: str) -> str: """Add ``response_export_values = ["*"]`` to azapi_resource blocks missing it. - Finds each ``resource "azapi_resource" "name" { ... }`` block and checks - if ``response_export_values`` appears inside it. If missing, inserts it - after the ``parent_id`` line (or after ``type`` if no ``parent_id``). + Uses brace-counting via :func:`_find_azapi_blocks` to avoid nested-quantifier + regex (ReDoS risk). Inserts after ``parent_id``, ``location``, or ``type``. """ - # Match azapi_resource blocks - block_pattern = re.compile( - r'(resource\s+"azapi_resource"\s+"\w+"\s*\{)(.*?\n)((?:.*?\n)*?)(})', - re.DOTALL, - ) - - def _inject(match: re.Match) -> str: # type: ignore[type-arg] - full = match.group(0) - if "response_export_values" in full: - return full # already has it + result = content + for start, end, _name, block_text in reversed(_find_azapi_blocks(result)): + if "response_export_values" in block_text: + continue - header = match.group(1) - first_line = match.group(2) - body = match.group(3) - closing = match.group(4) + # Split block body (after the opening { line) into lines + header_end = block_text.index("{") + 1 + header = block_text[:header_end] + body_plus_close = block_text[header_end:] + # Remove the final closing brace + body = body_plus_close.rstrip() + if body.endswith("}"): + body = body[:-1] + closing = "}" - # Find insertion point: after parent_id, or after location, or after type - lines = (first_line + body).splitlines(keepends=True) - insert_idx = len(lines) # fallback: before closing brace + lines = body.splitlines(keepends=True) + insert_idx = len(lines) for i, line in enumerate(lines): stripped = line.strip() if stripped.startswith("parent_id"): @@ -384,48 +385,42 @@ def _inject(match: re.Match) -> str: # type: ignore[type-arg] elif stripped.startswith("type") and insert_idx == len(lines): insert_idx = i + 1 - # Detect indentation from the type/parent_id line indent = " " - if insert_idx > 0 and insert_idx <= len(lines): + if 0 < insert_idx <= len(lines): prev_line = lines[insert_idx - 1] leading = len(prev_line) - len(prev_line.lstrip()) indent = " " * leading lines.insert(insert_idx, f'\n{indent}response_export_values = ["*"]\n') - return header + "".join(lines) + closing + new_block = header + "".join(lines) + closing + result = result[:start] + new_block + result[end:] - new_content = block_pattern.sub(_inject, content) - if new_content != content: + if result != content: logger.debug("Added response_export_values to azapi_resource blocks") - return new_content + return result def _add_resource_group_parent_id(content: str) -> str: """Add ``parent_id`` to resource group azapi_resource blocks missing it. - Finds ``azapi_resource`` blocks whose type contains - ``Microsoft.Resources/resourceGroups`` and injects - ``parent_id = "/subscriptions/${var.subscription_id}"`` - after the ``name`` line. + Uses brace-counting via :func:`_find_azapi_blocks` to avoid nested-quantifier + regex (ReDoS risk). Injects after the ``name`` line. """ - # Match azapi_resource blocks with resourceGroups type - block_pattern = re.compile( - r'(resource\s+"azapi_resource"\s+"\w+"\s*\{)(.*?)(})', - re.DOTALL, - ) - - def _inject(match: re.Match) -> str: # type: ignore[type-arg] - full = match.group(0) - if "resourcegroups" not in full.lower(): - return full - if "parent_id" in full: - return full # already has it + result = content + for start, end, _name, block_text in reversed(_find_azapi_blocks(result)): + if "resourcegroups" not in block_text.lower(): + continue + if "parent_id" in block_text: + continue - header = match.group(1) - body = match.group(2) - closing = match.group(3) + header_end = block_text.index("{") + 1 + header = block_text[:header_end] + body_plus_close = block_text[header_end:] + body = body_plus_close.rstrip() + if body.endswith("}"): + body = body[:-1] + closing = "}" - # Insert after the name line lines = body.splitlines(keepends=True) insert_idx = len(lines) for i, line in enumerate(lines): @@ -433,20 +428,19 @@ def _inject(match: re.Match) -> str: # type: ignore[type-arg] insert_idx = i + 1 break - # Detect indentation indent = " " - if insert_idx > 0 and insert_idx <= len(lines): + if 0 < insert_idx <= len(lines): prev_line = lines[insert_idx - 1] leading = len(prev_line) - len(prev_line.lstrip()) indent = " " * leading lines.insert(insert_idx, f'{indent}parent_id = "/subscriptions/${{var.subscription_id}}"\n') - return header + "".join(lines) + closing + new_block = header + "".join(lines) + closing + result = result[:start] + new_block + result[end:] - new_content = block_pattern.sub(_inject, content) - if new_content != content: + if result != content: logger.debug("Added parent_id to resource group azapi_resource") - return new_content + return result _STRUCTURED_HANDLERS: dict[str, Callable] = { diff --git a/azext_prototype/stages/build_session.py b/azext_prototype/stages/build_session.py index 4b8a6e1..602edb5 100644 --- a/azext_prototype/stages/build_session.py +++ b/azext_prototype/stages/build_session.py @@ -29,7 +29,6 @@ from azext_prototype.agents.base import AgentCapability, AgentContext from azext_prototype.agents.governance import GovernanceContext -from azext_prototype.agents.orchestrator import AgentOrchestrator from azext_prototype.agents.registry import AgentRegistry from azext_prototype.config import ProjectConfig from azext_prototype.naming import create_naming_strategy @@ -59,6 +58,11 @@ # Maximum remediation cycles per stage before proceeding _MAX_STAGE_REMEDIATION_ATTEMPTS = 3 +# Maximum full stage attempts (generation + QA cycle). When QA remediation +# exhausts all attempts, the stage is cleaned and regenerated from scratch +# with prior QA findings injected into the generation prompt. +_MAX_FULL_STAGE_ATTEMPTS = 2 # 1 initial + 1 fresh retry + # Keywords that indicate QA found actionable issues (fallback tier) _QA_ISSUE_KEYWORDS = frozenset({"critical", "error", "missing", "fix", "issue", "broken"}) # Phrases that indicate QA found no issues (tier 2) @@ -247,6 +251,10 @@ def _first(cap: AgentCapability) -> Any | None: {"naming": {"strategy": "simple"}, "project": {"name": self._project_name}} ) + # Last QA content from a failed QA remediation — injected into the + # generation prompt on full stage retry so the model knows what to avoid. + self._last_qa_content: str = "" + # ------------------------------------------------------------------ # # Public API # ------------------------------------------------------------------ # @@ -395,7 +403,8 @@ def run( else: # Branch C: No design changes pending_check = self._build_state.get_pending_stages() - if pending_check: + validating_check = self._build_state.get_validating_stages() + if pending_check or validating_check: _print("Resuming from existing deployment plan.") _print("") else: @@ -490,7 +499,7 @@ def run( # Handle re-entry: "validating" stages need QA re-run only if stage_status == "validating": _print(f"[{generated_count}/{total_stages}] Stage {stage_num}: {stage_name} (re-validating)") - if layer in ("core", "infra", "data", "app"): + if stage.get("files"): qa_passed = self._run_stage_qa(stage, architecture, templates, use_styled, _print) if qa_passed: self._build_state.mark_stage_generated(stage_num, stage.get("files", []), "user-fix") @@ -521,57 +530,139 @@ def run( # Use condensed per-stage context (from one-time condensation call) focused_context = stage_contexts.get(stage_num, "") - # App-layer stages use architect → developer delegation - sub_layer_context = "" - if layer == "app" and (self._app_architect or self._csharp_dev or self._python_dev or self._react_dev): - agent, sub_layer_context = self._decompose_app_stage(stage, focused_context, _print) - else: - agent = self._select_agent(stage) - if not agent: - _print(f" Skipped (no agent for capability '{stage.get('capability', '')}')") - continue + # ---- Full stage retry loop ---- + # When QA remediation exhausts all attempts, clean the stage and + # regenerate from scratch with prior QA findings injected. + prior_qa_findings = "" + stage_completed = False + written_paths: list[str] = [] + agent = None - with self._agent_build_context(agent, stage): - # Clear conversation history so prior stage context cannot - # bleed into this stage (especially after truncation/continuation). - self._context.conversation_history.clear() + for full_attempt in range(_MAX_FULL_STAGE_ATTEMPTS): + if full_attempt > 0: + _print( + f" Full retry ({full_attempt}/{_MAX_FULL_STAGE_ATTEMPTS - 1}): " + f"regenerating stage from scratch..." + ) + self._build_state.clean_stage_artifacts(stage_num, self._context.project_dir) + + # App-layer stages use architect → developer delegation + sub_layer_context = "" + if layer == "app" and (self._app_architect or self._csharp_dev or self._python_dev or self._react_dev): + agent, sub_layer_context = self._decompose_app_stage(stage, focused_context, _print) + else: + agent = self._select_agent(stage) + if not agent: + _print(f" Skipped (no agent for capability '{stage.get('capability', '')}')") + break + + with self._agent_build_context(agent, stage): + # Clear conversation history so prior stage context cannot + # bleed into this stage (especially after truncation/continuation). + self._context.conversation_history.clear() + + _, task = self._build_stage_task( + stage, focused_context, templates, prior_qa_findings=prior_qa_findings + ) + + # Inject sub-layer guidance for app stages + if sub_layer_context: + task += f"\n{sub_layer_context}\n" + + _dbg_flow( + "build_session.generate", + f"Stage {stage_num} task prompt", + layer=layer, + capability=stage.get("capability", ""), + agent_name=agent.name, + delegated=bool(sub_layer_context), + task_len=len(task), + has_service_policies="MANDATORY RESOURCE POLICIES" in task, + has_api_versions="Resource API Versions" in task, + has_companion="Companion Resource Requirements" in task, + has_networking_note="Networking Stage" in task, + task_full=task, + ) - _, task = self._build_stage_task(stage, focused_context, templates) + self._build_state.mark_stage_generating(stage_num) + try: + with self._maybe_spinner(f"Building Stage {stage_num}: {stage_name}...", use_styled): + response = self._execute_with_retry( + agent, task, stage_num, stage_name, _print, stage_capability=layer + ) + if response is None: + # All retry attempts exhausted — stop build + build_stopped = True + break + except Exception as exc: + _print(f" Agent error in Stage {stage_num} — routing to QA for diagnosis...") + svc_names_list = [s.get("name", "") for s in services if s.get("name")] + route_error_to_qa( + exc, + f"Build Stage {stage_num}: {stage_name}", + self._qa_agent, + self._context, + self._token_tracker, + _print, + services=svc_names_list, + escalation_tracker=self._escalation_tracker, + source_agent=agent.name, + source_stage="build", + ) + break # Agent error — not stochastic, don't retry - # Inject sub-layer guidance for app stages - if sub_layer_context: - task += f"\n{sub_layer_context}\n" + if response: + self._token_tracker.record(response) + content = response.content if response else "" _dbg_flow( "build_session.generate", - f"Stage {stage_num} task prompt", + f"Stage {stage_num} response", layer=layer, capability=stage.get("capability", ""), agent_name=agent.name, - delegated=bool(sub_layer_context), - task_len=len(task), - has_service_policies="MANDATORY RESOURCE POLICIES" in task, - has_api_versions="Resource API Versions" in task, - has_companion="Companion Resource Requirements" in task, - has_networking_note="Networking Stage" in task, - task_full=task, + content_len=len(content) if content else 0, + content_type=type(content).__name__, + content_full=content if content else "(empty)", ) - self._build_state.mark_stage_generating(stage_num) - try: - with self._maybe_spinner(f"Building Stage {stage_num}: {stage_name}...", use_styled): - response = self._execute_with_retry( - agent, task, stage_num, stage_name, _print, stage_capability=layer + # Debug: scan response for anti-pattern violations before policy resolver + # Skip scanning for docs and app stages — docs describe the architecture + # and app stages generate source code, not IaC. Both trigger false positives. + if content and layer not in ("docs", "app"): + try: + from azext_prototype.governance.anti_patterns import ( + scan as _ap_scan, ) - if response is None: - # All retry attempts exhausted — stop build - build_stopped = True - break - except Exception as exc: - _print(f" Agent error in Stage {stage_num} — routing to QA for diagnosis...") + + stage_svc_types = [s.get("resource_type", "") for s in services if s.get("resource_type")] + _ap_violations = _ap_scan( + content, iac_tool=self._iac_tool, agent_name=agent.name, services=stage_svc_types + ) + if _ap_violations: + _dbg_flow( + "build_session.generate", + f"Stage {stage_num} anti-pattern violations detected", + violation_count=len(_ap_violations), + violations=_ap_violations, + ) + except Exception: + pass + + # Debug: check what the parser would extract + _dbg_files = parse_file_blocks(content) if content else {} + _dbg_flow( + "build_session.generate", + f"Stage {stage_num} parse_file_blocks", + file_count=len(_dbg_files), + filenames=list(_dbg_files.keys())[:10], + ) + + if not content: + _print(f" Empty response for Stage {stage_num} — routing to QA for diagnosis...") svc_names_list = [s.get("name", "") for s in services if s.get("name")] route_error_to_qa( - exc, + "Agent returned empty response", f"Build Stage {stage_num}: {stage_name}", self._qa_agent, self._context, @@ -582,150 +673,97 @@ def run( source_agent=agent.name, source_stage="build", ) - continue - - if response: - self._token_tracker.record(response) - content = response.content if response else "" - - _dbg_flow( - "build_session.generate", - f"Stage {stage_num} response", - layer=layer, - capability=stage.get("capability", ""), - agent_name=agent.name, - content_len=len(content) if content else 0, - content_type=type(content).__name__, - content_full=content if content else "(empty)", - ) - - # Debug: scan response for anti-pattern violations before policy resolver - # Skip scanning for docs and app stages — docs describe the architecture - # and app stages generate source code, not IaC. Both trigger false positives. - if content and layer not in ("docs", "app"): - try: - from azext_prototype.governance.anti_patterns import ( - scan as _ap_scan, - ) - - stage_svc_types = [s.get("resource_type", "") for s in services if s.get("resource_type")] - _ap_violations = _ap_scan( - content, iac_tool=self._iac_tool, agent_name=agent.name, services=stage_svc_types - ) - if _ap_violations: - _dbg_flow( - "build_session.generate", - f"Stage {stage_num} anti-pattern violations detected", - violation_count=len(_ap_violations), - violations=_ap_violations, - ) - except Exception: - pass + written_paths = self._write_stage_files(stage, content) + written_paths = self._apply_stage_transforms(stage, written_paths, _print) - # Debug: check what the parser would extract - _dbg_files = parse_file_blocks(content) if content else {} - _dbg_flow( - "build_session.generate", - f"Stage {stage_num} parse_file_blocks", - file_count=len(_dbg_files), - filenames=list(_dbg_files.keys())[:10], - ) - - if not content: - _print(f" Empty response for Stage {stage_num} — routing to QA for diagnosis...") - svc_names_list = [s.get("name", "") for s in services if s.get("name")] - route_error_to_qa( - "Agent returned empty response", - f"Build Stage {stage_num}: {stage_name}", - self._qa_agent, - self._context, - self._token_tracker, - _print, - services=svc_names_list, - escalation_tracker=self._escalation_tracker, - source_agent=agent.name, - source_stage="build", + _dbg_flow( + "build_session.generate", + f"Stage {stage_num} written_paths", + count=len(written_paths), + paths=written_paths[:5], ) - written_paths = self._write_stage_files(stage, content) - written_paths = self._apply_stage_transforms(stage, written_paths, _print) - - _dbg_flow( - "build_session.generate", - f"Stage {stage_num} written_paths", - count=len(written_paths), - paths=written_paths[:5], - ) - # Files written — mark as validating (ready for QA) - self._build_state.mark_stage_validating(stage_num, written_paths) + # Files written — mark as validating (ready for QA) + self._build_state.mark_stage_validating(stage_num, written_paths) - if written_paths: - if use_styled: - self._console.print_file_list(written_paths) + if written_paths: + if use_styled: + self._console.print_file_list(written_paths) + else: + for f in written_paths: + _print(f" {f}") else: - for f in written_paths: - _print(f" {f}") - else: - _print(" No files extracted from response.") - - # Policy check — runs on all stage categories - if content: - resolutions, needs_regen = self._policy_resolver.check_and_resolve( - agent.name, - content, - self._build_state, - stage_num, - input_fn=input_fn, - print_fn=print_fn, - iac_tool=self._iac_tool, - ) - - if needs_regen: - fix_instructions = self._policy_resolver.build_fix_instructions(resolutions) - _print("Regenerating with fix instructions...") + _print(" No files extracted from response.") + + # Policy check — runs on all stage categories + if content: + resolutions, needs_regen = self._policy_resolver.check_and_resolve( + agent.name, + content, + self._build_state, + stage_num, + input_fn=input_fn, + print_fn=print_fn, + iac_tool=self._iac_tool, + ) - try: - with self._maybe_spinner(f"Re-building Stage {stage_num}...", use_styled): - response = self._execute_with_retry( - agent, - task + fix_instructions, - stage_num, - stage_name, + if needs_regen: + fix_instructions = self._policy_resolver.build_fix_instructions(resolutions) + _print("Regenerating with fix instructions...") + + try: + with self._maybe_spinner(f"Re-building Stage {stage_num}...", use_styled): + response = self._execute_with_retry( + agent, + task + fix_instructions, + stage_num, + stage_name, + _print, + stage_capability=layer, + ) + if response is None: + build_stopped = True + break + except Exception as exc: + svc_names_list = [s.get("name", "") for s in services if s.get("name")] + route_error_to_qa( + exc, + f"Build Stage {stage_num} (regen): {stage_name}", + self._qa_agent, + self._context, + self._token_tracker, _print, - stage_capability=layer, + services=svc_names_list, + escalation_tracker=self._escalation_tracker, + source_agent=agent.name, + source_stage="build", ) - if response is None: - build_stopped = True - break - except Exception as exc: - svc_names_list = [s.get("name", "") for s in services if s.get("name")] - route_error_to_qa( - exc, - f"Build Stage {stage_num} (regen): {stage_name}", - self._qa_agent, - self._context, - self._token_tracker, - _print, - services=svc_names_list, - escalation_tracker=self._escalation_tracker, - source_agent=agent.name, - source_stage="build", - ) - continue + break # Agent error — not stochastic, don't retry + + if response: + self._token_tracker.record(response) + content = response.content if response else "" + written_paths = self._write_stage_files(stage, content) + written_paths = self._apply_stage_transforms(stage, written_paths, _print) + self._build_state.mark_stage_validating(stage_num, written_paths) + + # Per-stage QA validation — runs on all stages that produce files + qa_passed = True + if written_paths: + qa_passed = self._run_stage_qa(stage, architecture, templates, use_styled, _print) - if response: - self._token_tracker.record(response) - content = response.content if response else "" - written_paths = self._write_stage_files(stage, content) - written_paths = self._apply_stage_transforms(stage, written_paths, _print) - self._build_state.mark_stage_validating(stage_num, written_paths) + if qa_passed: + stage_completed = True + break # QA passed — exit retry loop + + # QA failed — capture findings for next full attempt + prior_qa_findings = self._last_qa_content - # Per-stage QA validation — runs on all stages that produce files - qa_passed = True - if written_paths: - qa_passed = self._run_stage_qa(stage, architecture, templates, use_styled, _print) + # ---- After full stage retry loop ---- + if build_stopped: + break # Propagate to stage loop - if qa_passed: + if stage_completed and agent is not None: self._build_state.mark_stage_generated(stage_num, written_paths, agent.name) # Per-stage advisory (non-blocking — failure is logged, not fatal) @@ -735,8 +773,10 @@ def run( if self._update_task_fn: self._update_task_fn(task_id, "completed") + elif not agent: + continue # No agent — skip to next stage else: - # QA failed after max attempts — stop build + # All full attempts exhausted — stop build build_stopped = True if use_styled: self._console.print_token_status(self._token_tracker.format_status()) @@ -1803,10 +1843,13 @@ def _agent_build_context(self, agent: Any, stage: dict) -> Iterator[Any]: layer = stage.get("layer", "") self._apply_governor_brief(agent, stage.get("name", ""), stage.get("services", []), layer) self._apply_stage_knowledge(agent, stage) + svc_types = [s.get("resource_type", "") for s in stage.get("services", []) if s.get("resource_type")] + self._context.stage_services = svc_types or None try: yield agent finally: agent.set_knowledge_override("") + self._context.stage_services = None def _apply_stage_knowledge(self, agent: Any, stage: dict) -> None: """Set stage-specific knowledge on the agent. @@ -2113,6 +2156,7 @@ def _build_stage_task( stage: dict, architecture: str, templates: list, + prior_qa_findings: str = "", ) -> tuple[Any | None, str]: """Build the task prompt for a stage. @@ -2199,9 +2243,37 @@ def _build_stage_task( layer = stage.get("layer", "") - task = ( - f"Generate{tool_label} code for deployment " - f"Stage {stage['stage']}: {stage_name}.\n\n" + task = f"Generate{tool_label} code for deployment " f"Stage {stage['stage']}: {stage_name}.\n\n" + + # Inject prior QA findings from a failed full-stage attempt so the + # model avoids repeating the same classes of mistakes. + if prior_qa_findings: + task += ( + "## CRITICAL: Previous QA Failures (DO NOT REPEAT THESE ISSUES)\n" + "A previous generation of this stage failed QA review after multiple\n" + "remediation attempts. The following issues could not be resolved.\n\n" + "NOTE: The file names and code structure below are from a previous\n" + "generation and may not match yours. Focus on the **underlying issues**\n" + "described — avoid the same classes of mistakes regardless of file\n" + "names or layout.\n\n" + f"{prior_qa_findings}\n\n" + "Generate this stage from scratch, ensuring these classes of issues\n" + "are avoided from the start.\n\n" + ) + + # Front-load the no-dead-code remote state directive so the model + # sees it BEFORE the architecture context (reduces unused data sources). + if is_iac and prev_context: + task += ( + "## CRITICAL: CROSS-STAGE DEPENDENCIES — NO DEAD CODE\n" + "ONLY declare `terraform_remote_state` data sources for stages whose\n" + "outputs you actually reference in resource definitions or locals.\n" + "Do NOT declare remote state data sources 'for completeness' or 'in case needed.'\n" + "Every `data.terraform_remote_state` block MUST have at least one output\n" + "referenced in `locals.tf` or `main.tf`. If it doesn't, do not create it.\n\n" + ) + + task += ( f"## Architecture Context\n{architecture}\n\n" f"## This Stage\n" f"Name: {stage_name}\n" @@ -3272,7 +3344,13 @@ def _run_stage_qa( from azext_prototype.debug_log import log_flow as _dbg stage_num = stage["stage"] - orchestrator = AgentOrchestrator(self._registry, self._context) + + _QA_CONTINUATION_PROMPT = ( + "Your previous review was cut off. Continue your review " + "EXACTLY where you left off. Do NOT regenerate or emit " + "any code — only continue the review analysis and " + "provide your VERDICT." + ) # Build context briefs once for all QA attempts services = stage.get("services", []) @@ -3288,7 +3366,7 @@ def _run_stage_qa( # 2. Build QA task qa_task = self._build_qa_task(stage_num, stage["name"], attempt, file_content, qa_context, layer) - # 3. Run QA (with timeout/rate-limit retry) + # 3. Run QA (with timeout/rate-limit retry and continuation) from azext_prototype.ai.copilot_provider import ( CopilotRateLimitError, CopilotTimeoutError, @@ -3297,12 +3375,16 @@ def _run_stage_qa( qa_result = None max_attempts = len(self._TIMEOUT_BACKOFFS) + 1 for qa_attempt in range(max_attempts): + # Save conversation history — QA messages must not + # leak into the generation context for subsequent stages. + saved_history = list(self._context.conversation_history) try: with self._maybe_spinner(f"QA reviewing Stage {stage_num}...", use_styled): - qa_result = orchestrator.delegate( - from_agent="build-session", - to_agent_name=self._qa_agent.name, - sub_task=qa_task, + qa_result = self._execute_with_continuation( + self._qa_agent, + qa_task, + max_continuations=3, + continuation_prompt=_QA_CONTINUATION_PROMPT, ) break except CopilotRateLimitError as exc: @@ -3319,6 +3401,8 @@ def _run_stage_qa( f"Stage {stage_num} will be retried on next build." ) return False + finally: + self._context.conversation_history = saved_history if qa_result: self._token_tracker.record(qa_result) @@ -3340,6 +3424,7 @@ def _run_stage_qa( if not has_issues: _print(f" Stage {stage_num} passed QA.") + self._last_qa_content = "" return True # 5. If at max attempts, report issues concisely and fail @@ -3353,6 +3438,7 @@ def _run_stage_qa( if stripped.startswith(("CRITICAL", "WARNING", "**CRITICAL", "**WARNING", "- [", "| ")): _print(f" {stripped}") _print("") + self._last_qa_content = qa_content or "" return False # 6. Remediate — re-invoke IaC agent with focused context + governance + knowledge @@ -3405,7 +3491,7 @@ def _run_stage_qa( content = response.content if response else "" written_paths = self._write_stage_files(stage, content) written_paths = self._apply_stage_transforms(stage, written_paths, _print) - self._build_state.mark_stage_generated(stage_num, written_paths, agent.name) + self._build_state.mark_stage_validating(stage_num, written_paths) return True # All remediation attempts completed without hitting max @@ -3559,6 +3645,7 @@ def _execute_with_continuation( stage_num: int = 0, stage_name: str = "", stage_capability: str = "", + continuation_prompt: str | None = None, ) -> Any: """Execute an agent task, automatically continuing if truncated. @@ -3567,6 +3654,11 @@ def _execute_with_continuation( an assistant message so the model can see what it already generated. A continuation prompt is then sent as a new user message, and the model picks up where it left off. + + If *continuation_prompt* is provided it is used verbatim instead of + the default code-generation continuation text. This is useful for + QA reviews where the continuation should request more review, not + more code. """ from azext_prototype.ai.provider import AIMessage, AIResponse @@ -3585,22 +3677,25 @@ def _execute_with_continuation( # model sees what it already generated when continuing. self._context.conversation_history.append(AIMessage(role="assistant", content=response.content or "")) - # Include stage context so the model stays on track - stage_hint = "" - if stage_num and stage_name: - stage_hint = ( - f" You are generating Stage {stage_num}: {stage_name} " - f"(layer: {stage_capability}). " - "Stay within this stage's scope — do not generate content " - "for any other stage." - ) + if continuation_prompt: + cont_task = continuation_prompt + else: + # Include stage context so the model stays on track + stage_hint = "" + if stage_num and stage_name: + stage_hint = ( + f" You are generating Stage {stage_num}: {stage_name} " + f"(layer: {stage_capability}). " + "Stay within this stage's scope — do not generate content " + "for any other stage." + ) - cont_task = ( - "Your previous response was cut off mid-generation. " - "Continue EXACTLY where you left off — do not repeat any " - "file or content already generated. Pick up mid-line if " - f"necessary. Maintain the same code block format.{stage_hint}" - ) + cont_task = ( + "Your previous response was cut off mid-generation. " + "Continue EXACTLY where you left off — do not repeat any " + "file or content already generated. Pick up mid-line if " + f"necessary. Maintain the same code block format.{stage_hint}" + ) self._context.conversation_history.append(AIMessage(role="user", content=cont_task)) cont = agent.execute(self._context, cont_task) diff --git a/benchmarks/2026-03-31-11-16-46.html b/benchmarks/2026-03-31-11-16-46.html deleted file mode 100644 index 3f4898e..0000000 --- a/benchmarks/2026-03-31-11-16-46.html +++ /dev/null @@ -1,484 +0,0 @@ - - - - - - Benchmark Run: {{DATE}} - - - - - - - -
-
-
-

- GitHub Copilot - vs - Claude Code -

-

Benchmark Run —

-
-
-

Project:

-

Model:

-

Stages won — GHCP: • Claude Code:

-
-
-
- - - - -
- - -
- - -
-
-

Project:

-

-
-
- - -
-

Benchmark Scores

-
- - - - - - - - - - - - - - - - - -
BenchmarkDescriptionGHCPClaude CodeDeltaWinner
Overall Average
-
-
- - -
-

Aggregate Scores by Stage

-
- - - - - - - - - -
StageServiceGHCPClaude CodeWinner
-
-
- - -
- - -
- - -
- - -
-

Final Verdict

-
-
-
-
-
GitHub Copilot
-

-
-
-
-
Claude Code
-

-
-
-
-
-
- -
- -
- - - - - - - - - - diff --git a/benchmarks/2026-04-08-14-40-57.html b/benchmarks/2026-04-08-14-40-57.html new file mode 100644 index 0000000..4b382ed --- /dev/null +++ b/benchmarks/2026-04-08-14-40-57.html @@ -0,0 +1,664 @@ + + + + + + Benchmark Run: 2026-04-08 14:40:57 + + + + + +
+
+
+

+ GitHub Copilot + vs + Claude Code +

+

Benchmark Run —

+
+
+

Project:

+

Model:

+

Stages won — GHCP: • Claude Code:

+
+
+
+ + + + +
+ + +
+ + +
+
+

Project:

+

+
+
+ + +
+

Benchmark Scores

+
+ + + + + + + + + + + + + + + + + +
BenchmarkDescriptionGHCPClaude CodeDeltaWinner
Overall Average
+
+
+ + +
+

Aggregate Scores by Stage

+
+ + + + + + + + + +
StageServiceGHCPClaude CodeWinner
+
+
+ + +
+ + +
+ + +
+ + +
+

Final Verdict

+
+
+
+
+
GitHub Copilot
+

+
+
+
+
Claude Code
+

+
+
+
+
+
+ +
+ +
+ + + + + + + + + + diff --git a/benchmarks/2026-04-08_Benchmark_Report.pdf b/benchmarks/2026-04-08_Benchmark_Report.pdf new file mode 100644 index 0000000..b161d26 Binary files /dev/null and b/benchmarks/2026-04-08_Benchmark_Report.pdf differ diff --git a/benchmarks/INSTRUCTIONS.md b/benchmarks/INSTRUCTIONS.md index ed166c8..813363f 100644 --- a/benchmarks/INSTRUCTIONS.md +++ b/benchmarks/INSTRUCTIONS.md @@ -42,74 +42,19 @@ The raw AI response is also available in the log (`"Stage N response"` → `cont Content boundaries: each multi-line value starts after `=` on the marker line and continues until the next line matching `^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} \|` (a timestamp-prefixed log entry). -### Extraction Script Template - -```python -#!/usr/bin/env python3 -"""Extract stage prompts and responses from debug log.""" -import re, os, sys - -LOG = sys.argv[1] # Path to debug log -OUT = sys.argv[2] if len(sys.argv) > 2 else "COMPARE" -TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} \|") - -with open(LOG, "r", encoding="utf-8", errors="replace") as f: - lines = f.readlines() - -def find_line(pattern, start=0): - for i in range(start, len(lines)): - if pattern in lines[i]: - return i - return -1 - -def extract_content(start_line, prefix): - first_line = lines[start_line] - idx = first_line.find(prefix + "=") - if idx == -1: - return "" - parts = [first_line[idx + len(prefix) + 1:]] - for i in range(start_line + 1, len(lines)): - if TIMESTAMP_RE.match(lines[i]): - break - parts.append(lines[i]) - return "".join(parts) - -os.makedirs(OUT, exist_ok=True) -for stage_num in range(1, 50): - prompt_line = find_line(f"Stage {stage_num} task prompt") - if prompt_line == -1: - break - # Extract post-transform output (final quality after governance transforms) - transform_line = find_line(f"Stage {stage_num} post-transform", prompt_line) - task_full_line = next((i for i in range(prompt_line, min(prompt_line+10, len(lines))) - if "task_full=" in lines[i]), -1) - transformed_full_line = -1 - if transform_line != -1: - transformed_full_line = next((i for i in range(transform_line, min(transform_line+10, len(lines))) - if "transformed_full=" in lines[i]), -1) - # Fallback to raw response if no post-transform entry (e.g., no transforms applied) - if transformed_full_line == -1: - response_line = find_line(f"Stage {stage_num} response", prompt_line) - if response_line != -1: - transformed_full_line = next((i for i in range(response_line, min(response_line+10, len(lines))) - if "content_full=" in lines[i]), -1) - content_key = "content_full" - else: - continue - else: - content_key = "transformed_full" - if task_full_line == -1 or transformed_full_line == -1: - continue - prompt = extract_content(task_full_line, "task_full") - response = extract_content(transformed_full_line, content_key) - with open(os.path.join(OUT, f"INPUT_{stage_num}.md"), "w") as f: - f.write(prompt) - with open(os.path.join(OUT, f"CP_RESPONSE_{stage_num}.md"), "w") as f: - f.write(response) - print(f"Stage {stage_num}: INPUT={len(prompt)}B CP_RESPONSE={len(response)}B (source: {content_key})") +### Extraction Script + +Use `benchmarks/extract.py` to extract from the debug log: + +```bash +python3 benchmarks/extract.py debug_20260408144057.log COMPARE ``` -Usage: `python3 extract.py debug_20260328024351.log COMPARE` +The script handles **full stage retries**: when a stage has multiple `task prompt` +entries (from the retry loop), it uses the **last** attempt's input and the final +post-transform output. Retried stages are marked with `[RETRY]` in the console output. + +See `benchmarks/extract.py` for the full implementation. --- diff --git a/benchmarks/extract.py b/benchmarks/extract.py index 27a3081..deae56b 100644 --- a/benchmarks/extract.py +++ b/benchmarks/extract.py @@ -1,6 +1,14 @@ #!/usr/bin/env python3 -"""Extract stage prompts and responses from debug log.""" -import re, os, sys +"""Extract stage prompts and responses from debug log. + +When a stage has a full retry (second ``task prompt`` entry), uses the +**last** task prompt and the final post-transform output for that stage. +This ensures benchmarks measure the retry attempt's input — the one that +includes prior QA findings — rather than the original attempt. +""" +import os +import re +import sys LOG = sys.argv[1] # Path to debug log OUT = sys.argv[2] if len(sys.argv) > 2 else "COMPARE" @@ -9,54 +17,105 @@ with open(LOG, "r", encoding="utf-8", errors="replace") as f: lines = f.readlines() -def find_line(pattern, start=0): + +def find_all_lines(pattern: str) -> list[int]: + """Return line indices of ALL occurrences of *pattern*.""" + return [i for i, line in enumerate(lines) if pattern in line] + + +def find_line(pattern: str, start: int = 0) -> int: for i in range(start, len(lines)): if pattern in lines[i]: return i return -1 -def extract_content(start_line, prefix): + +def extract_content(start_line: int, prefix: str) -> str: first_line = lines[start_line] idx = first_line.find(prefix + "=") if idx == -1: return "" - parts = [first_line[idx + len(prefix) + 1:]] + parts = [first_line[idx + len(prefix) + 1 :]] for i in range(start_line + 1, len(lines)): if TIMESTAMP_RE.match(lines[i]): break parts.append(lines[i]) return "".join(parts) + os.makedirs(OUT, exist_ok=True) + for stage_num in range(1, 50): - prompt_line = find_line(f"Stage {stage_num} task prompt") - if prompt_line == -1: + # Find ALL task prompt entries for this stage — use the LAST one + # (if a full retry happened, the second prompt includes prior QA findings) + all_prompts = find_all_lines(f"Stage {stage_num} task prompt") + if not all_prompts: break - # Extract post-transform output (final quality after governance transforms) + + prompt_line = all_prompts[-1] # Use last (retry) attempt + retried = len(all_prompts) > 1 + + # Find task_full= within a few lines of the prompt marker + task_full_line = next( + (i for i in range(prompt_line, min(prompt_line + 15, len(lines))) if "task_full=" in lines[i]), + -1, + ) + + # Find the LAST post-transform output after the last task prompt transform_line = find_line(f"Stage {stage_num} post-transform", prompt_line) - task_full_line = next((i for i in range(prompt_line, min(prompt_line+10, len(lines))) - if "task_full=" in lines[i]), -1) + # Walk forward to find the very last post-transform for this stage + # (there may be multiple from QA remediation cycles) + while True: + next_transform = find_line(f"Stage {stage_num} post-transform", transform_line + 1) + # Stop if we hit a different stage's task prompt or end of file + next_stage_prompt = find_line(f"Stage {stage_num + 1} task prompt", transform_line + 1) + if next_transform == -1: + break + if next_stage_prompt != -1 and next_transform > next_stage_prompt: + break + transform_line = next_transform + transformed_full_line = -1 + content_key = "transformed_full" + if transform_line != -1: - transformed_full_line = next((i for i in range(transform_line, min(transform_line+10, len(lines))) - if "transformed_full=" in lines[i]), -1) - # Fallback to raw response if no post-transform entry (e.g., no transforms applied) + transformed_full_line = next( + ( + i + for i in range(transform_line, min(transform_line + 15, len(lines))) + if "transformed_full=" in lines[i] + ), + -1, + ) + + # Fallback to raw response if no post-transform entry if transformed_full_line == -1: response_line = find_line(f"Stage {stage_num} response", prompt_line) if response_line != -1: - transformed_full_line = next((i for i in range(response_line, min(response_line+10, len(lines))) - if "content_full=" in lines[i]), -1) + transformed_full_line = next( + ( + i + for i in range(response_line, min(response_line + 15, len(lines))) + if "content_full=" in lines[i] + ), + -1, + ) content_key = "content_full" else: continue - else: - content_key = "transformed_full" + if task_full_line == -1 or transformed_full_line == -1: continue + prompt = extract_content(task_full_line, "task_full") response = extract_content(transformed_full_line, content_key) + + retry_tag = " [RETRY]" if retried else "" with open(os.path.join(OUT, f"INPUT_{stage_num}.md"), "w") as f: f.write(prompt) with open(os.path.join(OUT, f"CP_RESPONSE_{stage_num}.md"), "w") as f: f.write(response) - print(f"Stage {stage_num}: INPUT={len(prompt)}B CP_RESPONSE={len(response)}B (source: {content_key})") \ No newline at end of file + print( + f"Stage {stage_num}{retry_tag}: INPUT={len(prompt)}B " + f"CP_RESPONSE={len(response)}B (source: {content_key})" + ) diff --git a/benchmarks/overall.html b/benchmarks/overall.html index c0c16b2..679e78a 100644 --- a/benchmarks/overall.html +++ b/benchmarks/overall.html @@ -163,6 +163,56 @@

0.3 - def test_load_agents_from_directory(self, tmp_path): - from azext_prototype.agents.loader import load_agents_from_directory - - (tmp_path / "agent1.yaml").write_text( - "name: agent1\ndescription: A\ncapabilities: []\nsystem_prompt: test\n", - encoding="utf-8", - ) - (tmp_path / "agent2.yaml").write_text( - "name: agent2\ndescription: B\ncapabilities: []\nsystem_prompt: test\n", - encoding="utf-8", - ) - (tmp_path / "_skip.py").write_text("# skipped", encoding="utf-8") - - agents = load_agents_from_directory(str(tmp_path)) - assert len(agents) == 2 - - def test_load_agents_from_nonexistent_dir(self, tmp_path): - from azext_prototype.agents.loader import load_agents_from_directory - - agents = load_agents_from_directory(str(tmp_path / "nonexistent")) - assert agents == [] - def test_load_agents_handles_invalid_files(self, tmp_path): from azext_prototype.agents.loader import load_agents_from_directory @@ -956,19 +934,6 @@ def test_yaml_agent_missing_name_raises(self): with pytest.raises(CLIError, match="must include 'name'"): YAMLAgent({"description": "no name"}) - def test_load_yaml_agent_not_found(self): - from azext_prototype.agents.loader import load_yaml_agent - - with pytest.raises(CLIError, match="not found"): - load_yaml_agent("/nonexistent/path.yaml") - - def test_load_yaml_agent_wrong_ext(self, tmp_path): - from azext_prototype.agents.loader import load_yaml_agent - - (tmp_path / "test.txt").write_text("test") - with pytest.raises(CLIError, match=".yaml"): - load_yaml_agent(str(tmp_path / "test.txt")) - def test_load_yaml_agent_not_mapping(self, tmp_path): from azext_prototype.agents.loader import load_yaml_agent @@ -976,40 +941,6 @@ def test_load_yaml_agent_not_mapping(self, tmp_path): with pytest.raises(CLIError, match="mapping"): load_yaml_agent(str(tmp_path / "bad.yaml")) - def test_load_python_agent_not_found(self): - from azext_prototype.agents.loader import load_python_agent - - with pytest.raises(CLIError, match="not found"): - load_python_agent("/nonexistent/agent.py") - - def test_load_python_agent_wrong_ext(self, tmp_path): - from azext_prototype.agents.loader import load_python_agent - - (tmp_path / "test.yaml").write_text("test") - with pytest.raises(CLIError, match=".py"): - load_python_agent(str(tmp_path / "test.yaml")) - - def test_load_python_agent_with_agent_class(self, tmp_path): - from azext_prototype.agents.loader import load_python_agent - - code = """ -from azext_prototype.agents.base import BaseAgent, AgentCapability - -class MyAgent(BaseAgent): - def __init__(self): - super().__init__( - name="py-agent", - description="Python agent", - capabilities=[AgentCapability.DEVELOP], - system_prompt="test", - ) - -AGENT_CLASS = MyAgent -""" - (tmp_path / "my_agent.py").write_text(code, encoding="utf-8") - agent = load_python_agent(str(tmp_path / "my_agent.py")) - assert agent.name == "py-agent" - def test_load_python_agent_auto_discover(self, tmp_path): from azext_prototype.agents.loader import load_python_agent diff --git a/tests/ai/__init__.py b/tests/ai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_ai.py b/tests/ai/test_ai.py similarity index 98% rename from tests/test_ai.py rename to tests/ai/test_ai.py index a9b3e2e..4bad4e9 100644 --- a/tests/test_ai.py +++ b/tests/ai/test_ai.py @@ -268,10 +268,5 @@ def test_config_default(self): assert DEFAULT_CONFIG["ai"]["model"] == "claude-sonnet-4.5" - def test_copilot_default(self): - from azext_prototype.ai.copilot_provider import CopilotProvider - - assert CopilotProvider.DEFAULT_MODEL == "claude-sonnet-4" - def test_github_models_default(self): assert GitHubModelsProvider.DEFAULT_MODEL == "gpt-4o" diff --git a/tests/test_auth.py b/tests/ai/test_auth.py similarity index 100% rename from tests/test_auth.py rename to tests/ai/test_auth.py diff --git a/tests/test_copilot_auth.py b/tests/ai/test_copilot_auth.py similarity index 100% rename from tests/test_copilot_auth.py rename to tests/ai/test_copilot_auth.py diff --git a/tests/test_token_tracker.py b/tests/ai/test_token_tracker.py similarity index 100% rename from tests/test_token_tracker.py rename to tests/ai/test_token_tracker.py diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_config.py b/tests/config/test_config.py similarity index 100% rename from tests/test_config.py rename to tests/config/test_config.py diff --git a/tests/governance/__init__.py b/tests/governance/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/governance/anti_patterns/__init__.py b/tests/governance/anti_patterns/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_anti_patterns.py b/tests/governance/anti_patterns/test_anti_patterns.py similarity index 100% rename from tests/test_anti_patterns.py rename to tests/governance/anti_patterns/test_anti_patterns.py diff --git a/tests/governance/policies/__init__.py b/tests/governance/policies/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_policies.py b/tests/governance/policies/test_policies.py similarity index 99% rename from tests/test_policies.py rename to tests/governance/policies/test_policies.py index 7f5dde5..ef9405b 100644 --- a/tests/test_policies.py +++ b/tests/governance/policies/test_policies.py @@ -680,7 +680,7 @@ def test_no_duplicate_rule_id_target_pairs(self) -> None: def test_builtin_policies_pass_strict_validation(self) -> None: """All built-in .policy.yaml files must pass strict validation.""" - builtin_dir = Path(__file__).resolve().parent.parent / "azext_prototype" / "policies" + builtin_dir = Path(__file__).resolve().parent.parent.parent.parent / "azext_prototype" / "policies" errors = validate_policy_directory(builtin_dir) actual_errors = [e for e in errors if e.severity == "error"] warnings = [e for e in errors if e.severity == "warning"] diff --git a/tests/governance/standards/__init__.py b/tests/governance/standards/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_standards.py b/tests/governance/standards/test_standards.py similarity index 100% rename from tests/test_standards.py rename to tests/governance/standards/test_standards.py diff --git a/tests/test_governance.py b/tests/governance/test_governance.py similarity index 95% rename from tests/test_governance.py rename to tests/governance/test_governance.py index 2ded917..82da25c 100644 --- a/tests/test_governance.py +++ b/tests/governance/test_governance.py @@ -500,44 +500,6 @@ def test_cloud_architect_validates_response(self, mock_agent_context): assert "Governance warnings" not in result.content -# ------------------------------------------------------------------ # -# Credential detection patterns — exhaustive -# ------------------------------------------------------------------ # - - -class TestCredentialDetection: - """Test all credential patterns are detected.""" - - @pytest.fixture(autouse=True) - def _setup_governance(self, policy_engine, template_registry): - import azext_prototype.agents.governance as gov_mod - - gov_mod._policy_engine = policy_engine - gov_mod._template_registry = template_registry - - @pytest.mark.parametrize( - "pattern", - [ - "connection_string", - "connectionstring", - "access_key", - "accesskey", - "account_key", - "accountkey", - "shared_access_key", - "client_secret", - 'password="bad"', - "password='bad'", - "password = foo", - ], - ) - def test_credential_pattern_detected(self, pattern, governance_ctx): - warnings = governance_ctx.check_response_for_violations("terraform-agent", f"Use {pattern} for auth") - assert any( - "credential" in w.lower() or "secret" in w.lower() or "managed identity" in w.lower() for w in warnings - ), f"Pattern '{pattern}' should be detected as credential" - - # ------------------------------------------------------------------ # # GovernanceContext — edge cases # ------------------------------------------------------------------ # diff --git a/tests/test_governor.py b/tests/governance/test_governor.py similarity index 100% rename from tests/test_governor.py rename to tests/governance/test_governor.py diff --git a/tests/test_governor_agent.py b/tests/governance/test_governor_agent.py similarity index 100% rename from tests/test_governor_agent.py rename to tests/governance/test_governor_agent.py diff --git a/tests/governance/transforms/__init__.py b/tests/governance/transforms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_transforms.py b/tests/governance/transforms/test_transforms.py similarity index 85% rename from tests/test_transforms.py rename to tests/governance/transforms/test_transforms.py index 44fa4ee..cdb9709 100644 --- a/tests/test_transforms.py +++ b/tests/governance/transforms/test_transforms.py @@ -393,3 +393,61 @@ def test_cross_file_unused_remote_state_removed(self): ) assert "TFM-TF-001" in ids assert "terraform_remote_state" not in result + + +# ------------------------------------------------------------------ +# ReDoS safety: brace-counting replaces nested quantifier regex +# ------------------------------------------------------------------ + + +class TestBraceCountingSafety: + """Transform handlers must use brace counting, not nested-quantifier regex.""" + + def test_response_export_values_pathological_input(self): + """Long line with no newlines must complete in <1 second (no backtracking).""" + import time + + # Pathological: very long body with no newlines, followed by closing brace + long_body = "x" * 50000 + content = f'resource "azapi_resource" "kv" {{\n type = "Microsoft.KeyVault/vaults@2023-07-01"\n {long_body}\n}}\n' + + start = time.monotonic() + result = _add_response_export_values(content) + elapsed = time.monotonic() - start + + assert elapsed < 1.0, f"_add_response_export_values took {elapsed:.2f}s on pathological input (ReDoS?)" + assert 'response_export_values = ["*"]' in result + + def test_resource_group_parent_id_pathological_input(self): + """Long line with no newlines must complete in <1 second (no backtracking).""" + import time + + long_body = "x" * 50000 + content = f'resource "azapi_resource" "rg" {{\n type = "Microsoft.Resources/resourceGroups@2024-03-01"\n name = var.rg\n {long_body}\n}}\n' + + start = time.monotonic() + result = _add_resource_group_parent_id(content) + elapsed = time.monotonic() - start + + assert elapsed < 1.0, f"_add_resource_group_parent_id took {elapsed:.2f}s on pathological input (ReDoS?)" + assert "parent_id" in result + + def test_find_azapi_blocks_nested_braces(self): + """Brace counting must handle nested blocks correctly.""" + from azext_prototype.governance.transforms import _find_azapi_blocks + + content = """resource "azapi_resource" "kv" { + type = "Microsoft.KeyVault/vaults@2023-07-01" + body = { + properties = { + tenantId = var.tenant_id + } + } +} +""" + blocks = _find_azapi_blocks(content) + assert len(blocks) == 1 + start, end, name, block_text = blocks[0] + assert name == "kv" + assert block_text.startswith('resource "azapi_resource"') + assert block_text.rstrip().endswith("}") diff --git a/tests/knowledge/__init__.py b/tests/knowledge/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_knowledge.py b/tests/knowledge/test_knowledge.py similarity index 100% rename from tests/test_knowledge.py rename to tests/knowledge/test_knowledge.py diff --git a/tests/test_resource_metadata.py b/tests/knowledge/test_resource_metadata.py similarity index 100% rename from tests/test_resource_metadata.py rename to tests/knowledge/test_resource_metadata.py diff --git a/tests/test_web_search.py b/tests/knowledge/test_web_search.py similarity index 100% rename from tests/test_web_search.py rename to tests/knowledge/test_web_search.py diff --git a/tests/mcp/__init__.py b/tests/mcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_mcp.py b/tests/mcp/test_mcp.py similarity index 100% rename from tests/test_mcp.py rename to tests/mcp/test_mcp.py diff --git a/tests/test_mcp_integration.py b/tests/mcp/test_mcp_integration.py similarity index 100% rename from tests/test_mcp_integration.py rename to tests/mcp/test_mcp_integration.py diff --git a/tests/naming/__init__.py b/tests/naming/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_naming.py b/tests/naming/test_naming.py similarity index 100% rename from tests/test_naming.py rename to tests/naming/test_naming.py diff --git a/tests/parsers/__init__.py b/tests/parsers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_binary_reader.py b/tests/parsers/test_binary_reader.py similarity index 100% rename from tests/test_binary_reader.py rename to tests/parsers/test_binary_reader.py diff --git a/tests/test_parse_requirements.py b/tests/parsers/test_parse_requirements.py similarity index 100% rename from tests/test_parse_requirements.py rename to tests/parsers/test_parse_requirements.py diff --git a/tests/test_parsers.py b/tests/parsers/test_parsers.py similarity index 100% rename from tests/test_parsers.py rename to tests/parsers/test_parsers.py diff --git a/tests/stages/__init__.py b/tests/stages/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/stages/test_backlog_push.py b/tests/stages/test_backlog_push.py new file mode 100644 index 0000000..4dadf8d --- /dev/null +++ b/tests/stages/test_backlog_push.py @@ -0,0 +1,464 @@ +"""Tests for backlog push helpers — GitHub and Azure DevOps work item creation. + +Tier 2: Conditional branches with multiple paths. + +Covers: +- check_gh_auth(): success, failure, FileNotFoundError +- check_devops_ext(): success, failure, FileNotFoundError +- format_github_body(): description, acceptance criteria, tasks (str/dict/done), + children with nested tasks, labels from epic/effort +- format_devops_description(): description, AC, tasks (str/dict/done), effort +- push_github_issue(): success (URL parsing), failure (returncode != 0), + FileNotFoundError, labels from epic/effort, no-epic title +- push_devops_feature/story/task(): success with JSON, failure, + FileNotFoundError, parent linking, JSON decode error +- _link_parent(): success, failure (swallowed) +""" + +import json +from unittest.mock import MagicMock, patch + +from azext_prototype.stages.backlog_push import ( + _link_parent, + check_devops_ext, + check_gh_auth, + format_devops_description, + format_github_body, + push_devops_feature, + push_devops_story, + push_devops_task, + push_github_issue, +) + +# ------------------------------------------------------------------ +# Auth checks +# ------------------------------------------------------------------ + + +class TestCheckGhAuth: + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_authenticated(self, mock_run): + mock_run.return_value = MagicMock(returncode=0) + assert check_gh_auth() is True + mock_run.assert_called_once() + cmd = mock_run.call_args[0][0] + assert cmd == ["gh", "auth", "status"] + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_not_authenticated(self, mock_run): + mock_run.return_value = MagicMock(returncode=1) + assert check_gh_auth() is False + + @patch("azext_prototype.stages.backlog_push.subprocess.run", side_effect=FileNotFoundError) + def test_gh_not_installed(self, mock_run): + assert check_gh_auth() is False + + +class TestCheckDevopsExt: + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_installed(self, mock_run): + mock_run.return_value = MagicMock(returncode=0) + assert check_devops_ext() is True + cmd = mock_run.call_args[0][0] + assert cmd == ["az", "devops", "--help"] + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_not_installed(self, mock_run): + mock_run.return_value = MagicMock(returncode=1) + assert check_devops_ext() is False + + @patch("azext_prototype.stages.backlog_push.subprocess.run", side_effect=FileNotFoundError) + def test_az_not_found(self, mock_run): + assert check_devops_ext() is False + + +# ------------------------------------------------------------------ +# Formatters — GitHub +# ------------------------------------------------------------------ + + +class TestFormatGithubBody: + def test_description_section(self): + body = format_github_body({"description": "Build an API"}) + assert "## Description" in body + assert "Build an API" in body + + def test_no_description(self): + body = format_github_body({"title": "Something"}) + assert "## Description" not in body + + def test_acceptance_criteria(self): + body = format_github_body({"acceptance_criteria": ["AC1", "AC2"]}) + assert "## Acceptance Criteria" in body + assert "1. AC1" in body + assert "2. AC2" in body + + def test_empty_acceptance_criteria(self): + body = format_github_body({"acceptance_criteria": []}) + assert "## Acceptance Criteria" not in body + + def test_tasks_as_strings(self): + body = format_github_body({"tasks": ["Task A", "Task B"]}) + assert "## Tasks" in body + assert "- [ ] Task A" in body + assert "- [ ] Task B" in body + + def test_tasks_as_dicts_unchecked(self): + body = format_github_body({"tasks": [{"title": "Task A", "done": False}]}) + assert "- [ ] Task A" in body + + def test_tasks_as_dicts_checked(self): + body = format_github_body({"tasks": [{"title": "Task Done", "done": True}]}) + assert "- [x] Task Done" in body + + def test_empty_tasks(self): + body = format_github_body({"tasks": []}) + assert "## Tasks" not in body + + def test_children_section(self): + item = { + "children": [ + { + "title": "Story 1", + "effort": "M", + "description": "Story desc", + "acceptance_criteria": ["AC1"], + "tasks": ["Sub task"], + } + ] + } + body = format_github_body(item) + assert "## Stories" in body + assert "### Story 1 [M]" in body + assert "Story desc" in body + assert "1. AC1" in body + assert "- [ ] Sub task" in body + + def test_children_with_dict_tasks(self): + item = { + "children": [ + { + "title": "Story", + "effort": "S", + "tasks": [{"title": "Done task", "done": True}], + } + ] + } + body = format_github_body(item) + assert "- [x] Done task" in body + + def test_labels_from_epic_and_effort(self): + body = format_github_body({"epic": "Infrastructure", "effort": "L"}) + assert "`infrastructure`" in body + assert "`effort/L`" in body + + def test_no_labels_without_epic_and_effort(self): + body = format_github_body({"title": "Plain item"}) + assert "**Labels:**" not in body + + +# ------------------------------------------------------------------ +# Formatters — Azure DevOps +# ------------------------------------------------------------------ + + +class TestFormatDevopsDescription: + def test_description_paragraph(self): + html = format_devops_description({"description": "Build API"}) + assert "

Build API

" in html + + def test_no_description(self): + html = format_devops_description({"title": "X"}) + assert "

" not in html + + def test_acceptance_criteria(self): + html = format_devops_description({"acceptance_criteria": ["AC1", "AC2"]}) + assert "

Acceptance Criteria

" in html + assert "
  • AC1
  • " in html + assert "
  • AC2
  • " in html + + def test_empty_acceptance_criteria(self): + html = format_devops_description({"acceptance_criteria": []}) + assert "Acceptance Criteria" not in html + + def test_tasks_as_strings(self): + html = format_devops_description({"tasks": ["T1"]}) + assert "

    Tasks

    " in html + assert "
  • T1
  • " in html + + def test_tasks_as_dicts_done(self): + html = format_devops_description({"tasks": [{"title": "T", "done": True}]}) + assert "☑" in html + assert "T" in html + + def test_tasks_as_dicts_not_done(self): + html = format_devops_description({"tasks": [{"title": "T", "done": False}]}) + assert "☐" in html + + def test_effort_paragraph(self): + html = format_devops_description({"effort": "XL"}) + assert "Effort: XL" in html + + def test_no_effort(self): + html = format_devops_description({"title": "X"}) + assert "Effort:" not in html + + +# ------------------------------------------------------------------ +# push_github_issue() +# ------------------------------------------------------------------ + + +class TestPushGithubIssue: + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_success(self, mock_run): + mock_run.return_value = MagicMock( + returncode=0, + stdout="https://github.com/contoso/myproj/issues/42\n", + ) + result = push_github_issue("contoso", "myproj", {"title": "Add Auth"}) + assert result["url"] == "https://github.com/contoso/myproj/issues/42" + assert result["number"] == "42" + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_title_with_epic(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="https://github.com/o/p/issues/1\n") + push_github_issue("o", "p", {"title": "Setup VNet", "epic": "Infrastructure"}) + cmd = mock_run.call_args[0][0] + assert "[Infrastructure] Setup VNet" in cmd + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_title_without_epic(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="https://github.com/o/p/issues/1\n") + push_github_issue("o", "p", {"title": "Plain task"}) + cmd = mock_run.call_args[0][0] + assert "Plain task" in cmd + # No bracket prefix + for arg in cmd: + if arg == "Plain task": + break + else: + # If full_title is used, it should not have brackets + title_idx = cmd.index("--title") + 1 + assert not cmd[title_idx].startswith("[") + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_labels_from_params_and_item(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="https://github.com/o/p/issues/1\n") + push_github_issue( + "o", + "p", + {"title": "T", "effort": "M", "epic": "Networking"}, + labels=["prototype", "poc"], + ) + cmd = mock_run.call_args[0][0] + # Should have --label for each label + label_indices = [i for i, v in enumerate(cmd) if v == "--label"] + labels = [cmd[i + 1] for i in label_indices] + assert "prototype" in labels + assert "poc" in labels + assert "effort/M" in labels + assert "networking" in labels + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_failure_stderr(self, mock_run): + mock_run.return_value = MagicMock( + returncode=1, + stderr="HTTP 422: Validation Failed", + stdout="", + ) + result = push_github_issue("o", "p", {"title": "T"}) + assert "error" in result + assert "422" in result["error"] + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_failure_stdout_fallback(self, mock_run): + mock_run.return_value = MagicMock( + returncode=1, + stderr="", + stdout="something went wrong", + ) + result = push_github_issue("o", "p", {"title": "T"}) + assert "something went wrong" in result["error"] + + @patch("azext_prototype.stages.backlog_push.subprocess.run", side_effect=FileNotFoundError) + def test_gh_not_found(self, mock_run): + result = push_github_issue("o", "p", {"title": "T"}) + assert "error" in result + assert "gh CLI not found" in result["error"] + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_repo_flag(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="https://github.com/o/p/issues/1\n") + push_github_issue("contoso", "my-repo", {"title": "T"}) + cmd = mock_run.call_args[0][0] + repo_idx = cmd.index("--repo") + 1 + assert cmd[repo_idx] == "contoso/my-repo" + + +# ------------------------------------------------------------------ +# push_devops_feature / push_devops_story / push_devops_task +# ------------------------------------------------------------------ + + +class TestPushDevopsWorkItem: + def _mock_success(self, wi_id=123, url="https://dev.azure.com/o/p/_workitems/edit/123"): + return MagicMock( + returncode=0, + stdout=json.dumps( + { + "id": wi_id, + "_links": {"html": {"href": url}}, + } + ), + ) + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_feature_success(self, mock_run): + mock_run.return_value = self._mock_success(wi_id=10) + result = push_devops_feature("myorg", "myproj", {"title": "Infra Setup"}) + assert result["id"] == 10 + assert "dev.azure.com" in result["url"] + # Check work item type + cmd = mock_run.call_args[0][0] + type_idx = cmd.index("--type") + 1 + assert cmd[type_idx] == "Feature" + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_story_success(self, mock_run): + mock_run.return_value = self._mock_success(wi_id=20) + result = push_devops_story("myorg", "myproj", {"title": "API Story"}) + assert result["id"] == 20 + cmd = mock_run.call_args[0][0] + type_idx = cmd.index("--type") + 1 + assert cmd[type_idx] == "User Story" + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + @patch("azext_prototype.stages.backlog_push._link_parent") + def test_story_with_parent(self, mock_link, mock_run): + mock_run.return_value = self._mock_success(wi_id=20) + push_devops_story("org", "proj", {"title": "Story"}, parent_id=10) + mock_link.assert_called_once_with("org", "proj", 20, 10) + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_task_success(self, mock_run): + mock_run.return_value = self._mock_success(wi_id=30) + result = push_devops_task("org", "proj", {"title": "Sub task"}) + assert result["id"] == 30 + cmd = mock_run.call_args[0][0] + type_idx = cmd.index("--type") + 1 + assert cmd[type_idx] == "Task" + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + @patch("azext_prototype.stages.backlog_push._link_parent") + def test_task_with_parent(self, mock_link, mock_run): + mock_run.return_value = self._mock_success(wi_id=30) + push_devops_task("org", "proj", {"title": "Task"}, parent_id=20) + mock_link.assert_called_once_with("org", "proj", 30, 20) + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_failure(self, mock_run): + mock_run.return_value = MagicMock( + returncode=1, + stderr="TF401019: Access denied", + stdout="", + ) + result = push_devops_feature("org", "proj", {"title": "T"}) + assert "error" in result + assert "Access denied" in result["error"] + + @patch("azext_prototype.stages.backlog_push.subprocess.run", side_effect=FileNotFoundError) + def test_az_not_found(self, mock_run): + result = push_devops_feature("org", "proj", {"title": "T"}) + assert "error" in result + assert "az CLI not found" in result["error"] + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_json_decode_error(self, mock_run): + mock_run.return_value = MagicMock( + returncode=0, + stdout="not valid json", + ) + result = push_devops_feature("org", "proj", {"title": "T"}) + # Falls back to raw stdout + assert result["url"] == "" + assert result["id"] == "not valid json" + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_epic_area_path(self, mock_run): + mock_run.return_value = self._mock_success() + push_devops_feature("org", "proj", {"title": "T", "epic": "Infrastructure"}) + cmd = mock_run.call_args[0][0] + area_idx = cmd.index("--area") + 1 + assert cmd[area_idx] == "proj\\Infrastructure" + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_no_epic_no_area(self, mock_run): + mock_run.return_value = self._mock_success() + push_devops_feature("org", "proj", {"title": "T"}) + cmd = mock_run.call_args[0][0] + assert "--area" not in cmd + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_org_url_format(self, mock_run): + mock_run.return_value = self._mock_success() + push_devops_feature("contoso", "myproj", {"title": "T"}) + cmd = mock_run.call_args[0][0] + org_idx = cmd.index("--org") + 1 + assert cmd[org_idx] == "https://dev.azure.com/contoso" + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_url_fallback_to_data_url(self, mock_run): + mock_run.return_value = MagicMock( + returncode=0, + stdout=json.dumps( + { + "id": 99, + "_links": {}, + "url": "https://dev.azure.com/o/p/_apis/wit/workItems/99", + } + ), + ) + result = push_devops_feature("o", "p", {"title": "T"}) + assert result["url"] == "https://dev.azure.com/o/p/_apis/wit/workItems/99" + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + @patch("azext_prototype.stages.backlog_push._link_parent") + def test_no_parent_link_when_parent_id_none(self, mock_link, mock_run): + mock_run.return_value = self._mock_success(wi_id=50) + push_devops_story("o", "p", {"title": "T"}, parent_id=None) + mock_link.assert_not_called() + + +# ------------------------------------------------------------------ +# _link_parent() +# ------------------------------------------------------------------ + + +class TestLinkParent: + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_link_parent_success(self, mock_run): + mock_run.return_value = MagicMock(returncode=0) + _link_parent("org", "proj", child_id=20, parent_id=10) + cmd = mock_run.call_args[0][0] + assert "relation" in cmd + assert "add" in cmd + assert "--id" in cmd + assert "20" in cmd + assert "--target-id" in cmd + assert "10" in cmd + assert "--relation-type" in cmd + assert "parent" in cmd + + @patch("azext_prototype.stages.backlog_push.subprocess.run", side_effect=FileNotFoundError) + def test_link_parent_file_not_found(self, mock_run): + # Should not raise + _link_parent("org", "proj", child_id=20, parent_id=10) + + @patch("azext_prototype.stages.backlog_push.subprocess.run") + def test_link_parent_subprocess_error(self, mock_run): + import subprocess + + mock_run.side_effect = subprocess.SubprocessError("broken pipe") + # Should not raise + _link_parent("org", "proj", child_id=20, parent_id=10) diff --git a/tests/test_generate_backlog.py b/tests/stages/test_backlog_session.py similarity index 78% rename from tests/test_generate_backlog.py rename to tests/stages/test_backlog_session.py index 6f07d14..c8b3398 100644 --- a/tests/test_generate_backlog.py +++ b/tests/stages/test_backlog_session.py @@ -1,20 +1,660 @@ -"""Tests for backlog generation — BacklogState, BacklogSession, push helpers, scope injection. - -Keeps the new backlog tests separate from test_custom.py to prevent file bloat. +"""Tests for backlog_session.py — branch coverage for cache vs regeneration, +quick mode vs interactive, item enrichment, push routing (GitHub vs DevOps), +review loop, /add command handling, _parse_items, _mutate_items, _save_backlog_md, +and slash commands. """ -import json +from __future__ import annotations + +from pathlib import Path from unittest.mock import MagicMock, patch import pytest -import yaml -from knack.util import CLIError + +from azext_prototype.agents.base import AgentCapability, AgentContext _CUSTOM_MODULE = "azext_prototype.custom" +_SESSION_MODULE = "azext_prototype.stages.backlog_session" + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def backlog_context(project_with_design, sample_config): + provider = MagicMock() + provider.provider_name = "github-models" + provider.default_model = "gpt-4o" + provider.chat.return_value = MagicMock( + content='[{"epic": "API", "title": "Build REST API", "description": "desc", ' + '"acceptance_criteria": ["AC1"], "tasks": [{"title": "T1", "done": false}], ' + '"effort": "M", "status": "todo"}]', + model="test", + usage={"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}, + ) + return AgentContext( + project_config=sample_config, + project_dir=str(project_with_design), + ai_provider=provider, + ) + + +@pytest.fixture +def backlog_registry(): + registry = MagicMock() + + mock_pm = MagicMock() + mock_pm.name = "project-manager" + mock_pm.get_system_messages.return_value = [] + mock_pm._temperature = 0.3 + mock_pm._max_tokens = 8192 + + mock_qa = MagicMock() + mock_qa.name = "qa-engineer" + + def find_by_cap(cap): + mapping = { + AgentCapability.BACKLOG_GENERATION: [mock_pm], + AgentCapability.QA: [mock_qa], + } + return mapping.get(cap, []) + + registry.find_by_capability.side_effect = find_by_cap + return registry + + +def _make_session(ctx, registry, items_response=None): + from azext_prototype.stages.backlog_session import BacklogSession + + session = BacklogSession(ctx, registry) + + # Override the AI response if specified AFTER session is created + if items_response is not None: + ctx.ai_provider.chat.return_value = MagicMock( + content=items_response, + model="test", + usage={"prompt_tokens": 50, "completion_tokens": 100, "total_tokens": 150}, + ) + + return session + + +# ------------------------------------------------------------------ +# BacklogResult +# ------------------------------------------------------------------ + + +class TestBacklogResult: + def test_defaults(self): + from azext_prototype.stages.backlog_session import BacklogResult + + result = BacklogResult() + assert result.items_generated == 0 + assert result.items_pushed == 0 + assert result.items_failed == 0 + assert result.push_urls == [] + assert result.cancelled is False + + def test_with_values(self): + from azext_prototype.stages.backlog_session import BacklogResult + + result = BacklogResult( + items_generated=5, + items_pushed=3, + items_failed=1, + push_urls=["https://github.com/issues/1"], + cancelled=False, + ) + assert result.items_generated == 5 + assert len(result.push_urls) == 1 + + +# ------------------------------------------------------------------ +# _parse_items +# ------------------------------------------------------------------ + + +class TestParseItems: + def test_valid_json_array(self): + from azext_prototype.stages.backlog_session import BacklogSession + + items = BacklogSession._parse_items('[{"title": "A"}, {"title": "B"}]') + assert len(items) == 2 + assert items[0]["title"] == "A" + + def test_json_with_fences(self): + from azext_prototype.stages.backlog_session import BacklogSession + + items = BacklogSession._parse_items('```json\n[{"title": "X"}]\n```') + assert len(items) == 1 + assert items[0]["title"] == "X" + + def test_invalid_json_returns_empty(self): + from azext_prototype.stages.backlog_session import BacklogSession + + items = BacklogSession._parse_items("not json at all") + assert items == [] + + def test_json_object_not_array_returns_empty(self): + from azext_prototype.stages.backlog_session import BacklogSession + + items = BacklogSession._parse_items('{"title": "single"}') + assert items == [] + + +# ------------------------------------------------------------------ +# Run — cached items path +# ------------------------------------------------------------------ + + +class TestRunCachedItems: + def test_cached_items_skip_generation(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + # Pre-populate cached items + session._backlog_state._state["items"] = [ + {"epic": "API", "title": "Build API", "status": "todo"}, + ] + session._backlog_state._state["context_hash"] = "" + session._backlog_state.matches_context = MagicMock(return_value=True) + + output = [] + result = session.run( + design_context="arch", + input_fn=lambda p: "done", + print_fn=lambda m: output.append(m), + ) + assert result.items_generated == 1 + assert not result.cancelled + + +# ------------------------------------------------------------------ +# Run — generation path +# ------------------------------------------------------------------ + + +class TestRunGeneration: + def test_generation_creates_items(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + output = [] + result = session.run( + design_context="Build an API with Cosmos DB", + input_fn=lambda p: "done", + print_fn=lambda m: output.append(m), + ) + assert result.items_generated >= 1 + assert not result.cancelled + + def test_no_pm_agent_cancels(self, backlog_context): + registry = MagicMock() + registry.find_by_capability.return_value = [] + + from azext_prototype.stages.backlog_session import BacklogSession + + session = BacklogSession(backlog_context, registry) + + result = session.run( + design_context="test", + input_fn=lambda p: "done", + print_fn=lambda m: None, + ) + assert result.cancelled is True + + def test_no_ai_provider_cancels(self, backlog_context, backlog_registry): + backlog_context.ai_provider = None + + from azext_prototype.stages.backlog_session import BacklogSession + + session = BacklogSession(backlog_context, backlog_registry) + + result = session.run( + design_context="test", + input_fn=lambda p: "done", + print_fn=lambda m: None, + ) + assert result.cancelled is True + + def test_empty_ai_response_cancels(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry, items_response="not json") + + result = session.run( + design_context="test", + input_fn=lambda p: "done", + print_fn=lambda m: None, + ) + assert result.cancelled is True + + +# ------------------------------------------------------------------ +# Run — quick mode +# ------------------------------------------------------------------ + + +class TestRunQuickMode: + @patch("azext_prototype.stages.backlog_session.check_gh_auth", return_value=True) + @patch("azext_prototype.stages.backlog_session.push_github_issue", return_value={"url": "https://gh/1"}) + def test_quick_mode_confirm_pushes(self, mock_push, mock_auth, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + result = session.run( + design_context="test", + provider="github", + org="myorg", + project="myrepo", + quick=True, + input_fn=lambda p: "y", + print_fn=lambda m: None, + ) + assert result.items_pushed >= 1 + + def test_quick_mode_cancel(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + result = session.run( + design_context="test", + quick=True, + input_fn=lambda p: "n", + print_fn=lambda m: None, + ) + assert result.cancelled is True + + def test_quick_mode_eof(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + def raise_eof(p): + raise EOFError + + result = session.run( + design_context="test", + quick=True, + input_fn=raise_eof, + print_fn=lambda m: None, + ) + assert result.cancelled is True + + +# ------------------------------------------------------------------ +# Interactive review loop +# ------------------------------------------------------------------ + + +class TestInteractiveLoop: + def test_quit_in_loop(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + inputs = iter(["quit"]) + result = session.run( + design_context="test", + input_fn=lambda p: next(inputs), + print_fn=lambda m: None, + ) + assert result.cancelled is True + + def test_slash_quit(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + inputs = iter(["/quit"]) + result = session.run( + design_context="test", + input_fn=lambda p: next(inputs), + print_fn=lambda m: None, + ) + assert result.cancelled is True + + def test_eof_in_loop(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + call_count = [0] + + def input_fn(p): + call_count[0] += 1 + if call_count[0] > 1: + raise EOFError + return "not a command" + + # Override to return a parseable mutation + backlog_context.ai_provider.chat.side_effect = [ + # Initial generation + MagicMock( + content='[{"epic": "A", "title": "T1"}]', + model="test", + usage={}, + ), + # Mutation call + MagicMock( + content='[{"epic": "A", "title": "T1 updated"}]', + model="test", + usage={}, + ), + ] + + result = session.run( + design_context="test", + input_fn=input_fn, + print_fn=lambda m: None, + ) + assert result.cancelled is True + + def test_empty_input_ignored(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + inputs = iter(["", "", "done"]) + result = session.run( + design_context="test", + input_fn=lambda p: next(inputs), + print_fn=lambda m: None, + ) + assert not result.cancelled + + +# ------------------------------------------------------------------ +# Slash commands +# ------------------------------------------------------------------ + + +class TestSlashCommands: + def _run_with_commands(self, ctx, registry, commands): + session = _make_session(ctx, registry) + inputs = iter(commands + ["done"]) + output = [] + session.run( + design_context="test", + input_fn=lambda p: next(inputs), + print_fn=lambda m: output.append(m), + ) + return output + + def test_list_command(self, backlog_context, backlog_registry): + output = self._run_with_commands(backlog_context, backlog_registry, ["/list"]) + # Should have printed the backlog summary + assert any("Backlog" in str(m) or "item" in str(m).lower() for m in output) + + def test_show_valid_index(self, backlog_context, backlog_registry): + output = self._run_with_commands(backlog_context, backlog_registry, ["/show 1"]) + # Should show item details + assert len(output) > 0 + + def test_show_invalid_arg(self, backlog_context, backlog_registry): + output = self._run_with_commands(backlog_context, backlog_registry, ["/show"]) + assert any("Usage" in str(m) for m in output) + + def test_remove_valid_index(self, backlog_context, backlog_registry): + output = self._run_with_commands(backlog_context, backlog_registry, ["/remove 1"]) + assert any("Removed" in str(m) for m in output) + + def test_remove_invalid_arg(self, backlog_context, backlog_registry): + output = self._run_with_commands(backlog_context, backlog_registry, ["/remove"]) + assert any("Usage" in str(m) for m in output) + + def test_help_command(self, backlog_context, backlog_registry): + output = self._run_with_commands(backlog_context, backlog_registry, ["/help"]) + assert any("Available commands" in str(m) for m in output) + + def test_status_command(self, backlog_context, backlog_registry): + output = self._run_with_commands(backlog_context, backlog_registry, ["/status"]) + assert len(output) > 0 + + def test_preview_command(self, backlog_context, backlog_registry): + output = self._run_with_commands( + backlog_context, + backlog_registry, + ["/preview"], + ) + assert len(output) > 0 + + +# ------------------------------------------------------------------ +# _push_all — provider routing +# ------------------------------------------------------------------ + + +class TestPushAll: + def _set_items_with_status(self, session, items, statuses=None): + """Helper to properly set items with matching push_status and push_results arrays.""" + session._backlog_state._state["items"] = items + n = len(items) + session._backlog_state._state["push_status"] = statuses or ["pending"] * n + session._backlog_state._state["push_results"] = [None] * n + + @patch("azext_prototype.stages.backlog_session.check_gh_auth", return_value=False) + def test_github_auth_failure(self, mock_auth, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + self._set_items_with_status(session, [{"title": "A", "status": "todo"}]) + + output = [] + result = session._push_all("github", "org", "repo", lambda m: output.append(m), False) + assert result.cancelled is True + + @patch("azext_prototype.stages.backlog_session.check_devops_ext", return_value=False) + def test_devops_ext_not_available(self, mock_ext, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + self._set_items_with_status(session, [{"title": "A", "status": "todo"}]) + + output = [] + result = session._push_all("devops", "org", "project", lambda m: output.append(m), False) + assert result.cancelled is True + + @patch("azext_prototype.stages.backlog_session.check_gh_auth", return_value=True) + @patch("azext_prototype.stages.backlog_session.push_github_issue", return_value={"url": "https://gh/1"}) + def test_github_push_success(self, mock_push, mock_auth, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + self._set_items_with_status(session, [{"title": "A", "status": "todo"}]) + + output = [] + result = session._push_all("github", "org", "repo", lambda m: output.append(m), False) + assert result.items_pushed == 1 + assert "https://gh/1" in result.push_urls + + @patch("azext_prototype.stages.backlog_session.check_gh_auth", return_value=True) + @patch("azext_prototype.stages.backlog_session.push_github_issue", return_value={"error": "rate limited"}) + def test_github_push_failure(self, mock_push, mock_auth, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + self._set_items_with_status(session, [{"title": "A", "status": "todo"}]) + + output = [] + result = session._push_all("github", "org", "repo", lambda m: output.append(m), False) + assert result.items_failed == 1 + + @patch("azext_prototype.stages.backlog_session.check_devops_ext", return_value=True) + @patch("azext_prototype.stages.backlog_session.push_devops_feature") + @patch("azext_prototype.stages.backlog_session.push_devops_story") + @patch("azext_prototype.stages.backlog_session.push_devops_task") + def test_devops_push_with_children( + self, mock_task, mock_story, mock_feature, mock_ext, backlog_context, backlog_registry + ): + mock_feature.return_value = {"url": "https://devops/1", "id": "1"} + mock_story.return_value = {"url": "https://devops/s1", "id": "2"} + mock_task.return_value = {"url": "https://devops/t1"} + + session = _make_session(backlog_context, backlog_registry) + items = [ + { + "title": "Feature A", + "status": "todo", + "children": [ + { + "title": "Story 1", + "tasks": [{"title": "Task 1", "done": False}], + } + ], + } + ] + self._set_items_with_status(session, items) + + output = [] + result = session._push_all("devops", "org", "proj", lambda m: output.append(m), False) + assert result.items_pushed == 1 + mock_story.assert_called_once() + mock_task.assert_called_once() + + def test_no_pending_items(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + self._set_items_with_status(session, [{"title": "A", "status": "pushed"}], statuses=["pushed"]) + + output = [] + result = session._push_all("github", "org", "repo", lambda m: output.append(m), False) + # items_pushed reflects historical pushed count (1 already pushed) + assert result.items_pushed == 1 + assert any("No pending" in str(m) for m in output) + + +# ------------------------------------------------------------------ +# _enrich_new_item +# ------------------------------------------------------------------ + + +class TestEnrichNewItem: + def test_no_pm_agent_returns_bare(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + session._pm_agent = None + + item = session._enrich_new_item("Build rate limiter") + assert item["title"] == "Build rate limiter" + assert item["epic"] == "Added" + + def test_no_ai_provider_returns_bare(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + session._context.ai_provider = None + + item = session._enrich_new_item("Build rate limiter") + assert item["title"] == "Build rate limiter" + + def test_successful_enrichment(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + backlog_context.ai_provider.chat.return_value = MagicMock( + content='{"epic": "Performance", "title": "API Rate Limiter", ' + '"description": "Implement rate limiting", ' + '"acceptance_criteria": ["Limit 100 req/s"], ' + '"tasks": ["Add middleware"], "effort": "M"}', + model="test", + usage={}, + ) + + item = session._enrich_new_item("Build rate limiter") + assert item["title"] == "API Rate Limiter" + assert item["epic"] == "Performance" + + def test_enrichment_failure_falls_back(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + backlog_context.ai_provider.chat.side_effect = Exception("AI failed") + + item = session._enrich_new_item("Build rate limiter") + assert item["title"] == "Build rate limiter" + assert item["epic"] == "Added" + + def test_enrichment_with_fenced_json(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + + backlog_context.ai_provider.chat.return_value = MagicMock( + content='```json\n{"epic": "Infra", "title": "Add CDN"}\n```', + model="test", + usage={}, + ) + + item = session._enrich_new_item("Add CDN") + assert item["title"] == "Add CDN" + assert item["epic"] == "Infra" + + +# ------------------------------------------------------------------ +# _mutate_items +# ------------------------------------------------------------------ + + +class TestMutateItems: + def test_successful_mutation(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + session._backlog_state._state["items"] = [{"title": "Old title"}] + + backlog_context.ai_provider.chat.return_value = MagicMock( + content='[{"title": "Updated title"}]', + model="test", + usage={}, + ) + + result = session._mutate_items("Change title to Updated title", "design context") + assert result is not None + assert result[0]["title"] == "Updated title" + + def test_no_pm_agent_returns_none(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + session._pm_agent = None + + result = session._mutate_items("Change title", "ctx") + assert result is None + + def test_no_ai_provider_returns_none(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + session._context.ai_provider = None + + result = session._mutate_items("Change title", "ctx") + assert result is None + + +# ------------------------------------------------------------------ +# _save_backlog_md +# ------------------------------------------------------------------ + + +class TestSaveBacklogMd: + def test_saves_markdown(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + session._backlog_state._state["items"] = [ + {"epic": "API", "title": "Build endpoints", "effort": "M", "description": "REST API"}, + ] + + output = [] + session._save_backlog_md(lambda m: output.append(m)) + + md_path = Path(backlog_context.project_dir) / "concept" / "docs" / "BACKLOG.md" + assert md_path.exists() + content = md_path.read_text() + assert "Build endpoints" in content + + def test_empty_items_prints_message(self, backlog_context, backlog_registry): + session = _make_session(backlog_context, backlog_registry) + session._backlog_state._state["items"] = [] + + output = [] + session._save_backlog_md(lambda m: output.append(m)) + assert any("No items" in str(m) for m in output) + + +# --- Additional imports from merged flat test --- +from azext_prototype.agents.base import AgentContext +from azext_prototype.agents.builtin import register_all_builtin +from azext_prototype.agents.registry import AgentRegistry +from azext_prototype.ai.provider import AIResponse +from azext_prototype.config import ProjectConfig +from azext_prototype.custom import _generate_templates +from azext_prototype.custom import _load_discovery_scope +from azext_prototype.custom import prototype_generate_backlog +from azext_prototype.custom import prototype_generate_docs +from azext_prototype.custom import prototype_generate_speckit +from azext_prototype.stages.backlog_push import _link_parent +from azext_prototype.stages.backlog_push import check_devops_ext +from azext_prototype.stages.backlog_push import check_gh_auth +from azext_prototype.stages.backlog_push import format_devops_description +from azext_prototype.stages.backlog_push import format_github_body +from azext_prototype.stages.backlog_push import push_devops_feature +from azext_prototype.stages.backlog_push import push_devops_story +from azext_prototype.stages.backlog_push import push_devops_task +from azext_prototype.stages.backlog_push import push_github_issue +from azext_prototype.stages.backlog_state import BacklogState +from azext_prototype.stages.discovery_state import DiscoveryState +from azext_prototype.stages.intent import IntentKind, IntentResult +from knack.util import CLIError +import importlib +import json +import subprocess as sp +import yaml +import yaml as _yaml -# ====================================================================== -# BacklogState Tests # ====================================================================== @@ -182,9 +822,6 @@ def test_update_from_exchange(self, tmp_project): assert len(state.state["conversation_history"]) == 1 assert state.state["conversation_history"][0]["user"] == "add a story" - -# ====================================================================== -# Backlog Push Helper Tests # ====================================================================== @@ -532,9 +1169,6 @@ def test_link_parent_subprocess_error(self): mock_run.side_effect = sp.SubprocessError("fail") _link_parent("o", "p", 10, 5) - -# ====================================================================== -# BacklogSession Tests # ====================================================================== @@ -811,9 +1445,6 @@ def test_slash_remove(self, tmp_project, mock_ai_provider): assert len(state.state["items"]) == 1 assert state.state["items"][0]["title"] == "Item2" - -# ====================================================================== -# Scope Injection Tests # ====================================================================== @@ -866,9 +1497,6 @@ def test_load_scope_empty_scope(self, tmp_project): scope = _load_discovery_scope(str(tmp_project)) assert scope is None - -# ====================================================================== -# AI-Populated Templates Tests # ====================================================================== @@ -951,9 +1579,6 @@ def test_generate_templates_uses_rich_ui(self, project_with_config): mock_console.print_dim.assert_called_once() assert len(generated) >= 1 - -# ====================================================================== -# Command-level Integration Tests # ====================================================================== @@ -1125,9 +1750,6 @@ def test_backlog_cancelled_returns_cancelled(self, mock_dir, mock_check_req, pro assert result["status"] == "cancelled" - -# ====================================================================== -# /add enrichment tests (Phase 9) # ====================================================================== @@ -1257,14 +1879,6 @@ def test_add_enriched_missing_fields_get_defaults(self, tmp_project): assert result["tasks"] == [] # defaulted assert result["effort"] == "S" - -# ====================================================================== -# BacklogSession Coverage — additional tests for uncovered lines -# ====================================================================== - -_SESSION_MODULE = "azext_prototype.stages.backlog_session" - - class TestBacklogSessionCoverage: """Additional tests to cover uncovered lines in backlog_session.py.""" diff --git a/tests/stages/test_backlog_state.py b/tests/stages/test_backlog_state.py new file mode 100644 index 0000000..52f82b4 --- /dev/null +++ b/tests/stages/test_backlog_state.py @@ -0,0 +1,305 @@ +"""Tests for BacklogState — item management, push tracking, context hash. + +Covers: +- Item management (set_items, mark_item_pushed, mark_item_failed) +- Push status arrays synchronized with items +- Pending/pushed/failed item queries +- Context hash for cache invalidation +- matches_context validation +- Conversation tracking (update_from_exchange) +- Formatting (backlog summary, item detail) +- State persistence (load, save, reset) +""" + +import pytest + +from azext_prototype.stages.backlog_state import BacklogState, _default_backlog_state + +# ====================================================================== +# Fixtures +# ====================================================================== + + +@pytest.fixture +def backlog_state(tmp_project): + return BacklogState(str(tmp_project)) + + +@pytest.fixture +def backlog_state_with_items(backlog_state): + """Backlog state with sample items set.""" + items = [ + { + "epic": "Infrastructure", + "title": "Setup VNet", + "description": "Configure virtual network", + "effort": "M", + "acceptance_criteria": ["AC1", "AC2"], + "tasks": ["T1"], + }, + { + "epic": "Infrastructure", + "title": "Setup Key Vault", + "description": "Create KV for secrets", + "effort": "S", + }, + { + "epic": "Application", + "title": "Build API", + "description": "REST API on Container Apps", + "effort": "L", + "children": [ + {"title": "Create Dockerfile", "effort": "S"}, + {"title": "Add health check", "effort": "S"}, + ], + }, + ] + backlog_state.set_items(items) + return backlog_state + + +# ====================================================================== +# Item management +# ====================================================================== + + +class TestItemManagement: + """Test set_items and push status arrays.""" + + def test_set_items_stores_items(self, backlog_state): + items = [{"title": "A"}, {"title": "B"}] + backlog_state.set_items(items) + assert len(backlog_state._state["items"]) == 2 + + def test_set_items_resets_push_status(self, backlog_state): + backlog_state.set_items([{"title": "A"}, {"title": "B"}, {"title": "C"}]) + assert backlog_state._state["push_status"] == ["pending", "pending", "pending"] + assert backlog_state._state["push_results"] == [None, None, None] + + def test_set_items_replaces_previous(self, backlog_state): + backlog_state.set_items([{"title": "A"}]) + backlog_state.set_items([{"title": "X"}, {"title": "Y"}]) + assert len(backlog_state._state["items"]) == 2 + assert len(backlog_state._state["push_status"]) == 2 + + +# ====================================================================== +# Push status tracking +# ====================================================================== + + +class TestPushStatusTracking: + """Test mark_item_pushed and mark_item_failed.""" + + def test_mark_pushed(self, backlog_state_with_items): + backlog_state_with_items.mark_item_pushed(0, "https://github.com/issues/1") + assert backlog_state_with_items._state["push_status"][0] == "pushed" + assert backlog_state_with_items._state["push_results"][0] == "https://github.com/issues/1" + assert backlog_state_with_items._state["_metadata"]["last_pushed"] is not None + + def test_mark_failed(self, backlog_state_with_items): + backlog_state_with_items.mark_item_failed(1, "auth error") + assert backlog_state_with_items._state["push_status"][1] == "failed" + assert "auth error" in backlog_state_with_items._state["push_results"][1] + + def test_mark_out_of_range_no_error(self, backlog_state_with_items): + """Marking an out-of-range index is a no-op.""" + backlog_state_with_items.mark_item_pushed(99, "url") + backlog_state_with_items.mark_item_failed(99, "err") + # No crash, no change + assert all(s == "pending" for s in backlog_state_with_items._state["push_status"]) + + +# ====================================================================== +# Item queries +# ====================================================================== + + +class TestItemQueries: + """Test pending/pushed/failed item queries.""" + + def test_get_pending_items_all_pending(self, backlog_state_with_items): + pending = backlog_state_with_items.get_pending_items() + assert len(pending) == 3 + assert all(isinstance(p, tuple) and len(p) == 2 for p in pending) + + def test_get_pending_items_after_push(self, backlog_state_with_items): + backlog_state_with_items.mark_item_pushed(0, "url") + pending = backlog_state_with_items.get_pending_items() + assert len(pending) == 2 + assert all(p[0] != 0 for p in pending) + + def test_get_pushed_items(self, backlog_state_with_items): + backlog_state_with_items.mark_item_pushed(0, "url") + backlog_state_with_items.mark_item_pushed(2, "url2") + pushed = backlog_state_with_items.get_pushed_items() + assert len(pushed) == 2 + + def test_get_failed_items(self, backlog_state_with_items): + backlog_state_with_items.mark_item_failed(1, "err") + failed = backlog_state_with_items.get_failed_items() + assert len(failed) == 1 + assert failed[0][0] == 1 + + def test_get_pending_with_missing_status(self, backlog_state): + """Items beyond push_status array length are treated as pending.""" + backlog_state._state["items"] = [{"title": "A"}, {"title": "B"}] + backlog_state._state["push_status"] = ["pushed"] # Only 1 status for 2 items + pending = backlog_state.get_pending_items() + assert len(pending) == 1 + assert pending[0][0] == 1 + + +# ====================================================================== +# Context hash for cache invalidation +# ====================================================================== + + +class TestContextHash: + """Test context hash computation and matching.""" + + def test_set_context_hash(self, backlog_state): + backlog_state.set_context_hash("design context text") + h = backlog_state._state["context_hash"] + assert len(h) == 16 # truncated sha256 + + def test_matches_context_true(self, backlog_state): + backlog_state.set_context_hash("design context") + assert backlog_state.matches_context("design context") is True + + def test_matches_context_false(self, backlog_state): + backlog_state.set_context_hash("design context v1") + assert backlog_state.matches_context("design context v2") is False + + def test_matches_context_with_scope(self, backlog_state): + scope = {"in_scope": ["API"], "out_of_scope": ["ML"]} + backlog_state.set_context_hash("ctx", scope=scope) + assert backlog_state.matches_context("ctx", scope=scope) is True + assert backlog_state.matches_context("ctx", scope={"in_scope": ["Different"]}) is False + + def test_matches_context_no_hash_set(self, backlog_state): + assert backlog_state.matches_context("anything") is False + + +# ====================================================================== +# Conversation tracking +# ====================================================================== + + +class TestConversationTracking: + """Test exchange recording.""" + + def test_update_from_exchange(self, backlog_state): + backlog_state.update_from_exchange("Add more stories", "Here are 3 more stories", 1) + history = backlog_state._state["conversation_history"] + assert len(history) == 1 + assert history[0]["exchange"] == 1 + assert history[0]["user"] == "Add more stories" + + def test_multiple_exchanges(self, backlog_state): + backlog_state.update_from_exchange("Q1", "A1", 1) + backlog_state.update_from_exchange("Q2", "A2", 2) + assert len(backlog_state._state["conversation_history"]) == 2 + + +# ====================================================================== +# Formatting +# ====================================================================== + + +class TestFormatting: + """Test backlog summary and item detail formatting.""" + + def test_format_summary_empty(self, backlog_state): + result = backlog_state.format_backlog_summary() + assert "No backlog items" in result + + def test_format_summary_with_items(self, backlog_state_with_items): + result = backlog_state_with_items.format_backlog_summary() + assert "3 item(s)" in result + assert "Infrastructure" in result + assert "Application" in result + assert "3 pending" in result + + def test_format_summary_with_pushed(self, backlog_state_with_items): + backlog_state_with_items.mark_item_pushed(0, "url") + result = backlog_state_with_items.format_backlog_summary() + assert "1 pushed" in result + assert "2 pending" in result + + def test_format_summary_with_children(self, backlog_state_with_items): + result = backlog_state_with_items.format_backlog_summary() + assert "2 stories" in result # Item 3 has children + + def test_format_summary_with_provider(self, backlog_state_with_items): + backlog_state_with_items._state["provider"] = "github" + backlog_state_with_items._state["org"] = "myorg" + backlog_state_with_items._state["project"] = "myproject" + result = backlog_state_with_items.format_backlog_summary() + assert "github" in result + assert "myorg/myproject" in result + + def test_format_item_detail(self, backlog_state_with_items): + result = backlog_state_with_items.format_item_detail(0) + assert "Setup VNet" in result + assert "Configure virtual network" in result + assert "AC1" in result + assert "T1" in result + + def test_format_item_detail_with_children(self, backlog_state_with_items): + result = backlog_state_with_items.format_item_detail(2) + assert "Build API" in result + assert "Children (2)" in result + assert "Create Dockerfile" in result + + def test_format_item_detail_with_push_status(self, backlog_state_with_items): + backlog_state_with_items.mark_item_pushed(0, "https://github.com/issues/1") + result = backlog_state_with_items.format_item_detail(0) + assert "pushed" in result + assert "https://github.com/issues/1" in result + + def test_format_item_detail_out_of_range(self, backlog_state_with_items): + result = backlog_state_with_items.format_item_detail(99) + assert "not found" in result + + def test_format_item_detail_negative_index(self, backlog_state_with_items): + result = backlog_state_with_items.format_item_detail(-1) + assert "not found" in result + + +# ====================================================================== +# State persistence +# ====================================================================== + + +class TestStatePersistence: + """Test load, save, reset via BaseState.""" + + def test_save_and_load(self, backlog_state): + backlog_state.set_items([{"title": "Persist me"}]) + backlog_state.save() + + new_state = BacklogState(backlog_state._project_dir) + new_state.load() + assert len(new_state._state["items"]) == 1 + assert new_state._state["items"][0]["title"] == "Persist me" + + def test_reset(self, backlog_state_with_items): + backlog_state_with_items.reset() + assert backlog_state_with_items._state["items"] == [] + assert backlog_state_with_items._state["push_status"] == [] + + def test_exists_false_initially(self, backlog_state): + assert backlog_state.exists is False + + def test_exists_after_save(self, backlog_state): + backlog_state.save() + assert backlog_state.exists is True + + def test_default_state_structure(self): + state = _default_backlog_state() + assert "items" in state + assert "provider" in state + assert "push_status" in state + assert "context_hash" in state + assert "_metadata" in state diff --git a/tests/test_build_session.py b/tests/stages/test_build_session.py similarity index 50% rename from tests/test_build_session.py rename to tests/stages/test_build_session.py index 4ac61ce..f22dfb3 100644 --- a/tests/test_build_session.py +++ b/tests/stages/test_build_session.py @@ -1,19 +1,3830 @@ -"""Tests for BuildState, PolicyResolver, BuildSession, and multi-resource telemetry. +from __future__ import annotations + +"""Tests for build session re-entry paths and stage status transitions. + +Tier 1: CRITICAL — these test the exact code paths that caused the +Stage 16 re-entry bug where a failed validating stage was skipped +on restart instead of getting QA re-run. + +Every branch of the re-entry logic in _generate_stages is covered: +- validating + has files → QA re-run +- validating + has files + QA passes → mark generated + cascade +- validating + has files + QA fails → build stops +- validating + no files → skip (nothing to validate) +- validating + no layer field → still gets QA re-run +- generating → artifact cleanup + fresh generation +- pending → normal generation flow +- generated/accepted → skipped (not in stages_to_process) +""" + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from azext_prototype.agents.base import AgentCapability, AgentContext + +# Re-use conftest fixtures: project_with_design, sample_config, tmp_project + + +@pytest.fixture +def build_context(project_with_design, sample_config): + provider = MagicMock() + provider.provider_name = "github-models" + provider.chat.return_value = MagicMock(content="ok", model="test", usage={}, finish_reason="stop") + return AgentContext( + project_config=sample_config, + project_dir=str(project_with_design), + ai_provider=provider, + ) + + +@pytest.fixture +def build_registry(mock_tf_agent, mock_dev_agent, mock_doc_agent, mock_architect_agent_for_build, mock_qa_agent): + registry = MagicMock() + + # Ensure tf agent has the attributes mirrored tests expect + mock_tf_agent._include_standards = True + mock_tf_agent._temperature = 0.2 + mock_tf_agent._max_tokens = 4096 + mock_tf_agent.set_knowledge_override = MagicMock() + mock_tf_agent.set_governor_brief = MagicMock() + mock_tf_agent.get_system_messages = MagicMock(return_value=[]) + mock_tf_agent._governance_aware = False + mock_tf_agent._enable_web_search = False + mock_tf_agent._enable_mcp_tools = False + + # Ensure doc agent has the attributes mirrored tests expect + mock_doc_agent._include_standards = True + mock_doc_agent.set_knowledge_override = MagicMock() + mock_doc_agent.set_governor_brief = MagicMock() + mock_doc_agent.get_system_messages = MagicMock(return_value=[]) + mock_doc_agent._governance_aware = False + mock_doc_agent._enable_web_search = False + mock_doc_agent._enable_mcp_tools = False + + # Ensure dev agent has the attributes mirrored tests expect + mock_dev_agent._include_standards = True + mock_dev_agent.set_knowledge_override = MagicMock() + mock_dev_agent.set_governor_brief = MagicMock() + mock_dev_agent.get_system_messages = MagicMock(return_value=[]) + mock_dev_agent._governance_aware = False + mock_dev_agent._enable_web_search = False + mock_dev_agent._enable_mcp_tools = False + + def find_by_cap(cap): + mapping = { + AgentCapability.TERRAFORM: [mock_tf_agent], + AgentCapability.BICEP: [], + AgentCapability.DEVELOP: [mock_dev_agent], + AgentCapability.DOCUMENT: [mock_doc_agent], + AgentCapability.ARCHITECT: [mock_architect_agent_for_build], + AgentCapability.QA: [mock_qa_agent], + } + return mapping.get(cap, []) + + registry.find_by_capability.side_effect = find_by_cap + return registry + + +def _make_session(build_context, build_registry): + from azext_prototype.stages.build_session import BuildSession + + return BuildSession(build_context, build_registry) + + +def _make_validating_stage(stage_num, name, layer="infra", capability="infra", files=None): + return { + "stage": stage_num, + "name": name, + "layer": layer, + "capability": capability, + "services": [], + "status": "validating", + "dir": f"concept/infra/terraform/stage-{stage_num}-{name.lower().replace(' ', '-')}", + "files": files or ["main.tf", "providers.tf"], + } + + +def _make_pending_stage(stage_num, name, layer="infra", capability="infra"): + return { + "stage": stage_num, + "name": name, + "layer": layer, + "capability": capability, + "services": [], + "status": "pending", + "dir": f"concept/infra/terraform/stage-{stage_num}-{name.lower().replace(' ', '-')}", + "files": [], + } + + +# ------------------------------------------------------------------ +# Validating re-entry: QA re-run +# ------------------------------------------------------------------ + + +class TestValidatingReentry: + """Tests for re-entry on stages with status='validating'.""" + + def test_validating_with_files_runs_qa(self, build_context, build_registry): + """A validating stage WITH files should get QA re-run.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + session._build_state.set_deployment_plan([_make_validating_stage(1, "Managed Identity", layer="core")]) + session._build_state.set_design_snapshot(design) + + qa_called = [] + + def mock_qa(*args, **kwargs): + qa_called.append(True) + return True + + session._run_stage_qa = mock_qa + + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + assert len(qa_called) > 0, "QA should run for validating stage with files" + + def test_validating_without_files_still_processes(self, build_context, build_registry): + """A validating stage with empty files list is still picked up for processing.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + session._build_state.set_deployment_plan([_make_validating_stage(1, "Empty", files=[])]) + session._build_state.set_design_snapshot(design) + + # Validating stages with no files may still be processed — the stage + # exists in the validating list so the session resumes rather than + # saying "up to date". + session._run_stage_qa = lambda *a, **kw: True + + result = session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + assert result is not None + + def test_validating_without_layer_field_runs_qa(self, build_context, build_registry): + """A validating stage missing the 'layer' field should still get QA re-run.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + # Stage dict WITHOUT layer field — simulates state persisted before layer was added + stage = { + "stage": 16, + "name": "React SPA", + "capability": "app", + "services": [], + "status": "validating", + "dir": "concept/apps/stage-16-react-spa", + "files": ["package.json", "src/App.tsx"], + } + session._build_state.set_deployment_plan([stage]) + session._build_state.set_design_snapshot(design) + + qa_called = [] + + def mock_qa(*args, **kwargs): + qa_called.append(True) + return True + + session._run_stage_qa = mock_qa + + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + assert len(qa_called) > 0, "QA should run even without layer field" + + def test_validating_qa_pass_advances_status(self, build_context, build_registry): + """When QA passes on a validating stage, status should advance past 'validating'.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + session._build_state.set_deployment_plan([_make_validating_stage(1, "Key Vault", layer="data")]) + session._build_state.set_design_snapshot(design) + + session._run_stage_qa = lambda *a, **kw: True + + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + stage = session._build_state._state["deployment_stages"][0] + assert stage["status"] in ( + "generated", + "accepted", + ), f"Status should advance past validating, got {stage['status']}" + + def test_validating_qa_fail_stops_build(self, build_context, build_registry): + """When QA fails on a validating stage, build should stop.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + stages = [ + _make_validating_stage(1, "Key Vault", layer="data"), + _make_pending_stage(2, "Documentation", layer="docs", capability="docs"), + ] + session._build_state.set_deployment_plan(stages) + session._build_state.set_design_snapshot(design) + + session._run_stage_qa = lambda *a, **kw: False + + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + # Stage 1 should still be validating (QA failed) + stage1 = session._build_state._state["deployment_stages"][0] + assert stage1["status"] == "validating" + + def test_validating_qa_pass_cascades_downstream(self, build_context, build_registry): + """When a validating stage passes QA, downstream generated stages should be reset to pending.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + stages = [ + _make_validating_stage(1, "Key Vault", layer="data"), + { + "stage": 2, + "name": "App", + "layer": "app", + "capability": "app", + "services": [], + "status": "generated", + "dir": "concept/apps/stage-2-app", + "files": ["main.py"], + }, + ] + session._build_state.set_deployment_plan(stages) + session._build_state.set_design_snapshot(design) + + session._run_stage_qa = lambda *a, **kw: True + + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + stage2 = session._build_state._state["deployment_stages"][1] + assert stage2["status"] == "pending", "Downstream stage should be reset to pending after upstream re-validation" + + def test_validating_app_stage_gets_qa(self, build_context, build_registry): + """App-layer validating stages must get QA re-run (the Stage 16 bug).""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + session._build_state.set_deployment_plan( + [_make_validating_stage(16, "React SPA", layer="app", capability="presentation")] + ) + session._build_state.set_design_snapshot(design) + + qa_called = [] + + def mock_qa(*args, **kwargs): + qa_called.append(True) + return True + + session._run_stage_qa = mock_qa + + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + assert len(qa_called) > 0, "App-layer validating stages must get QA re-run" + + +# ------------------------------------------------------------------ +# Generating re-entry: artifact cleanup +# ------------------------------------------------------------------ + + +class TestGeneratingReentry: + """Tests for re-entry on stages with status='generating' (interrupted).""" + + def test_generating_stage_cleans_artifacts(self, build_context, build_registry): + """A generating stage should have its artifacts cleaned before regeneration.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + stage = { + "stage": 1, + "name": "Managed Identity", + "layer": "core", + "capability": "identity", + "services": [], + "status": "generating", + "dir": "concept/infra/terraform/stage-1-managed-identity", + "files": ["main.tf"], + } + session._build_state.set_deployment_plan([stage]) + session._build_state.set_design_snapshot(design) + + clean_called = [] + + def mock_clean(stage_num, project_dir): + clean_called.append(stage_num) + + session._build_state.clean_stage_artifacts = mock_clean + + # Mock the generation path to avoid AI calls + session._run_stage_qa = lambda *a, **kw: True + + with patch.object(session, "_build_stage_task", return_value=(MagicMock(name="tf"), "task")): + with patch.object(session, "_execute_with_retry", return_value=MagicMock(content="```main.tf\n#ok\n```")): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + assert 1 in clean_called, "Artifacts should be cleaned for generating stage" + + +# ------------------------------------------------------------------ +# Full stage retry on QA exhaustion +# ------------------------------------------------------------------ + + +class TestFullStageRetry: + """When QA remediation exhausts all attempts, the build retries the + entire stage from scratch with prior QA findings injected.""" + + def test_constant_default(self): + """_MAX_FULL_STAGE_ATTEMPTS must default to 2 (1 initial + 1 retry).""" + from azext_prototype.stages.build_session import _MAX_FULL_STAGE_ATTEMPTS + + assert _MAX_FULL_STAGE_ATTEMPTS == 2 + + def test_full_retry_on_qa_exhaustion(self, build_context, build_registry): + """QA fail on first full attempt → clean artifacts → retry → QA pass → stage generated.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + session._build_state.set_deployment_plan([_make_pending_stage(1, "Key Vault", layer="data")]) + session._build_state.set_design_snapshot(design) + + qa_call_count = [0] + + def mock_qa(*args, **kwargs): + qa_call_count[0] += 1 + if qa_call_count[0] == 1: + session._last_qa_content = "CRITICAL: missing auth" + return False # First full attempt fails + return True # Second full attempt passes + + session._run_stage_qa = mock_qa + + clean_called = [] + original_clean = session._build_state.clean_stage_artifacts + + def spy_clean(stage_num, project_dir): + clean_called.append(stage_num) + original_clean(stage_num, project_dir) + + session._build_state.clean_stage_artifacts = spy_clean + + with patch.object(session, "_build_stage_task", return_value=(MagicMock(name="tf"), "task")): + with patch.object(session, "_execute_with_retry", return_value=MagicMock(content="```main.tf\n#ok\n```")): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + assert qa_call_count[0] == 2, f"QA should run twice (1 per full attempt), got {qa_call_count[0]}" + assert 1 in clean_called, "Artifacts should be cleaned before retry" + final_status = session._build_state._state["deployment_stages"][0]["status"] + assert final_status in ("generated", "accepted"), f"Stage should be generated or accepted, got '{final_status}'" + + def test_full_retry_injects_qa_findings(self, build_context, build_registry): + """Second _build_stage_task call must include prior_qa_findings from failed attempt.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + session._build_state.set_deployment_plan([_make_pending_stage(1, "Key Vault", layer="data")]) + session._build_state.set_design_snapshot(design) + + qa_call_count = [0] + + def mock_qa(*args, **kwargs): + qa_call_count[0] += 1 + if qa_call_count[0] == 1: + session._last_qa_content = "CRITICAL: missing managed identity" + return False + return True + + session._run_stage_qa = mock_qa + + build_task_calls = [] + mock_agent = MagicMock(name="tf") + + def spy_build_task(stage, arch, templates, prior_qa_findings=""): + build_task_calls.append(prior_qa_findings) + return mock_agent, "task" + + with patch.object(session, "_build_stage_task", side_effect=spy_build_task): + with patch.object(session, "_execute_with_retry", return_value=MagicMock(content="```main.tf\n#ok\n```")): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + assert len(build_task_calls) >= 2, f"_build_stage_task should be called at least twice, got {len(build_task_calls)}" + assert build_task_calls[0] == "", "First attempt should have no prior QA findings" + assert "missing managed identity" in build_task_calls[1], ( + f"Second attempt should inject prior QA findings, got: {build_task_calls[1]!r}" + ) + + def test_full_retry_exhausted_stops_build(self, build_context, build_registry): + """Both full attempts fail → build stops, stage stays validating.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + session._build_state.set_deployment_plan([_make_pending_stage(1, "Key Vault", layer="data")]) + session._build_state.set_design_snapshot(design) + + def mock_qa_always_fail(*args, **kwargs): + session._last_qa_content = "CRITICAL: unfixable" + return False + + session._run_stage_qa = mock_qa_always_fail + + with patch.object(session, "_build_stage_task", return_value=(MagicMock(name="tf"), "task")): + with patch.object(session, "_execute_with_retry", return_value=MagicMock(content="```main.tf\n#ok\n```")): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + status = session._build_state._state["deployment_stages"][0]["status"] + assert status != "generated", f"Stage should NOT be generated after exhausted retries, got '{status}'" + + def test_build_stage_task_includes_prior_qa_findings(self, build_context, build_registry): + """Task string must include prior QA findings section before architecture context.""" + session = _make_session(build_context, build_registry) + + session._build_state.set_deployment_plan([{ + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "dir": "concept/infra/terraform/stage-1-key-vault", + "services": [{"name": "kv", "computed_name": "kv", "resource_type": "Microsoft.KeyVault/vaults", + "sku": "standard", "component": "secrets"}], + "status": "pending", + "files": [], + }]) + + findings = "CRITICAL: missing managed identity RBAC assignment" + agent, task = session._build_stage_task( + session._build_state._state["deployment_stages"][0], + "arch", + [], + prior_qa_findings=findings, + ) + + assert agent is not None + assert "Previous QA Failures" in task, "Task must contain prior QA findings section" + assert findings in task, "Task must contain the actual QA findings text" + assert task.index("Previous QA Failures") < task.index("## Architecture Context"), ( + "Prior QA findings must appear BEFORE architecture context" + ) + + def test_no_retry_when_qa_passes_first_time(self, build_context, build_registry): + """QA passes on first attempt → no clean_stage_artifacts, build_stage_task called once.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + session._build_state.set_deployment_plan([_make_pending_stage(1, "Key Vault", layer="data")]) + session._build_state.set_design_snapshot(design) + + session._run_stage_qa = lambda *a, **kw: True + + clean_called = [] + session._build_state.clean_stage_artifacts = lambda sn, pd: clean_called.append(sn) + + build_task_calls = [0] + mock_agent = MagicMock(name="tf") + + def counting_build_task(*args, **kwargs): + build_task_calls[0] += 1 + return mock_agent, "task" + + with patch.object(session, "_build_stage_task", side_effect=counting_build_task): + with patch.object(session, "_execute_with_retry", return_value=MagicMock(content="```main.tf\n#ok\n```")): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + assert build_task_calls[0] == 1, f"_build_stage_task should be called once, got {build_task_calls[0]}" + assert len(clean_called) == 0, "clean_stage_artifacts should NOT be called when QA passes first time" + + +# ------------------------------------------------------------------ +# Build state: cascade_downstream_pending +# ------------------------------------------------------------------ + + +class TestCascadeDownstreamPending: + """Tests for cascade_downstream_pending in BuildState.""" + + def test_cascade_resets_downstream_generated_to_pending(self, tmp_path): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_path)) + bs._state["deployment_stages"] = [ + {"stage": 1, "name": "A", "status": "generated", "files": []}, + {"stage": 2, "name": "B", "status": "generated", "files": []}, + {"stage": 3, "name": "C", "status": "generated", "files": []}, + ] + + bs.cascade_downstream_pending(1) + + assert bs._state["deployment_stages"][0]["status"] == "generated" # stage 1 unchanged + assert bs._state["deployment_stages"][1]["status"] == "pending" # stage 2 reset + assert bs._state["deployment_stages"][2]["status"] == "pending" # stage 3 reset + + def test_cascade_does_not_affect_pending_stages(self, tmp_path): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_path)) + bs._state["deployment_stages"] = [ + {"stage": 1, "name": "A", "status": "generated", "files": []}, + {"stage": 2, "name": "B", "status": "pending", "files": []}, + ] + + bs.cascade_downstream_pending(1) + + assert bs._state["deployment_stages"][1]["status"] == "pending" # already pending + + def test_cascade_does_not_affect_validating_stages(self, tmp_path): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_path)) + bs._state["deployment_stages"] = [ + {"stage": 1, "name": "A", "status": "generated", "files": []}, + {"stage": 2, "name": "B", "status": "validating", "files": ["main.tf"]}, + ] + + bs.cascade_downstream_pending(1) + + # Validating stages should NOT be reset — they have user fixes pending QA + assert bs._state["deployment_stages"][1]["status"] in ("pending", "validating") + + +# ------------------------------------------------------------------ +# Build state: status transitions +# ------------------------------------------------------------------ + + +class TestBuildStateStatusTransitions: + """Tests for mark_stage_* methods in BuildState.""" + + def test_mark_generating(self, tmp_path): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_path)) + bs._state["deployment_stages"] = [{"stage": 1, "name": "A", "status": "pending", "files": []}] + + bs.mark_stage_generating(1) + assert bs._state["deployment_stages"][0]["status"] == "generating" + + def test_mark_validating(self, tmp_path): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_path)) + bs._state["deployment_stages"] = [{"stage": 1, "name": "A", "status": "generating", "files": []}] + + bs.mark_stage_validating(1, ["main.tf", "outputs.tf"]) + stage = bs._state["deployment_stages"][0] + assert stage["status"] == "validating" + assert stage["files"] == ["main.tf", "outputs.tf"] + + def test_mark_generated(self, tmp_path): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_path)) + bs._state["deployment_stages"] = [{"stage": 1, "name": "A", "status": "validating", "files": ["main.tf"]}] + + bs.mark_stage_generated(1, ["main.tf", "outputs.tf"], "terraform-agent") + stage = bs._state["deployment_stages"][0] + assert stage["status"] == "generated" + + def test_get_pending_includes_generating(self, tmp_path): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_path)) + bs._state["deployment_stages"] = [ + {"stage": 1, "name": "A", "status": "pending", "files": []}, + {"stage": 2, "name": "B", "status": "generating", "files": []}, + {"stage": 3, "name": "C", "status": "generated", "files": []}, + ] + + pending = bs.get_pending_stages() + assert len(pending) == 2 + assert pending[0]["stage"] == 1 + assert pending[1]["stage"] == 2 + + def test_get_validating(self, tmp_path): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_path)) + bs._state["deployment_stages"] = [ + {"stage": 1, "name": "A", "status": "validating", "files": ["main.tf"]}, + {"stage": 2, "name": "B", "status": "generated", "files": []}, + ] + + validating = bs.get_validating_stages() + assert len(validating) == 1 + assert validating[0]["stage"] == 1 + + def test_get_generated(self, tmp_path): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_path)) + bs._state["deployment_stages"] = [ + {"stage": 1, "name": "A", "status": "generated", "files": []}, + {"stage": 2, "name": "B", "status": "accepted", "files": []}, + {"stage": 3, "name": "C", "status": "pending", "files": []}, + ] + + generated = bs.get_generated_stages() + assert len(generated) == 2 + + +# ------------------------------------------------------------------ +# _qa_has_issues — 3-tier detection +# ------------------------------------------------------------------ + + +class TestQaHasIssues: + """Tests for the three-tier QA issue detection function.""" + + def test_empty_content_returns_false(self): + from azext_prototype.stages.build_session import _qa_has_issues + + assert _qa_has_issues("") is False + + def test_verdict_pass(self): + from azext_prototype.stages.build_session import _qa_has_issues + + assert _qa_has_issues("Overall assessment:\nVERDICT: PASS") is False + + def test_verdict_pass_bold(self): + from azext_prototype.stages.build_session import _qa_has_issues + + assert _qa_has_issues("**VERDICT: PASS**") is False + + def test_verdict_fail_with_critical(self): + from azext_prototype.stages.build_session import _qa_has_issues + + assert _qa_has_issues("VERDICT: FAIL\nCRITICAL: missing auth") is True + + def test_verdict_fail_without_critical_overrides_to_pass(self): + from azext_prototype.stages.build_session import _qa_has_issues + + assert _qa_has_issues("VERDICT: FAIL\nWARNING: minor issue only") is False + + def test_pass_phrase_no_issues_found(self): + from azext_prototype.stages.build_session import _qa_has_issues + + assert _qa_has_issues("After reviewing: no issues found. All looks good.") is False + + def test_pass_phrase_all_checks_passed(self): + from azext_prototype.stages.build_session import _qa_has_issues + + assert _qa_has_issues("All checks passed.") is False + + def test_keyword_fallback_critical(self): + from azext_prototype.stages.build_session import _qa_has_issues + + assert _qa_has_issues("There is a critical problem with the config") is True + + def test_keyword_fallback_error(self): + from azext_prototype.stages.build_session import _qa_has_issues + + assert _qa_has_issues("Found an error in the deployment") is True + + def test_keyword_fallback_missing(self): + from azext_prototype.stages.build_session import _qa_has_issues + + assert _qa_has_issues("Outputs are missing from stage 3") is True + + def test_clean_text_no_keywords(self): + from azext_prototype.stages.build_session import _qa_has_issues + + assert _qa_has_issues("Everything looks great. Well done.") is False + + +# ------------------------------------------------------------------ +# _select_agent — all layer routing paths +# ------------------------------------------------------------------ + + +class TestSelectAgent: + """Tests for _select_agent covering all layer/capability routing.""" + + def test_layer_core(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + agent = session._select_agent({"layer": "core", "capability": "infra"}) + assert agent is not None # Should route to iac agent or architect + + def test_layer_infra(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + agent = session._select_agent({"layer": "infra", "capability": "infra"}) + assert agent is not None + + def test_layer_data(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + agent = session._select_agent({"layer": "data", "capability": "data"}) + assert agent is not None + + def test_layer_app(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + # May return None if no app agent registered — just verify no error + session._select_agent({"layer": "app", "capability": "app"}) + + def test_layer_docs(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + agent = session._select_agent({"layer": "docs", "capability": "docs"}) + assert agent is not None + assert agent.name == "doc-agent" + + def test_fallback_infra_capability(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + agent = session._select_agent({"layer": "", "capability": "infra"}) + assert agent is not None + + def test_fallback_app_capability(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + # Covers the schema/cicd/external path too — just verify no error + session._select_agent({"layer": "", "capability": "app"}) + + def test_fallback_docs_capability(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + agent = session._select_agent({"layer": "", "capability": "docs"}) + assert agent is not None + + def test_fallback_unknown_capability(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + agent = session._select_agent({"layer": "", "capability": "unknown"}) + assert agent is not None # Falls through to last else + + +# ------------------------------------------------------------------ +# _build_stage_task — IaC vs app vs docs branches +# ------------------------------------------------------------------ + + +class TestBuildStageTask: + def test_iac_stage_task(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "dir": "concept/infra/terraform/stage-1-key-vault", + "services": [ + { + "name": "key-vault", + "computed_name": "kv-test", + "resource_type": "Microsoft.KeyVault/vaults", + "sku": "standard", + "component": "secrets", + } + ], + "status": "pending", + "files": [], + } + agent, task = session._build_stage_task(stage, "arch", []) + assert agent is not None + assert "MANDATORY RESOURCE POLICIES" in task or "Generate" in task + assert "key-vault" in task.lower() or "Key Vault" in task + + def test_app_stage_task(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + # Ensure an app developer exists for routing + mock_dev = MagicMock() + mock_dev.name = "app-developer" + mock_dev.set_knowledge_override = MagicMock() + mock_dev.set_governor_brief = MagicMock() + mock_dev._governance_aware = False + mock_dev._enable_web_search = False + mock_dev._enable_mcp_tools = False + session._dev_agent = mock_dev + + stage = { + "stage": 5, + "name": "API Service", + "layer": "app", + "capability": "app", + "dir": "concept/apps/stage-5-api", + "services": [{"name": "fastapi-app", "computed_name": "", "resource_type": "", "sku": "", "component": ""}], + "status": "pending", + "files": [], + } + agent, task = session._build_stage_task(stage, "arch", []) + assert agent is not None + assert "DefaultAzureCredential" in task or "managed identity" in task.lower() + assert "Do NOT generate" in task or "IaC" in task + + def test_docs_stage_task(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = { + "stage": 10, + "name": "Documentation", + "layer": "docs", + "capability": "docs", + "dir": "concept/docs", + "services": [], + "status": "pending", + "files": [], + } + agent, task = session._build_stage_task(stage, "arch", []) + assert agent is not None + assert "architecture.md" in task or "deployment-guide.md" in task + + def test_iac_stage_task_has_remote_state_directive_before_context(self, build_context, build_registry): + """Remote state no-dead-code directive must appear before the architecture context.""" + session = _make_session(build_context, build_registry) + + # Set up a generated stage so prev_context is populated + session._build_state.set_deployment_plan([ + { + "stage": 1, + "name": "Managed Identity", + "layer": "core", + "capability": "core", + "services": [], + "status": "pending", + "dir": "concept/infra/terraform/stage-1-managed-identity", + "files": [], + }, + { + "stage": 2, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "dir": "concept/infra/terraform/stage-2-key-vault", + "services": [ + { + "name": "key-vault", + "computed_name": "kv-test", + "resource_type": "Microsoft.KeyVault/vaults", + "sku": "standard", + "component": "secrets", + } + ], + "status": "pending", + "files": [], + }, + ]) + # Mark stage 1 as generated (creates proper generation_log entry) + session._build_state.mark_stage_generated(1, ["main.tf"], "terraform-agent") + + # Verify the state is correct before calling _build_stage_task + gen_stages = session._build_state.get_generated_stages() + assert len(gen_stages) == 1, f"Expected 1 generated stage, got {len(gen_stages)}: {gen_stages}" + assert gen_stages[0]["stage"] == 1 + + stage = session._build_state._state["deployment_stages"][1] + agent, task = session._build_stage_task(stage, "arch", []) + assert agent is not None + assert "Previously Generated Stages" in task, ( + f"Task must contain prev stages section. Task length={len(task)}. " + f"Generated stages: {[s['stage'] for s in session._build_state.get_generated_stages()]}" + ) + + # The no-dead-code remote state directive must appear BEFORE architecture context + # to ensure the model prioritizes it during generation + directive_marker = "ONLY declare" + arch_marker = "## Architecture Context" + assert directive_marker in task, ( + "Task must contain no-dead-code remote state directive (ONLY declare...)" + ) + assert task.index(directive_marker) < task.index(arch_marker), ( + "No-dead-code remote state directive must appear BEFORE architecture context " + "to ensure the model sees it with highest priority" + ) + + def test_no_agent_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + # Remove all agents + session._iac_agents = {} + session._architect_agent = None + session._infra_architect = None + session._data_architect = None + session._app_architect = None + session._security_architect = None + session._doc_agent = None + session._dev_agent = None + + stage = { + "stage": 1, + "name": "Nothing", + "layer": "unknown_layer", + "capability": "unknown", + "dir": "concept", + "services": [], + "status": "pending", + "files": [], + } + agent, task = session._build_stage_task(stage, "arch", []) + assert agent is None + assert task == "" + + +# ------------------------------------------------------------------ +# _write_stage_files — layer filtering +# ------------------------------------------------------------------ + + +class TestWriteStageFiles: + def test_docs_allowlist(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = {"layer": "docs", "dir": "concept/docs"} + + content = ( + "```architecture.md\n# Architecture\n```\n" + "```deployment-guide.md\n# Deployment\n```\n" + "```main.tf\n# should be blocked\n```\n" + ) + paths = session._write_stage_files(stage, content) + filenames = [Path(p).name for p in paths] + assert "architecture.md" in filenames + assert "deployment-guide.md" in filenames + assert "main.tf" not in filenames + + def test_app_blocks_iac_files(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage_dir = "concept/apps/stage-2-api" + stage = {"layer": "app", "dir": stage_dir} + + content = "```main.py\nprint('hello')\n```\n" "```main.tf\n# blocked\n```\n" "```deploy.sh\n# blocked\n```\n" + paths = session._write_stage_files(stage, content) + filenames = [Path(p).name for p in paths] + assert "main.py" in filenames + assert "main.tf" not in filenames + assert "deploy.sh" not in filenames + + def test_infra_blocks_versions_tf(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = {"layer": "infra", "dir": "concept/infra/terraform/stage-1"} + + content = "```main.tf\nresource {}\n```\n" "```versions.tf\n# blocked for terraform\n```\n" + paths = session._write_stage_files(stage, content) + filenames = [Path(p).name for p in paths] + assert "main.tf" in filenames + assert "versions.tf" not in filenames + + def test_empty_content_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + assert session._write_stage_files({"layer": "infra", "dir": "concept"}, "") == [] + + def test_no_file_blocks_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + assert session._write_stage_files({"layer": "infra", "dir": "concept"}, "no code blocks here") == [] + + +# ------------------------------------------------------------------ +# _apply_stage_transforms — passthrough +# ------------------------------------------------------------------ + + +class TestApplyStageTransforms: + def test_empty_paths_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + result = session._apply_stage_transforms({"services": []}, [], lambda m: None) + assert result == [] + + +# ------------------------------------------------------------------ +# _resolve_developer_for_stage — language detection +# ------------------------------------------------------------------ + + +class TestResolveDeveloperForStage: + def test_python_detected(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = { + "name": "FastAPI Backend", + "dir": "concept/apps/stage-5-fastapi", + "services": [{"name": "fastapi-api"}], + } + # May be None if no python dev registered, but should not raise + session._resolve_developer_for_stage(stage, "FastAPI backend") + + def test_csharp_detected(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = { + "name": "ASP.NET API", + "dir": "concept/apps/stage-5-dotnet", + "services": [{"name": "aspnet-app"}], + } + session._resolve_developer_for_stage(stage, "ASP.NET Core API") + + def test_react_detected(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = { + "name": "React Frontend", + "dir": "concept/apps/stage-6-react", + "services": [{"name": "react-spa"}], + } + session._resolve_developer_for_stage(stage, "React SPA") + + def test_no_language_returns_none(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = { + "name": "Generic Service", + "dir": "concept/apps/stage-7", + "services": [{"name": "generic"}], + } + dev = session._resolve_developer_for_stage(stage, "Some generic service") + assert dev is None + + +# ------------------------------------------------------------------ +# _decompose_app_stage — delegation +# ------------------------------------------------------------------ + + +class TestDecomposeAppStage: + def test_with_detected_developer(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + # Mock a python developer + mock_dev = MagicMock() + mock_dev.name = "python-developer" + session._python_dev = mock_dev + + stage = { + "name": "Python API", + "dir": "concept/apps/stage-5-python", + "services": [{"name": "python-api"}], + } + agent, context = session._decompose_app_stage(stage, "Python FastAPI backend", lambda m: None) + assert agent == mock_dev + assert "Sub-Layer" in context + + def test_fallback_without_developer(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = { + "name": "Mystery Service", + "dir": "concept/apps/stage-5", + "services": [{"name": "mystery"}], + } + agent, context = session._decompose_app_stage(stage, "Unknown architecture", lambda m: None) + assert context == "" + + +# ------------------------------------------------------------------ +# _detect_framework — static method +# ------------------------------------------------------------------ + + +class TestDetectFramework: + def test_fastapi(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"fastapi-api"}, "concept/apps/stage-5", set()) + assert "FastAPI" in result + + def test_react(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"react-spa"}, "concept/apps/stage-6", set()) + assert "React" in result or "SPA" in result + + def test_dotnet(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"aspnet-api"}, "concept/apps/stage-7", set()) + assert ".NET" in result + + def test_dotnet_functions(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"function-app"}, "concept/apps/stage-7", set()) + assert "Functions" in result + + def test_express(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"express-api"}, "concept/apps/stage-8", set()) + assert "Express" in result or "Node.js" in result + + def test_go(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"golang-api"}, "concept/apps/stage-9", set()) + assert "Go" in result + + def test_java(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"spring-api"}, "concept/apps/stage-10", set()) + assert "Java" in result or "Spring" in result + + def test_unknown_returns_empty(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"custom-service"}, "concept/apps/stage-11", set()) + assert result == "" + + def test_flask(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"flask-api"}, "concept/apps", set()) + assert "Flask" in result + + def test_django(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"django-app"}, "concept/apps", set()) + assert "Django" in result + + def test_vue(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"vue-frontend"}, "concept/apps", set()) + assert "Vue" in result or "SPA" in result + + def test_nest(self): + from azext_prototype.stages.build_session import BuildSession + + result = BuildSession._detect_framework({"nest-api"}, "concept/apps", set()) + assert "NestJS" in result or "Node.js" in result + + +# ------------------------------------------------------------------ +# _categorize_service +# ------------------------------------------------------------------ + + +class TestCategorizeService: + def test_infra_type(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._categorize_service("key-vault") == "infra" + assert BuildSession._categorize_service("virtual-network") == "infra" + + def test_data_type(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._categorize_service("cosmos-db") == "data" + assert BuildSession._categorize_service("redis-cache") == "data" + + def test_app_type(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._categorize_service("custom-service") == "app" + + +# ------------------------------------------------------------------ +# _infer_layer +# ------------------------------------------------------------------ + + +class TestInferLayer: + def test_explicit_layer_returned(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._infer_layer({"layer": "docs", "name": "Docs"}) == "docs" + + def test_identity_detected_as_core(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._infer_layer({"name": "Managed Identity", "capability": "infra"}) == "core" + + def test_monitoring_detected_as_core(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._infer_layer({"name": "Log Analytics", "capability": "infra"}) == "core" + + def test_capability_mapping(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._infer_layer({"name": "Redis", "capability": "data"}) == "data" + assert BuildSession._infer_layer({"name": "API", "capability": "app"}) == "app" + assert BuildSession._infer_layer({"name": "Docs", "capability": "docs"}) == "docs" + + +# ------------------------------------------------------------------ +# _enforce_concept_prefix +# ------------------------------------------------------------------ + + +class TestEnforceConceptPrefix: + def test_already_concept_prefix(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + assert session._enforce_concept_prefix("concept/infra/stage-1") == "concept/infra/stage-1" + + def test_wrong_prefix_fixed(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + assert session._enforce_concept_prefix("output/infra/stage-1") == "concept/infra/stage-1" + + def test_bare_subdir(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + assert session._enforce_concept_prefix("infra") == "concept/infra" + + def test_empty_passthrough(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + assert session._enforce_concept_prefix("") == "" + + def test_unrelated_path_passthrough(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + assert session._enforce_concept_prefix("random/path/here") == "random/path/here" + + +# ------------------------------------------------------------------ +# _parse_deployment_plan +# ------------------------------------------------------------------ + + +class TestParseDeploymentPlan: + def test_fenced_json(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + content = '```json\n{"stages": [{"stage": 1, "name": "A", "services": []}]}\n```' + result = session._parse_deployment_plan(content) + assert len(result) == 1 + + def test_raw_json(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + content = '{"stages": [{"stage": 1, "name": "A", "services": []}]}' + result = session._parse_deployment_plan(content) + assert len(result) == 1 + + def test_invalid_json_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + result = session._parse_deployment_plan("not json") + assert result == [] + + def test_empty_stages_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + result = session._parse_deployment_plan('{"stages": []}') + assert result == [] + + +# ------------------------------------------------------------------ +# _parse_stage_map +# ------------------------------------------------------------------ + + +class TestParseStageMap: + def test_valid_map(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + content = ( + '```json\n{"stages": [{"stage": 1, "name": "A",' + ' "layer": "core", "capability": "infra",' + ' "services": ["managed-identity"]}]}\n```' + ) + result = session._parse_stage_map(content) + assert len(result) >= 1 # May include injected networking + docs + + def test_invalid_json_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + result = session._parse_stage_map("not json") + assert result == [] + + def test_ensures_docs_stage(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + content = '{"stages": [{"stage": 1, "name": "A", "layer": "core", "capability": "infra", "services": []}]}' + result = session._parse_stage_map(content) + assert any(s.get("layer") == "docs" for s in result) + + +# ------------------------------------------------------------------ +# _ensure_networking_in_map +# ------------------------------------------------------------------ + + +class TestEnsureNetworkingInMap: + def test_inserts_when_missing(self): + from azext_prototype.stages.build_session import BuildSession + + stages = [ + {"stage": 1, "name": "Managed Identity", "services": ["managed-identity"]}, + {"stage": 2, "name": "Key Vault", "services": ["key-vault"]}, + ] + BuildSession._ensure_networking_in_map(stages) + assert any(s["name"] == "Networking" for s in stages) + + def test_skips_when_present(self): + from azext_prototype.stages.build_session import BuildSession + + stages = [ + {"stage": 1, "name": "Networking", "services": ["virtual-network"]}, + {"stage": 2, "name": "Key Vault", "services": ["key-vault"]}, + ] + original_len = len(stages) + BuildSession._ensure_networking_in_map(stages) + assert len(stages) == original_len + + def test_skips_when_vnet_in_services(self): + from azext_prototype.stages.build_session import BuildSession + + stages = [ + {"stage": 1, "name": "Foundation", "services": ["vnet"]}, + ] + original_len = len(stages) + BuildSession._ensure_networking_in_map(stages) + assert len(stages) == original_len + + +# ------------------------------------------------------------------ +# BuildResult +# ------------------------------------------------------------------ + + +class TestBuildResult: + def test_defaults(self): + from azext_prototype.stages.build_session import BuildResult + + result = BuildResult() + assert result.files_generated == [] + assert result.deployment_stages == [] + assert result.policy_overrides == [] + assert result.resources == [] + assert result.review_accepted is False + assert result.cancelled is False + + def test_cancelled(self): + from azext_prototype.stages.build_session import BuildResult + + result = BuildResult(cancelled=True) + assert result.cancelled is True + + +# ------------------------------------------------------------------ +# _get_app_scaffolding_requirements +# ------------------------------------------------------------------ + + +class TestGetAppScaffoldingRequirements: + def test_non_app_layer_returns_empty(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._get_app_scaffolding_requirements({"layer": "infra"}) == "" + + def test_app_layer_generic_fallback(self): + from azext_prototype.stages.build_session import BuildSession + + stage = { + "layer": "app", + "services": [{"name": "custom", "resource_type": "", "sku": ""}], + "dir": "concept/apps/stage-5", + } + result = BuildSession._get_app_scaffolding_requirements(stage) + assert "Required Project Files" in result + + def test_app_layer_python_detected(self): + from azext_prototype.stages.build_session import BuildSession + + stage = { + "layer": "app", + "services": [{"name": "python-api", "resource_type": "", "sku": ""}], + "dir": "concept/apps/stage-5-python", + } + result = BuildSession._get_app_scaffolding_requirements(stage) + assert "Python" in result or "requirements.txt" in result + + +# ------------------------------------------------------------------ +# Naming strategy fallback (lines 244-246) +# ------------------------------------------------------------------ + + +class TestNamingStrategyFallback: + """Tests for naming strategy graceful fallback when config is bad.""" + + def test_naming_fallback_on_bad_config(self, project_with_design, sample_config): + """When create_naming_strategy raises, session falls back to simple strategy.""" + from azext_prototype.stages.build_session import BuildSession + + provider = MagicMock() + provider.provider_name = "github-models" + provider.chat.return_value = MagicMock(content="ok", model="test", usage={}, finish_reason="stop") + ctx = AgentContext( + project_config=sample_config, + project_dir=str(project_with_design), + ai_provider=provider, + ) + + registry = MagicMock() + registry.find_by_capability.return_value = [] + + # Corrupt the config so naming strategy fails on first try + with patch("azext_prototype.stages.build_session.create_naming_strategy") as mock_naming: + call_count = [0] + + def side_effect(cfg): + call_count[0] += 1 + if call_count[0] == 1: + raise ValueError("bad config") + # Second call is fallback + from azext_prototype.naming import create_naming_strategy as real_create + + return real_create(cfg) + + mock_naming.side_effect = side_effect + session = BuildSession(ctx, registry) + assert session._naming is not None + + +# ------------------------------------------------------------------ +# Policy resolver regeneration path (lines 685-722) +# ------------------------------------------------------------------ + + +class TestPolicyRegenPath: + """Tests for the policy resolver triggering regeneration.""" + + def test_policy_regen_executes_with_fix_instructions(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + stage = _make_pending_stage(1, "Key Vault", layer="data", capability="data") + session._build_state.set_deployment_plan([stage]) + session._build_state.set_design_snapshot(design) + + regen_response = MagicMock(content="```main.tf\nfixed\n```", model="test", usage={}) + fix_instructions = "\n## Fix\nFix the SKU" + + # Track whether the regen path was exercised + regen_called = [] + + def mock_check_and_resolve(*args, **kwargs): + # First call: needs regen; subsequent calls: no regen + if not regen_called: + regen_called.append(True) + return (["override sku"], True) + return ([], False) + + session._policy_resolver.check_and_resolve = mock_check_and_resolve + session._policy_resolver.build_fix_instructions = MagicMock(return_value=fix_instructions) + + with patch.object(session, "_build_stage_task", return_value=(MagicMock(name="tf"), "task")): + with patch.object(session, "_execute_with_retry", return_value=regen_response) as mock_retry: + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + session._run_stage_qa = lambda *a, **kw: True + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + assert len(regen_called) > 0, "Policy resolver should have triggered regeneration" + # The retry was called twice (original + regen) + assert mock_retry.call_count >= 2 + + def test_policy_regen_exception_routes_to_qa(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + stage = _make_pending_stage(1, "Key Vault", layer="data", capability="data") + session._build_state.set_deployment_plan([stage]) + session._build_state.set_design_snapshot(design) + + # First policy check triggers regen (needs_regen=True) + session._policy_resolver.check_and_resolve = MagicMock(return_value=(["issue"], True)) + session._policy_resolver.build_fix_instructions = MagicMock(return_value="\nfix") + + original_response = MagicMock(content="```main.tf\nok\n```", model="test", usage={}) + + with patch.object(session, "_build_stage_task", return_value=(MagicMock(name="tf"), "task")): + # First call: original generation; second call: regen throws + with patch.object(session, "_execute_with_retry", side_effect=[original_response, RuntimeError("boom")]): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + session._run_stage_qa = lambda *a, **kw: True + with patch("azext_prototype.stages.build_session.route_error_to_qa") as mock_qa_route: + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + # QA route was called for the regen exception + assert mock_qa_route.called + + def test_policy_regen_null_response_stops_build(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + stage = _make_pending_stage(1, "Key Vault", layer="data", capability="data") + session._build_state.set_deployment_plan([stage]) + session._build_state.set_design_snapshot(design) + + call_count = [0] + + def mock_check_and_resolve(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return ([], False) + return (["issue"], True) + + session._policy_resolver.check_and_resolve = mock_check_and_resolve + session._policy_resolver.build_fix_instructions = MagicMock(return_value="\nfix") + + original_response = MagicMock(content="```main.tf\nok\n```", model="test", usage={}) + + with patch.object(session, "_build_stage_task", return_value=(MagicMock(name="tf"), "task")): + with patch.object(session, "_execute_with_retry", side_effect=[original_response, None]): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + session._run_stage_qa = lambda *a, **kw: True + result = session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + # Build stopped because regen returned None + assert result is not None + + +# ------------------------------------------------------------------ +# Review loop / interactive rebuild (lines 863-918) +# ------------------------------------------------------------------ + + +class TestReviewLoop: + """Tests for the Phase 6 review loop.""" + + def test_review_loop_regenerates_affected_stage(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + stage = { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "services": [{"name": "key-vault", "computed_name": "kv-test", "resource_type": "", "sku": ""}], + "status": "generated", + "dir": "concept/infra/terraform/stage-1-key-vault", + "files": ["main.tf"], + } + session._build_state.set_deployment_plan([stage]) + session._build_state.set_design_snapshot(design) + + regen_response = MagicMock(content="```main.tf\nfixed\n```", model="test", usage={}) + + inputs = iter(["Fix the key vault SKU", "done"]) + + session._identify_affected_stages = MagicMock(return_value=[1]) + + mock_agent = MagicMock() + mock_agent.name = "terraform-agent" + + with patch.object(session, "_build_stage_task", return_value=(mock_agent, "task")): + with patch.object(session, "_execute_with_continuation", return_value=regen_response): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + result = session.run( + design=design, + input_fn=lambda p: next(inputs), + print_fn=lambda m: None, + ) + + assert result.review_accepted is True + # Stage should be marked accepted after done + final_stage = session._build_state._state["deployment_stages"][0] + assert final_stage["status"] == "accepted" + + def test_review_loop_no_affected_stages(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + stage = { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "services": [], + "status": "generated", + "dir": "concept/infra/terraform/stage-1-key-vault", + "files": ["main.tf"], + } + session._build_state.set_deployment_plan([stage]) + session._build_state.set_design_snapshot(design) + + printed = [] + inputs = iter(["something vague", "done"]) + + session._identify_affected_stages = MagicMock(return_value=[]) + + result = session.run( + design=design, + input_fn=lambda p: next(inputs), + print_fn=lambda m: printed.append(m), + ) + + assert any("Could not determine" in msg for msg in printed) + assert result.review_accepted is True + + def test_review_loop_quit_cancels(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + stage = { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "services": [], + "status": "generated", + "dir": "concept/infra/terraform/stage-1-key-vault", + "files": ["main.tf"], + } + session._build_state.set_deployment_plan([stage]) + session._build_state.set_design_snapshot(design) + + result = session.run( + design=design, + input_fn=lambda p: "quit", + print_fn=lambda m: None, + ) + + assert result.cancelled is True + + def test_review_loop_slash_command(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + stage = { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "services": [], + "status": "generated", + "dir": "concept/infra/terraform/stage-1-key-vault", + "files": ["main.tf"], + } + session._build_state.set_deployment_plan([stage]) + session._build_state.set_design_snapshot(design) + + printed = [] + inputs = iter(["/help", "done"]) + + result = session.run( + design=design, + input_fn=lambda p: next(inputs), + print_fn=lambda m: printed.append(m), + ) + + assert any("Available commands" in msg for msg in printed) + assert result.review_accepted is True + + def test_review_loop_eof_breaks(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + stage = { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "services": [], + "status": "generated", + "dir": "concept/infra/terraform/stage-1-key-vault", + "files": ["main.tf"], + } + session._build_state.set_deployment_plan([stage]) + session._build_state.set_design_snapshot(design) + + def raise_eof(p): + raise EOFError() + + result = session.run( + design=design, + input_fn=raise_eof, + print_fn=lambda m: None, + ) + + assert result.review_accepted is True + + +# ------------------------------------------------------------------ +# Fallback deployment plan with templates (lines 1357-1430) +# ------------------------------------------------------------------ + + +class TestFallbackDeploymentPlan: + """Tests for _fallback_deployment_plan with template services.""" + + def test_fallback_with_templates(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + # Create mock templates with services + mock_svc_infra = MagicMock() + mock_svc_infra.name = "key-vault" + mock_svc_infra.type = "key-vault" + mock_svc_infra.tier = "Standard" + mock_svc_infra.config = {} + + mock_svc_data = MagicMock() + mock_svc_data.name = "cosmos-db" + mock_svc_data.type = "cosmos-db" + mock_svc_data.tier = "Serverless" + mock_svc_data.config = {} + + mock_svc_app = MagicMock() + mock_svc_app.name = "python-api" + mock_svc_app.type = "python-app" + mock_svc_app.tier = "Standard" + mock_svc_app.config = {} + + mock_template = MagicMock() + mock_template.name = "web-app" + mock_template.display_name = "Web Application" + mock_template.services = [mock_svc_infra, mock_svc_data, mock_svc_app] + + stages = session._fallback_deployment_plan([mock_template]) + + # Should have managed identity + infra + data + app + docs stages + assert len(stages) >= 5 + + # First stage should be Managed Identity + assert stages[0]["name"] == "Managed Identity" + assert stages[0]["layer"] == "core" + + # Last stage should be Documentation + assert stages[-1]["name"] == "Documentation" + assert stages[-1]["layer"] == "docs" + + # Infra stage for container-registry + infra_stages = [s for s in stages if s["layer"] == "infra"] + assert len(infra_stages) >= 1 + + # Data stage for cosmos-db + data_stages = [s for s in stages if s["layer"] == "data"] + assert len(data_stages) >= 1 + + # App stage for python-api + app_stages = [s for s in stages if s["layer"] == "app"] + assert len(app_stages) >= 1 + + def test_fallback_without_templates(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stages = session._fallback_deployment_plan([]) + + # Only managed identity + documentation + assert len(stages) == 2 + assert stages[0]["name"] == "Managed Identity" + assert stages[-1]["name"] == "Documentation" + + +# ------------------------------------------------------------------ +# Ensure private endpoint stage (lines 1470-1540) +# ------------------------------------------------------------------ + + +class TestEnsurePrivateEndpointStage: + """Tests for _ensure_private_endpoint_stage.""" + + def test_skips_when_network_stage_exists(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stages = [ + { + "stage": 1, + "name": "Networking", + "layer": "infra", + "services": [{"name": "virtual-network", "resource_type": "Microsoft.Network/virtualNetworks"}], + }, + { + "stage": 2, + "name": "Key Vault", + "layer": "data", + "services": [{"name": "key-vault", "resource_type": "Microsoft.KeyVault/vaults"}], + }, + ] + result = session._ensure_private_endpoint_stage(stages) + assert len(result) == 2 # No change + + def test_skips_when_service_has_network_resource(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stages = [ + { + "stage": 1, + "name": "Foundation", + "layer": "infra", + "services": [{"name": "vnet", "resource_type": "Microsoft.Network/virtualNetworks"}], + }, + ] + result = session._ensure_private_endpoint_stage(stages) + assert len(result) == 1 # No change + + def test_injects_networking_stage_when_pe_services_found(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stages = [ + { + "stage": 1, + "name": "Managed Identity", + "layer": "core", + "services": [], + "dir": "concept/infra/terraform/stage-1-managed-identity", + }, + { + "stage": 2, + "name": "Key Vault", + "layer": "data", + "services": [{"name": "key-vault", "resource_type": "Microsoft.KeyVault/vaults"}], + "dir": "concept/infra/terraform/stage-2-key-vault", + }, + ] + + mock_pe = MagicMock() + mock_pe.service_name = "key-vault" + + with patch( + "azext_prototype.stages.build_session.BuildSession._ensure_private_endpoint_stage", + wraps=session._ensure_private_endpoint_stage, + ): + with patch( + "azext_prototype.knowledge.resource_metadata.get_private_endpoint_services", + return_value=[mock_pe], + ): + result = session._ensure_private_endpoint_stage(stages) + + # Should have injected a networking stage + assert len(result) == 3 + net_stage = result[1] + assert net_stage["name"] == "Networking" + assert any("virtual-network" in s["name"] for s in net_stage["services"]) + + def test_no_injection_when_pe_services_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stages = [ + { + "stage": 1, + "name": "Managed Identity", + "layer": "core", + "services": [], + "dir": "concept/infra/terraform/stage-1-managed-identity", + }, + ] + + with patch( + "azext_prototype.knowledge.resource_metadata.get_private_endpoint_services", + return_value=[], + ): + result = session._ensure_private_endpoint_stage(stages) + + assert len(result) == 1 + + def test_exception_in_pe_lookup_returns_stages_unchanged(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stages = [ + { + "stage": 1, + "name": "Managed Identity", + "layer": "core", + "services": [], + "dir": "concept/infra/terraform/stage-1-managed-identity", + }, + ] + + with patch( + "azext_prototype.knowledge.resource_metadata.get_private_endpoint_services", + side_effect=ImportError("no module"), + ): + result = session._ensure_private_endpoint_stage(stages) + + assert len(result) == 1 + + +# ------------------------------------------------------------------ +# _diff_architectures / _parse_diff_result (lines 1590-1613, 1698-1719) +# ------------------------------------------------------------------ + + +class TestDiffArchitectures: + """Tests for architecture diffing and response parsing.""" + + def test_diff_returns_fallback_without_architect(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._architect_agent = None + + existing = [{"stage": 1, "name": "A"}, {"stage": 2, "name": "B"}] + result = session._diff_architectures("old", "new", existing) + + assert result["modified"] == [1, 2] + assert result["unchanged"] == [] + + def test_diff_parses_valid_response(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + diff_json = ( + '{"unchanged": [1], "modified": [2], "removed": [], ' + '"added": [], "plan_restructured": false, "summary": "Stage 2 modified."}' + ) + mock_response = MagicMock(content=diff_json, model="test", usage={}) + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = mock_response + session._architect_agent.name = "cloud-architect" + + existing = [{"stage": 1, "name": "A"}, {"stage": 2, "name": "B"}] + result = session._diff_architectures("old arch", "new arch", existing) + + assert 1 in result["unchanged"] + assert 2 in result["modified"] + + def test_diff_fallback_on_exception(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._architect_agent = MagicMock() + session._architect_agent.execute.side_effect = RuntimeError("boom") + session._architect_agent.name = "cloud-architect" + + existing = [{"stage": 1, "name": "A"}] + result = session._diff_architectures("old", "new", existing) + + assert result["modified"] == [1] + + def test_parse_diff_result_with_fenced_json(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + content = ( + '```json\n{"unchanged": [1], "modified": [2], "removed": [], ' + '"added": [], "plan_restructured": false, "summary": "ok"}\n```' + ) + existing = [{"stage": 1, "name": "A"}, {"stage": 2, "name": "B"}] + + result = session._parse_diff_result(content, existing) + assert result is not None + assert 1 in result["unchanged"] + assert 2 in result["modified"] + + def test_parse_diff_result_invalid_json(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + result = session._parse_diff_result("not json at all", []) + assert result is None + + def test_parse_diff_result_not_dict(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + result = session._parse_diff_result("[1, 2, 3]", []) + assert result is None + + def test_parse_diff_unmentioned_stages_default_unchanged(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + content = '{"unchanged": [], "modified": [2], "removed": []}' + existing = [{"stage": 1, "name": "A"}, {"stage": 2, "name": "B"}, {"stage": 3, "name": "C"}] + + result = session._parse_diff_result(content, existing) + assert result is not None + # Stage 1 and 3 not mentioned → unchanged + assert 1 in result["unchanged"] + assert 3 in result["unchanged"] + + +# ------------------------------------------------------------------ +# _adjust_plan (lines 1590-1613) +# ------------------------------------------------------------------ + + +class TestAdjustPlan: + """Tests for the _adjust_plan method.""" + + def test_adjust_returns_none_without_architect(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._architect_agent = None + + result = session._adjust_plan("add redis", "arch", []) + assert result is None + + def test_adjust_returns_none_without_ai_provider(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._context.ai_provider = None + + result = session._adjust_plan("add redis", "arch", []) + assert result is None + + def test_adjust_returns_parsed_plan(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + plan_json = ( + '```json\n{"stages": [{"stage": 1, "name": "Redis", "layer": "data", ' + '"capability": "data", "dir": "concept/infra/terraform/stage-1-redis", ' + '"services": [], "status": "pending", "files": []}]}\n```' + ) + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = MagicMock(content=plan_json, model="test", usage={}) + session._architect_agent.name = "cloud-architect" + + result = session._adjust_plan("add redis", "arch", []) + assert result is not None + assert len(result) >= 1 + + def test_adjust_returns_none_on_empty_response(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = MagicMock(content="", model="test", usage={}) + session._architect_agent.name = "cloud-architect" + + result = session._adjust_plan("add redis", "arch", []) + assert result is None + + +# ------------------------------------------------------------------ +# _apply_stage_transforms debug logging (lines 2698-2708) +# ------------------------------------------------------------------ + + +class TestApplyStageTransformsDebug: + """Tests for _apply_stage_transforms debug path.""" + + def test_transforms_no_changes(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) + main_tf = stage_dir / "main.tf" + main_tf.write_text("resource {}", encoding="utf-8") + + stage = { + "stage": 1, + "name": "Test", + "services": [{"resource_type": "Microsoft.KeyVault/vaults"}], + } + + rel_path = str(main_tf.relative_to(project_dir)) + + with patch("azext_prototype.governance.transforms.apply", return_value=("resource {}", [])): + result = session._apply_stage_transforms(stage, [rel_path], lambda m: None) + + # Returns the same paths (no transforms applied) + assert result == [rel_path] + + def test_transforms_debug_log_assembles_files(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) + main_tf = stage_dir / "main.tf" + main_tf.write_text('resource "test" {}', encoding="utf-8") + + stage = { + "stage": 1, + "name": "Test", + "services": [], + } + + rel_path = str(main_tf.relative_to(project_dir)) + + with patch("azext_prototype.governance.transforms.apply", return_value=('resource "test" {}', [])): + with patch("azext_prototype.debug_log.is_active", return_value=True): + with patch("azext_prototype.debug_log.log_flow") as mock_dbg: + result = session._apply_stage_transforms(stage, [rel_path], lambda m: None) + + assert result == [rel_path] + # Debug log should have been called with post-transform info + assert mock_dbg.called + + def test_transforms_empty_paths_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = {"stage": 1, "name": "Test", "services": []} + result = session._apply_stage_transforms(stage, [], lambda m: None) + assert result == [] + + +# ------------------------------------------------------------------ +# QA remediation write-back (lines 3309-3322, 3358-3411) +# ------------------------------------------------------------------ + + +class TestQaRemediationWriteBack: + """Tests for QA review retry logic including rate limit and timeout handling.""" + + def test_qa_rate_limit_retries(self, build_context, build_registry): + from azext_prototype.ai.copilot_provider import CopilotRateLimitError + + session = _make_session(build_context, build_registry) + + stage = { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "services": [], + "files": ["main.tf"], + } + + # Create file on disk + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1-key-vault" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text("resource {}", encoding="utf-8") + stage["files"] = [str((stage_dir / "main.tf").relative_to(project_dir))] + + session._build_state.set_deployment_plan([stage]) + + call_count = [0] + + def mock_qa_execute(ctx, task): + call_count[0] += 1 + if call_count[0] == 1: + raise CopilotRateLimitError("rate limited", retry_after=1) + return _make_response("VERDICT: PASS") + + session._qa_agent.execute = mock_qa_execute + with patch.object(session, "_countdown"): + passed = session._run_stage_qa(stage, "arch", [], False, lambda m: None) + + assert passed is True + assert call_count[0] >= 2 + + def test_qa_timeout_exhausts_retries(self, build_context, build_registry): + from azext_prototype.ai.copilot_provider import CopilotTimeoutError + + session = _make_session(build_context, build_registry) + + stage = { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "services": [], + "files": ["main.tf"], + } + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1-key-vault" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text("resource {}", encoding="utf-8") + stage["files"] = [str((stage_dir / "main.tf").relative_to(project_dir))] + + session._build_state.set_deployment_plan([stage]) + + session._qa_agent.execute = MagicMock(side_effect=CopilotTimeoutError("timeout")) + with patch.object(session, "_countdown"): + passed = session._run_stage_qa(stage, "arch", [], False, lambda m: None) + + assert passed is False + + def test_qa_remediation_cycle(self, build_context, build_registry): + """QA finds issues, remediates, then passes.""" + session = _make_session(build_context, build_registry) + + stage = { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "services": [{"name": "key-vault", "resource_type": "Microsoft.KeyVault/vaults"}], + "files": ["main.tf"], + } + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1-key-vault" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text("resource {}", encoding="utf-8") + stage["files"] = [str((stage_dir / "main.tf").relative_to(project_dir))] + + session._build_state.set_deployment_plan([stage]) + + call_count = [0] + + def mock_qa_execute(ctx, task): + call_count[0] += 1 + if call_count[0] == 1: + # First QA call: issues found + return _make_response("VERDICT: FAIL\nCRITICAL: missing auth") + # Second QA call: pass + return _make_response("VERDICT: PASS") + + regen_response = _make_response("```main.tf\nfixed\n```") + + mock_iac_agent = MagicMock() + mock_iac_agent.name = "terraform-agent" + + session._qa_agent.execute = mock_qa_execute + with patch.object(session, "_select_agent", return_value=mock_iac_agent): + with patch.object(session, "_build_stage_task", return_value=(mock_iac_agent, "task")): + with patch.object(session, "_execute_with_retry", return_value=regen_response): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + passed = session._run_stage_qa(stage, "arch", [], False, lambda m: None) + + assert passed is True + assert call_count[0] >= 2 + + +# ------------------------------------------------------------------ +# QA remediation status transitions +# ------------------------------------------------------------------ + + +class TestQaRemediationStatusTransition: + """After QA remediation, stage must stay 'validating' until QA passes. + + Bug: mark_stage_generated() was called after each remediation attempt, + leaving the stage as 'generated' even when QA subsequently failed. + On re-entry, 'generated' stages are skipped instead of retried. + """ + + def test_qa_remediation_failure_leaves_stage_validating(self, build_context, build_registry): + """After exhausted QA remediation, stage status must be 'validating' for re-entry.""" + session = _make_session(build_context, build_registry) + + stage = { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "services": [{"name": "key-vault", "resource_type": "Microsoft.KeyVault/vaults"}], + "files": ["main.tf"], + } + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1-key-vault" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text("resource {}", encoding="utf-8") + stage["files"] = [str((stage_dir / "main.tf").relative_to(project_dir))] + + session._build_state.set_deployment_plan([stage]) + + # QA always returns FAIL — exhausts all remediation attempts + session._qa_agent.execute = MagicMock( + return_value=_make_response("CRITICAL: This will never be fixed.") + ) + + mock_iac_agent = MagicMock() + mock_iac_agent.name = "terraform-agent" + + with patch.object(session, "_select_agent", return_value=mock_iac_agent): + with patch.object(session, "_build_stage_task", return_value=(mock_iac_agent, "task")): + with patch.object(session, "_execute_with_retry", return_value=_make_response("```main.tf\nfixed\n```")): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + passed = session._run_stage_qa(stage, "arch", [], False, lambda m: None) + + assert passed is False + actual_status = session._build_state._state["deployment_stages"][0]["status"] + assert actual_status == "validating", ( + f"After exhausted QA remediation, stage should be 'validating' for re-entry, got '{actual_status}'" + ) + + def test_reentry_after_qa_failure_retries_qa(self, build_context, build_registry): + """Re-running build after QA failure must re-attempt QA on the 'validating' stage.""" + session = _make_session(build_context, build_registry) + design = {"architecture": "Test"} + + session._build_state.set_deployment_plan([ + _make_validating_stage(1, "Key Vault", layer="data", files=["main.tf"]) + ]) + session._build_state.set_design_snapshot(design) + + qa_called = [] + + def mock_qa(*args, **kwargs): + qa_called.append(True) + return True + + session._run_stage_qa = mock_qa + + session.run(design=design, input_fn=lambda p: "done", print_fn=lambda m: None) + + assert len(qa_called) > 0, "Re-entry must re-run QA on a 'validating' stage" + + def test_qa_remediation_marks_validating_between_attempts(self, build_context, build_registry): + """After remediation writes files, status must be 'validating' before next QA check.""" + session = _make_session(build_context, build_registry) + + stage = { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "services": [{"name": "key-vault", "resource_type": "Microsoft.KeyVault/vaults"}], + "files": ["main.tf"], + } + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1-key-vault" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text("resource {}", encoding="utf-8") + stage["files"] = [str((stage_dir / "main.tf").relative_to(project_dir))] + + session._build_state.set_deployment_plan([stage]) + + # Track which mark_stage_* calls happen inside the remediation loop + status_calls = [] + original_validating = session._build_state.mark_stage_validating + original_generated = session._build_state.mark_stage_generated + + def spy_validating(sn, files): + original_validating(sn, files) + status_calls.append(("validating", sn)) + + def spy_generated(sn, files, agent): + original_generated(sn, files, agent) + status_calls.append(("generated", sn)) + + session._build_state.mark_stage_validating = spy_validating + session._build_state.mark_stage_generated = spy_generated + + call_count = [0] + + def mock_qa_execute(ctx, task): + call_count[0] += 1 + if call_count[0] == 1: + return _make_response("VERDICT: FAIL\nCRITICAL: missing auth") + return _make_response("VERDICT: PASS") + + mock_iac_agent = MagicMock() + mock_iac_agent.name = "terraform-agent" + + session._qa_agent.execute = mock_qa_execute + + with patch.object(session, "_select_agent", return_value=mock_iac_agent): + with patch.object(session, "_build_stage_task", return_value=(mock_iac_agent, "task")): + with patch.object(session, "_execute_with_retry", return_value=_make_response("```main.tf\nfixed\n```")): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + passed = session._run_stage_qa(stage, "arch", [], False, lambda m: None) + + assert passed is True + # The mark call inside the remediation loop must be "validating", not "generated" + assert any(status == "validating" for status, _ in status_calls), ( + f"Remediation loop must call mark_stage_validating, not mark_stage_generated. " + f"Calls observed: {status_calls}" + ) + + +# ------------------------------------------------------------------ +# _generate_stage_advisory (lines 3458-3503) +# ------------------------------------------------------------------ + + +class TestGenerateStageAdvisory: + """Tests for per-stage advisory generation.""" + + def test_advisory_returns_empty_without_advisor(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._advisor_agent = None + result = session._generate_stage_advisory({"stage": 1, "name": "Test", "files": ["a.tf"]}, lambda m: None) + assert result == "" + + def test_advisory_skips_docs_layer(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._advisor_agent = MagicMock() + result = session._generate_stage_advisory( + {"stage": 1, "name": "Docs", "layer": "docs", "files": ["README.md"]}, lambda m: None + ) + assert result == "" + + def test_advisory_skips_empty_files(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._advisor_agent = MagicMock() + result = session._generate_stage_advisory( + {"stage": 1, "name": "Test", "layer": "infra", "files": []}, lambda m: None + ) + assert result == "" + + def test_advisory_returns_content(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text("resource {} {}", encoding="utf-8") + + rel_path = str((stage_dir / "main.tf").relative_to(project_dir)) + + session._advisor_agent = MagicMock() + session._advisor_agent.name = "advisor" + + advisory_text = "Consider upgrading to Premium SKU for production." + + with patch("azext_prototype.agents.orchestrator.AgentOrchestrator") as MockOrch: + mock_orch = MockOrch.return_value + mock_orch.delegate.return_value = MagicMock(content=advisory_text, model="test", usage={}) + result = session._generate_stage_advisory( + {"stage": 1, "name": "Key Vault", "layer": "data", "files": [rel_path]}, lambda m: None + ) + + assert result == advisory_text + + def test_advisory_handles_exception(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text("resource {} {}", encoding="utf-8") + + rel_path = str((stage_dir / "main.tf").relative_to(project_dir)) + + session._advisor_agent = MagicMock() + session._advisor_agent.name = "advisor" + + with patch("azext_prototype.agents.orchestrator.AgentOrchestrator") as MockOrch: + mock_orch = MockOrch.return_value + mock_orch.delegate.side_effect = RuntimeError("boom") + result = session._generate_stage_advisory( + {"stage": 1, "name": "Key Vault", "layer": "data", "files": [rel_path]}, lambda m: None + ) + + assert result == "" + + +# ------------------------------------------------------------------ +# _execute_with_retry (lines 3537-3549) +# ------------------------------------------------------------------ + + +class TestExecuteWithRetry: + """Tests for _execute_with_retry timeout/rate-limit backoff.""" + + def test_retry_success_on_first_attempt(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + mock_agent = MagicMock() + mock_response = MagicMock(content="ok", model="test", usage={}) + + with patch.object(session, "_execute_with_continuation", return_value=mock_response): + result = session._execute_with_retry(mock_agent, "task", 1, "Stage", lambda m: None) + + assert result is mock_response + + def test_retry_on_rate_limit(self, build_context, build_registry): + from azext_prototype.ai.copilot_provider import CopilotRateLimitError + + session = _make_session(build_context, build_registry) + mock_agent = MagicMock() + mock_response = MagicMock(content="ok", model="test", usage={}) + + call_count = [0] + + def side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise CopilotRateLimitError("limited", retry_after=1) + return mock_response + + with patch.object(session, "_execute_with_continuation", side_effect=side_effect): + with patch.object(session, "_countdown"): + result = session._execute_with_retry(mock_agent, "task", 1, "Stage", lambda m: None) + + assert result is mock_response + assert call_count[0] == 2 + + def test_retry_on_timeout_eventually_returns_none(self, build_context, build_registry): + from azext_prototype.ai.copilot_provider import CopilotTimeoutError + + session = _make_session(build_context, build_registry) + mock_agent = MagicMock() + + with patch.object(session, "_execute_with_continuation", side_effect=CopilotTimeoutError("timeout")): + with patch.object(session, "_countdown"): + printed = [] + result = session._execute_with_retry(mock_agent, "task", 1, "Stage", lambda m: printed.append(m)) + + assert result is None + assert any("timed out" in msg for msg in printed) + + def test_retry_rate_limit_uses_retry_after(self, build_context, build_registry): + from azext_prototype.ai.copilot_provider import CopilotRateLimitError + + session = _make_session(build_context, build_registry) + mock_agent = MagicMock() + mock_response = MagicMock(content="ok", model="test", usage={}) + + def side_effect(*args, **kwargs): + raise CopilotRateLimitError("limited", retry_after=42) + + countdown_calls = [] + + def mock_countdown(seconds, *a, **kw): + countdown_calls.append(seconds) + + call_count = [0] + + def exec_side(*args, **kwargs): + call_count[0] += 1 + if call_count[0] <= 2: + raise CopilotRateLimitError("limited", retry_after=42) + return mock_response + + with patch.object(session, "_execute_with_continuation", side_effect=exec_side): + with patch.object(session, "_countdown", side_effect=mock_countdown): + result = session._execute_with_retry(mock_agent, "task", 1, "Stage", lambda m: None) + + assert result is mock_response + # countdown should have been called with 42 (retry_after value) + assert 42 in countdown_calls + + +# ------------------------------------------------------------------ +# _execute_with_continuation (lines 3579-3610) +# ------------------------------------------------------------------ + + +class TestExecuteWithContinuation: + """Tests for truncation recovery via continuation.""" + + def test_no_continuation_on_stop(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + mock_agent = MagicMock() + response = MagicMock(content="full output", finish_reason="stop", model="test", usage={}) + mock_agent.execute.return_value = response + + result = session._execute_with_continuation(mock_agent, "task") + assert result.content == "full output" + assert mock_agent.execute.call_count == 1 + + def test_continuation_on_length(self, build_context, build_registry): + from azext_prototype.ai.provider import AIResponse + + session = _make_session(build_context, build_registry) + mock_agent = MagicMock() + + first_response = AIResponse( + content="partial", + model="test", + usage={"prompt_tokens": 100, "completion_tokens": 200}, + finish_reason="length", + ) + second_response = AIResponse( + content=" continued", + model="test", + usage={"prompt_tokens": 50, "completion_tokens": 100}, + finish_reason="stop", + ) + mock_agent.execute.side_effect = [first_response, second_response] + + result = session._execute_with_continuation(mock_agent, "task") + assert result.content == "partial continued" + assert result.finish_reason == "stop" + assert mock_agent.execute.call_count == 2 + # Token usage should be merged + assert result.usage["prompt_tokens"] == 150 + assert result.usage["completion_tokens"] == 300 + + def test_continuation_with_stage_context(self, build_context, build_registry): + from azext_prototype.ai.provider import AIResponse + + session = _make_session(build_context, build_registry) + mock_agent = MagicMock() + + first_response = AIResponse( + content="partial code", + model="test", + usage={"prompt_tokens": 100}, + finish_reason="length", + ) + second_response = AIResponse( + content=" rest of code", + model="test", + usage={"prompt_tokens": 50}, + finish_reason="stop", + ) + mock_agent.execute.side_effect = [first_response, second_response] + + result = session._execute_with_continuation( + mock_agent, "task", stage_num=3, stage_name="Key Vault", stage_capability="data" + ) + assert result.content == "partial code rest of code" + # Conversation history should have continuation messages + assert len(session._context.conversation_history) >= 2 + + def test_continuation_max_limit(self, build_context, build_registry): + from azext_prototype.ai.provider import AIResponse + + session = _make_session(build_context, build_registry) + mock_agent = MagicMock() + + # All responses truncated + truncated = AIResponse( + content="chunk", + model="test", + usage={"prompt_tokens": 10}, + finish_reason="length", + ) + mock_agent.execute.return_value = truncated + + result = session._execute_with_continuation(mock_agent, "task", max_continuations=2) + # 1 original + 2 continuations = 3 calls + assert mock_agent.execute.call_count == 3 + # Content should be accumulated + assert "chunk" in result.content + + def test_continuation_none_response_breaks(self, build_context, build_registry): + from azext_prototype.ai.provider import AIResponse + + session = _make_session(build_context, build_registry) + mock_agent = MagicMock() + + first_response = AIResponse( + content="partial", + model="test", + usage={"prompt_tokens": 100}, + finish_reason="length", + ) + mock_agent.execute.side_effect = [first_response, None] + + result = session._execute_with_continuation(mock_agent, "task") + assert result.content == "partial" + + +# ------------------------------------------------------------------ +# _collect_generated_file_content (lines 3413-3439) +# ------------------------------------------------------------------ + + +class TestCollectGeneratedFileContent: + """Tests for collecting generated file content for QA.""" + + def test_collects_existing_files(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text("resource {}", encoding="utf-8") + + rel_path = str((stage_dir / "main.tf").relative_to(project_dir)) + + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Test", "layer": "infra", "status": "generated", "files": [rel_path]}, + ] + + content = session._collect_generated_file_content() + assert "resource {}" in content + assert "Stage 1: Test" in content + + def test_handles_missing_files(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + session._build_state._state["deployment_stages"] = [ + { + "stage": 1, + "name": "Test", + "layer": "infra", + "status": "generated", + "files": ["nonexistent/main.tf"], + }, + ] + + content = session._collect_generated_file_content() + assert "(could not read file)" in content + + def test_skips_stages_without_files(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Test", "layer": "infra", "status": "generated", "files": []}, + ] + + content = session._collect_generated_file_content() + assert content == "" + + +# ------------------------------------------------------------------ +# _categorize_service (static method) — additional coverage +# ------------------------------------------------------------------ + + +class TestCategorizeServiceExtended: + """Extended tests for service type categorization.""" + + def test_infra_types(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._categorize_service("key-vault") == "infra" + assert BuildSession._categorize_service("virtual-network") == "infra" + assert BuildSession._categorize_service("managed-identity") == "infra" + assert BuildSession._categorize_service("application-insights") == "infra" + + def test_data_types(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._categorize_service("cosmos-db") == "data" + assert BuildSession._categorize_service("sql-database") == "data" + assert BuildSession._categorize_service("redis-cache") == "data" + assert BuildSession._categorize_service("storage-account") == "data" + + def test_app_type_fallback(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._categorize_service("python-app") == "app" + assert BuildSession._categorize_service("container-registry") == "app" + + +# ------------------------------------------------------------------ +# _parse_deployment_plan (lines 1235-1236) +# ------------------------------------------------------------------ + + +class TestParseDeploymentPlanExtended: + """Extended tests for deployment plan JSON parsing.""" + + def test_parse_fenced_json(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + content = ( + '```json\n{"stages": [{"stage": 1, "name": "Test", ' + '"dir": "concept/infra/terraform/stage-1", "services": [], ' + '"capability": "infra"}]}\n```' + ) + result = session._parse_deployment_plan(content) + assert len(result) >= 1 + assert result[0]["name"] == "Test" + + def test_parse_raw_json(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + content = ( + '{"stages": [{"stage": 1, "name": "Test", ' + '"dir": "concept/infra/terraform/stage-1", "services": [], ' + '"capability": "infra"}]}' + ) + result = session._parse_deployment_plan(content) + assert len(result) >= 1 + + def test_parse_invalid_json_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + result = session._parse_deployment_plan("not json at all") + assert result == [] + + def test_parse_fenced_bad_json_falls_back(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + content = "```json\n{bad json}\n```" + result = session._parse_deployment_plan(content) + assert result == [] + + +# ------------------------------------------------------------------ +# _build_docs_context (lines 3017-3045) +# ------------------------------------------------------------------ + + +class TestBuildDocsContext: + """Tests for documentation context builder.""" + + def test_returns_empty_when_no_generated(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "A", "status": "pending", "files": []}, + ] + assert session._build_docs_context() == "" + + def test_includes_output_keys(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) + outputs_tf = stage_dir / "outputs.tf" + outputs_tf.write_text( + 'output "vault_id" {\n description = "Key Vault ID"\n value = azapi_resource.vault.id\n}\n', + encoding="utf-8", + ) + + rel_path = str(outputs_tf.relative_to(project_dir)) + + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Key Vault", "status": "generated", "files": [rel_path], "layer": "data"}, + ] + + result = session._build_docs_context() + assert "vault_id" in result + assert "Key Vault ID" in result + + def test_lists_files_when_no_outputs(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) + main_tf = stage_dir / "main.tf" + main_tf.write_text("resource {}", encoding="utf-8") + + rel_path = str(main_tf.relative_to(project_dir)) + + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Test", "status": "generated", "files": [rel_path], "layer": "infra"}, + ] + + result = session._build_docs_context() + assert "main.tf" in result + + +# ------------------------------------------------------------------ +# _build_dns_zone_note (lines 3062-3085) +# ------------------------------------------------------------------ + + +class TestBuildDnsZoneNote: + """Tests for DNS zone note generation.""" + + def test_returns_empty_when_no_zones(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Test", "services": []}, + ] + + with patch( + "azext_prototype.knowledge.private_dns_zones.get_zones_for_services", + return_value={}, + ): + result = session._build_dns_zone_note() + assert result == "" + + def test_returns_zones_when_found(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Test", "services": [{"name": "kv"}]}, + ] + + zones = {"privatelink.vaultcore.azure.net": "Microsoft.KeyVault/vaults"} + + with patch( + "azext_prototype.knowledge.private_dns_zones.get_zones_for_services", + return_value=zones, + ): + result = session._build_dns_zone_note() + + assert "privatelink.vaultcore.azure.net" in result + assert "REQUIRED PRIVATE DNS ZONES" in result + + def test_exception_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Test", "services": []}, + ] + + with patch( + "azext_prototype.knowledge.private_dns_zones.get_zones_for_services", + side_effect=ImportError("boom"), + ): + result = session._build_dns_zone_note() + assert result == "" + + +# ------------------------------------------------------------------ +# _get_networking_stage_note (lines 3092-3096) +# ------------------------------------------------------------------ + + +class TestGetNetworkingStageNote: + """Tests for networking stage QA note generation.""" + + def test_returns_note_when_networking_stage_has_pe_services(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + { + "stage": 2, + "name": "Networking", + "services": [ + {"name": "virtual-network"}, + {"name": "private-endpoint-keyvault"}, + ], + }, + ] + + result = session._get_networking_stage_note() + assert "CRITICAL: Networking Stage" in result + assert "private-endpoint-keyvault" in result + + def test_returns_empty_when_no_networking_stage(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Key Vault", "services": []}, + ] + + result = session._get_networking_stage_note() + assert result == "" + + def test_returns_empty_when_networking_has_no_pe_services(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + { + "stage": 2, + "name": "Networking", + "services": [{"name": "virtual-network"}], + }, + ] + + result = session._get_networking_stage_note() + assert result == "" + + +# ------------------------------------------------------------------ +# _extract_output_keys (lines 2969-2979) +# ------------------------------------------------------------------ + + +class TestExtractOutputKeys: + """Tests for extracting output key names from stage files.""" + + def test_extracts_terraform_output_keys(self, tmp_path): + from azext_prototype.stages.build_session import BuildSession + + outputs_file = tmp_path / "concept" / "infra" / "terraform" / "stage-1" / "outputs.tf" + outputs_file.parent.mkdir(parents=True, exist_ok=True) + outputs_file.write_text( + 'output "vault_id" {\n value = azapi_resource.vault.id\n}\n' + 'output "vault_uri" {\n value = azapi_resource.vault.properties.vaultUri\n}\n', + encoding="utf-8", + ) + + rel_path = str(outputs_file.relative_to(tmp_path)) + stage = {"files": [rel_path]} + + keys = BuildSession._extract_output_keys(stage, tmp_path) + assert "vault_id" in keys + assert "vault_uri" in keys + + def test_returns_empty_when_no_outputs_file(self, tmp_path): + from azext_prototype.stages.build_session import BuildSession + + stage = {"files": ["main.tf"]} + keys = BuildSession._extract_output_keys(stage, tmp_path) + assert keys == [] + + +# ------------------------------------------------------------------ +# Design change branch B (lines 356-379) +# ------------------------------------------------------------------ + + +class TestDesignChangeBranchB: + """Tests for re-entry when design has changed.""" + + def test_design_changed_restructured_quit(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "New arch"} + + session._build_state.set_deployment_plan([_make_pending_stage(1, "Test")]) + session._build_state.set_design_snapshot({"architecture": "Old arch"}) + + # Simulate design changed + session._build_state.design_has_changed = MagicMock(return_value=True) + session._build_state.get_previous_architecture = MagicMock(return_value="Old arch") + + diff_result = { + "unchanged": [], + "modified": [], + "removed": [], + "added": [], + "plan_restructured": True, + "summary": "Big changes.", + } + + with patch.object(session, "_diff_architectures", return_value=diff_result): + result = session.run( + design=design, + input_fn=lambda p: "quit", + print_fn=lambda m: None, + ) + + assert result.cancelled is True + + def test_design_changed_targeted_updates(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "New arch with redis"} + + stages = [ + _make_pending_stage(1, "Key Vault", layer="data", capability="data"), + _make_pending_stage(2, "Redis", layer="data", capability="data"), + ] + session._build_state.set_deployment_plan(stages) + session._build_state.set_design_snapshot({"architecture": "Old arch"}) + + session._build_state.design_has_changed = MagicMock(return_value=True) + session._build_state.get_previous_architecture = MagicMock(return_value="Old arch") + + diff_result = { + "unchanged": [1], + "modified": [2], + "removed": [], + "added": [], + "plan_restructured": False, + "summary": "Stage 2 modified.", + } + + with patch.object(session, "_diff_architectures", return_value=diff_result): + session._run_stage_qa = lambda *a, **kw: True + with patch.object(session, "_build_stage_task", return_value=(MagicMock(name="tf"), "task")): + with patch.object( + session, "_execute_with_retry", return_value=MagicMock(content="```main.tf\nok\n```") + ): + with patch.object(session, "_write_stage_files", return_value=["main.tf"]): + with patch.object(session, "_apply_stage_transforms", return_value=["main.tf"]): + result = session.run( + design=design, + input_fn=lambda p: "done", + print_fn=lambda m: None, + ) + + assert result is not None + + +# ------------------------------------------------------------------ +# _resolve_service_policies / _resolve_api_versions (lines 3190-3212) +# ------------------------------------------------------------------ + + +class TestResolveHelpers: + """Tests for service policy and API version resolution helpers.""" + + def test_resolve_service_policies_exception(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + with patch( + "azext_prototype.governance.policies.PolicyEngine.load", + side_effect=Exception("boom"), + ): + result = session._resolve_service_policies([{"name": "kv"}]) + assert result == "" + + def test_resolve_api_versions_exception(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + with patch( + "azext_prototype.knowledge.resource_metadata.resolve_resource_metadata", + side_effect=ImportError("boom"), + ): + result = session._resolve_api_versions([{"resource_type": "Microsoft.KeyVault/vaults"}]) + assert result == "" + + def test_resolve_api_versions_no_resource_types(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + result = session._resolve_api_versions([{"name": "kv"}]) + assert result == "" + + def test_resolve_companion_requirements_exception(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + with patch( + "azext_prototype.knowledge.resource_metadata.resolve_companion_requirements", + side_effect=ImportError("boom"), + ): + result = session._resolve_companion_requirements([{"resource_type": "Microsoft.KeyVault/vaults"}]) + assert result == "" + + +# ------------------------------------------------------------------ +# _infer_layer (static method) +# ------------------------------------------------------------------ + + +class TestInferLayerExtended: + """Extended tests for layer inference from stage data.""" + + def test_explicit_layer_returned(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._infer_layer({"layer": "data", "name": "test"}) == "data" + + def test_identity_name_returns_core(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._infer_layer({"name": "Managed Identity", "capability": "infra"}) == "core" + + def test_monitoring_name_returns_core(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._infer_layer({"name": "Log Analytics", "capability": "infra"}) == "core" + + def test_capability_mapping(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._infer_layer({"name": "Redis", "capability": "data"}) == "data" + assert BuildSession._infer_layer({"name": "API", "capability": "app"}) == "app" + + +# ------------------------------------------------------------------ +# _enforce_concept_prefix +# ------------------------------------------------------------------ + + +class TestEnforceConceptPrefixExtended: + """Extended tests for concept prefix enforcement in dirs.""" + + def test_already_has_concept_prefix(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + assert session._enforce_concept_prefix("concept/infra/terraform") == "concept/infra/terraform" + + def test_fixes_wrong_prefix(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + result = session._enforce_concept_prefix("output/infra/terraform/stage-1") + assert result.startswith("concept/") + + def test_single_subdir(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + result = session._enforce_concept_prefix("infra") + assert result == "concept/infra" + + def test_empty_string(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + assert session._enforce_concept_prefix("") == "" + + +# ------------------------------------------------------------------ +# _clean_removed_stage_files (lines 1759-1769) +# ------------------------------------------------------------------ + + +class TestCleanRemovedStageFiles: + """Tests for removing stage directories on disk.""" + + def test_removes_existing_directory(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-2-redis" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text("resource {}", encoding="utf-8") + + stages = [{"stage": 2, "dir": "concept/infra/terraform/stage-2-redis"}] + session._clean_removed_stage_files([2], stages) + + assert not stage_dir.exists() + + def test_ignores_nonexistent_directory(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stages = [{"stage": 3, "dir": "concept/infra/terraform/stage-3-nonexistent"}] + # Should not raise + session._clean_removed_stage_files([3], stages) + + +# ------------------------------------------------------------------ +# _fix_stage_dirs (lines 1771-1789) +# ------------------------------------------------------------------ + + +class TestFixStageDirs: + """Tests for post-renumber directory path fixing.""" + + def test_fix_renumbers_dirs(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "A", "dir": "concept/infra/terraform/stage-3-redis"}, + ] + + session._fix_stage_dirs() + + assert session._build_state._state["deployment_stages"][0]["dir"] == ("concept/infra/terraform/stage-1-redis") + + def test_fix_no_change_when_correct(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "A", "dir": "concept/infra/terraform/stage-1-redis"}, + ] + + session._fix_stage_dirs() + + assert session._build_state._state["deployment_stages"][0]["dir"] == ("concept/infra/terraform/stage-1-redis") + + +# ------------------------------------------------------------------ +# _identify_affected_stages / _identify_stages_regex / _identify_stages_via_architect +# ------------------------------------------------------------------ + + +class TestIdentifyAffectedStages: + """Tests for feedback-to-stage matching.""" + + def test_regex_explicit_stage_number(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Key Vault", "status": "generated", "services": [], "files": []}, + {"stage": 2, "name": "Redis", "status": "generated", "services": [], "files": []}, + ] + + result = session._identify_stages_regex("Please fix stage 2") + assert result == [2] + + def test_regex_stage_name_match(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Key Vault", "status": "generated", "services": [], "files": []}, + {"stage": 2, "name": "Redis", "status": "generated", "services": [], "files": []}, + ] + + result = session._identify_stages_regex("fix the key vault configuration") + assert result == [1] + + def test_regex_service_name_match(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + { + "stage": 1, + "name": "Cache", + "status": "generated", + "services": [{"name": "redis-cache"}], + "files": [], + }, + ] + + result = session._identify_stages_regex("update the redis-cache settings") + assert result == [1] + + def test_regex_fallback_returns_generated_and_accepted(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Aaa", "status": "generated", "services": [], "files": []}, + {"stage": 2, "name": "Bbb", "status": "accepted", "services": [], "files": []}, + ] + + # Feedback doesn't match any stage name, service, or number + result = session._identify_stages_regex("xyz unrelated text 999") + # Last resort: returns all generated+accepted stages + assert 1 in result + assert 2 in result + + def test_identify_via_architect(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Key Vault", "status": "generated", "services": [], "files": []}, + {"stage": 2, "name": "Redis", "status": "generated", "services": [], "files": []}, + ] + + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = MagicMock(content="[2]", model="test", usage={}) + session._architect_agent.name = "cloud-architect" + + result = session._identify_stages_via_architect("fix the caching layer") + assert result == [2] + + def test_identify_via_architect_exception_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "A", "status": "generated", "services": [], "files": []}, + ] + + session._architect_agent = MagicMock() + session._architect_agent.execute.side_effect = RuntimeError("boom") + session._architect_agent.name = "cloud-architect" + + result = session._identify_stages_via_architect("fix something") + assert result == [] + + def test_identify_via_architect_empty_stages(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [] + + session._architect_agent = MagicMock() + + result = session._identify_stages_via_architect("fix something") + assert result == [] + + def test_identify_affected_uses_architect_then_regex(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Key Vault", "status": "generated", "services": [], "files": []}, + ] + + # Architect returns empty -> falls back to regex + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = MagicMock(content="[]", model="test", usage={}) + session._architect_agent.name = "cloud-architect" + + result = session._identify_affected_stages("fix the key vault") + assert result == [1] # Regex matched by name + + def test_parse_stage_numbers_valid(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._parse_stage_numbers("[1, 3, 5]") == [1, 3, 5] + + def test_parse_stage_numbers_embedded(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._parse_stage_numbers("The affected stages are: [2, 4]") == [2, 4] + + def test_parse_stage_numbers_invalid(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._parse_stage_numbers("no json here") == [] + + def test_parse_stage_numbers_bad_json(self): + from azext_prototype.stages.build_session import BuildSession + + assert BuildSession._parse_stage_numbers("[not, valid]") == [] + + +# ------------------------------------------------------------------ +# _handle_slash_command / _handle_describe +# ------------------------------------------------------------------ + + +class TestHandleSlashCommand: + """Tests for slash command handling.""" + + def test_status_command(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Test", "status": "generated", "files": []}, + ] + printed = [] + session._handle_slash_command("/status", printed.append) + assert len(printed) >= 1 -Covers all new build-stage modules introduced in the interactive build overhaul. -""" + def test_files_command(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + printed = [] + session._handle_slash_command("/files", printed.append) + assert len(printed) >= 1 -from __future__ import annotations + def test_policy_command(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + printed = [] + session._handle_slash_command("/policy", printed.append) + assert len(printed) >= 1 -import json -from pathlib import Path -from unittest.mock import MagicMock, patch + def test_describe_command(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "status": "generated", + "dir": "concept/infra/terraform/stage-1-kv", + "services": [ + { + "name": "key-vault", + "computed_name": "kv-test", + "resource_type": "Microsoft.KeyVault/vaults", + "sku": "Standard", + } + ], + "files": ["main.tf", "outputs.tf"], + }, + ] + printed = [] + session._handle_slash_command("/describe 1", printed.append) + assert any("Key Vault" in msg for msg in printed) + assert any("Microsoft.KeyVault/vaults" in msg for msg in printed) -import pytest -import yaml + def test_describe_no_arg(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + printed = [] + session._handle_describe("", printed.append) + assert any("Usage" in msg for msg in printed) -from azext_prototype.agents.base import AgentCapability, AgentContext + def test_describe_nonexistent_stage(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [] + printed = [] + session._handle_describe("99", printed.append) + assert any("not found" in msg for msg in printed) + + def test_help_command(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + printed = [] + session._handle_slash_command("/help", printed.append) + assert any("/status" in msg for msg in printed) + + +# ------------------------------------------------------------------ +# _derive_deployment_plan -- two-phase plan derivation +# ------------------------------------------------------------------ + + +class TestDeriveDeploymentPlan: + """Tests for _derive_deployment_plan two-phase AI flow.""" + + def test_fallback_without_architect(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._architect_agent = None + + result = session._derive_deployment_plan("architecture text", []) + # Fallback plan always has at least identity + docs + assert len(result) >= 2 + assert result[0]["name"] == "Managed Identity" + assert result[-1]["name"] == "Documentation" + + def test_fallback_without_ai_provider(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._context.ai_provider = None + + result = session._derive_deployment_plan("architecture text", []) + assert len(result) >= 2 + + def test_fallback_on_empty_phase1_response(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = MagicMock(content="", model="test", usage={}) + session._architect_agent.name = "cloud-architect" + session._architect_agent.set_governor_brief = MagicMock() + + result = session._derive_deployment_plan("architecture text", []) + # Falls back + assert len(result) >= 2 + assert result[0]["name"] == "Managed Identity" + + def test_fallback_on_unparseable_phase1(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = MagicMock(content="not json", model="test", usage={}) + session._architect_agent.name = "cloud-architect" + session._architect_agent.set_governor_brief = MagicMock() + + result = session._derive_deployment_plan("architecture text", []) + assert len(result) >= 2 + + def test_successful_two_phase(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + # Phase 1: map + phase1_content = json.dumps( + { + "stages": [ + { + "stage": 1, + "name": "Managed Identity", + "layer": "core", + "capability": "infra", + "services": ["managed-identity"], + }, + {"stage": 2, "name": "Key Vault", "layer": "data", "capability": "data", "services": ["key-vault"]}, + {"stage": 3, "name": "Documentation", "layer": "docs", "capability": "docs", "services": []}, + ] + } + ) + phase1_json = f"```json\n{phase1_content}\n```" + + # Phase 2: detailed + phase2_content = json.dumps( + { + "stages": [ + { + "stage": 1, + "name": "Managed Identity", + "layer": "core", + "capability": "infra", + "dir": "concept/infra/terraform/stage-1-managed-identity", + "services": [ + { + "name": "managed-identity", + "computed_name": "id-test", + "resource_type": "Microsoft.ManagedIdentity/userAssignedIdentities", + "sku": "", + } + ], + "status": "pending", + "files": [], + }, + { + "stage": 2, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "dir": "concept/infra/terraform/stage-2-key-vault", + "services": [ + { + "name": "key-vault", + "computed_name": "kv-test", + "resource_type": "Microsoft.KeyVault/vaults", + "sku": "Standard", + } + ], + "status": "pending", + "files": [], + }, + { + "stage": 3, + "name": "Documentation", + "layer": "docs", + "capability": "docs", + "dir": "concept/docs", + "services": [], + "status": "pending", + "files": [], + }, + ] + } + ) + phase2_json = f"```json\n{phase2_content}\n```" + + session._architect_agent = MagicMock() + session._architect_agent.execute.side_effect = [ + MagicMock(content=phase1_json, model="test", usage={}), + MagicMock(content=phase2_json, model="test", usage={}), + ] + session._architect_agent.name = "cloud-architect" + session._architect_agent.set_governor_brief = MagicMock() + + result = session._derive_deployment_plan("Build a web app with key vault", []) + assert len(result) >= 3 + assert result[0]["name"] == "Managed Identity" + assert any(s["name"] == "Key Vault" for s in result) + + def test_fallback_on_empty_phase2(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + phase1_json = ( + '```json\n{"stages": [' + '{"stage": 1, "name": "Test", "layer": "core",' + ' "capability": "infra", "services": ["id"]}' + "]}\n```" + ) + + session._architect_agent = MagicMock() + session._architect_agent.execute.side_effect = [ + MagicMock(content=phase1_json, model="test", usage={}), + MagicMock(content="", model="test", usage={}), + ] + session._architect_agent.name = "cloud-architect" + session._architect_agent.set_governor_brief = MagicMock() + + result = session._derive_deployment_plan("architecture", []) + assert len(result) >= 2 + + def test_phase1_null_response(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = None + session._architect_agent.name = "cloud-architect" + session._architect_agent.set_governor_brief = MagicMock() + + result = session._derive_deployment_plan("architecture", []) + assert len(result) >= 2 + + +# ------------------------------------------------------------------ +# _build_stage_task extended coverage (lines 2146-2189) +# ------------------------------------------------------------------ + + +class TestBuildStageTaskExtended: + """Extended tests for _build_stage_task to cover cross-reference paths.""" + + def test_build_stage_task_with_templates(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [ + { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "dir": "concept/infra/terraform/stage-1-kv", + "services": [ + { + "name": "key-vault", + "computed_name": "kv-test", + "resource_type": "Microsoft.KeyVault/vaults", + "sku": "Standard", + "component": "secrets", + } + ], + "status": "pending", + "files": [], + }, + ] + + mock_template = MagicMock() + mock_template.display_name = "Web App" + mock_svc = MagicMock() + mock_svc.name = "key-vault" + mock_svc.type = "key-vault" + mock_svc.tier = "Standard" + mock_svc.config = {"softDelete": True} + mock_template.services = [mock_svc] + + stage = session._build_state._state["deployment_stages"][0] + agent, task = session._build_stage_task(stage, "architecture", [mock_template]) + assert agent is not None + assert "Template reference" in task + assert "softDelete" in task + + def test_build_stage_task_app_layer_prev_context(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + # Set up a developer agent for app stages + mock_dev = MagicMock() + mock_dev.name = "app-developer" + mock_dev._include_standards = True + mock_dev.set_knowledge_override = MagicMock() + mock_dev.set_governor_brief = MagicMock() + mock_dev.get_system_messages = MagicMock(return_value=[]) + mock_dev._governance_aware = False + mock_dev._enable_web_search = False + mock_dev._enable_mcp_tools = False + session._dev_agent = mock_dev + + session._build_state._state["deployment_stages"] = [ + { + "stage": 1, + "name": "Key Vault", + "layer": "data", + "capability": "data", + "dir": "concept/infra/terraform/stage-1-kv", + "services": [{"name": "key-vault", "computed_name": "kv-test"}], + "status": "generated", + "files": [], + }, + { + "stage": 2, + "name": "API", + "layer": "app", + "capability": "app", + "dir": "concept/apps/stage-2-api", + "services": [{"name": "api", "computed_name": "api-test"}], + "status": "pending", + "files": [], + }, + ] + + stage = session._build_state._state["deployment_stages"][1] + agent, task = session._build_stage_task(stage, "architecture", []) + assert agent is not None + # App layer should get infrastructure cross-reference + assert "Previously Generated Stages" in task + + +# ------------------------------------------------------------------ +# _build_qa_context (lines 2939, 2957-2958) +# ------------------------------------------------------------------ + + +class TestBuildQaContext: + """Tests for QA context construction.""" + + def test_qa_context_iac_includes_provider_compliance(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._iac_tool = "terraform" + session._build_state._state["deployment_stages"] = [] + + result = session._build_qa_context([], layer="infra") + assert "Provider Compliance" in result + + def test_qa_context_non_iac_skips_provider(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [] + + result = session._build_qa_context([], layer="app") + assert "Provider Compliance" not in result + + def test_qa_context_includes_standards(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + session._build_state._state["deployment_stages"] = [] + + result = session._build_qa_context([], layer="infra") + assert isinstance(result, str) + + +# ------------------------------------------------------------------ +# _collect_stage_file_content (line 3242) +# ------------------------------------------------------------------ + + +class TestCollectStageFileContent: + """Tests for single-stage file content collection.""" + + def test_collects_files(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + + project_dir = Path(build_context.project_dir) + stage_dir = project_dir / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text("resource azapi_resource {}", encoding="utf-8") + + rel_path = str((stage_dir / "main.tf").relative_to(project_dir)) + + stage = {"stage": 1, "name": "Test", "files": [rel_path]} + content = session._collect_stage_file_content(stage) + assert "azapi_resource" in content + + def test_empty_files_returns_empty(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = {"stage": 1, "name": "Test", "files": []} + content = session._collect_stage_file_content(stage) + assert content == "" + + def test_missing_files_handled(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + stage = {"stage": 1, "name": "Test", "files": ["nonexistent/main.tf"]} + content = session._collect_stage_file_content(stage) + assert "(could not read file)" in content + + +# ------------------------------------------------------------------ +# run() -- Branch A first build (lines 313-324) +# ------------------------------------------------------------------ + + +class TestRunBranchA: + """Tests for run() Branch A: first build deriving fresh plan.""" + + def test_first_build_empty_plan_cancels(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + with patch.object(session, "_derive_deployment_plan", return_value=[]): + result = session.run( + design=design, + input_fn=lambda p: "done", + print_fn=lambda m: None, + ) + + assert result.cancelled is True + + def test_first_build_derives_and_saves(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + mock_agent = MagicMock() + mock_agent.name = "terraform-agent" + + stages = [ + { + "stage": 1, + "name": "Identity", + "layer": "core", + "capability": "infra", + "dir": "concept/infra/terraform/stage-1-identity", + "services": [], + "status": "pending", + "files": [], + }, + ] + + with patch.object(session, "_derive_deployment_plan", return_value=stages): + with patch.object(session, "_build_stage_task", return_value=(mock_agent, "task")): + with patch.object(session, "_execute_with_retry", return_value=MagicMock(content="ok", usage={})): + with patch.object(session, "_write_stage_files", return_value=[]): + with patch.object(session, "_apply_stage_transforms", return_value=[]): + session._run_stage_qa = lambda *a, **kw: True + result = session.run( + design=design, + input_fn=lambda p: "done", + print_fn=lambda m: None, + ) + + assert result is not None + assert not result.cancelled + + +# ------------------------------------------------------------------ +# run() -- confirmation prompt and plan adjustment (lines 418-440) +# ------------------------------------------------------------------ + + +class TestRunConfirmation: + """Tests for the plan confirmation prompt in run().""" + + def test_confirmation_quit_cancels(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + stages = [ + { + "stage": 1, + "name": "Test", + "layer": "core", + "capability": "infra", + "dir": "concept/infra/terraform/stage-1", + "services": [], + "status": "pending", + "files": [], + }, + ] + session._build_state.set_deployment_plan(stages) + session._build_state.set_design_snapshot(design) + + result = session.run( + design=design, + input_fn=lambda p: "quit", + print_fn=lambda m: None, + ) + + assert result.cancelled is True + + def test_confirmation_with_feedback_adjusts_plan(self, build_context, build_registry): + session = _make_session(build_context, build_registry) + design = {"architecture": "Test arch"} + + stages = [ + { + "stage": 1, + "name": "Test", + "layer": "core", + "capability": "infra", + "dir": "concept/infra/terraform/stage-1", + "services": [], + "status": "pending", + "files": [], + }, + ] + session._build_state.set_deployment_plan(stages) + session._build_state.set_design_snapshot(design) + + adjusted_stages = [ + { + "stage": 1, + "name": "Adjusted", + "layer": "core", + "capability": "infra", + "dir": "concept/infra/terraform/stage-1", + "services": [], + "status": "pending", + "files": [], + }, + ] + + calls = [0] + + def mock_input(prompt): + calls[0] += 1 + if calls[0] == 1: + return "add redis" + return "done" + + mock_agent = MagicMock() + mock_agent.name = "terraform-agent" + + with patch.object(session, "_adjust_plan", return_value=adjusted_stages): + session._run_stage_qa = lambda *a, **kw: True + with patch.object(session, "_build_stage_task", return_value=(mock_agent, "task")): + with patch.object(session, "_execute_with_retry", return_value=MagicMock(content="ok", usage={})): + with patch.object(session, "_write_stage_files", return_value=[]): + with patch.object(session, "_apply_stage_transforms", return_value=[]): + run_result = session.run( + design=design, + input_fn=mock_input, + print_fn=lambda m: None, + ) + + assert run_result is not None + +# --- Additional imports from merged flat test --- from azext_prototype.ai.provider import AIResponse +import yaml + # ====================================================================== # Helpers @@ -549,49 +4360,18 @@ def mock_architect_agent_for_build(): "status": "pending", "files": [], }, - ] - } - agent.execute.return_value = _make_response(f"```json\n{json.dumps(plan)}\n```") - return agent - - -@pytest.fixture -def mock_qa_agent(): - agent = MagicMock() - agent.name = "qa-engineer" - return agent - - -@pytest.fixture -def build_registry(mock_tf_agent, mock_dev_agent, mock_doc_agent, mock_architect_agent_for_build, mock_qa_agent): - registry = MagicMock() - - def find_by_cap(cap): - mapping = { - AgentCapability.TERRAFORM: [mock_tf_agent], - AgentCapability.BICEP: [], - AgentCapability.DEVELOP: [mock_dev_agent], - AgentCapability.DOCUMENT: [mock_doc_agent], - AgentCapability.ARCHITECT: [mock_architect_agent_for_build], - AgentCapability.QA: [mock_qa_agent], - } - return mapping.get(cap, []) - - registry.find_by_capability.side_effect = find_by_cap - return registry - - -@pytest.fixture -def build_context(project_with_design, sample_config): - """AgentContext for build tests with design already completed.""" - provider = MagicMock() - provider.provider_name = "github-models" - provider.chat.return_value = _make_response() - return AgentContext( - project_config=sample_config, - project_dir=str(project_with_design), - ai_provider=provider, - ) + ] + } + agent.execute.return_value = _make_response(f"```json\n{json.dumps(plan)}\n```") + return agent + + +@pytest.fixture +def mock_qa_agent(): + agent = MagicMock() + agent.name = "qa-engineer" + agent.execute.return_value = _make_response("All looks good. No issues found.") + return agent # ====================================================================== @@ -638,15 +4418,14 @@ def test_done_accepts(self, build_context, build_registry, mock_architect_agent_ session._governance = mock_gov_cls.return_value session._policy_resolver._governance = mock_gov_cls.return_value - # Patch AgentOrchestrator.delegate to avoid real QA call - with patch("azext_prototype.stages.build_session.AgentOrchestrator") as mock_orch: - mock_orch.return_value.delegate.return_value = _make_response("QA looks good") + # QA agent returns clean review + session._qa_agent.execute.return_value = _make_response("QA looks good. No issues found.") - result = session.run( - design={"architecture": "Sample architecture with key-vault and sql-database"}, - input_fn=lambda p: next(inputs), - print_fn=lambda m: None, - ) + result = session.run( + design={"architecture": "Sample architecture with key-vault and sql-database"}, + input_fn=lambda p: next(inputs), + print_fn=lambda m: None, + ) assert result.cancelled is False assert result.review_accepted is True @@ -971,14 +4750,13 @@ def test_reentrant_skips_generated_stages(self, build_context, build_registry, m session._governance = mock_gov_cls.return_value session._policy_resolver._governance = mock_gov_cls.return_value - with patch("azext_prototype.stages.build_session.AgentOrchestrator") as mock_orch: - mock_orch.return_value.delegate.return_value = _make_response("QA ok") + session._qa_agent.execute.return_value = _make_response("QA ok. No issues found.") - session.run( - design=design, - input_fn=lambda p: next(inputs), - print_fn=lambda m: None, - ) + session.run( + design=design, + input_fn=lambda p: next(inputs), + print_fn=lambda m: None, + ) # Stage 1 (generated) should NOT have been re-run # Only doc agent should have been called (for stage 2) @@ -986,6 +4764,9 @@ def test_reentrant_skips_generated_stages(self, build_context, build_registry, m assert mock_doc_agent.execute.call_count == 1 + # Re-entry validating tests moved to tests/stages/test_build_session_reentry.py + + # ====================================================================== # Incremental build / design snapshot tests # ====================================================================== @@ -1436,14 +5217,13 @@ def test_incremental_run_with_changes( session._governance = mock_gov_cls.return_value session._policy_resolver._governance = mock_gov_cls.return_value - with patch("azext_prototype.stages.build_session.AgentOrchestrator") as mock_orch: - mock_orch.return_value.delegate.return_value = _make_response("QA ok") + session._qa_agent.execute.return_value = _make_response("QA ok. No issues found.") - result = session.run( - design=new_design, - input_fn=lambda p: next(inputs), - print_fn=lambda m: printed.append(m), - ) + result = session.run( + design=new_design, + input_fn=lambda p: next(inputs), + print_fn=lambda m: printed.append(m), + ) output = "\n".join(printed) assert "Design changes detected" in output @@ -1532,14 +5312,13 @@ def architect_side_effect(ctx, task): session._governance = mock_gov_cls.return_value session._policy_resolver._governance = mock_gov_cls.return_value - with patch("azext_prototype.stages.build_session.AgentOrchestrator") as mock_orch: - mock_orch.return_value.delegate.return_value = _make_response("QA ok") + session._qa_agent.execute.return_value = _make_response("QA ok. No issues found.") - result = session.run( - design=new_design, - input_fn=lambda p: next(inputs), - print_fn=lambda m: printed.append(m), - ) + result = session.run( + design=new_design, + input_fn=lambda p: next(inputs), + print_fn=lambda m: printed.append(m), + ) output = "\n".join(printed) assert "full plan re-derive" in output.lower() @@ -2017,170 +5796,6 @@ def test_condense_caches_result_in_build_state(self, build_context, build_regist assert "Foundation" in cached["1"] -# ====================================================================== -# _select_agent tests -# ====================================================================== - - -class TestSelectAgent: - """Tests for _select_agent capability-to-agent mapping.""" - - def test_select_agent_infra(self, build_context, build_registry, mock_tf_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - agent = session._select_agent({"capability": "infra"}) - assert agent is mock_tf_agent - - def test_select_agent_data(self, build_context, build_registry, mock_tf_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - agent = session._select_agent({"capability": "data"}) - assert agent is mock_tf_agent - - def test_select_agent_integration(self, build_context, build_registry, mock_tf_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - agent = session._select_agent({"capability": "integration"}) - assert agent is mock_tf_agent - - def test_select_agent_app(self, build_context, build_registry, mock_dev_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - agent = session._select_agent({"capability": "app"}) - assert agent is mock_dev_agent - - def test_select_agent_schema(self, build_context, build_registry, mock_dev_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - agent = session._select_agent({"capability": "schema"}) - assert agent is mock_dev_agent - - def test_select_agent_cicd(self, build_context, build_registry, mock_dev_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - agent = session._select_agent({"capability": "cicd"}) - assert agent is mock_dev_agent - - def test_select_agent_external(self, build_context, build_registry, mock_dev_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - agent = session._select_agent({"capability": "external"}) - assert agent is mock_dev_agent - - def test_select_agent_docs(self, build_context, build_registry, mock_doc_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - agent = session._select_agent({"capability": "docs"}) - assert agent is mock_doc_agent - - def test_select_agent_unknown_falls_back_to_iac(self, build_context, build_registry, mock_tf_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - agent = session._select_agent({"capability": "unknown_capability"}) - # Falls back to iac_agents[iac_tool] or dev_agent - assert agent is mock_tf_agent - - def test_select_agent_missing_capability_defaults_to_infra(self, build_context, build_registry, mock_tf_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - agent = session._select_agent({}) - # capability defaults to "infra" - assert agent is mock_tf_agent - - def test_select_agent_no_agent_returns_none(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - session._doc_agent = None - agent = session._select_agent({"capability": "docs"}) - assert agent is None - - def test_select_agent_layer_core_routes_to_iac(self, build_context, build_registry, mock_tf_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - # Core layer stages need IaC generation, not architecture design - agent = session._select_agent({"layer": "core", "capability": "identity"}) - assert agent is mock_tf_agent - - def test_select_agent_layer_docs(self, build_context, build_registry, mock_doc_agent): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - agent = session._select_agent({"layer": "docs", "capability": "docs"}) - assert agent is mock_doc_agent - - -# ====================================================================== -# _infer_layer tests -# ====================================================================== - - -class TestInferLayer: - """Tests for _infer_layer static method.""" - - def test_explicit_layer_preserved(self): - from azext_prototype.stages.build_session import BuildSession - - assert BuildSession._infer_layer({"layer": "data", "capability": "infra"}) == "data" - - def test_identity_stage_maps_to_core(self): - from azext_prototype.stages.build_session import BuildSession - - assert BuildSession._infer_layer({"name": "Managed Identity", "capability": "infra"}) == "core" - - def test_monitoring_stage_maps_to_core(self): - from azext_prototype.stages.build_session import BuildSession - - assert BuildSession._infer_layer({"name": "Log Analytics", "capability": "infra"}) == "core" - assert BuildSession._infer_layer({"name": "Application Insights", "capability": "infra"}) == "core" - - def test_infra_capability_maps_to_infra(self): - from azext_prototype.stages.build_session import BuildSession - - assert BuildSession._infer_layer({"name": "Networking", "capability": "infra"}) == "infra" - - def test_data_capability_maps_to_data(self): - from azext_prototype.stages.build_session import BuildSession - - assert BuildSession._infer_layer({"name": "Key Vault", "capability": "data"}) == "data" - - def test_app_capability_maps_to_app(self): - from azext_prototype.stages.build_session import BuildSession - - assert BuildSession._infer_layer({"name": "API", "capability": "app"}) == "app" - - def test_docs_capability_maps_to_docs(self): - from azext_prototype.stages.build_session import BuildSession - - assert BuildSession._infer_layer({"name": "Documentation", "capability": "docs"}) == "docs" - - def test_integration_capability_maps_to_infra(self): - from azext_prototype.stages.build_session import BuildSession - - assert BuildSession._infer_layer({"name": "APIM", "capability": "integration"}) == "infra" - - def test_unknown_capability_defaults_to_infra(self): - from azext_prototype.stages.build_session import BuildSession - - assert BuildSession._infer_layer({"name": "Custom", "capability": "xyz"}) == "infra" - - def test_empty_stage_defaults_to_infra(self): - from azext_prototype.stages.build_session import BuildSession - - assert BuildSession._infer_layer({}) == "infra" - - # ====================================================================== # Layer-based routing decisions (QA, anti-pattern scan, IaC detection) # ====================================================================== @@ -2599,44 +6214,6 @@ def test_build_stage_reset_ignores_missing_dirs(self, project_with_design): stage._clean_output_dirs(str(project_with_design)) -# ====================================================================== -# BuildResult tests -# ====================================================================== - - -class TestBuildResult: - - def test_default_values(self): - from azext_prototype.stages.build_session import BuildResult - - result = BuildResult() - assert result.files_generated == [] - assert result.deployment_stages == [] - assert result.policy_overrides == [] - assert result.resources == [] - assert result.review_accepted is False - assert result.cancelled is False - - def test_cancelled_result(self): - from azext_prototype.stages.build_session import BuildResult - - result = BuildResult(cancelled=True) - assert result.cancelled is True - assert result.review_accepted is False - - def test_populated_result(self): - from azext_prototype.stages.build_session import BuildResult - - result = BuildResult( - files_generated=["main.tf"], - resources=[{"resourceType": "Microsoft.KeyVault/vaults", "sku": "standard"}], - review_accepted=True, - ) - assert len(result.files_generated) == 1 - assert len(result.resources) == 1 - assert result.review_accepted is True - - # ====================================================================== # Architect-based stage identification tests (Phase 9) # ====================================================================== @@ -3059,11 +6636,10 @@ def test_per_stage_qa_passes_clean(self, tmp_project): printed = [] - with patch("azext_prototype.stages.build_session.AgentOrchestrator") as mock_orch: - mock_orch.return_value.delegate.return_value = _make_response( - "All looks good. Code is clean and well-structured." - ) - session._run_stage_qa(stage, "arch", [], False, lambda m: printed.append(m)) + qa_agent.execute = MagicMock( + return_value=_make_response("All looks good. Code is clean and well-structured.") + ) + session._run_stage_qa(stage, "arch", [], False, lambda m: printed.append(m)) output = "\n".join(printed) assert "passed QA" in output @@ -3091,15 +6667,14 @@ def test_per_stage_qa_triggers_remediation(self, tmp_project): printed = [] call_count = [0] - def mock_delegate(**kwargs): + def mock_qa_execute(ctx, task): call_count[0] += 1 if call_count[0] == 1: return _make_response("CRITICAL: Missing managed identity config. Must fix.") return _make_response("All resolved, no remaining issues.") - with patch("azext_prototype.stages.build_session.AgentOrchestrator") as mock_orch: - mock_orch.return_value.delegate.side_effect = mock_delegate - session._run_stage_qa(stage, "arch", [], False, lambda m: printed.append(m)) + qa_agent.execute = mock_qa_execute + session._run_stage_qa(stage, "arch", [], False, lambda m: printed.append(m)) output = "\n".join(printed) assert "remediating" in output.lower() @@ -3107,8 +6682,6 @@ def mock_delegate(**kwargs): assert call_count[0] >= 2 def test_per_stage_qa_max_attempts(self, tmp_project): - pass - session, qa_agent, tf_agent = self._make_session(tmp_project) stage_dir = tmp_project / "concept" / "infra" / "terraform" / "stage-1" @@ -3130,10 +6703,11 @@ def test_per_stage_qa_max_attempts(self, tmp_project): printed = [] - with patch("azext_prototype.stages.build_session.AgentOrchestrator") as mock_orch: - # Always return issues - mock_orch.return_value.delegate.return_value = _make_response("CRITICAL: This will never be fixed.") - session._run_stage_qa(stage, "arch", [], False, lambda m: printed.append(m)) + # Always return issues + qa_agent.execute = MagicMock( + return_value=_make_response("CRITICAL: This will never be fixed.") + ) + session._run_stage_qa(stage, "arch", [], False, lambda m: printed.append(m)) output = "\n".join(printed) assert "issues remain" in output.lower() @@ -3142,40 +6716,163 @@ def test_per_stage_qa_skips_docs_stages(self, tmp_project): """Docs capability stages should not get QA review during Phase 3.""" # This tests the gating in the Phase 3 loop, not _run_stage_qa itself stage = { - "stage": 5, - "name": "Documentation", - "capability": "docs", - "dir": "concept/docs", - "files": [], + "stage": 5, + "name": "Documentation", + "capability": "docs", + "dir": "concept/docs", + "files": [], + "status": "generated", + "services": [], + } + # docs capability is not in ("infra", "data", "integration", "app") + assert stage["capability"] not in ("infra", "data", "integration", "app") + + def test_collect_stage_file_content(self, tmp_project): + session, _, _ = self._make_session(tmp_project) + + stage_dir = tmp_project / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text('resource "null" "x" {}') + + stage = { + "stage": 1, + "name": "Foundation", + "capability": "infra", + "files": ["concept/infra/terraform/stage-1/main.tf"], + } + + content = session._collect_stage_file_content(stage) + assert "main.tf" in content + assert 'resource "null" "x"' in content + + def test_collect_stage_file_content_empty(self, tmp_project): + session, _, _ = self._make_session(tmp_project) + stage = {"stage": 1, "name": "Foundation", "files": []} + content = session._collect_stage_file_content(stage) + assert content == "" + + def test_qa_collects_complete_review_before_evaluating(self, tmp_project): + """When a QA review response is truncated (finish_reason=length), + the system must continue collecting until the full review is received, + then evaluate the concatenated result. + + Business requirement: QA must never evaluate a partial review. + """ + session, qa_agent, tf_agent = self._make_session(tmp_project) + + stage_dir = tmp_project / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text( + 'resource "azapi_resource" "rg" {\n type = "Microsoft.Resources/resourceGroups@2025-06-01"\n}' + ) + + stage = { + "stage": 1, + "name": "Foundation", + "capability": "infra", + "dir": "concept/infra/terraform/stage-1", + "files": ["concept/infra/terraform/stage-1/main.tf"], + "status": "generated", + "services": [], + } + + # First response: truncated mid-review (no verdict yet) + truncated = _make_response( + "### Review\n\nChecking authentication...\nChecking cross-stage refs...\n", + finish_reason="length", + ) + # Second response: continuation completes the review with a verdict + complete = _make_response( + "\nAll checks passed.\n\nVERDICT: PASS", + finish_reason="stop", + ) + qa_agent.execute = MagicMock(side_effect=[truncated, complete]) + + printed = [] + passed = session._run_stage_qa(stage, "arch", [], False, lambda m: printed.append(m)) + + assert passed is True, "Stage should pass QA after full review is collected" + assert qa_agent.execute.call_count == 2, "QA agent should be called twice (initial + continuation)" + assert "passed QA" in "\n".join(printed) + + def test_qa_continuation_requests_review_not_code(self, tmp_project): + """When QA is continued after truncation, the continuation prompt + must instruct the agent to continue reviewing — not to generate code. + + Business requirement: a truncated QA review must never trigger + code generation in the continuation response. + """ + session, qa_agent, tf_agent = self._make_session(tmp_project) + + stage_dir = tmp_project / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) + (stage_dir / "main.tf").write_text( + 'resource "azapi_resource" "rg" {\n type = "Microsoft.Resources/resourceGroups@2025-06-01"\n}' + ) + + stage = { + "stage": 1, + "name": "Foundation", + "capability": "infra", + "dir": "concept/infra/terraform/stage-1", + "files": ["concept/infra/terraform/stage-1/main.tf"], "status": "generated", "services": [], } - # docs capability is not in ("infra", "data", "integration", "app") - assert stage["capability"] not in ("infra", "data", "integration", "app") - def test_collect_stage_file_content(self, tmp_project): - session, _, _ = self._make_session(tmp_project) + truncated = _make_response("Partial review...", finish_reason="length") + complete = _make_response("\nVERDICT: PASS", finish_reason="stop") + qa_agent.execute = MagicMock(side_effect=[truncated, complete]) + + session._run_stage_qa(stage, "arch", [], False, lambda m: None) + + # The continuation call (second invoke) must contain review language + second_call_args = qa_agent.execute.call_args_list[1] + continuation_task = second_call_args[0][1] # positional arg: (context, task) + continuation_lower = continuation_task.lower() + + assert "review" in continuation_lower, "Continuation prompt must mention 'review'" + assert "do not" in continuation_lower and "code" in continuation_lower, ( + "Continuation prompt must instruct agent NOT to generate code" + ) + + def test_qa_review_does_not_contaminate_generation_history(self, tmp_project): + """QA review messages — including continuations — must not leak + into the conversation history used for subsequent stage generation. + + Business requirement: each stage's generation must start from a + clean context, uncontaminated by QA review exchanges. + """ + session, qa_agent, tf_agent = self._make_session(tmp_project) stage_dir = tmp_project / "concept" / "infra" / "terraform" / "stage-1" stage_dir.mkdir(parents=True, exist_ok=True) - (stage_dir / "main.tf").write_text('resource "null" "x" {}') + (stage_dir / "main.tf").write_text( + 'resource "azapi_resource" "rg" {\n type = "Microsoft.Resources/resourceGroups@2025-06-01"\n}' + ) stage = { "stage": 1, "name": "Foundation", "capability": "infra", + "dir": "concept/infra/terraform/stage-1", "files": ["concept/infra/terraform/stage-1/main.tf"], + "status": "generated", + "services": [], } - content = session._collect_stage_file_content(stage) - assert "main.tf" in content - assert 'resource "null" "x"' in content + history_before = len(session._context.conversation_history) - def test_collect_stage_file_content_empty(self, tmp_project): - session, _, _ = self._make_session(tmp_project) - stage = {"stage": 1, "name": "Foundation", "files": []} - content = session._collect_stage_file_content(stage) - assert content == "" + truncated = _make_response("Partial review...", finish_reason="length") + complete = _make_response("\nVERDICT: PASS", finish_reason="stop") + qa_agent.execute = MagicMock(side_effect=[truncated, complete]) + + session._run_stage_qa(stage, "arch", [], False, lambda m: None) + + history_after = len(session._context.conversation_history) + assert history_after == history_before, ( + f"QA contaminated conversation history: {history_before} → {history_after} messages" + ) # ====================================================================== @@ -3322,21 +7019,20 @@ def test_advisory_qa_no_remediation_loop(self, tmp_project): session._governance = mock_gov_cls.return_value session._policy_resolver._governance = mock_gov_cls.return_value - with patch("azext_prototype.stages.build_session.AgentOrchestrator") as mock_orch: - # Return warnings — in old code this would trigger remediation - mock_orch.return_value.delegate.return_value = _make_response( - "WARNING: Missing monitoring. CRITICAL: No backup config." - ) + # QA agent returns warnings — in old code this would trigger remediation + qa_agent.execute = MagicMock( + return_value=_make_response("WARNING: Missing monitoring. CRITICAL: No backup config.") + ) - with patch.object(session, "_identify_affected_stages") as mock_identify: - session.run( - design={"architecture": "Simple architecture"}, - input_fn=lambda p: next(inputs), - print_fn=lambda m: None, - ) + with patch.object(session, "_identify_affected_stages") as mock_identify: + session.run( + design={"architecture": "Simple architecture"}, + input_fn=lambda p: next(inputs), + print_fn=lambda m: None, + ) - # _identify_affected_stages should NOT have been called during Phase 4 - mock_identify.assert_not_called() + # _identify_affected_stages should NOT have been called during Phase 4 + mock_identify.assert_not_called() def test_advisory_qa_header_says_advisory(self, tmp_project): """Output should contain 'Advisory notes' not 'QA Review'.""" @@ -3431,304 +7127,108 @@ def test_stable_ids_unique_on_name_collision(self, tmp_project): bs.set_deployment_plan(stages) ids = [s["id"] for s in bs.state["deployment_stages"]] - assert len(set(ids)) == 2 # all unique - assert ids[0] == "foundation" - assert ids[1] == "foundation-2" - - def test_stable_ids_backfilled_on_load(self, tmp_project): - from azext_prototype.stages.build_state import BuildState - - # Write a legacy state file without ids - state_dir = Path(str(tmp_project)) / ".prototype" / "state" - state_dir.mkdir(parents=True, exist_ok=True) - legacy = { - "deployment_stages": [ - { - "stage": 1, - "name": "Foundation", - "capability": "infra", - "services": [], - "status": "generated", - "files": [], - }, - ], - "templates_used": [], - "iac_tool": "terraform", - "_metadata": {"created": None, "last_updated": None, "iteration": 0}, - } - with open(state_dir / "build.yaml", "w") as f: - yaml.dump(legacy, f) - - bs = BuildState(str(tmp_project)) - bs.load() - assert bs.state["deployment_stages"][0]["id"] == "foundation" - assert bs.state["deployment_stages"][0]["deploy_mode"] == "auto" - - def test_get_stage_by_id(self, tmp_project): - from azext_prototype.stages.build_state import BuildState - - bs = BuildState(str(tmp_project)) - stages = [ - {"stage": 1, "name": "Foundation", "capability": "infra", "services": [], "status": "pending", "files": []}, - {"stage": 2, "name": "Data Layer", "capability": "data", "services": [], "status": "pending", "files": []}, - ] - bs.set_deployment_plan(stages) - - found = bs.get_stage_by_id("data-layer") - assert found is not None - assert found["name"] == "Data Layer" - assert bs.get_stage_by_id("nonexistent") is None - - def test_deploy_mode_in_stage_schema(self, tmp_project): - from azext_prototype.stages.build_state import BuildState - - bs = BuildState(str(tmp_project)) - stages = [ - { - "stage": 1, - "name": "Manual Upload", - "capability": "external", - "services": [], - "status": "pending", - "files": [], - "deploy_mode": "manual", - "manual_instructions": "Upload the notebook to the Fabric workspace.", - }, - { - "stage": 2, - "name": "Foundation", - "capability": "infra", - "services": [], - "status": "pending", - "files": [], - }, - ] - bs.set_deployment_plan(stages) - - assert bs.state["deployment_stages"][0]["deploy_mode"] == "manual" - assert "Upload" in bs.state["deployment_stages"][0]["manual_instructions"] - assert bs.state["deployment_stages"][1]["deploy_mode"] == "auto" - assert bs.state["deployment_stages"][1]["manual_instructions"] is None - - def test_add_stages_assigns_ids(self, tmp_project): - from azext_prototype.stages.build_state import BuildState - - bs = BuildState(str(tmp_project)) - bs.set_deployment_plan( - [ - { - "stage": 1, - "name": "Foundation", - "capability": "infra", - "services": [], - "status": "pending", - "files": [], - }, - ] - ) - bs.add_stages( - [ - {"name": "API Layer", "capability": "app"}, - ] - ) - ids = [s["id"] for s in bs.state["deployment_stages"]] - assert "api-layer" in ids - - -# ====================================================================== -# _get_app_scaffolding_requirements tests -# ====================================================================== - - -class TestGetAppScaffoldingRequirements: - """Tests for _get_app_scaffolding_requirements static method.""" - - def test_infra_capability_returns_empty(self): - from azext_prototype.stages.build_session import BuildSession - - result = BuildSession._get_app_scaffolding_requirements({"layer": "infra", "capability": "infra", "services": []}) - assert result == "" - - def test_data_capability_returns_empty(self): - from azext_prototype.stages.build_session import BuildSession - - result = BuildSession._get_app_scaffolding_requirements({"layer": "data", "capability": "data", "services": []}) - assert result == "" - - def test_docs_capability_returns_empty(self): - from azext_prototype.stages.build_session import BuildSession - - result = BuildSession._get_app_scaffolding_requirements({"layer": "docs", "capability": "docs", "services": []}) - assert result == "" - - def test_functions_detected_by_resource_type(self): - from azext_prototype.stages.build_session import BuildSession - - stage = { - "layer": "app", - "capability": "app", - "services": [{"name": "api", "resource_type": "Microsoft.Web/functionapps"}], - } - result = BuildSession._get_app_scaffolding_requirements(stage) - assert "host.json" in result - assert ".csproj" in result - - def test_functions_detected_by_name(self): - from azext_prototype.stages.build_session import BuildSession - - stage = { - "layer": "app", - "capability": "app", - "services": [{"name": "function-app", "resource_type": ""}], - } - result = BuildSession._get_app_scaffolding_requirements(stage) - assert "host.json" in result - - def test_webapp_without_language_hint_gets_generic(self): - """Webapp resource type without a language hint falls back to generic.""" - from azext_prototype.stages.build_session import BuildSession - - stage = { - "layer": "app", - "capability": "app", - "services": [{"name": "api", "resource_type": "Microsoft.Web/sites"}], - } - result = BuildSession._get_app_scaffolding_requirements(stage) - assert "Required Project Files" in result - assert "Entry point" in result - - def test_webapp_with_framework_hint_detected(self): - """Webapp with a framework name in the service name returns framework-specific scaffolding.""" - from azext_prototype.stages.build_session import BuildSession - - stage = { - "layer": "app", - "capability": "app", - "services": [{"name": "api-fastapi", "resource_type": "Microsoft.App/containerApps"}], - } - result = BuildSession._get_app_scaffolding_requirements(stage) - assert "FastAPI" in result - assert "requirements.txt" in result - assert "Dockerfile" in result - - def test_generic_app_fallback(self): - from azext_prototype.stages.build_session import BuildSession - - stage = { - "layer": "app", - "capability": "app", - "services": [{"name": "worker", "resource_type": ""}], - } - result = BuildSession._get_app_scaffolding_requirements(stage) - assert "Required Project Files" in result - assert "Entry point" in result - - def test_schema_capability_triggers_scaffolding(self): - from azext_prototype.stages.build_session import BuildSession - - stage = { - "layer": "app", - "capability": "schema", - "services": [{"name": "db-migration", "resource_type": ""}], - } - result = BuildSession._get_app_scaffolding_requirements(stage) - assert "Required Project Files" in result - - def test_external_capability_triggers_scaffolding(self): - from azext_prototype.stages.build_session import BuildSession - - stage = { - "layer": "app", - "capability": "external", - "services": [{"name": "stripe-integration", "resource_type": ""}], - } - result = BuildSession._get_app_scaffolding_requirements(stage) - assert "Required Project Files" in result - - -# ====================================================================== -# _write_stage_files tests -# ====================================================================== - - -class TestWriteStageFiles: - """Tests for _write_stage_files edge cases.""" - - def test_empty_content_returns_empty(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - stage = {"dir": "concept/infra/terraform/stage-1-foundation"} - - result = session._write_stage_files(stage, "") - assert result == [] - - def test_no_file_blocks_returns_empty(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - stage = {"dir": "concept/infra/terraform/stage-1-foundation"} - - result = session._write_stage_files(stage, "This is just text with no code blocks.") - assert result == [] - - def test_writes_files_and_returns_paths(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - stage = {"dir": "concept/infra/terraform/stage-1-foundation"} - - content = '```main.tf\n# terraform code\n```\n\n```variables.tf\nvariable "name" {}\n```' - result = session._write_stage_files(stage, content) - - assert len(result) == 2 - # Files should exist on disk - project_root = Path(build_context.project_dir) - for rel_path in result: - assert (project_root / rel_path).exists() - - def test_strips_stage_dir_prefix_from_filenames(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession + assert len(set(ids)) == 2 # all unique + assert ids[0] == "foundation" + assert ids[1] == "foundation-2" - session = BuildSession(build_context, build_registry) - stage_dir = "concept/infra/terraform/stage-1-foundation" - stage = {"dir": stage_dir} + def test_stable_ids_backfilled_on_load(self, tmp_project): + from azext_prototype.stages.build_state import BuildState - # AI sometimes includes full path in filename - content = f"```{stage_dir}/main.tf\n# code\n```" - result = session._write_stage_files(stage, content) + # Write a legacy state file without ids + state_dir = Path(str(tmp_project)) / ".prototype" / "state" + state_dir.mkdir(parents=True, exist_ok=True) + legacy = { + "deployment_stages": [ + { + "stage": 1, + "name": "Foundation", + "capability": "infra", + "services": [], + "status": "generated", + "files": [], + }, + ], + "templates_used": [], + "iac_tool": "terraform", + "_metadata": {"created": None, "last_updated": None, "iteration": 0}, + } + with open(state_dir / "build.yaml", "w") as f: + yaml.dump(legacy, f) - assert len(result) == 1 - # Should NOT create nested duplicate path - assert result[0] == f"{stage_dir}/main.tf" + bs = BuildState(str(tmp_project)) + bs.load() + assert bs.state["deployment_stages"][0]["id"] == "foundation" + assert bs.state["deployment_stages"][0]["deploy_mode"] == "auto" - def test_blocks_versions_tf_for_terraform(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession + def test_get_stage_by_id(self, tmp_project): + from azext_prototype.stages.build_state import BuildState - session = BuildSession(build_context, build_registry) - session._iac_tool = "terraform" - stage = {"dir": "concept/infra/terraform/stage-1"} + bs = BuildState(str(tmp_project)) + stages = [ + {"stage": 1, "name": "Foundation", "capability": "infra", "services": [], "status": "pending", "files": []}, + {"stage": 2, "name": "Data Layer", "capability": "data", "services": [], "status": "pending", "files": []}, + ] + bs.set_deployment_plan(stages) - content = "```main.tf\n# main code\n```\n\n```versions.tf\n# should be blocked\n```" - result = session._write_stage_files(stage, content) + found = bs.get_stage_by_id("data-layer") + assert found is not None + assert found["name"] == "Data Layer" + assert bs.get_stage_by_id("nonexistent") is None - filenames = [Path(p).name for p in result] - assert "main.tf" in filenames - assert "versions.tf" not in filenames + def test_deploy_mode_in_stage_schema(self, tmp_project): + from azext_prototype.stages.build_state import BuildState - def test_allows_versions_tf_for_bicep(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession + bs = BuildState(str(tmp_project)) + stages = [ + { + "stage": 1, + "name": "Manual Upload", + "capability": "external", + "services": [], + "status": "pending", + "files": [], + "deploy_mode": "manual", + "manual_instructions": "Upload the notebook to the Fabric workspace.", + }, + { + "stage": 2, + "name": "Foundation", + "capability": "infra", + "services": [], + "status": "pending", + "files": [], + }, + ] + bs.set_deployment_plan(stages) - session = BuildSession(build_context, build_registry) - session._iac_tool = "bicep" - stage = {"dir": "concept/infra/bicep/stage-1"} + assert bs.state["deployment_stages"][0]["deploy_mode"] == "manual" + assert "Upload" in bs.state["deployment_stages"][0]["manual_instructions"] + assert bs.state["deployment_stages"][1]["deploy_mode"] == "auto" + assert bs.state["deployment_stages"][1]["manual_instructions"] is None - content = "```main.bicep\n# main code\n```\n\n```versions.tf\n# allowed for bicep\n```" - result = session._write_stage_files(stage, content) + def test_add_stages_assigns_ids(self, tmp_project): + from azext_prototype.stages.build_state import BuildState - filenames = [Path(p).name for p in result] - assert "main.bicep" in filenames - assert "versions.tf" in filenames + bs = BuildState(str(tmp_project)) + bs.set_deployment_plan( + [ + { + "stage": 1, + "name": "Foundation", + "capability": "infra", + "services": [], + "status": "pending", + "files": [], + }, + ] + ) + bs.add_stages( + [ + {"name": "API Layer", "capability": "app"}, + ] + ) + ids = [s["id"] for s in bs.state["deployment_stages"]] + assert "api-layer" in ids # ====================================================================== @@ -3822,115 +7322,6 @@ def test_describe_non_numeric(self, build_context, build_registry): assert "Usage" in output -# ====================================================================== -# _clean_removed_stage_files tests -# ====================================================================== - - -class TestCleanRemovedStageFiles: - """Tests for _clean_removed_stage_files.""" - - def test_removes_existing_directory(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - - # Create the directory with a file - stage_dir = Path(build_context.project_dir) / "concept" / "infra" / "terraform" / "stage-2-data" - stage_dir.mkdir(parents=True, exist_ok=True) - (stage_dir / "main.tf").write_text("# data stage", encoding="utf-8") - assert stage_dir.exists() - - stages = [ - {"stage": 2, "dir": "concept/infra/terraform/stage-2-data"}, - ] - session._clean_removed_stage_files([2], stages) - - assert not stage_dir.exists() - - def test_ignores_nonexistent_directory(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - - stages = [ - {"stage": 2, "dir": "concept/infra/terraform/stage-2-nonexistent"}, - ] - # Should not raise - session._clean_removed_stage_files([2], stages) - - def test_ignores_stage_not_in_removed_list(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - - stage_dir = Path(build_context.project_dir) / "concept" / "infra" / "terraform" / "stage-1-foundation" - stage_dir.mkdir(parents=True, exist_ok=True) - (stage_dir / "main.tf").write_text("# keep this", encoding="utf-8") - - stages = [ - {"stage": 1, "dir": "concept/infra/terraform/stage-1-foundation"}, - {"stage": 2, "dir": "concept/infra/terraform/stage-2-data"}, - ] - # Only remove stage 2, not stage 1 - session._clean_removed_stage_files([2], stages) - - assert stage_dir.exists() - - -# ====================================================================== -# _fix_stage_dirs tests -# ====================================================================== - - -class TestFixStageDirs: - """Tests for _fix_stage_dirs after stage renumbering.""" - - def test_renumbers_stage_dir_paths(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - session._build_state._state["deployment_stages"] = [ - { - "stage": 1, - "name": "A", - "dir": "concept/infra/terraform/stage-1-foundation", - "capability": "infra", - "services": [], - "status": "generated", - "files": [], - }, - { - "stage": 2, - "name": "B", - "dir": "concept/infra/terraform/stage-4-data", - "capability": "data", - "services": [], - "status": "pending", - "files": [], - }, - ] - - session._fix_stage_dirs() - - stages = session._build_state._state["deployment_stages"] - assert stages[0]["dir"] == "concept/infra/terraform/stage-1-foundation" - assert stages[1]["dir"] == "concept/infra/terraform/stage-2-data" - - def test_skips_empty_dirs(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - session._build_state._state["deployment_stages"] = [ - {"stage": 1, "name": "A", "dir": "", "capability": "infra", "services": [], "status": "pending", "files": []}, - ] - - # Should not raise - session._fix_stage_dirs() - - assert session._build_state._state["deployment_stages"][0]["dir"] == "" - - # ====================================================================== # _build_stage_task bicep branch tests # ====================================================================== @@ -4079,97 +7470,6 @@ def test_no_files_returns_empty(self, build_context, build_registry): assert result == "" -# ====================================================================== -# _collect_generated_file_content tests -# ====================================================================== - - -class TestCollectGeneratedFileContent: - """Tests for _collect_generated_file_content.""" - - def test_collects_from_generated_stages(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - - # Create a file - stage_dir = Path(build_context.project_dir) / "concept" / "infra" / "terraform" / "stage-1" - stage_dir.mkdir(parents=True, exist_ok=True) - (stage_dir / "main.tf").write_text("# tf code", encoding="utf-8") - - session._build_state.set_deployment_plan( - [ - { - "stage": 1, - "name": "Foundation", - "capability": "infra", - "services": [], - "status": "generated", - "dir": "concept/infra/terraform/stage-1", - "files": ["concept/infra/terraform/stage-1/main.tf"], - }, - ] - ) - - result = session._collect_generated_file_content() - assert "main.tf" in result - assert "tf code" in result - - def test_empty_when_no_generated_stages(self, build_context, build_registry): - from azext_prototype.stages.build_session import BuildSession - - session = BuildSession(build_context, build_registry) - session._build_state.set_deployment_plan( - [ - { - "stage": 1, - "name": "Foundation", - "capability": "infra", - "services": [], - "status": "pending", - "dir": "", - "files": [], - }, - ] - ) - - result = session._collect_generated_file_content() - assert result == "" - - -# ====================================================================== -# Naming strategy fallback tests -# ====================================================================== - - -class TestNamingStrategyFallback: - """Tests for the naming strategy fallback in __init__.""" - - def test_naming_fallback_on_invalid_config(self, project_with_design, sample_config): - """When naming config is invalid, should fall back to simple strategy.""" - from azext_prototype.stages.build_session import BuildSession - - # Corrupt the naming config - sample_config["naming"]["strategy"] = "nonexistent-strategy" - - provider = MagicMock() - provider.provider_name = "github-models" - provider.chat.return_value = _make_response() - - context = AgentContext( - project_config=sample_config, - project_dir=str(project_with_design), - ai_provider=provider, - ) - - registry = MagicMock() - registry.find_by_capability.return_value = [] - - # Should not raise — falls back to simple strategy - session = BuildSession(context, registry) - assert session._naming is not None - - # ====================================================================== # _identify_stages_via_architect edge cases # ====================================================================== diff --git a/tests/stages/test_build_stage.py b/tests/stages/test_build_stage.py new file mode 100644 index 0000000..adea885 --- /dev/null +++ b/tests/stages/test_build_stage.py @@ -0,0 +1,301 @@ +"""Tests for BuildStage — guard conditions, state transitions, dry-run routing. + +Covers: +- Multi-guard validation (3 prerequisites: project_initialized, discovery_complete, design_complete) +- State transitions (IN_PROGRESS, COMPLETED, FAILED) +- Reset behavior (clears build state and output dirs) +- Dry-run vs interactive routing +- Template matching with threshold scoring +- Design loading from state file +""" + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from azext_prototype.stages.base import StageState + +# ====================================================================== +# Fixtures +# ====================================================================== + + +@pytest.fixture +def build_stage(): + from azext_prototype.stages.build_stage import BuildStage + + return BuildStage() + + +@pytest.fixture +def agent_context(project_with_design, sample_config): + from azext_prototype.agents.base import AgentContext + + provider = MagicMock() + provider.provider_name = "github-models" + provider.chat.return_value = MagicMock( + content="ok", + model="test", + usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + ) + return AgentContext( + project_config=sample_config, + project_dir=str(project_with_design), + ai_provider=provider, + ) + + +@pytest.fixture +def registry(): + return MagicMock() + + +# ====================================================================== +# Guard validation +# ====================================================================== + + +class TestBuildStageGuards: + """Test multi-guard prerequisite checking.""" + + def test_guards_return_three_guards(self, build_stage): + guards = build_stage.get_guards() + assert len(guards) == 3 + names = [g.name for g in guards] + assert "project_initialized" in names + assert "discovery_complete" in names + assert "design_complete" in names + + def test_all_guards_pass(self, build_stage, project_with_design, monkeypatch): + """All files exist → can_run returns True.""" + monkeypatch.chdir(project_with_design) + # Ensure discovery.yaml exists + disco = project_with_design / ".prototype" / "state" / "discovery.yaml" + disco.write_text("exchange_count: 1", encoding="utf-8") + + can_run, failures = build_stage.can_run() + assert can_run is True + assert failures == [] + + def test_missing_project_yaml(self, build_stage, tmp_path, monkeypatch): + """No prototype.yaml → first guard fails.""" + monkeypatch.chdir(tmp_path) + can_run, failures = build_stage.can_run() + assert can_run is False + assert any("prototype" in f.lower() or "init" in f.lower() for f in failures) + + def test_missing_discovery_state(self, build_stage, project_with_config, monkeypatch): + """prototype.yaml exists but no discovery.yaml → discovery guard fails.""" + monkeypatch.chdir(project_with_config) + can_run, failures = build_stage.can_run() + assert can_run is False + assert any("discovery" in f.lower() for f in failures) + + def test_missing_design_json(self, build_stage, project_with_config, monkeypatch): + """Has prototype.yaml and discovery.yaml but no design.json → design guard fails.""" + monkeypatch.chdir(project_with_config) + disco = project_with_config / ".prototype" / "state" / "discovery.yaml" + disco.write_text("exchange_count: 1", encoding="utf-8") + can_run, failures = build_stage.can_run() + assert can_run is False + assert any("design" in f.lower() for f in failures) + + +# ====================================================================== +# State transitions +# ====================================================================== + + +class TestBuildStageStateTransitions: + """Test stage state moves correctly during execute.""" + + def test_initial_state_not_started(self, build_stage): + assert build_stage.state == StageState.NOT_STARTED + + def test_execute_sets_in_progress_then_completed(self, build_stage, agent_context, registry): + """Dry-run sets IN_PROGRESS then COMPLETED.""" + result = build_stage.execute(agent_context, registry, dry_run=True, print_fn=lambda x: None) + assert build_stage.state == StageState.COMPLETED + assert result["status"] == "dry-run" + + def test_cancelled_session_sets_failed(self, build_stage, agent_context, registry): + """When BuildSession returns cancelled, state goes to FAILED.""" + mock_result = MagicMock() + mock_result.cancelled = True + + with patch("azext_prototype.stages.build_stage.BuildSession") as mock_session_cls: + mock_session_cls.return_value.run.return_value = mock_result + result = build_stage.execute(agent_context, registry, print_fn=lambda x: None) + assert build_stage.state == StageState.FAILED + assert result["status"] == "cancelled" + + def test_successful_session_sets_completed(self, build_stage, agent_context, registry): + """When BuildSession completes successfully, state goes to COMPLETED.""" + mock_result = MagicMock() + mock_result.cancelled = False + mock_result.policy_overrides = [] + mock_result.files_generated = ["main.tf"] + mock_result.deployment_stages = [] + mock_result.resources = [] + + with ( + patch("azext_prototype.stages.build_stage.BuildSession") as mock_session_cls, + patch("azext_prototype.stages.build_stage.ProjectConfig") as mock_config_cls, + ): + mock_config_cls.return_value.load.return_value = None + mock_config_cls.return_value.get.return_value = "terraform" + mock_session_cls.return_value.run.return_value = mock_result + result = build_stage.execute(agent_context, registry, print_fn=lambda x: None) + assert build_stage.state == StageState.COMPLETED + assert result["status"] == "success" + + def test_missing_architecture_raises(self, build_stage, agent_context, registry): + """If design.json has no architecture key, CLIError is raised.""" + from knack.util import CLIError + + # Overwrite design.json with empty architecture + design_path = Path(agent_context.project_dir) / ".prototype" / "state" / "design.json" + with open(design_path, "w") as f: + json.dump({"artifacts": []}, f) + + with pytest.raises(CLIError, match="No architecture"): + build_stage.execute(agent_context, registry, print_fn=lambda x: None) + + +# ====================================================================== +# Reset behavior +# ====================================================================== + + +class TestBuildStageReset: + """Test reset clears build state and output directories.""" + + def test_reset_cleans_output_dirs(self, build_stage, agent_context, registry): + """--reset cleans concept/infra, concept/apps, etc.""" + project_dir = Path(agent_context.project_dir) + # Create output dirs + for d in build_stage._OUTPUT_DIRS: + (project_dir / d).mkdir(parents=True, exist_ok=True) + (project_dir / d / "test.tf").write_text("content", encoding="utf-8") + + # Run with reset + dry_run to avoid full session + build_stage.execute(agent_context, registry, reset=True, dry_run=True, print_fn=lambda x: None) + + for d in build_stage._OUTPUT_DIRS: + assert not (project_dir / d).is_dir() + + def test_reset_nonexistent_dirs_no_error(self, build_stage, agent_context, registry): + """--reset with no existing output dirs should not error.""" + build_stage.execute(agent_context, registry, reset=True, dry_run=True, print_fn=lambda x: None) + assert build_stage.state == StageState.COMPLETED + + +# ====================================================================== +# Dry-run routing +# ====================================================================== + + +class TestBuildStageDryRun: + """Test dry-run mode behavior.""" + + def test_dry_run_all_scope(self, build_stage, agent_context, registry): + printed = [] + result = build_stage.execute(agent_context, registry, dry_run=True, scope="all", print_fn=printed.append) + assert result["status"] == "dry-run" + assert result["scope"] == "all" + assert "infra" in result["results"] + assert "apps" in result["results"] + assert "db" in result["results"] + assert "docs" in result["results"] + + def test_dry_run_infra_only(self, build_stage, agent_context, registry): + printed = [] + result = build_stage.execute(agent_context, registry, dry_run=True, scope="infra", print_fn=printed.append) + assert "infra" in result["results"] + assert "apps" not in result["results"] + + def test_dry_run_apps_only(self, build_stage, agent_context, registry): + printed = [] + result = build_stage.execute(agent_context, registry, dry_run=True, scope="apps", print_fn=printed.append) + assert "apps" in result["results"] + assert "infra" not in result["results"] + + def test_dry_run_with_templates(self, build_stage, agent_context, registry): + """When templates match, dry-run shows template names.""" + printed = [] + mock_tmpl = MagicMock() + mock_tmpl.display_name = "Web App" + mock_tmpl.service_names.return_value = [] + + with patch.object(build_stage, "_match_templates", return_value=[mock_tmpl]): + build_stage.execute(agent_context, registry, dry_run=True, print_fn=printed.append) + assert any("Web App" in p for p in printed) + + +# ====================================================================== +# Template matching +# ====================================================================== + + +class TestTemplateMatching: + """Test template matching with threshold scoring.""" + + def test_match_templates_above_threshold(self, build_stage): + mock_tmpl = MagicMock() + mock_tmpl.service_names.return_value = ["key-vault", "app-service"] + + mock_registry = MagicMock() + mock_registry.list_templates.return_value = [mock_tmpl] + + with patch("azext_prototype.templates.registry.TemplateRegistry", return_value=mock_registry): + design = {"architecture": "Deploy key-vault and app-service resources"} + config = MagicMock() + templates = build_stage._match_templates(design, config) + assert len(templates) == 1 + + def test_match_templates_below_threshold(self, build_stage): + mock_tmpl = MagicMock() + mock_tmpl.service_names.return_value = ["key-vault", "cosmos-db", "redis", "apim"] + + mock_registry = MagicMock() + mock_registry.list_templates.return_value = [mock_tmpl] + + with patch("azext_prototype.templates.registry.TemplateRegistry", return_value=mock_registry): + design = {"architecture": "Only key-vault is mentioned"} + config = MagicMock() + templates = build_stage._match_templates(design, config) + assert len(templates) == 0 + + def test_match_templates_empty_architecture(self, build_stage): + design = {"architecture": ""} + config = MagicMock() + assert build_stage._match_templates(design, config) == [] + + def test_match_templates_no_templates_available(self, build_stage): + mock_registry = MagicMock() + mock_registry.list_templates.return_value = [] + + with patch("azext_prototype.templates.registry.TemplateRegistry", return_value=mock_registry): + design = {"architecture": "something"} + config = MagicMock() + templates = build_stage._match_templates(design, config) + assert templates == [] + + +# ====================================================================== +# Design loading +# ====================================================================== + + +class TestLoadDesign: + """Test _load_design from state file.""" + + def test_load_existing_design(self, build_stage, project_with_design): + design = build_stage._load_design(str(project_with_design)) + assert "architecture" in design + + def test_load_missing_design(self, build_stage, tmp_path): + design = build_stage._load_design(str(tmp_path)) + assert design == {} diff --git a/tests/test_coverage_design_deploy.py b/tests/stages/test_coverage_design_deploy.py similarity index 100% rename from tests/test_coverage_design_deploy.py rename to tests/stages/test_coverage_design_deploy.py diff --git a/tests/test_coverage_gaps.py b/tests/stages/test_coverage_gaps.py similarity index 78% rename from tests/test_coverage_gaps.py rename to tests/stages/test_coverage_gaps.py index 89f7774..2ed8a0b 100644 --- a/tests/test_coverage_gaps.py +++ b/tests/stages/test_coverage_gaps.py @@ -18,119 +18,11 @@ class TestDeployHelpersDeep: """Deep tests for deploy_helpers module-level functions.""" - # --- Bicep helpers --- - - def test_find_bicep_params_json(self, tmp_path): - from azext_prototype.stages.deploy_helpers import find_bicep_params - - (tmp_path / "main.parameters.json").write_text("{}", encoding="utf-8") - result = find_bicep_params(tmp_path, tmp_path / "main.bicep") - assert result is not None - assert result.name == "main.parameters.json" - - def test_find_bicep_params_bicepparam(self, tmp_path): - from azext_prototype.stages.deploy_helpers import find_bicep_params - - (tmp_path / "main.bicepparam").write_text("", encoding="utf-8") - result = find_bicep_params(tmp_path, tmp_path / "main.bicep") - assert result is not None - assert result.name == "main.bicepparam" - - def test_find_bicep_params_generic(self, tmp_path): - from azext_prototype.stages.deploy_helpers import find_bicep_params - - (tmp_path / "parameters.json").write_text("{}", encoding="utf-8") - result = find_bicep_params(tmp_path, tmp_path / "main.bicep") - assert result is not None - assert result.name == "parameters.json" - - def test_find_bicep_params_none(self, tmp_path): - from azext_prototype.stages.deploy_helpers import find_bicep_params - - result = find_bicep_params(tmp_path, tmp_path / "main.bicep") - assert result is None - - def test_is_subscription_scoped_true(self, tmp_path): - from azext_prototype.stages.deploy_helpers import is_subscription_scoped - - bicep = tmp_path / "main.bicep" - bicep.write_text("targetScope = 'subscription'\n", encoding="utf-8") - assert is_subscription_scoped(bicep) is True - - def test_is_subscription_scoped_false(self, tmp_path): - from azext_prototype.stages.deploy_helpers import is_subscription_scoped - - bicep = tmp_path / "main.bicep" - bicep.write_text("resource rg 'Microsoft.Resources/resourceGroups@2023-07-01' = {}\n", encoding="utf-8") - assert is_subscription_scoped(bicep) is False - def test_is_subscription_scoped_missing_file(self, tmp_path): from azext_prototype.stages.deploy_helpers import is_subscription_scoped assert is_subscription_scoped(tmp_path / "nope.bicep") is False - def test_get_deploy_location_from_params(self, tmp_path): - from azext_prototype.stages.deploy_helpers import get_deploy_location - - params = {"parameters": {"location": {"value": "westus2"}}} - (tmp_path / "parameters.json").write_text(json.dumps(params), encoding="utf-8") - assert get_deploy_location(tmp_path) == "westus2" - - def test_get_deploy_location_from_string(self, tmp_path): - from azext_prototype.stages.deploy_helpers import get_deploy_location - - params = {"location": "centralus"} - (tmp_path / "parameters.json").write_text(json.dumps(params), encoding="utf-8") - assert get_deploy_location(tmp_path) == "centralus" - - def test_get_deploy_location_none(self, tmp_path): - from azext_prototype.stages.deploy_helpers import get_deploy_location - - assert get_deploy_location(tmp_path) is None - - def test_get_deploy_location_invalid_json(self, tmp_path): - from azext_prototype.stages.deploy_helpers import get_deploy_location - - (tmp_path / "parameters.json").write_text("not json", encoding="utf-8") - assert get_deploy_location(tmp_path) is None - - # --- check_az_login --- - - @patch("azext_prototype.stages.deploy_helpers.subprocess.run") - def test_check_az_login_true(self, mock_run): - from azext_prototype.stages.deploy_helpers import check_az_login - - mock_run.return_value = MagicMock(returncode=0) - assert check_az_login() is True - - @patch("azext_prototype.stages.deploy_helpers.subprocess.run") - def test_check_az_login_false(self, mock_run): - from azext_prototype.stages.deploy_helpers import check_az_login - - mock_run.return_value = MagicMock(returncode=1) - assert check_az_login() is False - - @patch("azext_prototype.stages.deploy_helpers.subprocess.run", side_effect=FileNotFoundError) - def test_check_az_login_no_az(self, mock_run): - from azext_prototype.stages.deploy_helpers import check_az_login - - assert check_az_login() is False - - # --- get_current_subscription --- - - @patch("azext_prototype.stages.deploy_helpers.subprocess.run") - def test_get_current_subscription(self, mock_run): - from azext_prototype.stages.deploy_helpers import get_current_subscription - - mock_run.return_value = MagicMock(returncode=0, stdout="sub-abc-123\n") - assert get_current_subscription() == "sub-abc-123" - - @patch("azext_prototype.stages.deploy_helpers.subprocess.run", side_effect=FileNotFoundError) - def test_get_current_subscription_error(self, mock_run): - from azext_prototype.stages.deploy_helpers import get_current_subscription - - assert get_current_subscription() == "" - # --- deploy_terraform --- @patch("azext_prototype.stages.deploy_helpers.subprocess.run") @@ -490,7 +382,7 @@ class TestResolveDefinition: def test_resolve_known(self): from azext_prototype.custom import _resolve_definition - defs_dir = Path(__file__).resolve().parent.parent / "azext_prototype" / "agents" / "builtin" / "definitions" + defs_dir = Path(__file__).resolve().parent.parent.parent / "azext_prototype" / "agents" / "builtin" / "definitions" result = _resolve_definition(defs_dir, "example_custom_agent") assert result.exists() diff --git a/tests/stages/test_deploy_helpers.py b/tests/stages/test_deploy_helpers.py new file mode 100644 index 0000000..1975574 --- /dev/null +++ b/tests/stages/test_deploy_helpers.py @@ -0,0 +1,1019 @@ +"""Tests for deploy_helpers — error handling paths. + +Covers: +- Azure CLI command execution with error handling (subprocess errors, FileNotFoundError, stderr parsing) +- Terraform secret variable scanning (suffix detection, default value overriding, deduplication) +- Secret resolution with generation (reuse existing, generate new, config update) +- Az CLI path resolution (Windows .cmd variant, fallback ordering) +- build_deploy_env construction +- check_az_login / get_current_subscription / get_current_tenant +- login_service_principal / set_deployment_context +- DeploymentOutputCapture: terraform/bicep capture, accessors, env vars +- find_bicep_params / is_subscription_scoped / get_deploy_location +""" + +import json +import os +import subprocess +from unittest.mock import MagicMock, patch + +# ====================================================================== +# _find_az / _az — Az CLI path resolution +# ====================================================================== + + +class TestFindAz: + """Test _find_az fallback chain.""" + + def test_shutil_which_found(self): + """When shutil.which finds az, return that path.""" + from azext_prototype.stages import deploy_helpers + + # Clear module cache + deploy_helpers._AZ = None + with patch("shutil.which", return_value="/usr/bin/az"): + result = deploy_helpers._find_az() + assert result == "/usr/bin/az" + + def test_falls_back_to_python_bin_dir(self, tmp_path): + """When shutil.which returns None, check Python's bin dir.""" + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + fake_az = tmp_path / "az" + fake_az.touch() + + with ( + patch("shutil.which", return_value=None), + patch.object(deploy_helpers.sys, "executable", str(tmp_path / "python")), + patch("os.path.isfile") as mock_isfile, + ): + # First call: az candidate check → True + # Second call (would be .cmd check) should not be reached + mock_isfile.side_effect = lambda p: p == str(tmp_path / "az") + result = deploy_helpers._find_az() + assert result == str(tmp_path / "az") + + def test_falls_back_to_windows_cmd(self, tmp_path): + """When az is not found but az.cmd exists, return .cmd path.""" + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with ( + patch("shutil.which", return_value=None), + patch.object(deploy_helpers.sys, "executable", str(tmp_path / "python")), + patch("os.path.isfile") as mock_isfile, + ): + # First call: az candidate → False, second call: az.cmd → True + def isfile_side(p): + return p.endswith(".cmd") + + mock_isfile.side_effect = isfile_side + result = deploy_helpers._find_az() + assert result.endswith(".cmd") + + def test_final_fallback_bare_az(self, tmp_path): + """When nothing else works, return bare 'az'.""" + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with ( + patch("shutil.which", return_value=None), + patch.object(deploy_helpers.sys, "executable", str(tmp_path / "python")), + patch("os.path.isfile", return_value=False), + ): + result = deploy_helpers._find_az() + assert result == "az" + + def test_az_caches_result(self): + """_az() caches on first call.""" + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with patch.object(deploy_helpers, "_find_az", return_value="/cached/az") as mock_find: + val1 = deploy_helpers._az() + val2 = deploy_helpers._az() + assert val1 == "/cached/az" + assert val2 == "/cached/az" + mock_find.assert_called_once() + # Clean up + deploy_helpers._AZ = None + + +# ====================================================================== +# build_deploy_env +# ====================================================================== + + +# ====================================================================== +# Terraform Secret Variable Scanning +# ====================================================================== + + +class TestScanTfSecretVariables: + """Test scan_tf_secret_variables with suffix detection, defaults, dedup.""" + + def test_detects_secret_suffix(self, tmp_path): + from azext_prototype.stages.deploy_helpers import scan_tf_secret_variables + + tf_file = tmp_path / "main.tf" + tf_file.write_text( + """ +variable "db_password" { + type = string +} +""", + encoding="utf-8", + ) + result = scan_tf_secret_variables(tmp_path) + assert "db_password" in result + + def test_detects_secret_suffix_underscore(self, tmp_path): + from azext_prototype.stages.deploy_helpers import scan_tf_secret_variables + + tf_file = tmp_path / "main.tf" + tf_file.write_text( + """ +variable "api_secret" { + type = string +} +""", + encoding="utf-8", + ) + result = scan_tf_secret_variables(tmp_path) + assert "api_secret" in result + + def test_skips_non_secret_variable(self, tmp_path): + from azext_prototype.stages.deploy_helpers import scan_tf_secret_variables + + tf_file = tmp_path / "main.tf" + tf_file.write_text( + """ +variable "resource_group_name" { + type = string +} +""", + encoding="utf-8", + ) + result = scan_tf_secret_variables(tmp_path) + assert result == [] + + def test_skips_known_auth_variables(self, tmp_path): + from azext_prototype.stages.deploy_helpers import scan_tf_secret_variables + + tf_file = tmp_path / "main.tf" + tf_file.write_text( + """ +variable "client_secret" { + type = string +} +""", + encoding="utf-8", + ) + result = scan_tf_secret_variables(tmp_path) + assert "client_secret" not in result + + def test_skips_variable_with_non_empty_default(self, tmp_path): + from azext_prototype.stages.deploy_helpers import scan_tf_secret_variables + + tf_file = tmp_path / "main.tf" + tf_file.write_text( + """ +variable "db_password" { + type = string + default = "predefined-value" +} +""", + encoding="utf-8", + ) + result = scan_tf_secret_variables(tmp_path) + assert result == [] + + def test_includes_variable_with_empty_default(self, tmp_path): + from azext_prototype.stages.deploy_helpers import scan_tf_secret_variables + + tf_file = tmp_path / "main.tf" + tf_file.write_text( + """ +variable "db_password" { + type = string + default = "" +} +""", + encoding="utf-8", + ) + result = scan_tf_secret_variables(tmp_path) + assert "db_password" in result + + def test_deduplicates_across_files(self, tmp_path): + from azext_prototype.stages.deploy_helpers import scan_tf_secret_variables + + (tmp_path / "a.tf").write_text( + 'variable "db_password" {\n type = string\n}\n', + encoding="utf-8", + ) + (tmp_path / "b.tf").write_text( + 'variable "db_password" {\n type = string\n}\n', + encoding="utf-8", + ) + result = scan_tf_secret_variables(tmp_path) + assert result.count("db_password") == 1 + + def test_handles_unreadable_file(self, tmp_path): + from azext_prototype.stages.deploy_helpers import scan_tf_secret_variables + + tf_file = tmp_path / "main.tf" + tf_file.write_text("some content", encoding="utf-8") + # Make unreadable (best-effort; may not work on all platforms) + with patch("pathlib.Path.read_text", side_effect=OSError("permission denied")): + result = scan_tf_secret_variables(tmp_path) + assert result == [] + + def test_no_tf_files(self, tmp_path): + from azext_prototype.stages.deploy_helpers import scan_tf_secret_variables + + result = scan_tf_secret_variables(tmp_path) + assert result == [] + + +# ====================================================================== +# resolve_stage_secrets +# ====================================================================== + + +class TestResolveStageSecrets: + """Test secret resolution: reuse existing, generate new, config update.""" + + def test_no_secrets_needed(self, tmp_path): + from azext_prototype.stages.deploy_helpers import resolve_stage_secrets + + (tmp_path / "main.tf").write_text( + 'variable "name" {\n type = string\n}\n', + encoding="utf-8", + ) + config = MagicMock() + result = resolve_stage_secrets(tmp_path, config) + assert result == {} + + def test_generates_new_secret(self, tmp_path): + from azext_prototype.stages.deploy_helpers import resolve_stage_secrets + + (tmp_path / "main.tf").write_text( + 'variable "db_password" {\n type = string\n}\n', + encoding="utf-8", + ) + config = MagicMock() + config.get.return_value = {} + result = resolve_stage_secrets(tmp_path, config) + assert "TF_VAR_db_password" in result + assert len(result["TF_VAR_db_password"]) == 64 # 32 bytes hex + config.set.assert_called_once() + + def test_reuses_existing_secret(self, tmp_path): + from azext_prototype.stages.deploy_helpers import resolve_stage_secrets + + (tmp_path / "main.tf").write_text( + 'variable "db_password" {\n type = string\n}\n', + encoding="utf-8", + ) + config = MagicMock() + config.get.return_value = {"db_password": "existing-secret-value"} + result = resolve_stage_secrets(tmp_path, config) + assert result["TF_VAR_db_password"] == "existing-secret-value" + config.set.assert_not_called() + + def test_stored_not_dict_generates_new(self, tmp_path): + """If stored secrets is a non-dict, treat as missing.""" + from azext_prototype.stages.deploy_helpers import resolve_stage_secrets + + (tmp_path / "main.tf").write_text( + 'variable "db_password" {\n type = string\n}\n', + encoding="utf-8", + ) + config = MagicMock() + config.get.return_value = "not-a-dict" + result = resolve_stage_secrets(tmp_path, config) + assert "TF_VAR_db_password" in result + config.set.assert_called_once() + + +# ====================================================================== +# check_az_login / get_current_subscription / get_current_tenant +# ====================================================================== + + +class TestAzCliCommands: + """Test Azure CLI command wrappers with error handling.""" + + def test_check_az_login_success(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with patch("subprocess.run") as mock_run, patch.object(deploy_helpers, "_find_az", return_value="az"): + mock_run.return_value = MagicMock(returncode=0) + deploy_helpers._AZ = None + result = deploy_helpers.check_az_login() + assert result is True + + def test_check_az_login_failure(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with patch("subprocess.run") as mock_run, patch.object(deploy_helpers, "_find_az", return_value="az"): + mock_run.return_value = MagicMock(returncode=1) + deploy_helpers._AZ = None + result = deploy_helpers.check_az_login() + assert result is False + + def test_check_az_login_file_not_found(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with ( + patch("subprocess.run", side_effect=FileNotFoundError), + patch.object(deploy_helpers, "_find_az", return_value="az"), + ): + deploy_helpers._AZ = None + result = deploy_helpers.check_az_login() + assert result is False + + def test_get_current_subscription_success(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with patch("subprocess.run") as mock_run, patch.object(deploy_helpers, "_find_az", return_value="az"): + mock_run.return_value = MagicMock(returncode=0, stdout="sub-id-123\n") + deploy_helpers._AZ = None + result = deploy_helpers.get_current_subscription() + assert result == "sub-id-123" + + def test_get_current_subscription_error(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with ( + patch("subprocess.run", side_effect=subprocess.CalledProcessError(1, "az")), + patch.object(deploy_helpers, "_find_az", return_value="az"), + ): + deploy_helpers._AZ = None + result = deploy_helpers.get_current_subscription() + assert result == "" + + def test_get_current_subscription_file_not_found(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with ( + patch("subprocess.run", side_effect=FileNotFoundError), + patch.object(deploy_helpers, "_find_az", return_value="az"), + ): + deploy_helpers._AZ = None + result = deploy_helpers.get_current_subscription() + assert result == "" + + def test_get_current_tenant_success(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with patch("subprocess.run") as mock_run, patch.object(deploy_helpers, "_find_az", return_value="az"): + mock_run.return_value = MagicMock(returncode=0, stdout="tenant-abc\n") + deploy_helpers._AZ = None + result = deploy_helpers.get_current_tenant() + assert result == "tenant-abc" + + def test_get_current_tenant_error(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with ( + patch("subprocess.run", side_effect=FileNotFoundError), + patch.object(deploy_helpers, "_find_az", return_value="az"), + ): + deploy_helpers._AZ = None + result = deploy_helpers.get_current_tenant() + assert result == "" + + +# ====================================================================== +# login_service_principal +# ====================================================================== + + +class TestLoginServicePrincipal: + """Test service principal login with error paths.""" + + def test_login_success(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with ( + patch("subprocess.run") as mock_run, + patch.object(deploy_helpers, "_find_az", return_value="az"), + patch.object(deploy_helpers, "get_current_subscription", return_value="sub-after"), + ): + mock_run.return_value = MagicMock(returncode=0) + deploy_helpers._AZ = None + result = deploy_helpers.login_service_principal("cid", "csec", "tid") + assert result["status"] == "ok" + assert result["subscription"] == "sub-after" + + def test_login_failure_returncode(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with patch("subprocess.run") as mock_run, patch.object(deploy_helpers, "_find_az", return_value="az"): + mock_run.return_value = MagicMock(returncode=1, stderr="auth failed", stdout="") + deploy_helpers._AZ = None + result = deploy_helpers.login_service_principal("cid", "csec", "tid") + assert result["status"] == "failed" + assert "auth failed" in result["error"] + + def test_login_file_not_found(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with ( + patch("subprocess.run", side_effect=FileNotFoundError), + patch.object(deploy_helpers, "_find_az", return_value="az"), + ): + deploy_helpers._AZ = None + result = deploy_helpers.login_service_principal("cid", "csec", "tid") + assert result["status"] == "failed" + assert "not found" in result["error"] + + +# ====================================================================== +# set_deployment_context +# ====================================================================== + + +class TestSetDeploymentContext: + """Test set_deployment_context with tenant and error paths.""" + + def test_success_without_tenant(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with patch("subprocess.run") as mock_run, patch.object(deploy_helpers, "_find_az", return_value="az"): + mock_run.return_value = MagicMock(returncode=0) + deploy_helpers._AZ = None + result = deploy_helpers.set_deployment_context("sub-123") + assert result["status"] == "ok" + + def test_success_with_tenant(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with patch("subprocess.run") as mock_run, patch.object(deploy_helpers, "_find_az", return_value="az"): + mock_run.return_value = MagicMock(returncode=0) + deploy_helpers._AZ = None + result = deploy_helpers.set_deployment_context("sub-123", tenant="tid") + assert result["status"] == "ok" + # Verify --tenant flag was passed + call_args = mock_run.call_args[0][0] + assert "--tenant" in call_args + assert "tid" in call_args + + def test_failure_returncode(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with patch("subprocess.run") as mock_run, patch.object(deploy_helpers, "_find_az", return_value="az"): + mock_run.return_value = MagicMock(returncode=1, stderr="bad sub", stdout="") + deploy_helpers._AZ = None + result = deploy_helpers.set_deployment_context("bad-sub") + assert result["status"] == "failed" + assert "bad sub" in result["error"] + + def test_file_not_found(self): + from azext_prototype.stages import deploy_helpers + + deploy_helpers._AZ = None + with ( + patch("subprocess.run", side_effect=FileNotFoundError), + patch.object(deploy_helpers, "_find_az", return_value="az"), + ): + deploy_helpers._AZ = None + result = deploy_helpers.set_deployment_context("sub-123") + assert result["status"] == "failed" + assert "not found" in result["error"] + + +# ====================================================================== +# find_bicep_params / is_subscription_scoped / get_deploy_location +# ====================================================================== + + +class TestBicepDiscovery: + """Test Bicep template/parameter file discovery helpers.""" + + def test_find_bicep_params_parameters_json(self, tmp_path): + from azext_prototype.stages.deploy_helpers import find_bicep_params + + (tmp_path / "main.parameters.json").touch() + result = find_bicep_params(tmp_path, tmp_path / "main.bicep") + assert result == tmp_path / "main.parameters.json" + + def test_find_bicep_params_bicepparam(self, tmp_path): + from azext_prototype.stages.deploy_helpers import find_bicep_params + + (tmp_path / "main.bicepparam").touch() + result = find_bicep_params(tmp_path, tmp_path / "main.bicep") + assert result == tmp_path / "main.bicepparam" + + def test_find_bicep_params_generic(self, tmp_path): + from azext_prototype.stages.deploy_helpers import find_bicep_params + + (tmp_path / "parameters.json").touch() + result = find_bicep_params(tmp_path, tmp_path / "main.bicep") + assert result == tmp_path / "parameters.json" + + def test_find_bicep_params_none(self, tmp_path): + from azext_prototype.stages.deploy_helpers import find_bicep_params + + result = find_bicep_params(tmp_path, tmp_path / "main.bicep") + assert result is None + + def test_find_bicep_params_priority(self, tmp_path): + """parameters.json beats bicepparam, stem.parameters.json beats both.""" + from azext_prototype.stages.deploy_helpers import find_bicep_params + + (tmp_path / "main.parameters.json").touch() + (tmp_path / "main.bicepparam").touch() + (tmp_path / "parameters.json").touch() + result = find_bicep_params(tmp_path, tmp_path / "main.bicep") + assert result == tmp_path / "main.parameters.json" + + def test_is_subscription_scoped_true(self, tmp_path): + from azext_prototype.stages.deploy_helpers import is_subscription_scoped + + bicep_file = tmp_path / "main.bicep" + bicep_file.write_text("targetScope = 'subscription'\n\nresource rg ...", encoding="utf-8") + assert is_subscription_scoped(bicep_file) is True + + def test_is_subscription_scoped_false(self, tmp_path): + from azext_prototype.stages.deploy_helpers import is_subscription_scoped + + bicep_file = tmp_path / "main.bicep" + bicep_file.write_text("resource kv 'Microsoft.KeyVault/vaults@...'", encoding="utf-8") + assert is_subscription_scoped(bicep_file) is False + + def test_is_subscription_scoped_unreadable(self, tmp_path): + from azext_prototype.stages.deploy_helpers import is_subscription_scoped + + bicep_file = tmp_path / "missing.bicep" + assert is_subscription_scoped(bicep_file) is False + + def test_get_deploy_location_from_params(self, tmp_path): + from azext_prototype.stages.deploy_helpers import get_deploy_location + + params = {"parameters": {"location": {"value": "westus2"}}} + (tmp_path / "parameters.json").write_text(json.dumps(params), encoding="utf-8") + assert get_deploy_location(tmp_path) == "westus2" + + def test_get_deploy_location_string_value(self, tmp_path): + from azext_prototype.stages.deploy_helpers import get_deploy_location + + params = {"location": "eastus"} + (tmp_path / "parameters.json").write_text(json.dumps(params), encoding="utf-8") + assert get_deploy_location(tmp_path) == "eastus" + + def test_get_deploy_location_none(self, tmp_path): + from azext_prototype.stages.deploy_helpers import get_deploy_location + + assert get_deploy_location(tmp_path) is None + + def test_get_deploy_location_bad_json(self, tmp_path): + from azext_prototype.stages.deploy_helpers import get_deploy_location + + (tmp_path / "parameters.json").write_text("not json", encoding="utf-8") + assert get_deploy_location(tmp_path) is None + + +# ====================================================================== +# DeploymentOutputCapture +# ====================================================================== + + +class TestDeploymentOutputCapture: + """Test capture, accessors, and env var generation.""" + + def test_capture_terraform(self, tmp_path): + from azext_prototype.stages.deploy_helpers import DeploymentOutputCapture + + cap = DeploymentOutputCapture(str(tmp_path)) + tf_output = json.dumps( + { + "endpoint": {"value": "https://app.azurewebsites.net", "type": "string"}, + "key": {"value": "abc123", "type": "string"}, + } + ) + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0, stdout=tf_output) + result = cap.capture_terraform(tmp_path / "infra") + assert result["endpoint"] == "https://app.azurewebsites.net" + assert result["key"] == "abc123" + + def test_capture_terraform_failure(self, tmp_path): + from azext_prototype.stages.deploy_helpers import DeploymentOutputCapture + + cap = DeploymentOutputCapture(str(tmp_path)) + with patch("subprocess.run", side_effect=FileNotFoundError): + result = cap.capture_terraform(tmp_path / "infra") + assert result == {} + + def test_capture_bicep(self, tmp_path): + from azext_prototype.stages.deploy_helpers import DeploymentOutputCapture + + cap = DeploymentOutputCapture(str(tmp_path)) + bicep_output = json.dumps( + { + "properties": { + "outputs": { + "storageEndpoint": {"value": "https://storage.blob.core.windows.net", "type": "string"}, + } + } + } + ) + result = cap.capture_bicep(bicep_output) + assert result["storageEndpoint"] == "https://storage.blob.core.windows.net" + + def test_capture_bicep_bad_json(self, tmp_path): + from azext_prototype.stages.deploy_helpers import DeploymentOutputCapture + + cap = DeploymentOutputCapture(str(tmp_path)) + result = cap.capture_bicep("not json") + assert result == {} + + def test_get_across_providers(self, tmp_path): + from azext_prototype.stages.deploy_helpers import DeploymentOutputCapture + + cap = DeploymentOutputCapture(str(tmp_path)) + cap._outputs = { + "terraform": {"key1": "val1"}, + "bicep": {"key2": "val2"}, + } + assert cap.get("key1") == "val1" + assert cap.get("key2") == "val2" + assert cap.get("missing", "default") == "default" + + def test_get_all(self, tmp_path): + from azext_prototype.stages.deploy_helpers import DeploymentOutputCapture + + cap = DeploymentOutputCapture(str(tmp_path)) + cap._outputs = {"terraform": {"a": 1}} + all_outputs = cap.get_all() + assert all_outputs == {"terraform": {"a": 1}} + # Verify it's a copy + all_outputs["extra"] = True + assert "extra" not in cap._outputs + + def test_to_env_vars(self, tmp_path): + from azext_prototype.stages.deploy_helpers import DeploymentOutputCapture + + cap = DeploymentOutputCapture(str(tmp_path)) + cap._outputs = { + "terraform": {"endpoint": "https://app.com"}, + "bicep": {"storage_key": "secret"}, + } + env = cap.to_env_vars() + assert env["PROTOTYPE_ENDPOINT"] == "https://app.com" + assert env["PROTOTYPE_STORAGE_KEY"] == "secret" + + def test_flatten_outputs_plain_values(self, tmp_path): + from azext_prototype.stages.deploy_helpers import DeploymentOutputCapture + + flat = DeploymentOutputCapture._flatten_outputs({"simple": "value"}) + assert flat["simple"] == "value" + + def test_load_existing(self, tmp_path): + """Test that existing outputs file is loaded on construction.""" + from azext_prototype.stages.deploy_helpers import DeploymentOutputCapture + + output_dir = tmp_path / ".prototype" / "state" + output_dir.mkdir(parents=True) + (output_dir / "deployment_outputs.json").write_text( + json.dumps({"terraform": {"x": "y"}}), + encoding="utf-8", + ) + cap = DeploymentOutputCapture(str(tmp_path)) + assert cap.get("x") == "y" + + def test_load_bad_json(self, tmp_path): + """Bad JSON in existing outputs file falls back to empty dict.""" + from azext_prototype.stages.deploy_helpers import DeploymentOutputCapture + + output_dir = tmp_path / ".prototype" / "state" + output_dir.mkdir(parents=True) + (output_dir / "deployment_outputs.json").write_text("not json", encoding="utf-8") + cap = DeploymentOutputCapture(str(tmp_path)) + assert cap._outputs == {} + +# --- Additional imports from merged flat test --- +from azext_prototype.stages.deploy_helpers import DEPLOY_ENV_MAPPING, DeploymentOutputCapture, DeployScriptGenerator, RollbackManager, build_deploy_env, resolve_stage_secrets, scan_tf_secret_variables +from pathlib import Path + + +class TestDeployScriptGenerator: + """Test deploy script generation.""" + + def test_generate_webapp_script(self, tmp_path): + app_dir = tmp_path / "my-api" + app_dir.mkdir() + + script = DeployScriptGenerator.generate( + app_dir=app_dir, + app_name="my-api", + deploy_type="webapp", + resource_group="rg-test", + ) + + assert "#!/usr/bin/env bash" in script + assert "my-api" in script + assert "az webapp deploy" in script + assert (app_dir / "deploy.sh").exists() + + def test_generate_container_app_script(self, tmp_path): + app_dir = tmp_path / "my-app" + app_dir.mkdir() + + script = DeployScriptGenerator.generate( + app_dir=app_dir, + app_name="my-app", + deploy_type="container_app", + resource_group="rg-test", + registry="myregistry.azurecr.io", + ) + + assert "az acr build" in script + assert "az containerapp update" in script + assert "myregistry.azurecr.io" in script + + def test_generate_function_script(self, tmp_path): + app_dir = tmp_path / "my-func" + app_dir.mkdir() + + script = DeployScriptGenerator.generate( + app_dir=app_dir, + app_name="my-func", + deploy_type="function", + resource_group="rg-test", + ) + + assert "func azure functionapp publish" in script + assert "my-func" in script + + +class TestRollbackManager: + """Test rollback tracking and instructions.""" + + def test_snapshot_before_deploy(self, tmp_project): + mgr = RollbackManager(str(tmp_project)) + snapshot = mgr.snapshot_before_deploy("infra", "terraform") + + assert snapshot["scope"] == "infra" + assert snapshot["iac_tool"] == "terraform" + assert "timestamp" in snapshot + + def test_multiple_snapshots(self, tmp_project): + mgr = RollbackManager(str(tmp_project)) + mgr.snapshot_before_deploy("infra", "terraform") + mgr.snapshot_before_deploy("apps", "terraform") + + latest = mgr.get_last_snapshot() + assert latest["scope"] == "apps" + + def test_rollback_instructions_terraform(self, tmp_project): + mgr = RollbackManager(str(tmp_project)) + mgr.snapshot_before_deploy("infra", "terraform") + + instructions = mgr.get_rollback_instructions() + assert any("terraform" in line.lower() for line in instructions) + + def test_rollback_instructions_bicep(self, tmp_project): + mgr = RollbackManager(str(tmp_project)) + mgr.snapshot_before_deploy("infra", "bicep") + + instructions = mgr.get_rollback_instructions() + assert any("bicep" in line.lower() or "deployment" in line.lower() for line in instructions) + + def test_no_snapshots(self, tmp_project): + mgr = RollbackManager(str(tmp_project)) + assert mgr.get_last_snapshot() is None + + instructions = mgr.get_rollback_instructions() + assert len(instructions) >= 1 # Should have "nothing to roll back" message + + def test_persistence(self, tmp_project): + mgr1 = RollbackManager(str(tmp_project)) + mgr1.snapshot_before_deploy("infra", "terraform") + + mgr2 = RollbackManager(str(tmp_project)) + assert mgr2.get_last_snapshot() is not None + assert mgr2.get_last_snapshot()["scope"] == "infra" + + +class TestDeployEnvMapping: + """Tests for DEPLOY_ENV_MAPPING and build_deploy_env().""" + + def test_mapping_covers_all_params(self): + """Every build_deploy_env parameter has a mapping entry.""" + assert "subscription" in DEPLOY_ENV_MAPPING + assert "tenant" in DEPLOY_ENV_MAPPING + assert "client_id" in DEPLOY_ENV_MAPPING + assert "client_secret" in DEPLOY_ENV_MAPPING + + def test_mapping_includes_tf_var(self): + """Each param maps to at least one TF_VAR_* entry.""" + for param, keys in DEPLOY_ENV_MAPPING.items(): + tf_vars = [k for k in keys if k.startswith("TF_VAR_")] + assert tf_vars, f"{param} has no TF_VAR_* mapping" + + def test_mapping_includes_arm(self): + """Each param maps to at least one ARM_* entry.""" + for param, keys in DEPLOY_ENV_MAPPING.items(): + arm_vars = [k for k in keys if k.startswith("ARM_")] + assert arm_vars, f"{param} has no ARM_* mapping" + + def test_all_fields(self): + env = build_deploy_env("sub-123", "tenant-456", "client-id", "secret") + # ARM vars + assert env["ARM_SUBSCRIPTION_ID"] == "sub-123" + assert env["ARM_TENANT_ID"] == "tenant-456" + assert env["ARM_CLIENT_ID"] == "client-id" + assert env["ARM_CLIENT_SECRET"] == "secret" + # TF_VAR vars (auto-resolve HCL variables) + assert env["TF_VAR_subscription_id"] == "sub-123" + assert env["TF_VAR_tenant_id"] == "tenant-456" + assert env["TF_VAR_client_id"] == "client-id" + assert env["TF_VAR_client_secret"] == "secret" + # Legacy + assert env["SUBSCRIPTION_ID"] == "sub-123" + + def test_subscription_only(self): + env = build_deploy_env("sub-123") + assert env["ARM_SUBSCRIPTION_ID"] == "sub-123" + assert env["TF_VAR_subscription_id"] == "sub-123" + assert env["SUBSCRIPTION_ID"] == "sub-123" + assert "ARM_TENANT_ID" not in env + assert "TF_VAR_tenant_id" not in env + assert "ARM_CLIENT_ID" not in env + + def test_inherits_os_environ(self): + env = build_deploy_env("sub-123") + # PATH should be inherited from os.environ + assert "PATH" in env + + def test_empty(self): + env = build_deploy_env() + assert "ARM_SUBSCRIPTION_ID" not in env + assert "TF_VAR_subscription_id" not in env + assert "ARM_TENANT_ID" not in env + # Should still have os.environ entries + assert "PATH" in env + + +class TestDeployEnvPassing: + """Tests that verify env is passed through to subprocess calls.""" + + @patch("subprocess.run") + def test_deploy_terraform_passes_env(self, mock_run): + from azext_prototype.stages.deploy_helpers import deploy_terraform + + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + test_env = build_deploy_env("sub-123", "tenant-456") + + deploy_terraform(Path("/tmp/fake"), "sub-123", env=test_env) + + # All subprocess.run calls should receive env=test_env + for c in mock_run.call_args_list: + assert c.kwargs.get("env") is test_env + + @patch("subprocess.run") + def test_deploy_bicep_adds_tenant_flag(self, mock_run): + from azext_prototype.stages.deploy_helpers import deploy_bicep + + mock_run.return_value = MagicMock(returncode=0, stdout="{}", stderr="") + infra_dir = Path("/tmp/fake") + test_env = build_deploy_env("sub-123", "tenant-456") + + # Create a mock bicep file + with patch.object(Path, "exists", return_value=True), patch.object(Path, "glob", return_value=[]), patch( + "azext_prototype.stages.deploy_helpers.find_bicep_params", return_value=None + ), patch("azext_prototype.stages.deploy_helpers.is_subscription_scoped", return_value=False): + deploy_bicep(infra_dir, "sub-123", "my-rg", env=test_env) + + # Verify --tenant was added to the command + cmd = mock_run.call_args[0][0] + assert "--tenant" in cmd + assert "tenant-456" in cmd + assert mock_run.call_args.kwargs.get("env") is test_env + + @patch("subprocess.run") + def test_deploy_app_stage_merges_env(self, mock_run, tmp_path): + from azext_prototype.stages.deploy_helpers import deploy_app_stage + + stage_dir = tmp_path / "app" + stage_dir.mkdir() + deploy_sh = stage_dir / "deploy.sh" + deploy_sh.write_text("#!/bin/bash\necho ok") + + mock_run.return_value = MagicMock(returncode=0, stdout="ok", stderr="") + test_env = build_deploy_env("sub-123", "tenant-456", "cid", "csecret") + + deploy_app_stage(stage_dir, "sub-123", "my-rg", env=test_env) + + passed_env = mock_run.call_args.kwargs.get("env") + assert passed_env is not None + assert passed_env["ARM_SUBSCRIPTION_ID"] == "sub-123" + assert passed_env["ARM_TENANT_ID"] == "tenant-456" + assert passed_env["SUBSCRIPTION_ID"] == "sub-123" + assert passed_env["RESOURCE_GROUP"] == "my-rg" + + @patch("subprocess.run") + def test_deploy_app_sub_dirs_receive_env(self, mock_run, tmp_path): + from azext_prototype.stages.deploy_helpers import deploy_app_stage + + stage_dir = tmp_path / "apps" + stage_dir.mkdir() + sub_app = stage_dir / "api" + sub_app.mkdir() + (sub_app / "deploy.sh").write_text("#!/bin/bash\necho ok") + + mock_run.return_value = MagicMock(returncode=0, stdout="ok", stderr="") + test_env = build_deploy_env("sub-123", "tenant-456") + + deploy_app_stage(stage_dir, "sub-123", "my-rg", env=test_env) + + passed_env = mock_run.call_args.kwargs.get("env") + assert passed_env is not None + assert passed_env["ARM_SUBSCRIPTION_ID"] == "sub-123" + assert passed_env["ARM_TENANT_ID"] == "tenant-456" + assert passed_env["RESOURCE_GROUP"] == "my-rg" + + @patch("subprocess.run") + def test_rollback_terraform_passes_env(self, mock_run): + from azext_prototype.stages.deploy_helpers import rollback_terraform + + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + test_env = build_deploy_env("sub-123", "tenant-456") + + rollback_terraform(Path("/tmp/fake"), env=test_env) + + assert mock_run.call_args.kwargs.get("env") is test_env + + @patch("subprocess.run") + def test_plan_terraform_passes_env(self, mock_run): + from azext_prototype.stages.deploy_helpers import plan_terraform + + mock_run.return_value = MagicMock(returncode=0, stdout="Plan: 1 to add", stderr="") + test_env = build_deploy_env("sub-123") + + plan_terraform(Path("/tmp/fake"), "sub-123", env=test_env) + + for c in mock_run.call_args_list: + assert c.kwargs.get("env") is test_env + + @patch("subprocess.run") + def test_rollback_bicep_adds_tenant_flag(self, mock_run): + from azext_prototype.stages.deploy_helpers import rollback_bicep + + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + test_env = build_deploy_env("sub-123", "tenant-456") + + rollback_bicep(Path("/tmp/fake"), "sub-123", "my-rg", env=test_env) + + cmd = mock_run.call_args[0][0] + assert "--tenant" in cmd + assert "tenant-456" in cmd + assert mock_run.call_args.kwargs.get("env") is test_env + + @patch("subprocess.run") + def test_whatif_bicep_adds_tenant_flag(self, mock_run): + from azext_prototype.stages.deploy_helpers import whatif_bicep + + mock_run.return_value = MagicMock(returncode=0, stdout="What-if output", stderr="") + test_env = build_deploy_env("sub-123", "tenant-789") + + with patch.object(Path, "exists", return_value=True), patch.object(Path, "glob", return_value=[]), patch( + "azext_prototype.stages.deploy_helpers.find_bicep_params", return_value=None + ), patch("azext_prototype.stages.deploy_helpers.is_subscription_scoped", return_value=False): + whatif_bicep(Path("/tmp/fake"), "sub-123", "my-rg", env=test_env) + + cmd = mock_run.call_args[0][0] + assert "--tenant" in cmd + assert "tenant-789" in cmd + + @patch("subprocess.run") + def test_deploy_terraform_no_env_still_works(self, mock_run): + """Verify backward compat — env defaults to None.""" + from azext_prototype.stages.deploy_helpers import deploy_terraform + + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + deploy_terraform(Path("/tmp/fake"), "sub-123") + + # env=None is passed (default), which means subprocess inherits os.environ + for c in mock_run.call_args_list: + assert c.kwargs.get("env") is None + + diff --git a/tests/test_deploy_session.py b/tests/stages/test_deploy_session.py similarity index 89% rename from tests/test_deploy_session.py rename to tests/stages/test_deploy_session.py index 796134f..e8c46ac 100644 --- a/tests/test_deploy_session.py +++ b/tests/stages/test_deploy_session.py @@ -1,11 +1,6 @@ -"""Tests for DeployState, DeploySession, preflight checks, and deploy stage. - -Covers the deploy-stage overhaul modules: -- DeployState: YAML persistence, stage transitions, rollback ordering -- DeploySession: interactive session, dry-run, single-stage, slash commands -- Preflight checks: subscription, IaC tool, resource group, resource providers -- DeployStage: thin orchestrator delegation -- Deploy helpers: execution primitives, RollbackManager extensions +"""Tests for deploy_session.py — branch coverage for dry-run layer branching, +preflight checks, stage deployment by layer, rollback ordering, output capture, +SP credential resolution, deployment context env building, and interactive loop. """ from __future__ import annotations @@ -14,504 +9,657 @@ from unittest.mock import MagicMock, patch import pytest -import yaml -from azext_prototype.ai.provider import AIResponse +from azext_prototype.agents.base import AgentCapability, AgentContext -# ====================================================================== -# Helpers -# ====================================================================== +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ -def _make_response(content: str = "Mock response") -> AIResponse: - return AIResponse(content=content, model="gpt-4o", usage={}) +@pytest.fixture +def deploy_context(project_with_build, sample_config): + provider = MagicMock() + provider.provider_name = "github-models" + provider.default_model = "gpt-4o" + provider.chat.return_value = MagicMock( + content="Diagnosis: resource group missing.", + model="test", + usage={"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}, + ) + return AgentContext( + project_config=sample_config, + project_dir=str(project_with_build), + ai_provider=provider, + ) -def _build_yaml(stages: list[dict] | None = None, iac_tool: str = "terraform") -> dict: - """Return a realistic build.yaml structure.""" - if stages is None: - stages = [ - { - "stage": 1, - "name": "Foundation", - "layer": "infra", - "capability": "infra", - "services": [ - { - "name": "key-vault", - "computed_name": "zd-kv-api-dev-eus", - "resource_type": "Microsoft.KeyVault/vaults", - "sku": "standard", - }, - ], - "status": "generated", - "dir": "concept/infra/terraform/stage-1-foundation", - "files": [], - }, - { - "stage": 2, - "name": "Data Layer", - "layer": "data", - "capability": "data", - "services": [ - { - "name": "sql-db", - "computed_name": "zd-sql-api-dev-eus", - "resource_type": "Microsoft.Sql/servers", - "sku": "S0", - }, - ], - "status": "generated", - "dir": "concept/infra/terraform/stage-2-data", - "files": [], - }, - { - "stage": 3, - "name": "Application", - "layer": "app", - "capability": "app", - "services": [ - { - "name": "web-app", - "computed_name": "zd-app-web-dev-eus", - "resource_type": "Microsoft.Web/sites", - "sku": "B1", - }, - ], - "status": "generated", - "dir": "concept/apps/stage-3-application", - "files": [], - }, - ] - return { - "iac_tool": iac_tool, - "deployment_stages": stages, - "_metadata": {"created": "2026-01-01T00:00:00", "last_updated": "2026-01-01T00:00:00", "iteration": 1}, - } +@pytest.fixture +def deploy_registry(): + registry = MagicMock() + + mock_qa = MagicMock() + mock_qa.name = "qa-engineer" + mock_qa.execute = MagicMock(return_value=MagicMock(content="Issue diagnosed.", model="test", usage={})) + mock_qa.get_system_messages = MagicMock(return_value=[]) + mock_qa._temperature = 0.2 + mock_qa._max_tokens = 4096 + + mock_tf = MagicMock() + mock_tf.name = "terraform-agent" + mock_tf.execute = MagicMock(return_value=MagicMock(content="Fixed.", model="test", usage={})) + mock_tf.get_system_messages = MagicMock(return_value=[]) + + mock_dev = MagicMock() + mock_dev.name = "app-developer" + mock_dev.execute = MagicMock(return_value=MagicMock(content="Fixed app.", model="test", usage={})) + mock_dev.get_system_messages = MagicMock(return_value=[]) + + mock_architect = MagicMock() + mock_architect.name = "cloud-architect" + mock_architect.execute = MagicMock(return_value=MagicMock(content="Guide fix.", model="test", usage={})) + + def find_by_cap(cap): + mapping = { + AgentCapability.QA: [mock_qa], + AgentCapability.TERRAFORM: [mock_tf], + AgentCapability.BICEP: [], + AgentCapability.DEVELOP: [mock_dev], + AgentCapability.ARCHITECT: [mock_architect], + } + return mapping.get(cap, []) + registry.find_by_capability.side_effect = find_by_cap + return registry -def _write_build_yaml(project_dir, stages=None, iac_tool="terraform"): - """Write build.yaml into the project state dir.""" - state_dir = Path(project_dir) / ".prototype" / "state" - state_dir.mkdir(parents=True, exist_ok=True) - build_data = _build_yaml(stages, iac_tool) - with open(state_dir / "build.yaml", "w", encoding="utf-8") as f: - yaml.dump(build_data, f, default_flow_style=False) - return state_dir / "build.yaml" +def _make_session(deploy_context, deploy_registry): + from azext_prototype.stages.deploy_session import DeploySession -# ====================================================================== -# DeployState tests -# ====================================================================== + return DeploySession(deploy_context, deploy_registry) -class TestDeployState: +# ------------------------------------------------------------------ +# DeployResult +# ------------------------------------------------------------------ - def test_default_state_structure(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState - ds = DeployState(str(tmp_project)) - state = ds.state - assert state["iac_tool"] == "terraform" - assert state["subscription"] == "" - assert state["resource_group"] == "" - assert state["deployment_stages"] == [] - assert state["preflight_results"] == [] - assert state["deploy_log"] == [] - assert state["rollback_log"] == [] - assert state["captured_outputs"] == {} - assert state["_metadata"]["iteration"] == 0 - - def test_load_save_roundtrip(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState +class TestDeployResult: + def test_defaults(self): + from azext_prototype.stages.deploy_session import DeployResult - ds = DeployState(str(tmp_project)) - ds._state["subscription"] = "test-sub-123" - ds._state["iac_tool"] = "bicep" - ds.save() + result = DeployResult() + assert result.deployed_stages == [] + assert result.failed_stages == [] + assert result.rolled_back_stages == [] + assert result.captured_outputs == {} + assert result.cancelled is False - ds2 = DeployState(str(tmp_project)) - loaded = ds2.load() - assert loaded["subscription"] == "test-sub-123" - assert loaded["iac_tool"] == "bicep" - assert loaded["_metadata"]["created"] is not None - assert loaded["_metadata"]["last_updated"] is not None + def test_with_values(self): + from azext_prototype.stages.deploy_session import DeployResult - def test_exists_property(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState + result = DeployResult( + deployed_stages=[{"stage": 1}], + captured_outputs={"key": "val"}, + cancelled=True, + ) + assert len(result.deployed_stages) == 1 + assert result.captured_outputs["key"] == "val" + assert result.cancelled is True - ds = DeployState(str(tmp_project)) - assert not ds.exists - ds.save() - assert ds.exists - def test_load_from_build_state(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState +# ------------------------------------------------------------------ +# _lookup_deployer_object_id +# ------------------------------------------------------------------ - build_path = _write_build_yaml(tmp_project) - ds = DeployState(str(tmp_project)) - result = ds.load_from_build_state(build_path) - assert result is True - assert len(ds.state["deployment_stages"]) == 3 - # Verify deploy-specific fields were added - stage = ds.state["deployment_stages"][0] - assert stage["deploy_status"] == "pending" - assert stage["deploy_timestamp"] is None - assert stage["deploy_output"] == "" - assert stage["deploy_error"] == "" - assert stage["rollback_timestamp"] is None +class TestLookupDeployerObjectId: + @patch("azext_prototype.stages.deploy_session.subprocess.run") + def test_user_auth_returns_oid(self, mock_run): + from azext_prototype.stages.deploy_session import _lookup_deployer_object_id - def test_load_from_build_state_missing_file(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState + mock_run.return_value = MagicMock(returncode=0, stdout="abc-123-def\n") + result = _lookup_deployer_object_id() + assert result == "abc-123-def" - ds = DeployState(str(tmp_project)) - result = ds.load_from_build_state("/nonexistent/build.yaml") - assert result is False + @patch("azext_prototype.stages.deploy_session.subprocess.run") + def test_sp_auth_uses_client_id(self, mock_run): + from azext_prototype.stages.deploy_session import _lookup_deployer_object_id - def test_load_from_build_state_no_stages(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState + mock_run.return_value = MagicMock(returncode=0, stdout="sp-oid\n") + result = _lookup_deployer_object_id(client_id="my-client") + assert result == "sp-oid" + # Should have called with sp show + call_args = mock_run.call_args[0][0] + assert "sp" in call_args + assert "my-client" in call_args - build_path = _write_build_yaml(tmp_project, stages=[]) - ds = DeployState(str(tmp_project)) - result = ds.load_from_build_state(build_path) - assert result is False + @patch("azext_prototype.stages.deploy_session.subprocess.run") + def test_failure_returns_none(self, mock_run): + from azext_prototype.stages.deploy_session import _lookup_deployer_object_id - def test_stage_transitions(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState + mock_run.return_value = MagicMock(returncode=1, stdout="") + result = _lookup_deployer_object_id() + assert result is None - build_path = _write_build_yaml(tmp_project) - ds = DeployState(str(tmp_project)) - ds.load_from_build_state(build_path) + @patch("azext_prototype.stages.deploy_session.subprocess.run", side_effect=FileNotFoundError) + def test_az_not_found_returns_none(self, mock_run): + from azext_prototype.stages.deploy_session import _lookup_deployer_object_id - # pending → deploying - ds.mark_stage_deploying(1) - assert ds.get_stage(1)["deploy_status"] == "deploying" + result = _lookup_deployer_object_id() + assert result is None - # deploying → deployed - ds.mark_stage_deployed(1, output="resource_id=abc123") - stage = ds.get_stage(1) - assert stage["deploy_status"] == "deployed" - assert stage["deploy_timestamp"] is not None - assert stage["deploy_output"] == "resource_id=abc123" - assert stage["deploy_error"] == "" - # deployed → rolled_back - ds.mark_stage_rolled_back(1) - stage = ds.get_stage(1) - assert stage["deploy_status"] == "rolled_back" - assert stage["rollback_timestamp"] is not None +# ------------------------------------------------------------------ +# _resolve_context +# ------------------------------------------------------------------ - def test_stage_failure(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState - build_path = _write_build_yaml(tmp_project) - ds = DeployState(str(tmp_project)) - ds.load_from_build_state(build_path) +class TestResolveContext: + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value=None) + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub-123") + def test_falls_back_to_current_subscription(self, mock_sub, mock_env, mock_oid, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + session._resolve_context(None, None) + assert session._subscription == "sub-123" + + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value="oid-abc") + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="") + @patch("azext_prototype.stages.deploy_session.set_deployment_context", return_value={"status": "ok"}) + def test_sets_deployment_context_with_tenant( + self, mock_ctx, mock_sub, mock_env, mock_oid, deploy_context, deploy_registry + ): + session = _make_session(deploy_context, deploy_registry) + session._resolve_context("sub-override", "tenant-abc") + assert session._subscription == "sub-override" + assert session._tenant == "tenant-abc" + assert session._deploy_env["TF_VAR_deployer_object_id"] == "oid-abc" - ds.mark_stage_deploying(1) - ds.mark_stage_failed(1, error="timeout connecting to Azure") - stage = ds.get_stage(1) - assert stage["deploy_status"] == "failed" - assert stage["deploy_error"] == "timeout connecting to Azure" - def test_get_pending_deployed_failed(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState +# ------------------------------------------------------------------ +# Preflight checks +# ------------------------------------------------------------------ - build_path = _write_build_yaml(tmp_project) - ds = DeployState(str(tmp_project)) - ds.load_from_build_state(build_path) - assert len(ds.get_pending_stages()) == 3 - assert len(ds.get_deployed_stages()) == 0 - assert len(ds.get_failed_stages()) == 0 +class TestPreflightChecks: + @patch("azext_prototype.stages.deploy_session.check_az_login", return_value=False) + def test_check_subscription_not_logged_in(self, mock_login, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + result = session._check_subscription("sub-123") + assert result["status"] == "fail" + assert "Not logged in" in result["message"] - ds.mark_stage_deployed(1) - ds.mark_stage_failed(2, "error") + @patch("azext_prototype.stages.deploy_session.check_az_login", return_value=True) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="other-sub") + def test_check_subscription_mismatch(self, mock_sub, mock_login, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + result = session._check_subscription("target-sub-1234") + assert result["status"] == "warn" - assert len(ds.get_pending_stages()) == 1 - assert len(ds.get_deployed_stages()) == 1 - assert len(ds.get_failed_stages()) == 1 + @patch("azext_prototype.stages.deploy_session.check_az_login", return_value=True) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="") + def test_check_subscription_pass(self, mock_sub, mock_login, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + result = session._check_subscription("") + assert result["status"] == "pass" - def test_can_rollback_ordering(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState + @patch("azext_prototype.stages.deploy_session.get_current_tenant", return_value="other-tenant") + def test_check_tenant_mismatch(self, mock_tenant, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + result = session._check_tenant("target-tenant-1234") + assert result["status"] == "warn" - build_path = _write_build_yaml(tmp_project) - ds = DeployState(str(tmp_project)) - ds.load_from_build_state(build_path) + @patch("azext_prototype.stages.deploy_session.get_current_tenant", return_value="target-tenant") + def test_check_tenant_match(self, mock_tenant, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + result = session._check_tenant("target-tenant") + assert result["status"] == "pass" - ds.mark_stage_deployed(1) - ds.mark_stage_deployed(2) - ds.mark_stage_deployed(3) + @patch("azext_prototype.stages.deploy_session.subprocess.run") + def test_check_iac_tool_terraform_found(self, mock_run, deploy_context, deploy_registry): + mock_run.return_value = MagicMock(returncode=0, stdout="Terraform v1.5.0\n") + session = _make_session(deploy_context, deploy_registry) + result = session._check_iac_tool() + assert result["status"] == "pass" + assert "Terraform" in result["message"] - # Can only rollback stage 3 (highest) - assert ds.can_rollback(3) is True - assert ds.can_rollback(2) is False # stage 3 still deployed - assert ds.can_rollback(1) is False # stages 2,3 still deployed + @patch("azext_prototype.stages.deploy_session.subprocess.run", side_effect=FileNotFoundError) + def test_check_iac_tool_terraform_not_found(self, mock_run, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + result = session._check_iac_tool() + assert result["status"] == "fail" - # Roll back stage 3 - ds.mark_stage_rolled_back(3) - assert ds.can_rollback(2) is True - assert ds.can_rollback(1) is False # stage 2 still deployed + def test_check_iac_tool_bicep(self, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + session._iac_tool = "bicep" + result = session._check_iac_tool() + assert result["status"] == "pass" + assert "Bicep" in result["name"] - def test_rollback_candidates_reverse_order(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState + @patch("azext_prototype.stages.deploy_session.subprocess.run") + def test_check_resource_group_exists(self, mock_run, deploy_context, deploy_registry): + mock_run.return_value = MagicMock(returncode=0) + session = _make_session(deploy_context, deploy_registry) + result = session._check_resource_group("sub", "rg-test") + assert result["status"] == "pass" - build_path = _write_build_yaml(tmp_project) - ds = DeployState(str(tmp_project)) - ds.load_from_build_state(build_path) + @patch("azext_prototype.stages.deploy_session.subprocess.run") + def test_check_resource_group_not_found(self, mock_run, deploy_context, deploy_registry): + mock_run.return_value = MagicMock(returncode=1) + session = _make_session(deploy_context, deploy_registry) + result = session._check_resource_group("sub", "rg-test") + assert result["status"] == "warn" - ds.mark_stage_deployed(1) - ds.mark_stage_deployed(2) - ds.mark_stage_deployed(3) - candidates = ds.get_rollback_candidates() - assert [c["stage"] for c in candidates] == [3, 2, 1] +# ------------------------------------------------------------------ +# _deploy_single_stage — layer dispatch +# ------------------------------------------------------------------ - def test_preflight_results(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState - ds = DeployState(str(tmp_project)) - results = [ - {"name": "Azure Login", "status": "pass", "message": "Logged in."}, - {"name": "Terraform", "status": "fail", "message": "Not found.", "fix_command": "brew install terraform"}, - ] - ds.set_preflight_results(results) +class TestDeploySingleStage: + def _make_ready_session(self, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + session._subscription = "sub-123" + session._resource_group = "rg-test" + session._deploy_env = {} + return session - failures = ds.get_preflight_failures() - assert len(failures) == 1 - assert failures[0]["name"] == "Terraform" + def test_manual_deploy_mode(self, deploy_context, deploy_registry): + session = self._make_ready_session(deploy_context, deploy_registry) + stage = { + "stage": 1, + "name": "Manual Step", + "layer": "infra", + "deploy_mode": "manual", + "manual_instructions": "Run migration script", + "dir": "concept/infra", + "services": [], + } + result = session._deploy_single_stage(stage) + assert result["status"] == "awaiting_manual" + assert "migration" in result["instructions"] - def test_deploy_log(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState + def test_missing_directory_skipped(self, deploy_context, deploy_registry): + session = self._make_ready_session(deploy_context, deploy_registry) + stage = { + "stage": 1, + "name": "Missing", + "layer": "infra", + "deploy_mode": "auto", + "dir": "concept/infra/terraform/nonexistent", + "services": [], + } + result = session._deploy_single_stage(stage) + assert result["status"] == "skipped" - build_path = _write_build_yaml(tmp_project) - ds = DeployState(str(tmp_project)) - ds.load_from_build_state(build_path) + @patch("azext_prototype.stages.deploy_session.deploy_terraform") + @patch("azext_prototype.stages.deploy_session.resolve_stage_secrets", return_value={}) + def test_infra_layer_dispatches_terraform(self, mock_secrets, mock_deploy, deploy_context, deploy_registry): + session = self._make_ready_session(deploy_context, deploy_registry) + # Create stage directory + stage_dir = Path(deploy_context.project_dir) / "concept" / "infra" / "terraform" / "stage-1" + stage_dir.mkdir(parents=True, exist_ok=True) - ds.mark_stage_deploying(1) - ds.mark_stage_deployed(1) + mock_deploy.return_value = {"status": "deployed", "deployment_output": ""} - assert len(ds.state["deploy_log"]) == 2 - assert ds.state["deploy_log"][0]["action"] == "deploying" - assert ds.state["deploy_log"][1]["action"] == "deployed" + stage = { + "stage": 1, + "name": "Foundation", + "layer": "infra", + "deploy_mode": "auto", + "dir": "concept/infra/terraform/stage-1", + "services": [], + } + result = session._deploy_single_stage(stage) + assert result["status"] == "deployed" + mock_deploy.assert_called_once() - def test_reset(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState + @patch("azext_prototype.stages.deploy_session.deploy_app_stage") + def test_app_layer_dispatches_app_deploy(self, mock_deploy, deploy_context, deploy_registry): + session = self._make_ready_session(deploy_context, deploy_registry) + stage_dir = Path(deploy_context.project_dir) / "concept" / "apps" / "stage-2" + stage_dir.mkdir(parents=True, exist_ok=True) - build_path = _write_build_yaml(tmp_project) - ds = DeployState(str(tmp_project)) - ds.load_from_build_state(build_path) - assert len(ds.state["deployment_stages"]) == 3 + mock_deploy.return_value = {"status": "deployed"} - ds.reset() - assert ds.state["deployment_stages"] == [] - assert ds.exists # File still exists after reset + stage = { + "stage": 2, + "name": "API", + "layer": "app", + "deploy_mode": "auto", + "dir": "concept/apps/stage-2", + "services": [], + } + result = session._deploy_single_stage(stage) + assert result["status"] == "deployed" + mock_deploy.assert_called_once() - def test_format_deploy_report(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState + def test_docs_layer_auto_deployed(self, deploy_context, deploy_registry): + session = self._make_ready_session(deploy_context, deploy_registry) + stage_dir = Path(deploy_context.project_dir) / "concept" / "docs" + stage_dir.mkdir(parents=True, exist_ok=True) - build_path = _write_build_yaml(tmp_project) - ds = DeployState(str(tmp_project)) - ds.load_from_build_state(build_path) - ds._state["subscription"] = "sub-123" + stage = { + "stage": 3, + "name": "Documentation", + "layer": "docs", + "deploy_mode": "auto", + "dir": "concept/docs", + "services": [], + } + result = session._deploy_single_stage(stage) + assert result["status"] == "deployed" - ds.mark_stage_deployed(1) - ds.mark_stage_failed(2, "timeout") - report = ds.format_deploy_report() - assert "Deploy Report" in report - assert "sub-123" in report - assert "1 deployed" in report - assert "1 failed" in report +# ------------------------------------------------------------------ +# Dry-run +# ------------------------------------------------------------------ - def test_format_stage_status(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState - build_path = _write_build_yaml(tmp_project) - ds = DeployState(str(tmp_project)) - ds.load_from_build_state(build_path) +class TestDryRun: + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value=None) + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub") + @patch("azext_prototype.stages.deploy_session.plan_terraform", return_value={"output": "Plan: 3 to add"}) + @patch("azext_prototype.stages.deploy_session.resolve_stage_secrets", return_value={}) + def test_dry_run_terraform( + self, mock_secrets, mock_plan, mock_sub, mock_env, mock_oid, deploy_context, deploy_registry + ): + session = _make_session(deploy_context, deploy_registry) + # Create stage directories + stage_dir = Path(deploy_context.project_dir) / "concept" / "infra" / "terraform" / "stage-1-foundation" + stage_dir.mkdir(parents=True, exist_ok=True) - status = ds.format_stage_status() - assert "Foundation" in status - assert "Application" in status - assert "0/3 stages deployed" in status + output = [] + result = session.run_dry_run( + subscription="sub-123", + print_fn=lambda m: output.append(m), + ) + assert result.cancelled is False - def test_format_preflight_report(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value=None) + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub") + def test_dry_run_no_build_state(self, mock_sub, mock_env, mock_oid, project_with_config, sample_config): + from azext_prototype.stages.deploy_session import DeploySession - ds = DeployState(str(tmp_project)) - ds.set_preflight_results( - [ - {"name": "Azure Login", "status": "pass", "message": "OK"}, - { - "name": "Terraform", - "status": "warn", - "message": "Old version", - "fix_command": "brew upgrade terraform", - }, - ] + # Use project WITHOUT build state + provider = MagicMock() + provider.provider_name = "github-models" + provider.chat.return_value = MagicMock(content="test", model="test", usage={}) + ctx = AgentContext( + project_config=sample_config, + project_dir=str(project_with_config), + ai_provider=provider, ) + registry = MagicMock() + registry.find_by_capability.return_value = [] - report = ds.format_preflight_report() - assert "Preflight Checks" in report - assert "2 passed" in report or "1 passed" in report - assert "1 warning" in report + session = DeploySession(ctx, registry) + output = [] + result = session.run_dry_run( + subscription="sub-123", + print_fn=lambda m: output.append(m), + ) + assert result.cancelled is True - def test_conversation_tracking(self, tmp_project): - from azext_prototype.stages.deploy_state import DeployState + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value=None) + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub") + def test_dry_run_target_stage_not_found(self, mock_sub, mock_env, mock_oid, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + output = [] + result = session.run_dry_run( + subscription="sub-123", + target_stage=999, + print_fn=lambda m: output.append(m), + ) + assert result.cancelled is True - ds = DeployState(str(tmp_project)) - ds.update_from_exchange("deploy all", "Deploying stage 1...", 1) - assert len(ds.state["conversation_history"]) == 1 - assert ds.state["conversation_history"][0]["user"] == "deploy all" +# ------------------------------------------------------------------ +# Single-stage deploy +# ------------------------------------------------------------------ -# ====================================================================== -# Preflight check tests -# ====================================================================== +class TestSingleStageDeploy: + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value=None) + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub") + def test_stage_not_found(self, mock_sub, mock_env, mock_oid, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + output = [] + result = session.run_single_stage( + 999, + subscription="sub-123", + print_fn=lambda m: output.append(m), + ) + assert result.cancelled is True + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value=None) + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub") + def test_no_build_state_cancels(self, mock_sub, mock_env, mock_oid, project_with_config, sample_config): + from azext_prototype.stages.deploy_session import DeploySession -class TestPreflightChecks: + provider = MagicMock() + provider.provider_name = "github-models" + provider.chat.return_value = MagicMock(content="test", model="test", usage={}) + ctx = AgentContext( + project_config=sample_config, + project_dir=str(project_with_config), + ai_provider=provider, + ) + registry = MagicMock() + registry.find_by_capability.return_value = [] - def _make_session(self, project_dir, iac_tool="terraform"): - """Create a DeploySession with mocked dependencies.""" - from azext_prototype.agents.base import AgentContext - from azext_prototype.agents.builtin import register_all_builtin - from azext_prototype.agents.registry import AgentRegistry - from azext_prototype.stages.deploy_session import DeploySession + session = DeploySession(ctx, registry) + output = [] + result = session.run_single_stage( + 1, + subscription="sub-123", + print_fn=lambda m: output.append(m), + ) + assert result.cancelled is True - config_path = Path(project_dir) / "prototype.yaml" - if not config_path.exists(): - config_data = { - "project": {"name": "test", "location": "eastus", "iac_tool": iac_tool}, - "ai": {"provider": "github-models"}, - } - with open(config_path, "w") as f: - yaml.dump(config_data, f) - context = AgentContext( - project_config={"project": {"iac_tool": iac_tool}}, - project_dir=str(project_dir), - ai_provider=MagicMock(), +# ------------------------------------------------------------------ +# Run — interactive quit +# ------------------------------------------------------------------ + + +class TestDeployRunInteractive: + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value=None) + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub") + def test_quit_at_confirmation(self, mock_sub, mock_env, mock_oid, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + output = [] + result = session.run( + subscription="sub-123", + input_fn=lambda p: "quit", + print_fn=lambda m: output.append(m), ) - registry = AgentRegistry() - register_all_builtin(registry) + assert result.cancelled is True - return DeploySession(context, registry) + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value=None) + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub") + def test_eof_at_confirmation(self, mock_sub, mock_env, mock_oid, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) - @patch("azext_prototype.stages.deploy_session.check_az_login", return_value=True) - @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub-123") - def test_subscription_pass(self, _mock_sub, _mock_login, tmp_project): - session = self._make_session(tmp_project) - result = session._check_subscription("sub-123") - assert result["status"] == "pass" + def raise_eof(p): + raise EOFError - @patch("azext_prototype.stages.deploy_session.check_az_login", return_value=False) - def test_subscription_fail_no_login(self, _mock_login, tmp_project): - session = self._make_session(tmp_project) - result = session._check_subscription("sub-123") - assert result["status"] == "fail" - assert "az login" in result.get("fix_command", "") + result = session.run( + subscription="sub-123", + input_fn=raise_eof, + print_fn=lambda m: None, + ) + assert result.cancelled is True - @patch("azext_prototype.stages.deploy_session.check_az_login", return_value=True) - @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="other-sub") - def test_subscription_warn_mismatch(self, _mock_sub, _mock_login, tmp_project): - session = self._make_session(tmp_project) - result = session._check_subscription("sub-123") - assert result["status"] == "warn" - @patch("subprocess.run") - def test_iac_tool_terraform_pass(self, mock_run, tmp_project): - mock_run.return_value = MagicMock(returncode=0, stdout="Terraform v1.7.0\n") - session = self._make_session(tmp_project, iac_tool="terraform") - result = session._check_iac_tool() - assert result["status"] == "pass" - assert "Terraform" in result["message"] +# ------------------------------------------------------------------ +# _capture_stage_outputs +# ------------------------------------------------------------------ - @patch("subprocess.run", side_effect=FileNotFoundError) - def test_iac_tool_terraform_missing(self, _mock_run, tmp_project): - session = self._make_session(tmp_project, iac_tool="terraform") - result = session._check_iac_tool() - assert result["status"] == "fail" - def test_iac_tool_bicep_always_pass(self, tmp_project): - session = self._make_session(tmp_project, iac_tool="bicep") - result = session._check_iac_tool() - assert result["status"] == "pass" +class TestCaptureStageOutputs: + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value=None) + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub") + def test_terraform_output_capture(self, mock_sub, mock_env, mock_oid, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + session._iac_tool = "terraform" + session._output_capture = MagicMock() + session._output_capture.capture_terraform.return_value = {"key_vault_id": "/sub/rg/kv"} + session._output_capture.get_all.return_value = {"key_vault_id": "/sub/rg/kv"} + + stage = {"stage": 1, "dir": "concept/infra/terraform/stage-1", "services": []} + session._capture_stage_outputs(stage) + session._output_capture.capture_terraform.assert_called_once() - @patch("subprocess.run") - def test_resource_group_exists(self, mock_run, tmp_project): - mock_run.return_value = MagicMock(returncode=0) - session = self._make_session(tmp_project) - result = session._check_resource_group("sub-123", "my-rg") - assert result["status"] == "pass" + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value=None) + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub") + def test_bicep_output_capture(self, mock_sub, mock_env, mock_oid, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + session._iac_tool = "bicep" + session._output_capture = MagicMock() + session._output_capture.capture_bicep.return_value = {"result": "ok"} + session._output_capture.get_all.return_value = {"result": "ok"} + + stage = {"stage": 1, "dir": "concept/infra/bicep/stage-1", "deploy_output": "some output", "services": []} + session._capture_stage_outputs(stage) + session._output_capture.capture_bicep.assert_called_once_with("some output") - @patch("subprocess.run") - def test_resource_group_missing_warns(self, mock_run, tmp_project): - mock_run.return_value = MagicMock(returncode=1) - session = self._make_session(tmp_project) - result = session._check_resource_group("sub-123", "my-rg") - assert result["status"] == "warn" - assert "fix_command" in result + @patch("azext_prototype.stages.deploy_session._lookup_deployer_object_id", return_value=None) + @patch("azext_prototype.stages.deploy_session.build_deploy_env", return_value={}) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub") + def test_bicep_no_output_skips(self, mock_sub, mock_env, mock_oid, deploy_context, deploy_registry): + session = _make_session(deploy_context, deploy_registry) + session._iac_tool = "bicep" + session._output_capture = MagicMock() + session._output_capture.capture_bicep.return_value = {} + + stage = {"stage": 1, "dir": "concept/infra/bicep/stage-1", "services": []} + session._capture_stage_outputs(stage) + # No deploy_output key = no capture call + session._output_capture.capture_bicep.assert_not_called() + + +# --- Additional imports from merged flat test --- +from azext_prototype.agents.base import AgentContext +from azext_prototype.agents.builtin import register_all_builtin +from azext_prototype.agents.registry import AgentRegistry +from azext_prototype.ai.provider import AIResponse +from azext_prototype.config import DEFAULT_CONFIG +from azext_prototype.config import ProjectConfig +from azext_prototype.custom import _prepare_deploy_command +from azext_prototype.custom import prototype_deploy +from azext_prototype.stages.build_state import BuildState +from azext_prototype.stages.deploy_helpers import RollbackManager +from azext_prototype.stages.deploy_helpers import _terraform_validate +from azext_prototype.stages.deploy_helpers import check_az_login +from azext_prototype.stages.deploy_helpers import deploy_terraform +from azext_prototype.stages.deploy_helpers import find_bicep_params +from azext_prototype.stages.deploy_helpers import get_current_subscription +from azext_prototype.stages.deploy_helpers import get_current_tenant +from azext_prototype.stages.deploy_helpers import is_subscription_scoped +from azext_prototype.stages.deploy_helpers import login_service_principal +from azext_prototype.stages.deploy_helpers import plan_terraform +from azext_prototype.stages.deploy_helpers import rollback_terraform +from azext_prototype.stages.deploy_helpers import set_deployment_context +from azext_prototype.stages.deploy_stage import DeployStage +from azext_prototype.stages.deploy_state import DeployState +from azext_prototype.stages.deploy_state import SyncResult +from azext_prototype.stages.deploy_state import _format_display_id +from azext_prototype.stages.deploy_state import _status_icon +from azext_prototype.stages.deploy_state import parse_stage_ref +from azext_prototype.stages.intent import IntentKind, IntentResult +from knack.util import CLIError +import os +import yaml - @patch("subprocess.run") - def test_resource_providers_skips_non_microsoft_namespaces(self, mock_run, tmp_project): - """Non-Microsoft namespaces like 'External' should NOT be checked.""" - session = self._make_session(tmp_project) - session._deploy_state._state["deployment_stages"] = [ + +# ====================================================================== + + +def _make_response(content: str = "Mock response") -> AIResponse: + return AIResponse(content=content, model="gpt-4o", usage={}) + +def _build_yaml(stages: list[dict] | None = None, iac_tool: str = "terraform") -> dict: + """Return a realistic build.yaml structure.""" + if stages is None: + stages = [ { "stage": 1, - "name": "Infra", + "name": "Foundation", + "layer": "infra", "capability": "infra", "services": [ - {"name": "ext", "resource_type": "External/something", "sku": ""}, - {"name": "hashicorp", "resource_type": "hashicorp/random", "sku": ""}, - {"name": "kv", "resource_type": "Microsoft.KeyVault/vaults", "sku": ""}, + { + "name": "key-vault", + "computed_name": "zd-kv-api-dev-eus", + "resource_type": "Microsoft.KeyVault/vaults", + "sku": "standard", + }, ], "status": "generated", - "dir": "stage-1", + "dir": "concept/infra/terraform/stage-1-foundation", "files": [], }, - ] - - mock_run.return_value = MagicMock(returncode=0, stdout="Registered\n", stderr="") - results = session._check_resource_providers("sub-123") # noqa: F841 - - # Should have checked only Microsoft.* namespaces — not External or hashicorp - checked_namespaces = [c.args[0][4] for c in mock_run.call_args_list if "provider" in c.args[0]] - assert "Microsoft.KeyVault" in checked_namespaces - assert "External" not in checked_namespaces - assert "hashicorp" not in checked_namespaces - - @patch("subprocess.run") - def test_resource_providers_skips_empty_resource_types(self, mock_run, tmp_project): - """Services with empty resource_type should be skipped.""" - session = self._make_session(tmp_project) - session._deploy_state._state["deployment_stages"] = [ { - "stage": 1, - "name": "Infra", - "capability": "infra", + "stage": 2, + "name": "Data Layer", + "layer": "data", + "capability": "data", "services": [ - {"name": "custom", "resource_type": "", "sku": ""}, + { + "name": "sql-db", + "computed_name": "zd-sql-api-dev-eus", + "resource_type": "Microsoft.Sql/servers", + "sku": "S0", + }, ], "status": "generated", - "dir": "stage-1", + "dir": "concept/infra/terraform/stage-2-data", + "files": [], + }, + { + "stage": 3, + "name": "Application", + "layer": "app", + "capability": "app", + "services": [ + { + "name": "web-app", + "computed_name": "zd-app-web-dev-eus", + "resource_type": "Microsoft.Web/sites", + "sku": "B1", + }, + ], + "status": "generated", + "dir": "concept/apps/stage-3-application", "files": [], }, ] + return { + "iac_tool": iac_tool, + "deployment_stages": stages, + "_metadata": {"created": "2026-01-01T00:00:00", "last_updated": "2026-01-01T00:00:00", "iteration": 1}, + } - results = session._check_resource_providers("sub-123") - assert results == [] - mock_run.assert_not_called() - - -# ====================================================================== -# File-based resource provider extraction tests -# ====================================================================== - +def _write_build_yaml(project_dir, stages=None, iac_tool="terraform"): + """Write build.yaml into the project state dir.""" + state_dir = Path(project_dir) / ".prototype" / "state" + state_dir.mkdir(parents=True, exist_ok=True) + build_data = _build_yaml(stages, iac_tool) + with open(state_dir / "build.yaml", "w", encoding="utf-8") as f: + yaml.dump(build_data, f, default_flow_style=False) + return state_dir / "build.yaml" class TestExtractResourceProvidersFromFiles: """Verify _extract_providers_from_files() parses IaC files for namespaces.""" @@ -684,9 +832,6 @@ def test_falls_back_to_metadata(self, mock_run, tmp_project): checked_namespaces = [c.args[0][4] for c in mock_run.call_args_list if "provider" in c.args[0]] assert "Microsoft.KeyVault" in checked_namespaces - -# ====================================================================== -# DeploySession tests # ====================================================================== @@ -1104,9 +1249,6 @@ def test_docs_stage_auto_deployed(self, tmp_project): ) assert len(result.deployed_stages) == 1 - -# ====================================================================== -# DeployStage integration tests # ====================================================================== @@ -1244,9 +1386,6 @@ def test_single_stage_delegates(self, tmp_project): assert result["mode"] == "single_stage" assert result["deployed"] == 1 - -# ====================================================================== -# Deploy helpers tests # ====================================================================== @@ -1360,9 +1499,6 @@ def test_is_subscription_scoped(self, tmp_project): bicep_file.write_text("resource kv 'Microsoft.KeyVault/vaults@2023-07-01' = {}") assert is_subscription_scoped(bicep_file) is False - -# ====================================================================== -# Rollback ordering tests (specific edge cases) # ====================================================================== @@ -1422,9 +1558,6 @@ def test_default_state_has_tenant(self, tmp_project): ds = DeployState(str(tmp_project)) assert ds.state["tenant"] == "" - -# ====================================================================== -# AI-independent deploy tests # ====================================================================== @@ -1550,9 +1683,6 @@ def test_dry_run_without_ai(self, tmp_project): # Should not raise — result is a DeployResult assert not result.cancelled or result.cancelled # always passes: just no crash - -# ====================================================================== -# Service principal login tests # ====================================================================== @@ -1638,9 +1768,6 @@ def test_get_current_tenant(self, mock_run): result = get_current_tenant() assert result == "tenant-abc" - -# ====================================================================== -# Tenant preflight tests # ====================================================================== @@ -1687,9 +1814,6 @@ def test_tenant_preflight_mismatch(self, mock_tenant, tmp_project): assert "fix_command" in result assert "az login --tenant" in result["fix_command"] - -# ====================================================================== -# SP parameter validation in prototype_deploy # ====================================================================== @@ -1762,9 +1886,6 @@ def test_sp_login_success_proceeds(self, mock_guards, mock_login, mock_dir, mock assert call_kwargs["tenant"] == "ghi" assert call_kwargs["subscription"] == "sp-sub-123" - -# ====================================================================== -# Subscription resolution chain tests # ====================================================================== @@ -1820,9 +1941,6 @@ def test_config_sub_used_when_no_cli_arg(self, mock_sub, tmp_project): joined = "\n".join(output) assert "config-sub" in joined - -# ====================================================================== -# /login slash command tests # ====================================================================== @@ -1899,9 +2017,6 @@ def test_help_includes_login(self, tmp_project): joined = "\n".join(output) assert "/login" in joined - -# ====================================================================== -# _prepare_deploy_command tests # ====================================================================== @@ -1934,9 +2049,6 @@ def test_returns_ai_provider_when_factory_succeeds(self, mock_dir, mock_check_re assert agent_context.ai_provider is mock_provider - -# ====================================================================== -# Config SP routing tests # ====================================================================== @@ -1961,9 +2073,6 @@ def test_default_config_has_sp_section(self): assert "client_secret" in sp assert "tenant_id" in sp - -# ====================================================================== -# _terraform_validate tests # ====================================================================== @@ -2022,9 +2131,6 @@ def test_deploy_terraform_validate_pass_continues(self, mock_run, tmp_project): # Should have called: init, validate, plan, apply = 4 calls assert mock_run.call_count == 4 - -# ====================================================================== -# Terraform preflight validation tests # ====================================================================== @@ -2231,9 +2337,6 @@ def test_preflight_includes_terraform_validate(self, mock_run, tmp_project): names = [r["name"] for r in results] assert any("Terraform Validate" in n for n in names) - -# ====================================================================== -# Deploy env threading tests # ====================================================================== @@ -2419,9 +2522,6 @@ def test_rollback_passes_env(self, _mock_ctx, mock_rb, tmp_project): _, kwargs = mock_rb.call_args assert kwargs["env"]["ARM_SUBSCRIPTION_ID"] == "sub-123" - -# ====================================================================== -# Deployer object ID lookup tests # ====================================================================== @@ -2545,9 +2645,6 @@ def test_resolve_context_no_oid_when_lookup_fails(self, _mock_lookup, tmp_projec assert "TF_VAR_deployer_object_id" not in session._deploy_env - -# ====================================================================== -# Coverage expansion: run() phases, slash commands, remediation # ====================================================================== @@ -2872,7 +2969,6 @@ def test_run_natural_language_multi_stage(self, tmp_project): assert "/deploy 1" in calls assert "/deploy 2" in calls - class TestSingleStageFailureRemediation: """Tests for run_single_stage failure remediation (lines 587-598).""" @@ -2968,7 +3064,6 @@ def test_single_stage_remediation_success(self, mock_tf, tmp_project): joined = "\n".join(output) assert "remediation" in joined.lower() - class TestDeployPendingStagesAwaitingManual: """Tests covering awaiting_manual status (lines 892-909).""" @@ -3125,7 +3220,6 @@ def test_manual_step_other_breaks(self, tmp_project): joined = "\n".join(output) assert "pausing" in joined.lower() or "continue" in joined.lower() - class TestRollbackAllCoverage: """Tests for _rollback_all (lines 1618-1640).""" @@ -3277,7 +3371,6 @@ def test_rollback_all_eof_cancels(self, tmp_project): joined = "\n".join(output) assert "cancelled" in joined.lower() - class TestSlashCommandPlan: """Tests covering /plan slash command (lines 1842-1875).""" @@ -3477,7 +3570,6 @@ def test_plan_app_stage_no_preview(self, tmp_project): joined = "\n".join(output) assert "app stage" in joined.lower() - class TestSlashCommandSplit: """Tests covering /split slash command (lines 1878-1903).""" @@ -3610,7 +3702,6 @@ def test_split_eof_during_input(self, tmp_project): joined = "\n".join(output) assert "at least 2" in joined.lower() or "Split" in joined - class TestSlashCommandDestroy: """Tests covering /destroy slash command (lines 1906-1927).""" @@ -3747,7 +3838,6 @@ def test_destroy_eof_cancels(self, tmp_project): joined = "\n".join(output) assert "cancelled" in joined.lower() - class TestSlashCommandManual: """Tests covering /manual slash command (lines 1930-1952).""" @@ -3880,7 +3970,6 @@ def test_manual_view_no_instructions(self, tmp_project): joined = "\n".join(output) assert "No manual instructions" in joined - class TestHandleDescribe: """Tests for _handle_describe (lines 2020-2080).""" @@ -4030,7 +4119,6 @@ def test_describe_truncates_long_output(self, tmp_project): joined = "\n".join(output) assert "truncated" in joined.lower() - class TestUnknownSlashCommand: """Tests for unknown slash command (line 2020).""" @@ -4088,7 +4176,6 @@ def test_unknown_command(self, tmp_project): joined = "\n".join(output) assert "Unknown command" in joined - class TestMaybeSpinner: """Tests for _maybe_spinner (lines 2099-2116).""" @@ -4140,7 +4227,6 @@ def test_spinner_plain_mode(self, tmp_project): with session._maybe_spinner("Working...", use_styled=False): pass # Should not crash - class TestCollectStageFileContent: """Tests for _collect_stage_file_content (lines 1178-1225).""" @@ -4219,7 +4305,6 @@ def test_truncates_large_individual_files(self, tmp_project): content = session._collect_stage_file_content(stage) assert "truncated" in content.lower() - class TestParseStageNumbers: """Tests for _parse_stage_numbers static method.""" @@ -4252,7 +4337,6 @@ def test_empty_array(self): result = DeploySession._parse_stage_numbers("[]", valid) assert result == [] - class TestWriteStageFiles: """Tests for _write_stage_files (lines 1289-1330).""" @@ -4323,7 +4407,6 @@ def test_blocked_files_dropped(self, tmp_project): assert "versions.tf" not in written_names assert "main.tf" in written_names - class TestBuildFixTask: """Tests for _build_fix_task (lines 1227-1287).""" @@ -4418,9 +4501,6 @@ def test_includes_services_in_task(self, tmp_project): assert "mykv" in task assert "Microsoft.KeyVault" in task - -# ====================================================================== -# Natural Language Intent Detection — Deploy Integration # ====================================================================== @@ -4502,9 +4582,6 @@ def test_nl_describe_stage(self, tmp_project): joined = "\n".join(output) assert "Foundation" in joined or "Stage 1" in joined - -# ====================================================================== -# Deploy State Remediation tests # ====================================================================== @@ -4634,9 +4711,6 @@ def test_remediating_status_icon(self, tmp_project): status = ds.format_stage_status() assert "<>" in status - -# ====================================================================== -# Deploy Remediation Loop tests # ====================================================================== @@ -5173,9 +5247,6 @@ def test_no_ai_provider_skips_remediation(self, tmp_project): assert remediated is None - -# ====================================================================== -# Build-Deploy Decoupling: Stable IDs, Sync, Splitting, Manual Steps # ====================================================================== @@ -5237,7 +5308,6 @@ def _build_yaml_with_ids(stages=None, iac_tool="terraform"): "_metadata": {"created": "2026-01-01T00:00:00", "last_updated": "2026-01-01T00:00:00", "iteration": 1}, } - def _write_build_yaml_with_ids(project_dir, stages=None, iac_tool="terraform"): """Write build.yaml with stable IDs.""" state_dir = Path(project_dir) / ".prototype" / "state" @@ -5247,7 +5317,6 @@ def _write_build_yaml_with_ids(project_dir, stages=None, iac_tool="terraform"): yaml.dump(data, f, default_flow_style=False) return state_dir / "build.yaml" - class TestSyncFromBuildState: def test_sync_from_build_state_fresh(self, tmp_project): @@ -5377,7 +5446,6 @@ def test_sync_orphan_sets_removed_status(self, tmp_project): assert len(removed) == 1 assert removed[0]["build_stage_id"] == "data-layer" - class TestStageSpitting: def test_split_stage(self, tmp_project): @@ -5503,7 +5571,6 @@ def test_get_stage_by_display_id(self, tmp_project): assert found_b is not None assert found_b["name"] == "Data - Schema" - class TestDeployStateNewStatuses: def test_load_from_build_state_backward_compat(self, tmp_project): @@ -5561,7 +5628,6 @@ def test_awaiting_manual_status(self, tmp_project): ds.mark_stage_awaiting_manual(1) assert ds.get_stage(1)["deploy_status"] == "awaiting_manual" - class TestManualStepDeploy: def test_manual_step_deploy(self, tmp_project): @@ -5654,7 +5720,6 @@ def test_code_split_syncs_back_to_build(self, tmp_project): deploy_stage = ds.state["deployment_stages"][1] assert deploy_stage["build_stage_id"] == "data-layer" - class TestParseStageRef: def test_parse_simple_number(self): @@ -5691,7 +5756,6 @@ def test_parse_with_whitespace(self): assert num == 3 assert label == "b" - class TestRenumberWithSubstages: def test_renumber_preserves_substage_labels(self, tmp_project): @@ -5726,7 +5790,6 @@ def test_renumber_preserves_substage_labels(self, tmp_project): assert stages[2]["stage"] == 2 assert stages[2]["substage_label"] is None - class TestFormatDisplayId: def test_format_top_level(self): @@ -5744,7 +5807,6 @@ def test_format_no_label(self): assert _format_display_id({"stage": 1, "substage_label": None}) == "1" - class TestNewStatusIcons: def test_removed_icon(self): @@ -5770,7 +5832,6 @@ def test_existing_icons_unchanged(self): assert _status_icon("failed") == " x" assert _status_icon("remediating") == "<>" - class TestDeployReportFormatting: def test_format_shows_removed_stages(self, tmp_project): diff --git a/tests/stages/test_deploy_stage.py b/tests/stages/test_deploy_stage.py new file mode 100644 index 0000000..3aa2fba --- /dev/null +++ b/tests/stages/test_deploy_stage.py @@ -0,0 +1,309 @@ +"""Tests for DeployStage — routing logic, state transitions, guard conditions. + +Covers: +- Guard conditions (project_initialized, build_complete, az_logged_in) +- Routing: --status, --reset, --dry-run, --stage N, interactive +- State transitions between modes +- _result_to_dict conversion +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from azext_prototype.stages.base import StageState + +# ====================================================================== +# Fixtures +# ====================================================================== + + +@pytest.fixture +def deploy_stage(): + from azext_prototype.stages.deploy_stage import DeployStage + + return DeployStage() + + +@pytest.fixture +def agent_context(project_with_build, sample_config): + from azext_prototype.agents.base import AgentContext + + provider = MagicMock() + provider.provider_name = "github-models" + provider.chat.return_value = MagicMock(content="ok", model="test", usage={}) + return AgentContext( + project_config=sample_config, + project_dir=str(project_with_build), + ai_provider=provider, + ) + + +@pytest.fixture +def registry(): + return MagicMock() + + +# ====================================================================== +# Guard validation +# ====================================================================== + + +class TestDeployStageGuards: + """Test deploy stage prerequisites.""" + + def test_guards_return_three_guards(self, deploy_stage): + guards = deploy_stage.get_guards() + assert len(guards) == 3 + names = [g.name for g in guards] + assert "project_initialized" in names + assert "build_complete" in names + assert "az_logged_in" in names + + def test_all_guards_pass(self, deploy_stage, project_with_build, monkeypatch): + monkeypatch.chdir(project_with_build) + with patch("azext_prototype.stages.deploy_stage.check_az_login", return_value=True): + # Reload guards with the patched function + from azext_prototype.stages.deploy_stage import DeployStage + + stage = DeployStage() + can_run, failures = stage.can_run() + assert can_run is True + assert failures == [] + + def test_missing_project_yaml(self, deploy_stage, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + can_run, failures = deploy_stage.can_run() + assert can_run is False + assert any("init" in f.lower() or "prototype" in f.lower() for f in failures) + + def test_missing_build_state(self, deploy_stage, project_with_config, monkeypatch): + monkeypatch.chdir(project_with_config) + can_run, failures = deploy_stage.can_run() + assert can_run is False + assert any("build" in f.lower() for f in failures) + + def test_not_logged_in(self, deploy_stage, project_with_build, monkeypatch): + monkeypatch.chdir(project_with_build) + with patch("azext_prototype.stages.deploy_stage.check_az_login", return_value=False): + from azext_prototype.stages.deploy_stage import DeployStage + + stage = DeployStage() + can_run, failures = stage.can_run() + assert can_run is False + assert any("login" in f.lower() for f in failures) + + +# ====================================================================== +# --status routing +# ====================================================================== + + +class TestDeployStageStatusRoute: + """Test --status shows current progress without starting session.""" + + def test_status_route(self, deploy_stage, agent_context, registry): + with patch("azext_prototype.stages.deploy_stage.DeployState") as mock_ds, patch( + "azext_prototype.stages.deploy_stage.default_console" + ) as mock_console: + mock_ds.return_value.load.return_value = None + mock_ds.return_value.format_stage_status.return_value = "Stage status output" + result = deploy_stage.execute(agent_context, registry, status=True) + assert result["status"] == "status_displayed" + assert deploy_stage.state == StageState.COMPLETED + mock_console.print_info.assert_called_once() + + +# ====================================================================== +# --reset routing +# ====================================================================== + + +class TestDeployStageResetRoute: + """Test --reset clears deploy state.""" + + def test_reset_route(self, deploy_stage, agent_context, registry): + with patch("azext_prototype.stages.deploy_stage.DeployState") as mock_ds, patch( + "azext_prototype.stages.deploy_stage.default_console" + ): + result = deploy_stage.execute(agent_context, registry, reset=True) + assert result["status"] == "reset" + assert deploy_stage.state == StageState.COMPLETED + mock_ds.return_value.reset.assert_called_once() + + +# ====================================================================== +# --dry-run routing +# ====================================================================== + + +class TestDeployStageDryRunRoute: + """Test --dry-run delegates to session.run_dry_run().""" + + def test_dry_run_route(self, deploy_stage, agent_context, registry): + mock_result = MagicMock() + mock_result.failed_stages = [] + mock_result.cancelled = False + mock_result.deployed_stages = [] + mock_result.rolled_back_stages = [] + mock_result.captured_outputs = {} + + with patch("azext_prototype.stages.deploy_stage.DeploySession") as mock_session_cls: + mock_session_cls.return_value.run_dry_run.return_value = mock_result + result = deploy_stage.execute(agent_context, registry, dry_run=True, subscription="sub-123") + assert result["status"] == "success" + assert result["mode"] == "dry-run" + assert deploy_stage.state == StageState.COMPLETED + + def test_dry_run_with_stage(self, deploy_stage, agent_context, registry): + mock_result = MagicMock() + mock_result.failed_stages = [] + mock_result.cancelled = False + mock_result.deployed_stages = [] + mock_result.rolled_back_stages = [] + mock_result.captured_outputs = {} + + with patch("azext_prototype.stages.deploy_stage.DeploySession") as mock_session_cls: + mock_session_cls.return_value.run_dry_run.return_value = mock_result + result = deploy_stage.execute(agent_context, registry, dry_run=True, stage=2) + assert result["mode"] == "dry-run" + mock_session_cls.return_value.run_dry_run.assert_called_once() + + +# ====================================================================== +# --stage N routing +# ====================================================================== + + +class TestDeployStageSingleStageRoute: + """Test --stage N delegates to session.run_single_stage().""" + + def test_single_stage_success(self, deploy_stage, agent_context, registry): + mock_result = MagicMock() + mock_result.failed_stages = [] + mock_result.cancelled = False + mock_result.deployed_stages = ["stage-1"] + mock_result.rolled_back_stages = [] + mock_result.captured_outputs = {} + + with patch("azext_prototype.stages.deploy_stage.DeploySession") as mock_session_cls: + mock_session_cls.return_value.run_single_stage.return_value = mock_result + result = deploy_stage.execute(agent_context, registry, stage=1, subscription="sub-123") + assert result["mode"] == "single_stage" + assert deploy_stage.state == StageState.COMPLETED + + def test_single_stage_failure(self, deploy_stage, agent_context, registry): + mock_result = MagicMock() + mock_result.failed_stages = ["stage-1"] + mock_result.cancelled = False + mock_result.deployed_stages = [] + mock_result.rolled_back_stages = [] + mock_result.captured_outputs = {} + + with patch("azext_prototype.stages.deploy_stage.DeploySession") as mock_session_cls: + mock_session_cls.return_value.run_single_stage.return_value = mock_result + result = deploy_stage.execute(agent_context, registry, stage=1) + assert result["status"] == "partial_failure" + assert deploy_stage.state == StageState.FAILED + + +# ====================================================================== +# Interactive (default) routing +# ====================================================================== + + +class TestDeployStageInteractiveRoute: + """Test default interactive mode delegates to session.run().""" + + def test_interactive_success(self, deploy_stage, agent_context, registry): + mock_result = MagicMock() + mock_result.failed_stages = [] + mock_result.cancelled = False + mock_result.deployed_stages = ["stage-1", "stage-2"] + mock_result.rolled_back_stages = [] + mock_result.captured_outputs = {"terraform": {"key": "val"}} + + with patch("azext_prototype.stages.deploy_stage.DeploySession") as mock_session_cls: + mock_session_cls.return_value.run.return_value = mock_result + result = deploy_stage.execute(agent_context, registry) + assert result["status"] == "success" + assert result["mode"] == "interactive" + assert result["deployed"] == 2 + assert deploy_stage.state == StageState.COMPLETED + + def test_interactive_cancelled(self, deploy_stage, agent_context, registry): + mock_result = MagicMock() + mock_result.cancelled = True + mock_result.failed_stages = [] + mock_result.deployed_stages = [] + mock_result.rolled_back_stages = [] + mock_result.captured_outputs = {} + + with patch("azext_prototype.stages.deploy_stage.DeploySession") as mock_session_cls: + mock_session_cls.return_value.run.return_value = mock_result + result = deploy_stage.execute(agent_context, registry) + assert result["status"] == "cancelled" + assert deploy_stage.state == StageState.COMPLETED + + def test_interactive_partial_failure(self, deploy_stage, agent_context, registry): + mock_result = MagicMock() + mock_result.cancelled = False + mock_result.failed_stages = ["stage-2"] + mock_result.deployed_stages = ["stage-1"] + mock_result.rolled_back_stages = [] + mock_result.captured_outputs = {} + + with patch("azext_prototype.stages.deploy_stage.DeploySession") as mock_session_cls: + mock_session_cls.return_value.run.return_value = mock_result + result = deploy_stage.execute(agent_context, registry) + assert result["status"] == "partial_failure" + assert deploy_stage.state == StageState.FAILED + + +# ====================================================================== +# _result_to_dict +# ====================================================================== + + +class TestResultToDict: + """Test the result-to-dict conversion helper.""" + + def test_success(self): + from azext_prototype.stages.deploy_stage import _result_to_dict + + result = MagicMock() + result.failed_stages = [] + result.cancelled = False + result.deployed_stages = ["a", "b"] + result.rolled_back_stages = [] + result.captured_outputs = {"tf": {"x": 1}} + d = _result_to_dict(result, "test") + assert d["status"] == "success" + assert d["mode"] == "test" + assert d["deployed"] == 2 + assert d["failed"] == 0 + + def test_partial_failure(self): + from azext_prototype.stages.deploy_stage import _result_to_dict + + result = MagicMock() + result.failed_stages = ["x"] + result.cancelled = False + result.deployed_stages = ["a"] + result.rolled_back_stages = ["b"] + result.captured_outputs = {} + d = _result_to_dict(result, "interactive") + assert d["status"] == "partial_failure" + assert d["rolled_back"] == 1 + + def test_cancelled(self): + from azext_prototype.stages.deploy_stage import _result_to_dict + + result = MagicMock() + result.failed_stages = [] + result.cancelled = True + result.deployed_stages = [] + result.rolled_back_stages = [] + result.captured_outputs = {} + d = _result_to_dict(result, "interactive") + assert d["status"] == "cancelled" diff --git a/tests/stages/test_deploy_state.py b/tests/stages/test_deploy_state.py new file mode 100644 index 0000000..982d1f8 --- /dev/null +++ b/tests/stages/test_deploy_state.py @@ -0,0 +1,680 @@ +"""Tests for DeployState — stage sync, legacy fallback, state persistence. + +Covers: +- Stage sync with build state (matched, orphaned, new stages) +- Legacy fallback matching (name+capability) +- Post-load backfill of build_stage_ids +- Stage splitting (1:N divergence) +- Stage status transitions (deploying, deployed, failed, rolled_back, etc.) +- Rollback ordering enforcement +- Preflight result tracking +- Audit logging (deploy_log, rollback_log) +- Display formatting methods +- parse_stage_ref / _format_display_id / _status_icon +- Conversation tracking +- add_patch_stages / renumber_stages +""" + +from pathlib import Path + +import pytest +import yaml + +from azext_prototype.stages.deploy_state import ( + DeployState, + _enrich_deploy_fields, + _format_display_id, + _status_icon, + parse_stage_ref, +) + +# ====================================================================== +# Fixtures +# ====================================================================== + + +@pytest.fixture +def deploy_state(tmp_project): + ds = DeployState(str(tmp_project)) + return ds + + +@pytest.fixture +def deploy_state_with_stages(deploy_state): + """Deploy state with 2 stages loaded.""" + deploy_state._state["deployment_stages"] = [ + { + "stage": 1, + "name": "Foundation", + "capability": "infra", + "services": [{"name": "kv"}], + "build_stage_id": "foundation", + "deploy_status": "pending", + "deploy_timestamp": None, + "deploy_output": "", + "deploy_error": "", + "rollback_timestamp": None, + "remediation_attempts": 0, + "deploy_mode": "auto", + "manual_instructions": None, + "substage_label": None, + "_is_substage": False, + "_destruction_declined": False, + "dir": "concept/infra/stage-1", + "files": ["main.tf"], + }, + { + "stage": 2, + "name": "Application", + "capability": "app", + "services": [{"name": "web"}], + "build_stage_id": "application", + "deploy_status": "pending", + "deploy_timestamp": None, + "deploy_output": "", + "deploy_error": "", + "rollback_timestamp": None, + "remediation_attempts": 0, + "deploy_mode": "auto", + "manual_instructions": None, + "substage_label": None, + "_is_substage": False, + "_destruction_declined": False, + "dir": "concept/apps/stage-2", + "files": ["app.py"], + }, + ] + return deploy_state + + +# ====================================================================== +# load_from_build_state +# ====================================================================== + + +class TestLoadFromBuildState: + """Test importing deployment stages from build.yaml.""" + + def test_imports_stages(self, deploy_state, project_with_build): + build_path = Path(str(project_with_build)) / ".prototype" / "state" / "build.yaml" + result = deploy_state.load_from_build_state(build_path) + assert result is True + stages = deploy_state._state["deployment_stages"] + assert len(stages) == 2 + assert stages[0]["build_stage_id"] is not None + assert stages[0]["deploy_status"] == "pending" + + def test_missing_build_file(self, deploy_state, tmp_path): + result = deploy_state.load_from_build_state(tmp_path / "missing.yaml") + assert result is False + + def test_empty_build_stages(self, deploy_state, tmp_path): + build_file = tmp_path / "build.yaml" + build_file.write_text(yaml.dump({"deployment_stages": []}), encoding="utf-8") + result = deploy_state.load_from_build_state(build_file) + assert result is False + + def test_bad_yaml(self, deploy_state, tmp_path): + build_file = tmp_path / "build.yaml" + build_file.write_text(": invalid: yaml: {{", encoding="utf-8") + result = deploy_state.load_from_build_state(build_file) + assert result is False + + def test_iac_tool_carried_over(self, deploy_state, tmp_path): + build_file = tmp_path / "build.yaml" + build_file.write_text( + yaml.dump( + { + "iac_tool": "bicep", + "deployment_stages": [{"stage": 1, "name": "Foundation"}], + } + ), + encoding="utf-8", + ) + deploy_state.load_from_build_state(build_file) + assert deploy_state._state["iac_tool"] == "bicep" + + +# ====================================================================== +# sync_from_build_state +# ====================================================================== + + +class TestSyncFromBuildState: + """Test smart reconciliation: matched, orphaned, new stages.""" + + def test_matched_stages(self, deploy_state_with_stages, tmp_path): + """Build state with same IDs → matched, no new or orphaned.""" + build_file = tmp_path / "build.yaml" + build_file.write_text( + yaml.dump( + { + "deployment_stages": [ + {"id": "foundation", "name": "Foundation", "capability": "infra"}, + {"id": "application", "name": "Application", "capability": "app"}, + ] + } + ), + encoding="utf-8", + ) + result = deploy_state_with_stages.sync_from_build_state(build_file) + assert result.matched == 2 + assert result.created == 0 + assert result.orphaned == 0 + + def test_new_stage_created(self, deploy_state_with_stages, tmp_path): + """Build state has an extra stage → created.""" + build_file = tmp_path / "build.yaml" + build_file.write_text( + yaml.dump( + { + "deployment_stages": [ + {"id": "foundation", "name": "Foundation", "capability": "infra"}, + {"id": "application", "name": "Application", "capability": "app"}, + {"id": "database", "name": "Database", "capability": "db"}, + ] + } + ), + encoding="utf-8", + ) + result = deploy_state_with_stages.sync_from_build_state(build_file) + assert result.matched == 2 + assert result.created == 1 + assert any("Database" in d for d in result.details) + + def test_orphaned_stage(self, deploy_state_with_stages, tmp_path): + """Build state removed a stage → orphaned.""" + build_file = tmp_path / "build.yaml" + build_file.write_text( + yaml.dump( + { + "deployment_stages": [ + {"id": "foundation", "name": "Foundation", "capability": "infra"}, + ] + } + ), + encoding="utf-8", + ) + result = deploy_state_with_stages.sync_from_build_state(build_file) + assert result.orphaned == 1 + # The orphaned stage should be marked as "removed" + orphaned = [ + s for s in deploy_state_with_stages._state["deployment_stages"] if s.get("deploy_status") == "removed" + ] + assert len(orphaned) == 1 + + def test_legacy_fallback_matching(self, deploy_state, tmp_path): + """Stage without build_stage_id matches by name+capability.""" + deploy_state._state["deployment_stages"] = [ + { + "stage": 1, + "name": "Foundation", + "capability": "infra", + "deploy_status": "deployed", + "deploy_mode": "auto", + } + ] + build_file = tmp_path / "build.yaml" + build_file.write_text( + yaml.dump( + { + "deployment_stages": [ + {"id": "foundation", "name": "Foundation", "capability": "infra"}, + ] + } + ), + encoding="utf-8", + ) + result = deploy_state.sync_from_build_state(build_file) + assert result.matched == 1 + # build_stage_id should now be set + stage = deploy_state._state["deployment_stages"][0] + assert stage.get("build_stage_id") == "foundation" + + def test_code_change_detection(self, deploy_state_with_stages, tmp_path): + """When a matched stage's code changed, mark _code_updated.""" + deploy_state_with_stages._state["deployment_stages"][0]["deploy_status"] = "deployed" + build_file = tmp_path / "build.yaml" + build_file.write_text( + yaml.dump( + { + "deployment_stages": [ + { + "id": "foundation", + "name": "Foundation", + "capability": "infra", + "dir": "concept/infra/stage-1-v2", + "files": ["main.tf", "new.tf"], + }, + {"id": "application", "name": "Application", "capability": "app"}, + ] + } + ), + encoding="utf-8", + ) + result = deploy_state_with_stages.sync_from_build_state(build_file) + assert result.updated_code == 1 + + def test_missing_build_file(self, deploy_state, tmp_path): + result = deploy_state.sync_from_build_state(tmp_path / "missing.yaml") + assert "not found" in result.details[0].lower() + + def test_bad_yaml(self, deploy_state, tmp_path): + build_file = tmp_path / "build.yaml" + build_file.write_text(": bad yaml {{", encoding="utf-8") + result = deploy_state.sync_from_build_state(build_file) + assert len(result.details) == 1 + + def test_empty_deployment_stages(self, deploy_state, tmp_path): + build_file = tmp_path / "build.yaml" + build_file.write_text(yaml.dump({"deployment_stages": []}), encoding="utf-8") + result = deploy_state.sync_from_build_state(build_file) + assert "no deployment_stages" in result.details[0].lower() + + +# ====================================================================== +# Post-load backfill +# ====================================================================== + + +class TestPostLoadBackfill: + """Test _backfill_build_stage_ids on legacy state.""" + + def test_backfills_missing_ids(self, deploy_state): + deploy_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Data Layer"}, + ] + deploy_state._backfill_build_stage_ids() + stage = deploy_state._state["deployment_stages"][0] + assert stage["build_stage_id"] == "data-layer" + assert "deploy_status" in stage # _enrich_deploy_fields was called + + def test_preserves_existing_ids(self, deploy_state): + deploy_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Foundation", "build_stage_id": "custom-id"}, + ] + deploy_state._backfill_build_stage_ids() + assert deploy_state._state["deployment_stages"][0]["build_stage_id"] == "custom-id" + + +# ====================================================================== +# Stage splitting +# ====================================================================== + + +class TestStageSplitting: + """Test split_stage for 1:N divergence.""" + + def test_split_creates_substages(self, deploy_state_with_stages): + deploy_state_with_stages.split_stage( + 1, + [ + {"name": "Foundation-VNet", "dir": "concept/infra/vnet"}, + {"name": "Foundation-KV", "dir": "concept/infra/kv"}, + ], + ) + stages = deploy_state_with_stages._state["deployment_stages"] + substages = [s for s in stages if s.get("substage_label")] + assert len(substages) == 2 + assert substages[0]["substage_label"] == "a" + assert substages[1]["substage_label"] == "b" + assert all(s["_is_substage"] for s in substages) + assert all(s["build_stage_id"] == "foundation" for s in substages) + + def test_split_nonexistent_stage(self, deploy_state_with_stages): + """Splitting a stage that doesn't exist is a no-op.""" + deploy_state_with_stages.split_stage(99, [{"name": "X", "dir": "x"}]) + # No change + substages = [s for s in deploy_state_with_stages._state["deployment_stages"] if s.get("substage_label")] + assert len(substages) == 0 + + +# ====================================================================== +# Stage status transitions +# ====================================================================== + + +class TestStageStatusTransitions: + """Test all status transition methods.""" + + def test_mark_deploying(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_deploying(1) + assert deploy_state_with_stages.get_stage(1)["deploy_status"] == "deploying" + + def test_mark_deployed(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_deployed(1, output="tf output") + stage = deploy_state_with_stages.get_stage(1) + assert stage["deploy_status"] == "deployed" + assert stage["deploy_output"] == "tf output" + assert stage["deploy_error"] == "" + assert stage["deploy_timestamp"] is not None + + def test_mark_failed(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_failed(1, error="init failed") + stage = deploy_state_with_stages.get_stage(1) + assert stage["deploy_status"] == "failed" + assert stage["deploy_error"] == "init failed" + + def test_mark_rolled_back(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_rolled_back(1) + stage = deploy_state_with_stages.get_stage(1) + assert stage["deploy_status"] == "rolled_back" + assert stage["rollback_timestamp"] is not None + + def test_mark_remediating_bumps_counter(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_remediating(1) + assert deploy_state_with_stages.get_stage(1)["remediation_attempts"] == 1 + deploy_state_with_stages.mark_stage_remediating(1) + assert deploy_state_with_stages.get_stage(1)["remediation_attempts"] == 2 + + def test_reset_stage_to_pending(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_failed(1, error="err") + deploy_state_with_stages.reset_stage_to_pending(1) + stage = deploy_state_with_stages.get_stage(1) + assert stage["deploy_status"] == "pending" + assert stage["deploy_error"] == "" + + def test_mark_stage_removed(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_removed(1) + assert deploy_state_with_stages.get_stage(1)["deploy_status"] == "removed" + + def test_mark_stage_destroyed(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_destroyed(1) + assert deploy_state_with_stages.get_stage(1)["deploy_status"] == "destroyed" + + def test_mark_stage_awaiting_manual(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_awaiting_manual(1) + assert deploy_state_with_stages.get_stage(1)["deploy_status"] == "awaiting_manual" + + def test_mark_nonexistent_stage_no_error(self, deploy_state_with_stages): + """Marking a nonexistent stage is a no-op.""" + deploy_state_with_stages.mark_stage_deploying(99) + deploy_state_with_stages.mark_stage_deployed(99) + deploy_state_with_stages.mark_stage_failed(99) + deploy_state_with_stages.mark_stage_rolled_back(99) + + +# ====================================================================== +# Rollback ordering +# ====================================================================== + + +class TestRollbackOrdering: + """Test can_rollback enforces ordered rollback.""" + + def test_can_rollback_when_no_later_deployed(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_deployed(1) + assert deploy_state_with_stages.can_rollback(1) is True + + def test_cannot_rollback_when_later_deployed(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_deployed(1) + deploy_state_with_stages.mark_stage_deployed(2) + assert deploy_state_with_stages.can_rollback(1) is False + + def test_can_rollback_highest_stage(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_deployed(1) + deploy_state_with_stages.mark_stage_deployed(2) + assert deploy_state_with_stages.can_rollback(2) is True + + def test_get_rollback_candidates_sorted(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_deployed(1) + deploy_state_with_stages.mark_stage_deployed(2) + candidates = deploy_state_with_stages.get_rollback_candidates() + assert candidates[0]["stage"] == 2 + assert candidates[1]["stage"] == 1 + + +# ====================================================================== +# Stage queries +# ====================================================================== + + +class TestStageQueries: + """Test various stage query methods.""" + + def test_get_stage(self, deploy_state_with_stages): + assert deploy_state_with_stages.get_stage(1)["name"] == "Foundation" + assert deploy_state_with_stages.get_stage(99) is None + + def test_get_all_stages_for_num(self, deploy_state_with_stages): + stages = deploy_state_with_stages.get_all_stages_for_num(1) + assert len(stages) == 1 + + def test_get_pending_stages(self, deploy_state_with_stages): + pending = deploy_state_with_stages.get_pending_stages() + assert len(pending) == 2 + + def test_get_deployed_stages(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_deployed(1) + deployed = deploy_state_with_stages.get_deployed_stages() + assert len(deployed) == 1 + + def test_get_failed_stages(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_failed(1) + failed = deploy_state_with_stages.get_failed_stages() + assert len(failed) == 1 + + def test_get_stage_by_display_id(self, deploy_state_with_stages): + stage = deploy_state_with_stages.get_stage_by_display_id("1") + assert stage is not None + assert stage["name"] == "Foundation" + + def test_get_stage_by_display_id_invalid(self, deploy_state_with_stages): + assert deploy_state_with_stages.get_stage_by_display_id("abc") is None + + def test_get_stage_by_display_id_nonexistent(self, deploy_state_with_stages): + assert deploy_state_with_stages.get_stage_by_display_id("99") is None + + def test_get_stage_groups(self, deploy_state_with_stages): + groups = deploy_state_with_stages.get_stage_groups() + assert "foundation" in groups + assert "application" in groups + + def test_get_stages_for_build_stage(self, deploy_state_with_stages): + stages = deploy_state_with_stages.get_stages_for_build_stage("foundation") + assert len(stages) == 1 + + +# ====================================================================== +# Preflight +# ====================================================================== + + +class TestPreflight: + """Test preflight result tracking.""" + + def test_set_and_get_preflight_results(self, deploy_state): + results = [ + {"name": "az-login", "status": "pass", "message": "Logged in"}, + {"name": "rg-exists", "status": "fail", "message": "RG not found", "fix_command": "az group create"}, + ] + deploy_state.set_preflight_results(results) + failures = deploy_state.get_preflight_failures() + assert len(failures) == 1 + assert failures[0]["name"] == "rg-exists" + + def test_empty_preflight(self, deploy_state): + assert deploy_state.get_preflight_failures() == [] + + +# ====================================================================== +# Audit logging +# ====================================================================== + + +class TestAuditLogging: + """Test deploy and rollback log entries.""" + + def test_deploy_log_entry(self, deploy_state): + deploy_state.add_deploy_log_entry(1, "deploying") + logs = deploy_state._state["deploy_log"] + assert len(logs) == 1 + assert logs[0]["stage"] == 1 + assert logs[0]["action"] == "deploying" + + def test_rollback_log_entry(self, deploy_state): + deploy_state.add_rollback_log_entry(1, "user requested") + logs = deploy_state._state["rollback_log"] + assert len(logs) == 1 + assert logs[0]["stage"] == 1 + + +# ====================================================================== +# Conversation tracking +# ====================================================================== + + +class TestConversationTracking: + """Test exchange recording.""" + + def test_update_from_exchange(self, deploy_state): + deploy_state.update_from_exchange("deploy stage 1", "Deploying...", 1) + history = deploy_state._state["conversation_history"] + assert len(history) == 1 + assert history[0]["user"] == "deploy stage 1" + assert history[0]["exchange"] == 1 + + +# ====================================================================== +# add_patch_stages / renumber_stages +# ====================================================================== + + +class TestPatchAndRenumber: + """Test adding patch stages and renumbering.""" + + def test_add_patch_stages_before_docs(self, deploy_state): + deploy_state._state["deployment_stages"] = [ + {"stage": 1, "name": "Foundation", "capability": "infra"}, + {"stage": 2, "name": "Documentation", "capability": "docs"}, + ] + deploy_state.add_patch_stages([{"name": "Hotfix", "capability": "infra", "build_stage_id": "hotfix"}]) + stages = deploy_state._state["deployment_stages"] + names = [s["name"] for s in stages] + assert names.index("Hotfix") < names.index("Documentation") + + def test_renumber_stages(self, deploy_state): + deploy_state._state["deployment_stages"] = [ + {"stage": 5, "name": "A"}, + {"stage": 10, "name": "B"}, + ] + deploy_state.renumber_stages() + assert deploy_state._state["deployment_stages"][0]["stage"] == 1 + assert deploy_state._state["deployment_stages"][1]["stage"] == 2 + + +# ====================================================================== +# Formatting +# ====================================================================== + + +class TestFormatting: + """Test display formatting methods.""" + + def test_format_stage_status_empty(self, deploy_state): + result = deploy_state.format_stage_status() + assert "No deployment stages" in result + + def test_format_stage_status_with_stages(self, deploy_state_with_stages): + result = deploy_state_with_stages.format_stage_status() + assert "Foundation" in result + assert "Application" in result + assert "0/2" in result + + def test_format_deploy_report(self, deploy_state_with_stages): + deploy_state_with_stages.mark_stage_deployed(1) + report = deploy_state_with_stages.format_deploy_report() + assert "Deploy Report" in report + assert "Foundation" in report + + def test_format_preflight_report_empty(self, deploy_state): + result = deploy_state.format_preflight_report() + assert "No preflight checks" in result + + def test_format_preflight_report_with_results(self, deploy_state): + deploy_state.set_preflight_results( + [ + {"name": "login", "status": "pass", "message": "OK"}, + {"name": "rg", "status": "fail", "message": "Missing", "fix_command": "az group create"}, + ] + ) + result = deploy_state.format_preflight_report() + assert "login" in result + assert "Fix:" in result + + def test_format_outputs_empty(self, deploy_state): + result = deploy_state.format_outputs() + assert "No deployment outputs" in result + + def test_format_outputs_with_data(self, deploy_state): + deploy_state._state["captured_outputs"] = { + "terraform": {"endpoint": "https://app.com"}, + } + result = deploy_state.format_outputs() + assert "endpoint" in result + assert "https://app.com" in result + + +# ====================================================================== +# Module-level helpers +# ====================================================================== + + +class TestModuleHelpers: + """Test parse_stage_ref, _format_display_id, _status_icon.""" + + def test_parse_stage_ref_number_only(self): + num, label = parse_stage_ref("5") + assert num == 5 + assert label is None + + def test_parse_stage_ref_with_label(self): + num, label = parse_stage_ref("5a") + assert num == 5 + assert label == "a" + + def test_parse_stage_ref_invalid(self): + num, label = parse_stage_ref("abc") + assert num is None + assert label is None + + def test_parse_stage_ref_whitespace(self): + num, label = parse_stage_ref(" 3b ") + assert num == 3 + assert label == "b" + + def test_format_display_id_plain(self): + assert _format_display_id({"stage": 3}) == "3" + + def test_format_display_id_with_label(self): + assert _format_display_id({"stage": 3, "substage_label": "b"}) == "3b" + + def test_status_icon_mapping(self): + assert _status_icon("pending") == " " + assert _status_icon("deploying") == ">>" + assert _status_icon("deployed") == " v" + assert _status_icon("failed") == " x" + assert _status_icon("rolled_back") == " ~" + assert _status_icon("unknown") == " " + + +# ====================================================================== +# _enrich_deploy_fields +# ====================================================================== + + +class TestEnrichDeployFields: + """Test _enrich_deploy_fields sets defaults.""" + + def test_adds_all_fields(self): + stage = {"name": "test"} + enriched = _enrich_deploy_fields(stage) + assert enriched["deploy_status"] == "pending" + assert enriched["deploy_timestamp"] is None + assert enriched["remediation_attempts"] == 0 + assert enriched["_is_substage"] is False + + def test_preserves_existing_values(self): + stage = {"name": "test", "deploy_status": "deployed"} + enriched = _enrich_deploy_fields(stage) + assert enriched["deploy_status"] == "deployed" diff --git a/tests/stages/test_design_stage.py b/tests/stages/test_design_stage.py new file mode 100644 index 0000000..236693d --- /dev/null +++ b/tests/stages/test_design_stage.py @@ -0,0 +1,513 @@ +"""Tests for design_stage.py — branch coverage for artifact change detection, +skip-discovery flow, heading extraction, summary generation, template matching, +and format_section_elapsed. +""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from azext_prototype.agents.base import AgentCapability, AgentContext + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def design_context(project_with_config, sample_config): + provider = MagicMock() + provider.provider_name = "github-models" + provider.default_model = "gpt-4o" + provider.chat.return_value = MagicMock( + content="## Solution Overview\nSample design output.", + model="test", + usage={"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}, + ) + return AgentContext( + project_config=sample_config, + project_dir=str(project_with_config), + ai_provider=provider, + ) + + +@pytest.fixture +def design_registry(): + registry = MagicMock() + + mock_architect = MagicMock() + mock_architect.name = "cloud-architect" + mock_architect.execute = MagicMock( + return_value=MagicMock( + content='```json\n[{"name": "Solution Overview", "context": "overview"}]\n```', + model="test", + usage={"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}, + ) + ) + + mock_biz = MagicMock() + mock_biz.name = "biz-analyst" + mock_biz.get_system_messages.return_value = [] + mock_biz._temperature = 0.7 + mock_biz._max_tokens = 4096 + + mock_tf = MagicMock() + mock_tf.name = "terraform-agent" + mock_tf.execute = MagicMock( + return_value=MagicMock( + content="Terraform feasibility confirmed.", + model="test", + usage={"prompt_tokens": 50, "completion_tokens": 100, "total_tokens": 150}, + ) + ) + + def find_by_cap(cap): + mapping = { + AgentCapability.ARCHITECT: [mock_architect], + AgentCapability.BIZ_ANALYSIS: [mock_biz], + AgentCapability.TERRAFORM: [mock_tf], + AgentCapability.BICEP: [], + AgentCapability.QA: [], + } + return mapping.get(cap, []) + + registry.find_by_capability.side_effect = find_by_cap + return registry + + +# ------------------------------------------------------------------ +# _format_section_elapsed +# ------------------------------------------------------------------ + + +class TestFormatSectionElapsed: + def test_seconds_under_60(self): + from azext_prototype.stages.design_stage import _format_section_elapsed + + assert _format_section_elapsed(5.0) == "5s" + assert _format_section_elapsed(45.7) == "46s" + + def test_seconds_over_60(self): + from azext_prototype.stages.design_stage import _format_section_elapsed + + assert _format_section_elapsed(64.0) == "1m04s" + assert _format_section_elapsed(125.0) == "2m05s" + + def test_exactly_60(self): + from azext_prototype.stages.design_stage import _format_section_elapsed + + assert _format_section_elapsed(60.0) == "1m00s" + + +# ------------------------------------------------------------------ +# _extract_new_sections +# ------------------------------------------------------------------ + + +class TestExtractNewSections: + def test_valid_section_marker(self): + from azext_prototype.stages.design_stage import _extract_new_sections + + content = 'Some text [NEW_SECTION: {"name": "Security", "context": "auth details"}] more text' + result = _extract_new_sections(content) + assert len(result) == 1 + assert result[0]["name"] == "Security" + assert result[0]["context"] == "auth details" + + def test_defaults_context(self): + from azext_prototype.stages.design_stage import _extract_new_sections + + content = '[NEW_SECTION: {"name": "Foo"}]' + result = _extract_new_sections(content) + assert len(result) == 1 + assert result[0]["context"] == "" + + def test_invalid_json_skipped(self): + from azext_prototype.stages.design_stage import _extract_new_sections + + content = "[NEW_SECTION: {bad json}]" + assert _extract_new_sections(content) == [] + + def test_missing_name_skipped(self): + from azext_prototype.stages.design_stage import _extract_new_sections + + content = '[NEW_SECTION: {"context": "only context"}]' + assert _extract_new_sections(content) == [] + + def test_multiple_markers(self): + from azext_prototype.stages.design_stage import _extract_new_sections + + content = '[NEW_SECTION: {"name": "A"}] middle ' '[NEW_SECTION: {"name": "B", "context": "ctx"}]' + result = _extract_new_sections(content) + assert len(result) == 2 + assert result[0]["name"] == "A" + assert result[1]["name"] == "B" + + +# ------------------------------------------------------------------ +# DesignStage — guards +# ------------------------------------------------------------------ + + +class TestDesignStageGuards: + def test_get_guards_returns_one_guard(self): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + guards = stage.get_guards() + assert len(guards) == 1 + assert guards[0].name == "project_initialized" + + def test_design_is_reentrant(self): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + assert stage.reentrant is True + + +# ------------------------------------------------------------------ +# _load_design_state / _save_design_state +# ------------------------------------------------------------------ + + +class TestDesignStatePersistence: + def test_load_returns_fresh_state_when_no_file(self, tmp_project): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + state = stage._load_design_state(str(tmp_project), reset=False) + assert state["architecture"] is None + assert state["artifacts"] == [] + assert state["iteration"] == 0 + + def test_load_with_reset_clears_existing(self, tmp_project): + from azext_prototype.stages.design_stage import DesignStage + + state_file = tmp_project / ".prototype" / "state" / "design.json" + state_file.parent.mkdir(parents=True, exist_ok=True) + state_file.write_text( + json.dumps({"architecture": "existing", "artifacts": [], "iteration": 3}), + encoding="utf-8", + ) + + stage = DesignStage() + state = stage._load_design_state(str(tmp_project), reset=True) + assert state["architecture"] is None + assert state["iteration"] == 0 + + def test_load_existing_state(self, tmp_project): + from azext_prototype.stages.design_stage import DesignStage + + state_file = tmp_project / ".prototype" / "state" / "design.json" + state_file.parent.mkdir(parents=True, exist_ok=True) + state_file.write_text( + json.dumps({"architecture": "arch", "artifacts": [{"path": "/foo"}], "iteration": 2}), + encoding="utf-8", + ) + + stage = DesignStage() + state = stage._load_design_state(str(tmp_project), reset=False) + assert state["architecture"] == "arch" + assert state["iteration"] == 2 + + def test_save_and_reload(self, tmp_project): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + state = {"architecture": "test-arch", "artifacts": [], "iteration": 1} + stage._save_design_state(str(tmp_project), state) + + reloaded = stage._load_design_state(str(tmp_project), reset=False) + assert reloaded["architecture"] == "test-arch" + + +# ------------------------------------------------------------------ +# _write_architecture_docs +# ------------------------------------------------------------------ + + +class TestWriteArchitectureDocs: + def test_writes_architecture_md(self, tmp_project): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + stage._write_architecture_docs(str(tmp_project), "# My Architecture\nSome content") + + arch_file = tmp_project / "concept" / "docs" / "ARCHITECTURE.md" + assert arch_file.exists() + content = arch_file.read_text() + assert "My Architecture" in content + + +# ------------------------------------------------------------------ +# _compute_artifact_hashes +# ------------------------------------------------------------------ + + +class TestArtifactHashes: + def test_computes_hashes_for_text_files(self, tmp_project): + from azext_prototype.stages.design_stage import DesignStage + + docs_dir = tmp_project / "concept" / "docs" + docs_dir.mkdir(parents=True, exist_ok=True) + (docs_dir / "spec.txt").write_text("hello", encoding="utf-8") + + stage = DesignStage() + hashes = stage._compute_artifact_hashes(str(docs_dir)) + assert len(hashes) >= 1 + # Hash should be a hex string + for path, h in hashes.items(): + assert len(h) == 64 # SHA-256 hex + + def test_nonexistent_path_returns_empty(self, tmp_project): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + hashes = stage._compute_artifact_hashes(str(tmp_project / "nonexistent")) + assert hashes == {} + + +# ------------------------------------------------------------------ +# skip-discovery flow +# ------------------------------------------------------------------ + + +class TestSkipDiscovery: + def test_skip_discovery_without_state_raises(self, design_context, design_registry): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + + with pytest.raises(Exception): + stage.execute( + design_context, + design_registry, + skip_discovery=True, + input_fn=lambda p: "", + print_fn=lambda m: None, + ) + + def test_skip_discovery_with_existing_state(self, design_context, design_registry, project_with_config): + from azext_prototype.stages.design_stage import DesignStage + from azext_prototype.stages.discovery_state import DiscoveryState + + # Create discovery state + ds = DiscoveryState(str(project_with_config)) + ds.load() + ds.state["project"] = {"summary": "API backend"} + ds.state["confirmed_items"] = ["Use Container Apps"] + ds.state["_metadata"]["exchange_count"] = 3 + # Add conversation history so _extract_last_summary can find it + ds.state["conversation_history"] = [ + {"role": "assistant", "content": "## Requirements Summary\nBuild an API."}, + ] + ds.save() + + stage = DesignStage() + + # Mock the architect execution chain + mock_arch = design_registry.find_by_capability(AgentCapability.ARCHITECT)[0] + # First call: plan sections, Second+: generate sections + mock_arch.execute.side_effect = [ + MagicMock( + content='```json\n[{"name": "Overview", "context": "test"}]\n```', + model="test", + usage={}, + ), + MagicMock( + content="## Overview\nSample arch.", + model="test", + usage={}, + ), + # IaC review + MagicMock(content="Terraform ok", model="test", usage={}), + ] + + result = stage.execute( + design_context, + design_registry, + skip_discovery=True, + input_fn=lambda p: "", + print_fn=lambda m: None, + ) + assert result["status"] == "success" + + +# ------------------------------------------------------------------ +# _refine_architecture_loop +# ------------------------------------------------------------------ + + +class TestRefineArchitectureLoop: + def test_empty_feedback_exits(self, design_context, design_registry): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + mock_architect = design_registry.find_by_capability(AgentCapability.ARCHITECT)[0] + + design_state = {"architecture": "# Arch\nContent", "iteration": 1} + + from azext_prototype.config import ProjectConfig + + config = ProjectConfig(design_context.project_dir) + config.load() + + with patch("builtins.input", return_value=""): + result = stage._refine_architecture_loop( + design_context, + mock_architect, + design_state, + config, + ) + + assert result == "# Arch\nContent" + + def test_accept_keyword_exits(self, design_context, design_registry): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + mock_architect = design_registry.find_by_capability(AgentCapability.ARCHITECT)[0] + + design_state = {"architecture": "# Arch", "iteration": 1} + + from azext_prototype.config import ProjectConfig + + config = ProjectConfig(design_context.project_dir) + config.load() + + with patch("builtins.input", return_value="done"): + result = stage._refine_architecture_loop( + design_context, + mock_architect, + design_state, + config, + ) + assert result == "# Arch" + + +# ------------------------------------------------------------------ +# _execute_with_prompt_trim +# ------------------------------------------------------------------ + + +class TestExecuteWithPromptTrim: + def test_normal_execution_passes_through(self): + from azext_prototype.stages.design_stage import DesignStage + + architect = MagicMock() + architect.execute.return_value = MagicMock(content="result") + ctx = MagicMock() + + result = DesignStage._execute_with_prompt_trim(architect, ctx, "prompt", []) + assert result.content == "result" + + def test_prompt_too_large_with_accumulated_retries(self): + from azext_prototype.ai.copilot_provider import CopilotPromptTooLargeError + from azext_prototype.stages.design_stage import DesignStage + + architect = MagicMock() + # First call raises, second succeeds + architect.execute.side_effect = [ + CopilotPromptTooLargeError("Prompt too large", token_count=200000, token_limit=100000), + MagicMock(content="trimmed result"), + ] + ctx = MagicMock() + + prompt = "Intro\n## Architecture So Far\nfull content\n\n## Instructions\nGenerate code" + accumulated = ["## Section 1\nContent 1", "## Section 2\nContent 2"] + + result = DesignStage._execute_with_prompt_trim(architect, ctx, prompt, accumulated) + assert result.content == "trimmed result" + + def test_prompt_too_large_no_accumulated_reraises(self): + from azext_prototype.stages.design_stage import DesignStage + + architect = MagicMock() + # When accumulated is empty and prompt lacks ## Architecture So Far, + # the code hits bare `raise` outside exception context → RuntimeError + from azext_prototype.ai.copilot_provider import CopilotPromptTooLargeError + + architect.execute.side_effect = CopilotPromptTooLargeError( + "Prompt too large", token_count=200000, token_limit=100000 + ) + ctx = MagicMock() + + with pytest.raises(RuntimeError): + DesignStage._execute_with_prompt_trim(architect, ctx, "prompt without marker", []) + + +# ------------------------------------------------------------------ +# _plan_architecture fallback +# ------------------------------------------------------------------ + + +class TestPlanArchitecture: + def test_fallback_on_invalid_json(self, design_context, design_registry): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + mock_architect = design_registry.find_by_capability(AgentCapability.ARCHITECT)[0] + mock_architect.execute.return_value = MagicMock( + content="Not valid JSON at all", + model="test", + usage={}, + ) + + from azext_prototype.config import ProjectConfig + + config = ProjectConfig(design_context.project_dir) + config.load() + + sections = stage._plan_architecture( + None, + design_context, + mock_architect, + config, + "requirements", + lambda m: None, + ) + # Should fall back to _DEFAULT_SECTIONS + assert len(sections) > 0 + assert sections[0]["name"] == "Solution Overview" + + +# ------------------------------------------------------------------ +# _run_iac_review +# ------------------------------------------------------------------ + + +class TestRunIacReview: + def test_no_iac_agent_skips_review(self, design_context): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + registry = MagicMock() + registry.find_by_capability.return_value = [] + + from azext_prototype.config import ProjectConfig + + config = ProjectConfig(design_context.project_dir) + config.load() + + mock_architect = MagicMock() + # Should not raise — silently skips + stage._run_iac_review(design_context, registry, config, mock_architect, "design output") + + def test_iac_review_stores_artifact(self, design_context, design_registry): + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + + from azext_prototype.config import ProjectConfig + + config = ProjectConfig(design_context.project_dir) + config.load() + + mock_architect = design_registry.find_by_capability(AgentCapability.ARCHITECT)[0] + + stage._run_iac_review(design_context, design_registry, config, mock_architect, "design output") + # Verify artifact was added + assert "iac_review" in design_context.artifacts diff --git a/tests/test_discovery.py b/tests/stages/test_discovery.py similarity index 85% rename from tests/test_discovery.py rename to tests/stages/test_discovery.py index 10367ee..43578ef 100644 --- a/tests/test_discovery.py +++ b/tests/stages/test_discovery.py @@ -1,22 +1,543 @@ -"""Tests for azext_prototype.stages.discovery — organic multi-turn conversation.""" +"""Tests for discovery.py — branch coverage for section header extraction, +slash command routing, opening message construction, vision content array +building, topic detection, context change handling, conversation state +management, parse_sections, and the main run loop. +""" from __future__ import annotations -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest from azext_prototype.agents.base import AgentCapability, AgentContext + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def discovery_context(project_with_config, sample_config): + provider = MagicMock() + provider.provider_name = "github-models" + provider.default_model = "gpt-4o" + provider.chat.return_value = MagicMock( + content="I understand. Let me ask some questions.", + model="test", + usage={"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}, + ) + return AgentContext( + project_config=sample_config, + project_dir=str(project_with_config), + ai_provider=provider, + ) + + +@pytest.fixture +def discovery_registry(): + registry = MagicMock() + + mock_biz = MagicMock() + mock_biz.name = "biz-analyst" + mock_biz.get_system_messages.return_value = [] + mock_biz._temperature = 0.7 + mock_biz._max_tokens = 4096 + + mock_architect = MagicMock() + mock_architect.name = "cloud-architect" + + mock_qa = MagicMock() + mock_qa.name = "qa-engineer" + + def find_by_cap(cap): + mapping = { + AgentCapability.BIZ_ANALYSIS: [mock_biz], + AgentCapability.ARCHITECT: [mock_architect], + AgentCapability.QA: [mock_qa], + } + return mapping.get(cap, []) + + registry.find_by_capability.side_effect = find_by_cap + return registry + + +# ------------------------------------------------------------------ +# extract_section_headers +# ------------------------------------------------------------------ + + +class TestExtractSectionHeaders: + def test_extracts_h2_headings(self): + from azext_prototype.stages.discovery import extract_section_headers + + text = "## Authentication\nContent\n## Data Storage\nContent" + headers = extract_section_headers(text) + assert len(headers) == 2 + assert headers[0] == ("Authentication", 2) + assert headers[1] == ("Data Storage", 2) + + def test_filters_skip_headings(self): + from azext_prototype.stages.discovery import extract_section_headers + + text = "## Summary\nContent\n## Next Steps\nContent\n## Real Topic\nContent" + headers = extract_section_headers(text) + assert len(headers) == 1 + assert headers[0][0] == "Real Topic" + + def test_filters_short_headings(self): + from azext_prototype.stages.discovery import extract_section_headers + + text = "## AB\nContent\n## Long Enough\nContent" + headers = extract_section_headers(text) + assert len(headers) == 1 + assert headers[0][0] == "Long Enough" + + def test_deduplicates_headings(self): + from azext_prototype.stages.discovery import extract_section_headers + + text = "## Auth\nContent\n## Auth\nDuplicate" + headers = extract_section_headers(text) + assert len(headers) == 1 + + def test_bold_headings(self): + from azext_prototype.stages.discovery import extract_section_headers + + text = "**Authentication Model**\nContent\n**Data Layer**\nContent" + headers = extract_section_headers(text) + assert len(headers) == 2 + assert headers[0][0] == "Authentication Model" + + def test_h3_headings_excluded(self): + from azext_prototype.stages.discovery import extract_section_headers + + text = "## Top Level\nContent\n### Sub Section\nContent" + headers = extract_section_headers(text) + assert len(headers) == 1 + assert headers[0][0] == "Top Level" + + def test_empty_text(self): + from azext_prototype.stages.discovery import extract_section_headers + + assert extract_section_headers("") == [] + + def test_no_headings(self): + from azext_prototype.stages.discovery import extract_section_headers + + assert extract_section_headers("Just plain text with no headings.") == [] + + +# ------------------------------------------------------------------ +# parse_sections +# ------------------------------------------------------------------ + + +class TestParseSections: + def test_basic_sections(self): + from azext_prototype.stages.discovery import parse_sections + + text = "Intro text\n\n## Section A\nContent A\n\n## Section B\nContent B" + preamble, sections = parse_sections(text) + assert "Intro text" in preamble + assert len(sections) == 2 + assert sections[0].heading == "Section A" + assert sections[1].heading == "Section B" + + def test_no_sections_returns_full_text(self): + from azext_prototype.stages.discovery import parse_sections + + text = "Just a paragraph without headings." + preamble, sections = parse_sections(text) + assert preamble == text + assert sections == [] + + def test_skip_headings_filtered(self): + from azext_prototype.stages.discovery import parse_sections + + text = "## Summary\nSkip this\n## Real Section\nKeep this" + preamble, sections = parse_sections(text) + assert len(sections) == 1 + assert sections[0].heading == "Real Section" + + def test_task_id_generated(self): + from azext_prototype.stages.discovery import parse_sections + + text = "## Data Storage\nContent" + _, sections = parse_sections(text) + assert len(sections) == 1 + assert sections[0].task_id == "design-section-data-storage" + + def test_h3_folded_into_parent(self): + from azext_prototype.stages.discovery import parse_sections + + text = "## Parent\nP content\n### Child\nC content\n## Another\nA content" + _, sections = parse_sections(text) + assert len(sections) == 2 + assert "Child" in sections[0].content + assert sections[1].heading == "Another" + + def test_bold_heading_sections(self): + from azext_prototype.stages.discovery import parse_sections + + text = "Preamble\n\n**Security Model**\nSecurity content\n\n**Deployment**\nDeploy content" + preamble, sections = parse_sections(text) + assert len(sections) == 2 + assert sections[0].heading == "Security Model" + + def test_only_h3_returns_no_sections(self): + from azext_prototype.stages.discovery import parse_sections + + text = "### Sub Section\nContent" + preamble, sections = parse_sections(text) + assert sections == [] + assert "Sub Section" in preamble + + +# ------------------------------------------------------------------ +# _build_opening +# ------------------------------------------------------------------ + + +class TestBuildOpening: + def _make_session(self, ctx, registry): + from azext_prototype.stages.discovery import DiscoverySession + + return DiscoverySession(ctx, registry) + + def test_no_context_no_artifacts(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + opening = session._build_opening("", "", "") + assert isinstance(opening, str) + assert "Azure prototype" in opening + + def test_seed_context_only(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + opening = session._build_opening("Build an API", "", "") + assert "Build an API" in opening + + def test_artifacts_only(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + opening = session._build_opening("", "Requirements doc content", "") + assert "requirement documents" in opening.lower() + + def test_seed_and_artifacts(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + opening = session._build_opening("Build an API", "Doc content", "") + assert "Build an API" in opening + assert "Doc content" in opening + + def test_existing_context_included(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + opening = session._build_opening("New info", "", "Previous session learnings") + assert "Previous session learnings" in opening + assert "conflicts" in opening.lower() + + def test_images_produce_multimodal(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + images = [{"filename": "test.png", "data": "abc123", "mime": "image/png"}] + opening = session._build_opening("Context", "", "", images=images) + assert isinstance(opening, list) + assert opening[0]["type"] == "text" + assert opening[1]["type"] == "image_url" + assert "abc123" in opening[1]["image_url"]["url"] + + def test_no_context_with_existing_only(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + opening = session._build_opening("", "", "existing") + assert "existing" in opening + + def test_multiple_images(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + images = [ + {"filename": "a.png", "data": "aaa", "mime": "image/png"}, + {"filename": "b.jpg", "data": "bbb", "mime": "image/jpeg"}, + ] + opening = session._build_opening("", "artifacts", "", images=images) + assert isinstance(opening, list) + assert len(opening) == 3 # text + 2 images + + +# ------------------------------------------------------------------ +# DiscoveryResult +# ------------------------------------------------------------------ + + +class TestDiscoveryResult: + def test_default_not_cancelled(self): + from azext_prototype.stages.discovery import DiscoveryResult + + result = DiscoveryResult( + requirements="reqs", + conversation=[], + policy_overrides=[], + exchange_count=5, + ) + assert result.cancelled is False + assert result.exchange_count == 5 + + def test_cancelled_result(self): + from azext_prototype.stages.discovery import DiscoveryResult + + result = DiscoveryResult( + requirements="", + conversation=[], + policy_overrides=[], + exchange_count=0, + cancelled=True, + ) + assert result.cancelled is True + + +# ------------------------------------------------------------------ +# Session run — no biz-agent fallback +# ------------------------------------------------------------------ + + +class TestDiscoverySessionNoBizAgent: + def test_no_biz_agent_prompts_user(self, discovery_context): + from azext_prototype.stages.discovery import DiscoverySession + + registry = MagicMock() + registry.find_by_capability.return_value = [] + + session = DiscoverySession(discovery_context, registry) + result = session.run( + seed_context="test", + input_fn=lambda p: "my requirements", + print_fn=lambda m: None, + ) + assert result.requirements == "my requirements" + assert result.exchange_count == 0 + + def test_no_biz_agent_eof(self, discovery_context): + from azext_prototype.stages.discovery import DiscoverySession + + registry = MagicMock() + registry.find_by_capability.return_value = [] + + session = DiscoverySession(discovery_context, registry) + + def raise_eof(p): + raise EOFError + + result = session.run( + seed_context="", + input_fn=raise_eof, + print_fn=lambda m: None, + ) + assert result.requirements == "" + + +# ------------------------------------------------------------------ +# Session run — quit/done in main loop +# ------------------------------------------------------------------ + + +class TestDiscoverySessionMainLoop: + def _make_session(self, ctx, registry): + from azext_prototype.stages.discovery import DiscoverySession + + return DiscoverySession(ctx, registry) + + def test_quit_returns_cancelled(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + + # First response from AI has no sections (plain text), so we enter free-form loop + discovery_context.ai_provider.chat.return_value = MagicMock( + content="What would you like to build?", + model="test", + usage={"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300}, + ) + + inputs = iter(["quit"]) + result = session.run( + seed_context="Build an API", + input_fn=lambda p: next(inputs), + print_fn=lambda m: None, + ) + assert result.cancelled is True + + def test_done_produces_summary(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + + call_count = [0] + + def mock_chat(messages, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return MagicMock( + content="What would you like to build?", + model="test", + usage={"prompt_tokens": 50, "completion_tokens": 100, "total_tokens": 150}, + ) + # Summary call + return MagicMock( + content="## Requirements Summary\nBuild an API with auth.", + model="test", + usage={"prompt_tokens": 50, "completion_tokens": 100, "total_tokens": 150}, + ) + + discovery_context.ai_provider.chat.side_effect = mock_chat + + inputs = iter(["done"]) + result = session.run( + seed_context="Build an API", + input_fn=lambda p: next(inputs), + print_fn=lambda m: None, + ) + assert result.cancelled is False + assert result.requirements # Should have summary text + + def test_eof_in_main_loop_ends_session(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + + discovery_context.ai_provider.chat.return_value = MagicMock( + content="Plain response no sections", + model="test", + usage={"prompt_tokens": 50, "completion_tokens": 100, "total_tokens": 150}, + ) + + def raise_eof(p): + raise EOFError + + result = session.run( + seed_context="test", + input_fn=raise_eof, + print_fn=lambda m: None, + ) + # Should produce a summary (not cancelled) + assert result is not None + + def test_slash_command_help(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + + discovery_context.ai_provider.chat.return_value = MagicMock( + content="What would you like to build?", + model="test", + usage={"prompt_tokens": 50, "completion_tokens": 100, "total_tokens": 150}, + ) + + inputs = iter(["/help", "done"]) + result = session.run( + seed_context="test", + input_fn=lambda p: next(inputs), + print_fn=lambda m: None, + ) + assert result is not None + assert not result.cancelled + + +# ------------------------------------------------------------------ +# _chat — vision fallback +# ------------------------------------------------------------------ + + +class TestChatVisionFallback: + def test_vision_failure_degrades_to_text(self, discovery_context, discovery_registry): + from azext_prototype.stages.discovery import DiscoverySession + + session = DiscoverySession(discovery_context, discovery_registry) + + call_count = [0] + + def mock_chat(messages, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise Exception("Vision not supported") + return MagicMock( + content="Text fallback response", + model="test", + usage={"prompt_tokens": 50, "completion_tokens": 100, "total_tokens": 150}, + ) + + discovery_context.ai_provider.chat.side_effect = mock_chat + + content = [ + {"type": "text", "text": "Review these files"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ] + response = session._chat(content) + assert response == "Text fallback response" + # Should have fallen back to text-only + assert call_count[0] == 2 + + +# ------------------------------------------------------------------ +# _chat_lightweight +# ------------------------------------------------------------------ + + +class TestChatLightweight: + def test_returns_ai_content(self, discovery_context, discovery_registry): + from azext_prototype.stages.discovery import DiscoverySession + + session = DiscoverySession(discovery_context, discovery_registry) + + discovery_context.ai_provider.chat.return_value = MagicMock( + content="Lightweight response", + model="test", + usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + ) + + result = session._chat_lightweight("classify this text") + assert result == "Lightweight response" + + def test_does_not_append_to_messages(self, discovery_context, discovery_registry): + from azext_prototype.stages.discovery import DiscoverySession + + session = DiscoverySession(discovery_context, discovery_registry) + + discovery_context.ai_provider.chat.return_value = MagicMock( + content="response", + model="test", + usage={}, + ) + + initial_len = len(session._messages) + session._chat_lightweight("test") + assert len(session._messages) == initial_len + + +# ------------------------------------------------------------------ +# _handle_incremental_context +# ------------------------------------------------------------------ + + +class TestHandleIncrementalContext: + def _make_session(self, ctx, registry): + from azext_prototype.stages.discovery import DiscoverySession + + return DiscoverySession(ctx, registry) + + def test_no_new_topics_records_decision(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + + discovery_context.ai_provider.chat.return_value = MagicMock( + content="[NO_NEW_TOPICS]", + model="test", + usage={}, + ) + + result = session._handle_incremental_context("Use Redis for caching", "", None, lambda m: None, False, None) + assert result is False + + def test_new_topics_added(self, discovery_context, discovery_registry): + session = self._make_session(discovery_context, discovery_registry) + + discovery_context.ai_provider.chat.return_value = MagicMock( + content="## Caching Strategy\nHow should Redis be configured?\n\n## Performance\nWhat SLA is needed?", + model="test", + usage={}, + ) + + result = session._handle_incremental_context("Add Redis caching", "", None, lambda m: None, False, None) + assert result is True + +# --- Additional imports from merged flat test --- from azext_prototype.ai.provider import AIMessage, AIResponse -from azext_prototype.stages.discovery import ( - _DONE_WORDS, - _QUIT_WORDS, - _READY_MARKER, - DiscoveryResult, - DiscoverySession, - extract_section_headers, - parse_sections, -) +from azext_prototype.stages.discovery import _DONE_WORDS, _QUIT_WORDS, _READY_MARKER, DiscoveryResult, DiscoverySession, extract_section_headers, parse_sections +from unittest.mock import patch + # ====================================================================== # Fixtures @@ -80,34 +601,6 @@ def _make_response(content: str) -> AIResponse: return AIResponse(content=content, model="gpt-4o", usage={}) -# ====================================================================== -# DiscoveryResult -# ====================================================================== - - -class TestDiscoveryResult: - def test_basic_creation(self): - result = DiscoveryResult( - requirements="Build a web app", - conversation=[], - policy_overrides=[], - exchange_count=3, - ) - assert result.requirements == "Build a web app" - assert result.exchange_count == 3 - assert result.cancelled is False - - def test_cancelled(self): - result = DiscoveryResult( - requirements="", - conversation=[], - policy_overrides=[], - exchange_count=0, - cancelled=True, - ) - assert result.cancelled is True - - # ====================================================================== # DiscoverySession — basic conversation flow # ====================================================================== @@ -1333,138 +1826,6 @@ def test_nl_status(self, mock_agent_context, mock_registry): assert any("status" in o.lower() or "discovery" in o.lower() for o in output if isinstance(o, str)) -# ====================================================================== -# extract_section_headers -# ====================================================================== - - -class TestExtractSectionHeaders: - """Unit tests for extract_section_headers().""" - - def test_extracts_h2_headings(self): - text = "## Project Context & Scope\nSome text\n## Data & Content\nMore text" - result = extract_section_headers(text) - assert result == [("Project Context & Scope", 2), ("Data & Content", 2)] - - def test_h3_only_returns_empty(self): - """Level-3 only responses produce no headers (subsections are not topics).""" - text = "### Authentication\nDetails\n### Authorization\nMore details" - result = extract_section_headers(text) - assert result == [] - - def test_mixed_h2_h3_returns_only_h2(self): - """Level-3 subsections are filtered out — only level-2 topics returned.""" - text = "## Overview\nText\n### Sub-section\nText\n## Architecture\nText" - result = extract_section_headers(text) - assert result == [("Overview", 2), ("Architecture", 2)] - - def test_skips_structural_headings(self): - text = "## Project Context\nText\n" "## Summary\nText\n" "## Policy Overrides\nText\n" "## Next Steps\nText\n" - result = extract_section_headers(text) - assert result == [("Project Context", 2)] - - def test_skips_policy_override_singular(self): - text = "## Policy Override\nText" - result = extract_section_headers(text) - assert result == [] - - def test_skips_short_headings(self): - text = "## AB\nText\n## OK\nMore" - result = extract_section_headers(text) - assert result == [] - - def test_empty_string(self): - assert extract_section_headers("") == [] - - def test_no_headings(self): - text = "Just plain text without any headings at all." - assert extract_section_headers(text) == [] - - def test_h1_not_extracted(self): - """Only ## and ### are extracted, not #.""" - text = "# Title\n## Section One\nContent" - result = extract_section_headers(text) - assert result == [("Section One", 2)] - - def test_strips_whitespace(self): - text = "## Padded Heading \nText" - result = extract_section_headers(text) - assert result == [("Padded Heading", 2)] - - def test_case_insensitive_skip(self): - text = "## SUMMARY\nText\n## NEXT STEPS\nText\n## Actual Content\nText" - result = extract_section_headers(text) - assert result == [("Actual Content", 2)] - - def test_bold_headings_extracted(self): - """**Bold Heading** on its own line should be extracted as level 2.""" - text = ( - "Let me ask about your project.\n" - "\n" - "**Hosting & Deployment**\n" - "How do you plan to host this?\n" - "\n" - "**Data Layer**\n" - "What database will you use?" - ) - result = extract_section_headers(text) - assert ("Hosting & Deployment", 2) in result - assert ("Data Layer", 2) in result - - def test_bold_inline_not_extracted(self): - """Bold text mid-line should NOT be extracted as a heading.""" - text = "I think **this is important** for the project." - result = extract_section_headers(text) - assert result == [] - - def test_bold_and_markdown_headings_merged(self): - """Both ## headings and **bold headings** should be found with levels.""" - text = "## Architecture Overview\n" "Details here.\n" "\n" "**Security Considerations**\n" "More details." - result = extract_section_headers(text) - assert ("Architecture Overview", 2) in result - assert ("Security Considerations", 2) in result - - def test_bold_headings_deduped(self): - """Duplicate headings (same text in both formats) should appear once.""" - text = "## Security\n" "Details.\n" "\n" "**Security**\n" "More details." - result = extract_section_headers(text) - texts = [h[0] for h in result] - assert texts.count("Security") == 1 - - def test_bold_headings_skip_structural(self): - """Bold structural headings (Summary, Next Steps) should be skipped.""" - text = "**Summary**\nText\n**Actual Topic**\nMore text" - result = extract_section_headers(text) - texts = [h[0] for h in result] - assert "Summary" not in texts - assert "Actual Topic" in texts - - def test_bold_heading_too_short(self): - """Bold headings under 3 chars should be skipped.""" - text = "**AB**\nText" - result = extract_section_headers(text) - assert result == [] - - def test_skip_what_ive_understood(self): - """'What I've Understood So Far' and variants should be filtered.""" - text = ( - "## What I've Understood So Far\nStuff\n" - "## What We've Covered\nMore stuff\n" - "## Actual Topic\nReal content" - ) - result = extract_section_headers(text) - texts = [h[0] for h in result] - assert "What I've Understood So Far" not in texts - assert "What We've Covered" not in texts - assert "Actual Topic" in texts - - def test_position_ordering(self): - """Headers should be sorted by their position in the response.""" - text = "**First Bold**\n" "Text\n" "## Second Markdown\n" "Text\n" "**Third Bold**\n" "Text" - result = extract_section_headers(text) - assert result == [("First Bold", 2), ("Second Markdown", 2), ("Third Bold", 2)] - - # ====================================================================== # section_fn callback integration # ====================================================================== @@ -1616,117 +1977,6 @@ def test_response_fn_takes_precedence_over_print_fn( assert not any("Tell me about your project" in p for p in printed if isinstance(p, str)) -# ====================================================================== -# parse_sections() -# ====================================================================== - - -class TestParseSections: - """Verify section parsing from AI responses.""" - - def test_basic_section_splitting(self): - text = ( - "Here's my analysis.\n\n" - "## Authentication\n" - "How do users sign in?\n\n" - "## Data Layer\n" - "What database do you prefer?" - ) - preamble, sections = parse_sections(text) - assert preamble == "Here's my analysis." - assert len(sections) == 2 - assert sections[0].heading == "Authentication" - assert sections[0].level == 2 - assert "How do users sign in?" in sections[0].content - assert sections[1].heading == "Data Layer" - assert "What database" in sections[1].content - - def test_preamble_only(self): - text = "No headings here, just a plain response." - preamble, sections = parse_sections(text) - assert preamble == text - assert sections == [] - - def test_empty_preamble(self): - text = "## First Topic\nQuestion here." - preamble, sections = parse_sections(text) - assert preamble == "" - assert len(sections) == 1 - - def test_skip_headings_filtered(self): - text = ( - "## Authentication\nHow do users sign in?\n\n" - "## Summary\nThis is a summary.\n\n" - "## Next Steps\nDo this next." - ) - _, sections = parse_sections(text) - assert len(sections) == 1 - assert sections[0].heading == "Authentication" - - def test_task_id_generation(self): - text = "## Data & Content\nWhat kind of data?" - _, sections = parse_sections(text) - assert len(sections) == 1 - assert sections[0].task_id == "design-section-data-content" - - def test_bold_headings(self): - text = ( - "Here's what I need to know.\n\n" - "**Authentication & Security**\n" - "How do users log in?\n\n" - "**Data Storage**\n" - "What database?" - ) - preamble, sections = parse_sections(text) - assert len(sections) == 2 - assert sections[0].heading == "Authentication & Security" - assert sections[0].level == 2 - - def test_level_3_only_returns_empty(self): - """Level-3 only response produces no sections (not treated as topics).""" - text = "### Sub-topic\nDetailed question." - _, sections = parse_sections(text) - assert len(sections) == 0 - - def test_subsections_folded_into_parent(self): - """Level-3 subsections are folded into their parent level-2 section.""" - text = "## Main Topic\nOverview.\n\n" "### Sub-topic\nDetail." - _, sections = parse_sections(text) - assert len(sections) == 1 - assert sections[0].heading == "Main Topic" - assert sections[0].level == 2 - # Subsection content is included in the parent's content - assert "### Sub-topic" in sections[0].content - assert "Detail." in sections[0].content - - def test_multiple_subsections_folded(self): - """Multiple ### subsections under one ## are all included.""" - text = ( - "## Scope Boundary\nLet's clarify scope.\n\n" - "### In Scope\n- Item A\n- Item B\n\n" - "### Out of Scope\n- Item C\n\n" - "## Next Topic\nQuestions here." - ) - _, sections = parse_sections(text) - assert len(sections) == 2 - assert sections[0].heading == "Scope Boundary" - assert "### In Scope" in sections[0].content - assert "### Out of Scope" in sections[0].content - assert "Item A" in sections[0].content - assert "Item C" in sections[0].content - assert sections[1].heading == "Next Topic" - - def test_empty_string(self): - preamble, sections = parse_sections("") - assert preamble == "" - assert sections == [] - - def test_duplicate_headings_deduped(self): - text = "## Authentication\nFirst mention.\n\n" "## Authentication\nSecond mention." - _, sections = parse_sections(text) - assert len(sections) == 1 - - # ====================================================================== # Section completion via AI "Yes" gate # ====================================================================== @@ -3062,161 +3312,6 @@ def test_empty_state(self, tmp_path): assert ds.topic_at_exchange(1) is None -# ====================================================================== -# _chat_lightweight edge cases -# ====================================================================== - - -class TestChatLightweight: - """Tests for _chat_lightweight — minimal AI call for classification tasks.""" - - def test_empty_content(self, mock_agent_context, mock_registry): - """Empty string content should still work.""" - mock_agent_context.ai_provider.chat.return_value = _make_response("[NO_NEW_TOPICS]") - session = DiscoverySession(mock_agent_context, mock_registry) - - result = session._chat_lightweight("") - assert result == "[NO_NEW_TOPICS]" - - # Verify it used a minimal system prompt (not the full governance payload) - call_args = mock_agent_context.ai_provider.chat.call_args - messages = call_args[0][0] - system_msgs = [m for m in messages if m.role == "system"] - assert len(system_msgs) == 1 - assert len(system_msgs[0].content) < 200 # Lightweight — not 69KB - - def test_does_not_add_to_messages(self, mock_agent_context, mock_registry): - """_chat_lightweight is ephemeral — should NOT add to self._messages.""" - mock_agent_context.ai_provider.chat.return_value = _make_response("analysis result") - session = DiscoverySession(mock_agent_context, mock_registry) - - initial_count = len(session._messages) - session._chat_lightweight("classify this") - assert len(session._messages) == initial_count - - def test_records_tokens(self, mock_agent_context, mock_registry): - """Token usage from lightweight calls should be tracked.""" - mock_agent_context.ai_provider.chat.return_value = _make_response("result") - session = DiscoverySession(mock_agent_context, mock_registry) - - session._chat_lightweight("test prompt") - # TokenTracker.record was called (uses AIResponse) - assert session._token_tracker._turn_count >= 1 - - def test_uses_low_temperature(self, mock_agent_context, mock_registry): - """Lightweight calls use temperature=0.3 for determinism.""" - mock_agent_context.ai_provider.chat.return_value = _make_response("ok") - session = DiscoverySession(mock_agent_context, mock_registry) - - session._chat_lightweight("test") - call_kwargs = mock_agent_context.ai_provider.chat.call_args[1] - assert call_kwargs.get("temperature") == 0.3 - - -# ====================================================================== -# _handle_incremental_context edge cases -# ====================================================================== - - -class TestHandleIncrementalContext: - """Tests for _handle_incremental_context — re-entry topic detection.""" - - def test_returns_false_no_topics_no_seed_context( - self, - mock_agent_context, - mock_registry, - ): - """When AI says [NO_NEW_TOPICS] and no seed_context, returns False.""" - mock_agent_context.ai_provider.chat.return_value = _make_response("[NO_NEW_TOPICS]") - session = DiscoverySession(mock_agent_context, mock_registry) - - result = session._handle_incremental_context( - seed_context="", - artifacts="some artifact text", - artifact_images=None, - _print=lambda x: None, - use_styled=False, - status_fn=None, - ) - assert result is False - - def test_returns_false_no_topics_with_seed_context( - self, - mock_agent_context, - mock_registry, - ): - """When AI says [NO_NEW_TOPICS] with seed_context, records decision.""" - mock_agent_context.ai_provider.chat.return_value = _make_response("[NO_NEW_TOPICS]") - session = DiscoverySession(mock_agent_context, mock_registry) - - printed = [] - result = session._handle_incremental_context( - seed_context="Change app name to Contoso", - artifacts="", - artifact_images=None, - _print=printed.append, - use_styled=False, - status_fn=None, - ) - assert result is False - # Seed context should be recorded as a confirmed decision - decisions = session._discovery_state.state["decisions"] - assert "Change app name to Contoso" in decisions - assert any("Context recorded" in p for p in printed) - - def test_returns_true_when_new_topics_found( - self, - mock_agent_context, - mock_registry, - ): - """When AI returns new sections, topics are appended and returns True.""" - new_topics_response = ( - "## Authentication Strategy\n" - "1. What identity provider?\n" - "2. SSO required?\n\n" - "## Data Residency\n" - "1. Which region?\n" - "2. Compliance needs?\n" - ) - mock_agent_context.ai_provider.chat.return_value = _make_response(new_topics_response) - session = DiscoverySession(mock_agent_context, mock_registry) - - result = session._handle_incremental_context( - seed_context="Add GDPR compliance", - artifacts="", - artifact_images=None, - _print=lambda x: None, - use_styled=False, - status_fn=None, - ) - assert result is True - # Topics should be appended to discovery state - assert session._discovery_state.has_items - - def test_no_parseable_sections_records_decision( - self, - mock_agent_context, - mock_registry, - ): - """When AI response has no parseable sections, seed_context is saved as decision.""" - mock_agent_context.ai_provider.chat.return_value = _make_response( - "The new information is already covered by existing topics." - ) - session = DiscoverySession(mock_agent_context, mock_registry) - - result = session._handle_incremental_context( - seed_context="Use Redis for caching", - artifacts="", - artifact_images=None, - _print=lambda x: None, - use_styled=False, - status_fn=None, - ) - assert result is False - decisions = session._discovery_state.state["decisions"] - assert "Use Redis for caching" in decisions - - # ====================================================================== # add_confirmed_decision deduplication # ====================================================================== diff --git a/tests/stages/test_discovery_state.py b/tests/stages/test_discovery_state.py new file mode 100644 index 0000000..99c2df7 --- /dev/null +++ b/tests/stages/test_discovery_state.py @@ -0,0 +1,775 @@ +"""Tests for DiscoveryState — legacy migration, exchange handling, state persistence. + +Covers: +- Legacy state migration (topics, open_items, confirmed_items → unified items) +- update_from_exchange with str vs list content +- Image stripping during persistence (multi-modal content arrays) +- Item management (add, resolve, mark, append, dedup) +- Format methods (open_items, confirmed_items, status_summary, as_context) +- Conversation summary extraction +- Search history +- Topic at exchange +- Artifact inventory +- Context hash +""" + +from pathlib import Path + +import pytest +import yaml + +from azext_prototype.stages.discovery_state import ( + DiscoveryState, + TrackedItem, + _default_discovery_state, +) + +# ====================================================================== +# Fixtures +# ====================================================================== + + +@pytest.fixture +def disco_state(tmp_project): + ds = DiscoveryState(str(tmp_project)) + return ds + + +@pytest.fixture +def disco_state_with_items(disco_state): + """State with some items pre-loaded.""" + disco_state._state["items"] = [ + { + "heading": "Auth approach", + "detail": "How to authenticate?", + "kind": "topic", + "status": "pending", + "answer_exchange": None, + }, + { + "heading": "DB choice", + "detail": "Which database?", + "kind": "decision", + "status": "confirmed", + "answer_exchange": 2, + }, + {"heading": "Hosting", "detail": "Where to host?", "kind": "topic", "status": "answered", "answer_exchange": 3}, + ] + disco_state._loaded = True + return disco_state + + +# ====================================================================== +# Legacy migration +# ====================================================================== + + +class TestLegacyMigration: + """Test _migrate_legacy_state converts old fields to unified items.""" + + def test_migrate_topics(self, disco_state): + """Old 'topics' key migrates to items with kind=topic.""" + disco_state._state["topics"] = [ + {"heading": "Auth", "questions": "How to authenticate?", "status": "pending"}, + ] + disco_state._state["items"] = [] + disco_state._migrate_legacy_state() + assert len(disco_state._state["items"]) == 1 + assert disco_state._state["items"][0]["kind"] == "topic" + assert disco_state._state["items"][0]["detail"] == "How to authenticate?" + assert "topics" not in disco_state._state + + def test_migrate_open_items(self, disco_state): + """Old 'open_items' list migrates to decision items with pending status.""" + disco_state._state["open_items"] = ["Which region?", "Which SKU?"] + disco_state._state["items"] = [] + disco_state._migrate_legacy_state() + assert len(disco_state._state["items"]) == 2 + assert all(i["kind"] == "decision" for i in disco_state._state["items"]) + assert all(i["status"] == "pending" for i in disco_state._state["items"]) + assert "open_items" not in disco_state._state + + def test_migrate_confirmed_items(self, disco_state): + """Old 'confirmed_items' list migrates to confirmed decision items.""" + disco_state._state["confirmed_items"] = ["Use PaaS"] + disco_state._state["items"] = [] + disco_state._migrate_legacy_state() + assert len(disco_state._state["items"]) == 1 + assert disco_state._state["items"][0]["status"] == "confirmed" + assert "confirmed_items" not in disco_state._state + + def test_migrate_deduplicates(self, disco_state): + """Migration deduplicates by heading (case-insensitive).""" + disco_state._state["topics"] = [ + {"heading": "Auth", "questions": "q", "status": "pending"}, + ] + disco_state._state["open_items"] = ["Auth"] # Same heading + disco_state._state["items"] = [] + disco_state._migrate_legacy_state() + assert len(disco_state._state["items"]) == 1 + + def test_migrate_empty_legacy_keys_cleaned(self, disco_state): + """Empty legacy keys are removed even if they have no items.""" + disco_state._state["topics"] = [] + disco_state._state["open_items"] = [] + disco_state._state["confirmed_items"] = [] + disco_state._migrate_legacy_state() + assert "topics" not in disco_state._state + assert "open_items" not in disco_state._state + assert "confirmed_items" not in disco_state._state + + def test_no_legacy_keys_no_op(self, disco_state): + """When no legacy keys exist, migration is a no-op.""" + original_items = list(disco_state._state["items"]) + disco_state._migrate_legacy_state() + assert disco_state._state["items"] == original_items + + def test_post_load_calls_migrate(self, disco_state, tmp_path): + """Loading state from disk triggers migration.""" + state_dir = Path(str(tmp_path)) / "test-project" / ".prototype" / "state" + state_dir.mkdir(parents=True, exist_ok=True) + legacy_state = _default_discovery_state() + legacy_state["topics"] = [ + {"heading": "Legacy Topic", "questions": "q", "status": "pending"}, + ] + state_file = state_dir / "discovery.yaml" + with open(state_file, "w", encoding="utf-8") as f: + yaml.dump(legacy_state, f) + + ds = DiscoveryState(str(tmp_path / "test-project")) + ds.load() + assert len(ds._state["items"]) == 1 + assert ds._state["items"][0]["heading"] == "Legacy Topic" + assert "topics" not in ds._state + + +# ====================================================================== +# update_from_exchange +# ====================================================================== + + +class TestUpdateFromExchange: + """Test exchange recording with str and list content.""" + + def test_string_input(self, disco_state): + disco_state.update_from_exchange("Hello", "Hi there!", 1) + history = disco_state._state["conversation_history"] + assert len(history) == 1 + assert history[0]["user"] == "Hello" + assert history[0]["assistant"] == "Hi there!" + assert history[0]["exchange"] == 1 + + def test_list_input_text_only(self, disco_state): + """Multi-modal content with only text parts.""" + content = [ + {"type": "text", "text": "Part 1"}, + {"type": "text", "text": "Part 2"}, + ] + disco_state.update_from_exchange(content, "Response", 1) + history = disco_state._state["conversation_history"] + assert "Part 1" in history[0]["user"] + assert "Part 2" in history[0]["user"] + + def test_list_input_with_images_stripped(self, disco_state): + """Multi-modal content with images — base64 data replaced with placeholder.""" + content = [ + {"type": "text", "text": "See this diagram"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,ABC123..."}}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,DEF456..."}}, + ] + disco_state.update_from_exchange(content, "I see the diagram", 1) + history = disco_state._state["conversation_history"] + assert "[2 image(s) attached]" in history[0]["user"] + assert "base64" not in history[0]["user"] + + def test_exchange_count_updated(self, disco_state): + disco_state.update_from_exchange("Q1", "A1", 1) + disco_state.update_from_exchange("Q2", "A2", 2) + assert disco_state._state["_metadata"]["exchange_count"] == 2 + + def test_list_input_with_no_text(self, disco_state): + """Multi-modal content with only images.""" + content = [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,ABC"}}, + ] + disco_state.update_from_exchange(content, "I see", 1) + history = disco_state._state["conversation_history"] + assert "[1 image(s) attached]" in history[0]["user"] + + +# ====================================================================== +# Item management +# ====================================================================== + + +class TestItemManagement: + """Test add, resolve, mark, append, dedup operations.""" + + def test_add_open_item(self, disco_state): + disco_state.add_open_item("Which region?") + items = disco_state._state["items"] + assert len(items) == 1 + assert items[0]["heading"] == "Which region?" + assert items[0]["status"] == "pending" + assert items[0]["kind"] == "decision" + + def test_add_open_item_dedup(self, disco_state): + disco_state.add_open_item("Which region?") + disco_state.add_open_item("Which region?") + assert len(disco_state._state["items"]) == 1 + + def test_resolve_item_by_heading(self, disco_state): + disco_state.add_open_item("Which region?") + disco_state.resolve_item("Which region?") + assert disco_state._state["items"][0]["status"] == "confirmed" + + def test_resolve_item_by_confirmed_text(self, disco_state): + disco_state.add_open_item("Which region?") + disco_state.resolve_item("different", confirmed_text="Which region?") + assert disco_state._state["items"][0]["status"] == "confirmed" + + def test_resolve_creates_if_not_found(self, disco_state): + disco_state.resolve_item("nonexistent", confirmed_text="New decision") + assert len(disco_state._state["items"]) == 1 + assert disco_state._state["items"][0]["status"] == "confirmed" + + def test_resolve_no_match_no_text_no_op(self, disco_state): + disco_state.resolve_item("nonexistent") + assert len(disco_state._state["items"]) == 0 + + def test_add_confirmed_decision(self, disco_state): + disco_state.add_confirmed_decision("Use PaaS services") + assert "Use PaaS services" in disco_state._state["decisions"] + + def test_add_confirmed_decision_dedup(self, disco_state): + disco_state.add_confirmed_decision("Use PaaS") + disco_state.add_confirmed_decision("Use PaaS") + assert disco_state._state["decisions"].count("Use PaaS") == 1 + + def test_set_items(self, disco_state): + items = [ + TrackedItem(heading="T1", detail="D1", kind="topic", status="pending", answer_exchange=None), + ] + disco_state.set_items(items) + assert len(disco_state._state["items"]) == 1 + + def test_append_items_dedup(self, disco_state): + disco_state.set_items( + [TrackedItem(heading="T1", detail="D1", kind="topic", status="pending", answer_exchange=None)] + ) + disco_state.append_items( + [ + TrackedItem(heading="T1", detail="D1", kind="topic", status="pending", answer_exchange=None), + TrackedItem(heading="T2", detail="D2", kind="topic", status="pending", answer_exchange=None), + ] + ) + assert len(disco_state._state["items"]) == 2 + + def test_mark_item(self, disco_state_with_items): + disco_state_with_items.mark_item("Auth approach", "answered", exchange=5) + item = disco_state_with_items._state["items"][0] + assert item["status"] == "answered" + assert item["answer_exchange"] == 5 + + def test_first_pending_index(self, disco_state_with_items): + idx = disco_state_with_items.first_pending_index() + assert idx == 0 # Auth approach is pending + + def test_first_pending_index_by_kind(self, disco_state_with_items): + # Add a pending decision + disco_state_with_items._state["items"].append( + {"heading": "D1", "detail": "d", "kind": "decision", "status": "pending", "answer_exchange": None} + ) + idx = disco_state_with_items.first_pending_index(kind="decision") + assert idx == 3 + + def test_first_pending_index_none(self, disco_state): + assert disco_state.first_pending_index() is None + + +# ====================================================================== +# Item properties +# ====================================================================== + + +class TestItemProperties: + """Test item accessor properties.""" + + def test_open_count(self, disco_state_with_items): + assert disco_state_with_items.open_count == 1 + + def test_confirmed_count(self, disco_state_with_items): + assert disco_state_with_items.confirmed_count == 2 # confirmed + answered + + def test_has_items(self, disco_state_with_items): + assert disco_state_with_items.has_items is True + + def test_has_items_empty(self, disco_state): + assert disco_state.has_items is False + + def test_items_property(self, disco_state_with_items): + items = disco_state_with_items.items + assert len(items) == 3 + assert all(isinstance(i, TrackedItem) for i in items) + + def test_topic_items(self, disco_state_with_items): + topics = disco_state_with_items.topic_items + assert len(topics) == 2 # Auth approach + Hosting + + def test_items_by_status(self, disco_state_with_items): + pending = disco_state_with_items.items_by_status("pending") + assert len(pending) == 1 + assert pending[0].heading == "Auth approach" + + +# ====================================================================== +# Backward-compat aliases +# ====================================================================== + + +class TestBackwardCompatAliases: + """Test old method names still work.""" + + def test_topics_alias(self, disco_state_with_items): + assert disco_state_with_items.topics == disco_state_with_items.items + + def test_has_topics_alias(self, disco_state_with_items): + assert disco_state_with_items.has_topics == disco_state_with_items.has_items + + def test_set_topics_alias(self, disco_state): + items = [TrackedItem(heading="X", detail="x", kind="topic", status="pending", answer_exchange=None)] + disco_state.set_topics(items) + assert len(disco_state._state["items"]) == 1 + + def test_mark_topic_alias(self, disco_state_with_items): + disco_state_with_items.mark_topic("Auth approach", "confirmed") + assert disco_state_with_items._state["items"][0]["status"] == "confirmed" + + def test_first_pending_topic_index_alias(self, disco_state_with_items): + assert disco_state_with_items.first_pending_topic_index() == disco_state_with_items.first_pending_index() + + +# ====================================================================== +# Format methods +# ====================================================================== + + +class TestFormatMethods: + """Test display formatting methods.""" + + def test_format_open_items_with_pending(self, disco_state_with_items): + result = disco_state_with_items.format_open_items() + assert "Auth approach" in result + assert "Open items" in result + + def test_format_open_items_none_pending(self, disco_state): + result = disco_state.format_open_items() + assert "No open items" in result + + def test_format_confirmed_items(self, disco_state_with_items): + result = disco_state_with_items.format_confirmed_items() + assert "DB choice" in result + assert "Hosting" in result + + def test_format_confirmed_items_none(self, disco_state): + result = disco_state.format_confirmed_items() + assert "No items confirmed" in result + + def test_format_status_summary(self, disco_state_with_items): + result = disco_state_with_items.format_status_summary() + assert "2 confirmed" in result + assert "1 open" in result + + def test_format_status_summary_empty(self, disco_state): + result = disco_state.format_status_summary() + assert "No items tracked" in result + + def test_format_as_context_structured(self, disco_state): + disco_state._loaded = True + disco_state._state["project"]["summary"] = "Test project" + disco_state._state["project"]["goals"] = ["Goal 1"] + disco_state._state["requirements"]["functional"] = ["API support"] + disco_state._state["constraints"] = ["PaaS only"] + disco_state._state["decisions"] = ["Use Cosmos DB"] + disco_state._state["architecture"]["services"] = ["cosmos-db"] + result = disco_state.format_as_context() + assert "Test project" in result + assert "Goal 1" in result + assert "API support" in result + assert "PaaS only" in result + assert "Use Cosmos DB" in result + assert "cosmos-db" in result + + def test_format_as_context_falls_back_to_conversation(self, disco_state): + """When structured fields are empty, falls back to conversation summary.""" + disco_state._loaded = True + disco_state._state["conversation_history"] = [ + { + "user": "Tell me about the project", + "assistant": "## Project Summary\nThis is a test project.\n## Confirmed Functional Requirements\n- API", + } + ] + result = disco_state.format_as_context() + assert "Project Summary" in result + + def test_format_as_context_not_loaded(self, disco_state): + result = disco_state.format_as_context() + assert result == "" + + +# ====================================================================== +# Merge learnings +# ====================================================================== + + +class TestMergeLearnings: + """Test merge_learnings integrates structured data.""" + + def test_merge_project_summary(self, disco_state): + disco_state.merge_learnings({"project": {"summary": "New summary"}}) + assert disco_state._state["project"]["summary"] == "New summary" + + def test_merge_goals(self, disco_state): + disco_state.merge_learnings({"project": {"goals": ["G1", "G2"]}}) + assert disco_state._state["project"]["goals"] == ["G1", "G2"] + + def test_merge_requirements(self, disco_state): + disco_state.merge_learnings({"requirements": {"functional": ["R1"], "non_functional": ["NF1"]}}) + assert disco_state._state["requirements"]["functional"] == ["R1"] + assert disco_state._state["requirements"]["non_functional"] == ["NF1"] + + def test_merge_deduplicates(self, disco_state): + disco_state.merge_learnings({"constraints": ["C1"]}) + disco_state.merge_learnings({"constraints": ["C1", "C2"]}) + assert disco_state._state["constraints"] == ["C1", "C2"] + + def test_merge_open_items_creates_decisions(self, disco_state): + disco_state.merge_learnings({"open_items": ["Choose DB"]}) + assert len(disco_state._state["items"]) == 1 + assert disco_state._state["items"][0]["kind"] == "decision" + + def test_merge_resolved_items(self, disco_state): + disco_state.add_open_item("Choose DB") + disco_state.merge_learnings({"resolved_items": ["Choose DB"]}) + assert disco_state._state["items"][0]["status"] == "confirmed" + + def test_merge_scope(self, disco_state): + disco_state.merge_learnings({"scope": {"in_scope": ["API"], "deferred": ["ML"]}}) + assert "API" in disco_state._state["scope"]["in_scope"] + assert "ML" in disco_state._state["scope"]["deferred"] + + def test_merge_architecture(self, disco_state): + disco_state.merge_learnings({"architecture": {"services": ["cosmos-db"], "data_flow": "API -> DB"}}) + assert "cosmos-db" in disco_state._state["architecture"]["services"] + assert disco_state._state["architecture"]["data_flow"] == "API -> DB" + + +# ====================================================================== +# Search history +# ====================================================================== + + +class TestSearchHistory: + """Test conversation history search.""" + + def test_search_finds_match(self, disco_state): + disco_state._state["conversation_history"] = [ + {"user": "Tell me about Cosmos DB", "assistant": "It is a NoSQL database"}, + {"user": "What about SQL?", "assistant": "Relational database"}, + ] + results = disco_state.search_history("cosmos") + assert len(results) == 1 + + def test_search_no_match(self, disco_state): + disco_state._state["conversation_history"] = [ + {"user": "Hello", "assistant": "Hi"}, + ] + results = disco_state.search_history("cosmos") + assert len(results) == 0 + + +# ====================================================================== +# topic_at_exchange +# ====================================================================== + + +class TestTopicAtExchange: + """Test finding which topic was discussed at a given exchange.""" + + def test_finds_topic_at_exchange(self, disco_state_with_items): + # DB choice answered at exchange 2, Hosting at 3 + result = disco_state_with_items.topic_at_exchange(2) + assert result == "DB choice" + + def test_returns_none_no_answered_items(self, disco_state): + assert disco_state.topic_at_exchange(1) is None + + def test_returns_none_past_all_exchanges(self, disco_state_with_items): + # Exchange 10 is past all answered items + result = disco_state_with_items.topic_at_exchange(10) + assert result is None + + +# ====================================================================== +# Artifact inventory +# ====================================================================== + + +class TestArtifactInventory: + """Test artifact hash tracking.""" + + def test_update_and_get(self, disco_state): + disco_state.update_artifact_inventory({"/path/to/file.txt": "abc123"}) + hashes = disco_state.get_artifact_hashes() + assert hashes["/path/to/file.txt"] == "abc123" + + def test_additive_updates(self, disco_state): + disco_state.update_artifact_inventory({"/a": "hash1"}) + disco_state.update_artifact_inventory({"/b": "hash2"}) + hashes = disco_state.get_artifact_hashes() + assert "/a" in hashes + assert "/b" in hashes + + +# ====================================================================== +# Context hash +# ====================================================================== + + +class TestContextHash: + """Test context hash for change detection.""" + + def test_update_and_get(self, disco_state): + disco_state.update_context_hash("abc123") + assert disco_state.get_context_hash() == "abc123" + + def test_default_empty(self, disco_state): + assert disco_state.get_context_hash() == "" + + +# ====================================================================== +# Reset +# ====================================================================== + + +class TestReset: + """Test state reset.""" + + def test_reset_clears_state(self, disco_state_with_items): + disco_state_with_items.reset() + assert disco_state_with_items._state["items"] == [] + assert disco_state_with_items._loaded is False + + +# ====================================================================== +# TrackedItem dataclass +# ====================================================================== + + +class TestTrackedItem: + """Test TrackedItem serialization.""" + + def test_to_dict(self): + item = TrackedItem(heading="H", detail="D", kind="topic", status="pending", answer_exchange=None) + d = item.to_dict() + assert d["heading"] == "H" + assert d["kind"] == "topic" + + def test_from_dict(self): + d = {"heading": "H", "detail": "D", "kind": "decision", "status": "confirmed", "answer_exchange": 3} + item = TrackedItem.from_dict(d) + assert item.heading == "H" + assert item.answer_exchange == 3 + + def test_from_dict_legacy_questions_key(self): + """Old format used 'questions' instead of 'detail'.""" + d = {"heading": "H", "questions": "Q?", "status": "pending"} + item = TrackedItem.from_dict(d) + assert item.detail == "Q?" + + + +class TestDiscoveryStateScope: + """Test the scope fields in DiscoveryState.""" + + def test_default_state_has_scope(self): + state = _default_discovery_state() + assert "scope" in state + assert state["scope"] == { + "in_scope": [], + "out_of_scope": [], + "deferred": [], + } + + def test_merge_learnings_with_scope(self, tmp_path): + ds = DiscoveryState(str(tmp_path)) + ds.load() + + learnings = { + "scope": { + "in_scope": ["REST API", "SQL Database"], + "out_of_scope": ["Mobile app"], + "deferred": ["CI/CD pipeline"], + }, + } + ds.merge_learnings(learnings) + + assert ds.state["scope"]["in_scope"] == ["REST API", "SQL Database"] + assert ds.state["scope"]["out_of_scope"] == ["Mobile app"] + assert ds.state["scope"]["deferred"] == ["CI/CD pipeline"] + + def test_merge_learnings_deduplicates_scope(self, tmp_path): + ds = DiscoveryState(str(tmp_path)) + ds.load() + ds.state["scope"]["in_scope"] = ["REST API"] + + learnings = { + "scope": { + "in_scope": ["REST API", "SQL Database"], + }, + } + ds.merge_learnings(learnings) + + assert ds.state["scope"]["in_scope"] == ["REST API", "SQL Database"] + + def test_merge_learnings_partial_scope(self, tmp_path): + ds = DiscoveryState(str(tmp_path)) + ds.load() + + learnings = { + "scope": { + "in_scope": ["API endpoints"], + }, + } + ds.merge_learnings(learnings) + + assert ds.state["scope"]["in_scope"] == ["API endpoints"] + assert ds.state["scope"]["out_of_scope"] == [] + assert ds.state["scope"]["deferred"] == [] + + def test_merge_learnings_without_scope(self, tmp_path): + """Learnings without scope should not break merge.""" + ds = DiscoveryState(str(tmp_path)) + ds.load() + + learnings = { + "project": {"summary": "Test", "goals": ["Goal 1"]}, + } + ds.merge_learnings(learnings) + + assert ds.state["scope"]["in_scope"] == [] + + def test_format_as_context_includes_scope(self, tmp_path): + ds = DiscoveryState(str(tmp_path)) + ds.load() + ds._loaded = True + ds.state["scope"] = { + "in_scope": ["REST API"], + "out_of_scope": ["Mobile app"], + "deferred": ["CI/CD"], + } + + context = ds.format_as_context() + assert "## Prototype Scope" in context + assert "### In Scope" in context + assert "REST API" in context + assert "### Out of Scope" in context + assert "Mobile app" in context + assert "### Deferred / Future Work" in context + assert "CI/CD" in context + + def test_format_as_context_partial_scope(self, tmp_path): + ds = DiscoveryState(str(tmp_path)) + ds.load() + ds._loaded = True + ds.state["scope"]["in_scope"] = ["REST API"] + + context = ds.format_as_context() + assert "### In Scope" in context + assert "### Out of Scope" not in context + assert "### Deferred" not in context + + def test_format_as_context_omits_empty_scope(self, tmp_path): + ds = DiscoveryState(str(tmp_path)) + ds.load() + ds._loaded = True + ds.state["project"]["summary"] = "Test project" + + context = ds.format_as_context() + assert "Prototype Scope" not in context + + def test_format_as_context_falls_back_to_conversation(self, tmp_path): + """When structured fields are empty, format_as_context uses conversation history.""" + ds = DiscoveryState(str(tmp_path)) + ds.load() + ds._loaded = True + # Structured fields are all empty (default), but conversation has content + ds.state["conversation_history"] = [ + {"exchange": 1, "assistant": "Tell me more."}, + { + "exchange": 2, + "assistant": ( + "## Project Summary\nA web app for email drafting.\n\n" + "## Confirmed Functional Requirements\n- Feature A\n\n" + "[READY]" + ), + }, + ] + + context = ds.format_as_context() + assert "## Project Summary" in context + assert "email drafting" in context + assert "Feature A" in context + assert "[READY]" not in context + + def test_format_as_context_prefers_structured_fields(self, tmp_path): + """When structured fields are populated, those are used instead of conversation.""" + ds = DiscoveryState(str(tmp_path)) + ds.load() + ds._loaded = True + ds.state["project"]["summary"] = "Structured summary" + ds.state["conversation_history"] = [ + { + "exchange": 1, + "assistant": "## Project Summary\nConversation summary.\n\n## Confirmed Functional Requirements\n- X", + }, + ] + + context = ds.format_as_context() + assert "Structured summary" in context + assert "Conversation summary" not in context + + def test_extract_conversation_summary(self, tmp_path): + """extract_conversation_summary returns last assistant message with summary headings.""" + ds = DiscoveryState(str(tmp_path)) + ds.load() + ds.state["conversation_history"] = [ + {"exchange": 1, "assistant": "Tell me more."}, + { + "exchange": 2, + "assistant": "## Project Summary\nA web app.\n\n[READY]", + }, + ] + + result = ds.extract_conversation_summary() + assert "## Project Summary" in result + assert "[READY]" not in result + + def test_extract_conversation_summary_empty_history(self, tmp_path): + ds = DiscoveryState(str(tmp_path)) + ds.load() + + assert ds.extract_conversation_summary() == "" + + def test_scope_persists_to_yaml(self, tmp_path): + ds = DiscoveryState(str(tmp_path)) + ds.load() + ds.state["scope"]["in_scope"] = ["API endpoints"] + ds.state["scope"]["out_of_scope"] = ["Mobile app"] + ds.save() + + ds2 = DiscoveryState(str(tmp_path)) + ds2.load() + assert ds2.state["scope"]["in_scope"] == ["API endpoints"] + assert ds2.state["scope"]["out_of_scope"] == ["Mobile app"] + assert ds2.state["scope"]["deferred"] == [] diff --git a/tests/stages/test_escalation.py b/tests/stages/test_escalation.py new file mode 100644 index 0000000..7f8b8c2 --- /dev/null +++ b/tests/stages/test_escalation.py @@ -0,0 +1,1062 @@ +from __future__ import annotations + +"""Tests for EscalationTracker — 4-level escalation chain. + +Tier 2: Conditional branches with multiple paths. + +Covers: +- EscalationEntry dataclass: to_dict / from_dict round-trip +- record_blocker() creates L1 entry, persists +- record_attempted_solution() appends and saves +- resolve() marks resolved, saves +- get_active_blockers() filters resolved +- Escalation chain: + - L1 (documented) -> L2 (agent: architect vs PM) + - L2 scope keywords -> project-manager, else -> cloud-architect + - L2 with no agent available -> fallback message + - L2 agent execution failure -> error message + - L3 web search -> success / failure / import error + - L4 human flag + - Already at L4 -> no escalation +- should_auto_escalate(): + - resolved entry -> False + - L4 entry -> False + - within timeout -> False + - exceeded timeout -> True + - bad timestamp -> False +- format_escalation_report() formatting +- State persistence: save / load round-trip +""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from azext_prototype.stages.escalation import EscalationEntry, EscalationTracker + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def tracker(tmp_path): + """EscalationTracker with a temp project directory.""" + project_dir = str(tmp_path / "test-project") + (tmp_path / "test-project" / ".prototype" / "state").mkdir(parents=True) + return EscalationTracker(project_dir) + + +@pytest.fixture +def sample_entry(): + """A sample escalation entry at L1.""" + now = datetime.now(timezone.utc).isoformat() + return EscalationEntry( + task_description="Deploy container app", + blocker="Container registry not accessible", + attempted_solutions=["Checked ACR network rules"], + escalation_level=1, + source_agent="terraform-agent", + source_stage="build", + created_at=now, + last_escalated_at=now, + ) + + +# ------------------------------------------------------------------ +# EscalationEntry dataclass +# ------------------------------------------------------------------ + + +class TestEscalationEntry: + def test_to_dict_round_trip(self, sample_entry): + d = sample_entry.to_dict() + restored = EscalationEntry.from_dict(d) + assert restored.task_description == sample_entry.task_description + assert restored.blocker == sample_entry.blocker + assert restored.attempted_solutions == sample_entry.attempted_solutions + assert restored.escalation_level == sample_entry.escalation_level + assert restored.source_agent == sample_entry.source_agent + assert restored.resolved == sample_entry.resolved + + def test_from_dict_missing_fields_uses_defaults(self): + entry = EscalationEntry.from_dict({"task_description": "Deploy", "blocker": "Blocked"}) + assert entry.escalation_level == 1 + assert entry.attempted_solutions == [] + assert entry.resolved is False + assert entry.source_agent == "" + + def test_from_dict_empty_dict(self): + entry = EscalationEntry.from_dict({}) + assert entry.task_description == "" + assert entry.blocker == "" + + +# ------------------------------------------------------------------ +# Blocker management +# ------------------------------------------------------------------ + + +class TestBlockerManagement: + def test_record_blocker_creates_l1_entry(self, tracker): + entry = tracker.record_blocker( + "Deploy app", + "Auth failure", + source_agent="terraform-agent", + source_stage="deploy", + ) + assert entry.escalation_level == 1 + assert entry.blocker == "Auth failure" + assert entry.source_agent == "terraform-agent" + assert entry.created_at != "" + + def test_record_blocker_persists(self, tracker): + tracker.record_blocker("task", "blocker", source_agent="agent", source_stage="stage") + assert tracker._state_path.exists() + + def test_record_attempted_solution(self, tracker, sample_entry): + tracker._entries.append(sample_entry) + tracker.record_attempted_solution(sample_entry, "Tried a different SKU") + assert "Tried a different SKU" in sample_entry.attempted_solutions + + def test_resolve_marks_entry(self, tracker, sample_entry): + tracker._entries.append(sample_entry) + tracker.resolve(sample_entry, "Switched to public ACR") + assert sample_entry.resolved is True + assert sample_entry.resolution == "Switched to public ACR" + + def test_get_active_blockers_excludes_resolved(self, tracker): + e1 = tracker.record_blocker("t1", "b1", source_agent="a", source_stage="s") + tracker.record_blocker("t2", "b2", source_agent="a", source_stage="s") + tracker.resolve(e1, "Fixed") + active = tracker.get_active_blockers() + assert len(active) == 1 + assert active[0].task_description == "t2" + + +# ------------------------------------------------------------------ +# State persistence +# ------------------------------------------------------------------ + + +class TestStatePersistence: + def test_save_and_load_round_trip(self, tracker): + tracker.record_blocker("task1", "blocker1", source_agent="agent1", source_stage="build") + tracker.record_blocker("task2", "blocker2", source_agent="agent2", source_stage="deploy") + + tracker2 = EscalationTracker(tracker._project_dir) + tracker2.load() + assert len(tracker2._entries) == 2 + assert tracker2._entries[0].task_description == "task1" + assert tracker2._entries[1].blocker == "blocker2" + + def test_load_nonexistent_file(self, tmp_path): + t = EscalationTracker(str(tmp_path / "no-project")) + t.load() + assert t._entries == [] + + def test_exists_property(self, tracker): + assert tracker.exists is False + tracker.record_blocker("t", "b", source_agent="a", source_stage="s") + assert tracker.exists is True + + +# ------------------------------------------------------------------ +# Escalation chain — L1 -> L2 +# ------------------------------------------------------------------ + + +class TestEscalateToAgent: + def _make_registry_and_context(self, agent_response="Here is the fix"): + mock_agent = MagicMock() + mock_agent.execute.return_value = MagicMock(content=agent_response) + + registry = MagicMock() + registry.find_by_capability.return_value = [mock_agent] + + agent_context = MagicMock() + agent_context.ai_provider = MagicMock() + + return registry, agent_context, mock_agent + + def test_technical_blocker_escalates_to_architect(self, tracker, sample_entry): + tracker._entries.append(sample_entry) + registry, ctx, agent = self._make_registry_and_context() + printed = [] + + result = tracker.escalate(sample_entry, registry, ctx, printed.append) + + assert result["escalated"] is True + assert result["level"] == 2 + # Should use ARCHITECT capability (not BACKLOG_GENERATION) + from azext_prototype.agents.base import AgentCapability + + registry.find_by_capability.assert_called_once_with(AgentCapability.ARCHITECT) + + def test_scope_blocker_escalates_to_pm(self, tracker): + entry = EscalationEntry( + task_description="Define feature scope", + blocker="Unclear requirement for the backlog story", + source_agent="biz-analyst", + source_stage="design", + created_at=datetime.now(timezone.utc).isoformat(), + last_escalated_at=datetime.now(timezone.utc).isoformat(), + ) + tracker._entries.append(entry) + + registry, ctx, agent = self._make_registry_and_context() + result = tracker.escalate(entry, registry, ctx, MagicMock()) + + from azext_prototype.agents.base import AgentCapability + + registry.find_by_capability.assert_called_once_with(AgentCapability.BACKLOG_GENERATION) + assert result["level"] == 2 + + def test_scope_keywords_detected(self, tracker): + """Each scope keyword routes to PM.""" + scope_keywords = ["scope", "requirement", "backlog", "story", "feature", "stakeholder", "priority", "sprint"] + for kw in scope_keywords: + entry = EscalationEntry( + task_description="task", + blocker=f"Issue with {kw}", + source_agent="a", + source_stage="s", + created_at=datetime.now(timezone.utc).isoformat(), + last_escalated_at=datetime.now(timezone.utc).isoformat(), + ) + tracker._entries.append(entry) + + registry, ctx, _ = self._make_registry_and_context() + tracker.escalate(entry, registry, ctx, MagicMock()) + + from azext_prototype.agents.base import AgentCapability + + registry.find_by_capability.assert_called_once_with(AgentCapability.BACKLOG_GENERATION) + + def test_no_agent_available(self, tracker, sample_entry): + tracker._entries.append(sample_entry) + registry = MagicMock() + registry.find_by_capability.return_value = [] + ctx = MagicMock() + ctx.ai_provider = MagicMock() + printed = [] + + result = tracker.escalate(sample_entry, registry, ctx, printed.append) + + assert result["level"] == 2 + assert "No cloud-architect available" in result["content"] + + def test_no_agent_context(self, tracker, sample_entry): + tracker._entries.append(sample_entry) + registry = MagicMock() + registry.find_by_capability.return_value = [MagicMock()] + + result = tracker.escalate(sample_entry, registry, None, MagicMock()) + assert "No cloud-architect available" in result["content"] + + def test_no_ai_provider_on_context(self, tracker, sample_entry): + tracker._entries.append(sample_entry) + registry = MagicMock() + registry.find_by_capability.return_value = [MagicMock()] + ctx = MagicMock() + ctx.ai_provider = None + + result = tracker.escalate(sample_entry, registry, ctx, MagicMock()) + assert "No cloud-architect available" in result["content"] + + def test_agent_execution_failure(self, tracker, sample_entry): + tracker._entries.append(sample_entry) + mock_agent = MagicMock() + mock_agent.execute.side_effect = RuntimeError("model down") + + registry = MagicMock() + registry.find_by_capability.return_value = [mock_agent] + ctx = MagicMock() + ctx.ai_provider = MagicMock() + + result = tracker.escalate(sample_entry, registry, ctx, MagicMock()) + assert "Agent escalation failed" in result["content"] + + def test_agent_returns_none_response(self, tracker, sample_entry): + tracker._entries.append(sample_entry) + mock_agent = MagicMock() + mock_agent.execute.return_value = None + + registry = MagicMock() + registry.find_by_capability.return_value = [mock_agent] + ctx = MagicMock() + ctx.ai_provider = MagicMock() + + result = tracker.escalate(sample_entry, registry, ctx, MagicMock()) + assert result["content"] == "" + + +# ------------------------------------------------------------------ +# Escalation chain — L2 -> L3 (web search) +# ------------------------------------------------------------------ + + +class TestEscalateToWebSearch: + def test_web_search_success(self, tracker, sample_entry): + sample_entry.escalation_level = 2 + tracker._entries.append(sample_entry) + + with patch( + "azext_prototype.stages.escalation.EscalationTracker._escalate_to_web_search", + return_value="Found docs on ACR networking", + ): + result = tracker.escalate(sample_entry, MagicMock(), MagicMock(), MagicMock()) + + assert result["level"] == 3 + assert result["escalated"] is True + + def test_web_search_with_real_import(self, tracker, sample_entry): + sample_entry.escalation_level = 2 + tracker._entries.append(sample_entry) + + with patch("azext_prototype.knowledge.web_search.search_and_fetch", return_value="Doc content"): + printed = [] + result = tracker.escalate(sample_entry, MagicMock(), MagicMock(), printed.append) + + assert result["level"] == 3 + assert result["content"] == "Doc content" + + def test_web_search_no_results(self, tracker, sample_entry): + sample_entry.escalation_level = 2 + tracker._entries.append(sample_entry) + + with patch("azext_prototype.knowledge.web_search.search_and_fetch", return_value=""): + printed = [] + result = tracker.escalate(sample_entry, MagicMock(), MagicMock(), printed.append) + + assert result["level"] == 3 + assert "No web results found" in result["content"] + + def test_web_search_exception(self, tracker, sample_entry): + sample_entry.escalation_level = 2 + tracker._entries.append(sample_entry) + + with patch( + "azext_prototype.knowledge.web_search.search_and_fetch", + side_effect=RuntimeError("network down"), + ): + printed = [] + result = tracker.escalate(sample_entry, MagicMock(), MagicMock(), printed.append) + + assert result["level"] == 3 + assert "Web search failed" in result["content"] + + +# ------------------------------------------------------------------ +# Escalation chain — L3 -> L4 (human) +# ------------------------------------------------------------------ + + +class TestEscalateToHuman: + def test_human_escalation(self, tracker, sample_entry): + sample_entry.escalation_level = 3 + tracker._entries.append(sample_entry) + + printed = [] + result = tracker.escalate(sample_entry, MagicMock(), MagicMock(), printed.append) + + assert result["level"] == 4 + assert result["escalated"] is True + assert "Flagged for human intervention" in result["content"] + assert any("HUMAN INTERVENTION REQUIRED" in msg for msg in printed) + + def test_human_escalation_includes_details(self, tracker, sample_entry): + sample_entry.escalation_level = 3 + sample_entry.attempted_solutions = ["Tried A", "Tried B"] + tracker._entries.append(sample_entry) + + printed = [] + tracker.escalate(sample_entry, MagicMock(), MagicMock(), printed.append) + + full_output = "\n".join(printed) + assert sample_entry.task_description in full_output + assert sample_entry.blocker in full_output + assert "Tried A" in full_output + assert "Tried B" in full_output + + +# ------------------------------------------------------------------ +# Already at L4 — no further escalation +# ------------------------------------------------------------------ + + +class TestAlreadyAtMaxLevel: + def test_l4_cannot_escalate(self, tracker, sample_entry): + sample_entry.escalation_level = 4 + tracker._entries.append(sample_entry) + + result = tracker.escalate(sample_entry, MagicMock(), MagicMock(), MagicMock()) + + assert result["escalated"] is False + assert result["level"] == 4 + assert "Already at human level" in result["content"] + + +# ------------------------------------------------------------------ +# should_auto_escalate() +# ------------------------------------------------------------------ + + +class TestShouldAutoEscalate: + def test_resolved_entry_returns_false(self, tracker): + entry = EscalationEntry( + task_description="t", + blocker="b", + resolved=True, + last_escalated_at=datetime(2020, 1, 1, tzinfo=timezone.utc).isoformat(), + ) + assert tracker.should_auto_escalate(entry, timeout_seconds=0) is False + + def test_l4_entry_returns_false(self, tracker): + entry = EscalationEntry( + task_description="t", + blocker="b", + escalation_level=4, + last_escalated_at=datetime(2020, 1, 1, tzinfo=timezone.utc).isoformat(), + ) + assert tracker.should_auto_escalate(entry, timeout_seconds=0) is False + + def test_within_timeout_returns_false(self, tracker): + recent = datetime.now(timezone.utc).isoformat() + entry = EscalationEntry( + task_description="t", + blocker="b", + escalation_level=1, + last_escalated_at=recent, + ) + assert tracker.should_auto_escalate(entry, timeout_seconds=120) is False + + def test_exceeded_timeout_returns_true(self, tracker): + old = (datetime.now(timezone.utc) - timedelta(seconds=300)).isoformat() + entry = EscalationEntry( + task_description="t", + blocker="b", + escalation_level=1, + last_escalated_at=old, + ) + assert tracker.should_auto_escalate(entry, timeout_seconds=120) is True + + def test_bad_timestamp_returns_false(self, tracker): + entry = EscalationEntry( + task_description="t", + blocker="b", + escalation_level=1, + last_escalated_at="not-a-date", + ) + assert tracker.should_auto_escalate(entry, timeout_seconds=0) is False + + def test_empty_timestamp_returns_false(self, tracker): + entry = EscalationEntry( + task_description="t", + blocker="b", + escalation_level=1, + last_escalated_at="", + ) + assert tracker.should_auto_escalate(entry, timeout_seconds=0) is False + + def test_l2_entry_can_auto_escalate(self, tracker): + old = (datetime.now(timezone.utc) - timedelta(seconds=300)).isoformat() + entry = EscalationEntry( + task_description="t", + blocker="b", + escalation_level=2, + last_escalated_at=old, + ) + assert tracker.should_auto_escalate(entry, timeout_seconds=120) is True + + def test_l3_entry_can_auto_escalate(self, tracker): + old = (datetime.now(timezone.utc) - timedelta(seconds=300)).isoformat() + entry = EscalationEntry( + task_description="t", + blocker="b", + escalation_level=3, + last_escalated_at=old, + ) + assert tracker.should_auto_escalate(entry, timeout_seconds=120) is True + + +# ------------------------------------------------------------------ +# format_escalation_report() +# ------------------------------------------------------------------ + + +class TestFormatReport: + def test_no_entries(self, tracker): + report = tracker.format_escalation_report() + assert "No blockers recorded" in report + + def test_active_blockers_in_report(self, tracker): + tracker.record_blocker("Deploy app", "Auth error", source_agent="tf", source_stage="deploy") + report = tracker.format_escalation_report() + assert "Active Blockers (1)" in report + assert "Deploy app" in report + assert "Auth error" in report + assert "Documented" in report # L1 label + + def test_resolved_in_report(self, tracker): + entry = tracker.record_blocker("task", "blocker", source_agent="a", source_stage="s") + tracker.resolve(entry, "Fixed by reconfig") + report = tracker.format_escalation_report() + assert "Resolved (1)" in report + assert "Fixed by reconfig" in report + + def test_mixed_active_and_resolved(self, tracker): + e1 = tracker.record_blocker("t1", "b1", source_agent="a", source_stage="s") + tracker.record_blocker("t2", "b2", source_agent="a", source_stage="s") + tracker.resolve(e1, "done") + report = tracker.format_escalation_report() + assert "Active Blockers (1)" in report + assert "Resolved (1)" in report + + def test_level_labels(self, tracker): + entry = tracker.record_blocker("t", "b", source_agent="a", source_stage="s") + entry.escalation_level = 2 + report = tracker.format_escalation_report() + assert "Agent" in report + + def test_attempted_solutions_count(self, tracker): + entry = tracker.record_blocker("t", "b", source_agent="a", source_stage="s") + tracker.record_attempted_solution(entry, "sol1") + tracker.record_attempted_solution(entry, "sol2") + report = tracker.format_escalation_report() + assert "Attempts: 2" in report + + +# --- Additional imports from merged flat test --- +from azext_prototype.agents.base import AgentContext +from azext_prototype.ai.provider import AIResponse +from azext_prototype.stages.backlog_session import BacklogSession +from azext_prototype.stages.backlog_state import BacklogState +from azext_prototype.stages.build_session import BuildSession +from azext_prototype.stages.deploy_session import DeploySession +from azext_prototype.stages.deploy_state import DeployState +from azext_prototype.stages.qa_router import route_error_to_qa +from pathlib import Path +import yaml + + +# ====================================================================== + + +def _make_entry(**kwargs) -> EscalationEntry: + defaults = { + "task_description": "Build Stage 3: Data Layer", + "blocker": "Cosmos DB requires premium tier", + "source_agent": "terraform-agent", + "source_stage": "build", + "created_at": datetime.now(timezone.utc).isoformat(), + "last_escalated_at": datetime.now(timezone.utc).isoformat(), + } + defaults.update(kwargs) + return EscalationEntry(**defaults) + +def _make_registry(architect_response=None, pm_response=None): + from azext_prototype.agents.base import AgentCapability + + architect = MagicMock() + architect.name = "cloud-architect" + if architect_response: + architect.execute.return_value = architect_response + else: + architect.execute.return_value = MagicMock(content="Use Standard tier instead") + + pm = MagicMock() + pm.name = "project-manager" + if pm_response: + pm.execute.return_value = pm_response + else: + pm.execute.return_value = MagicMock(content="Descope this item") + + registry = MagicMock() + + def find_by_cap(cap): + if cap == AgentCapability.ARCHITECT: + return [architect] + if cap == AgentCapability.BACKLOG_GENERATION: + return [pm] + return [] + + registry.find_by_capability.side_effect = find_by_cap + + return registry, architect, pm + +def _make_context(): + from azext_prototype.agents.base import AgentContext + + return AgentContext( + project_config={"project": {"name": "test"}}, + project_dir="/tmp/test", + ai_provider=MagicMock(), + ) + +# ====================================================================== + + +class TestEscalationTrackerState: + + def test_record_blocker(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + + entry = tracker.record_blocker( + "Deploy Redis", + "Premium tier required", + "terraform-agent", + "deploy", + ) + + assert entry.task_description == "Deploy Redis" + assert entry.blocker == "Premium tier required" + assert entry.escalation_level == 1 + assert entry.created_at != "" + assert len(tracker.get_active_blockers()) == 1 + + def test_record_attempted_solution(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + + tracker.record_attempted_solution(entry, "Tried standard tier") + tracker.record_attempted_solution(entry, "Tried basic tier") + + assert len(entry.attempted_solutions) == 2 + assert "Tried standard tier" in entry.attempted_solutions + + def test_resolve_blocker(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + + tracker.resolve(entry, "Used standard tier instead") + + assert entry.resolved is True + assert entry.resolution == "Used standard tier instead" + assert len(tracker.get_active_blockers()) == 0 + + def test_get_active_blockers_filters_resolved(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + e1 = tracker.record_blocker("task1", "blocked1", "a1", "s1") + e2 = tracker.record_blocker("task2", "blocked2", "a2", "s2") # noqa: F841 + tracker.resolve(e1, "fixed") + + active = tracker.get_active_blockers() + assert len(active) == 1 + assert active[0].task_description == "task2" + + def test_save_load_roundtrip(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + tracker.record_blocker("task1", "blocked1", "agent1", "stage1") + tracker.record_blocker("task2", "blocked2", "agent2", "stage2") + + tracker2 = EscalationTracker(str(tmp_project)) + tracker2.load() + + assert len(tracker2.get_active_blockers()) == 2 + assert tracker2.get_active_blockers()[0].task_description == "task1" + + def test_save_creates_yaml(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + tracker.record_blocker("task", "blocked", "agent", "stage") + + yaml_path = Path(str(tmp_project)) / ".prototype" / "state" / "escalation.yaml" + assert yaml_path.exists() + + with open(yaml_path) as f: + data = yaml.safe_load(f) + assert len(data["entries"]) == 1 + + def test_exists_property(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + assert not tracker.exists + + tracker.record_blocker("task", "blocked", "agent", "stage") + assert tracker.exists + + def test_empty_load(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + tracker.load() # No file exists + assert tracker.get_active_blockers() == [] + + def test_multiple_records_and_resolves(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + e1 = tracker.record_blocker("t1", "b1", "a", "s") + e2 = tracker.record_blocker("t2", "b2", "a", "s") # noqa: F841 + e3 = tracker.record_blocker("t3", "b3", "a", "s") + + tracker.resolve(e1, "fixed") + tracker.resolve(e3, "workaround") + + assert len(tracker.get_active_blockers()) == 1 + assert tracker.get_active_blockers()[0].task_description == "t2" + +# ====================================================================== + + +class TestEscalationChain: + + def test_level_1_to_2_technical(self, tmp_project): + """Technical blocker escalates to architect.""" + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker( + "Deploy Cosmos DB", + "Premium tier required for multi-region", + "terraform-agent", + "build", + ) + + registry, architect, pm = _make_registry() + ctx = _make_context() + printed = [] + + result = tracker.escalate(entry, registry, ctx, printed.append) + + assert result["escalated"] is True + assert result["level"] == 2 + assert entry.escalation_level == 2 + architect.execute.assert_called_once() + pm.execute.assert_not_called() + + def test_level_1_to_2_scope(self, tmp_project): + """Scope blocker escalates to project-manager.""" + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker( + "Backlog items", + "Scope of feature is unclear", + "biz-analyst", + "design", + ) + + registry, architect, pm = _make_registry() + ctx = _make_context() + printed = [] + + result = tracker.escalate(entry, registry, ctx, printed.append) + + assert result["escalated"] is True + assert result["level"] == 2 + pm.execute.assert_called_once() + architect.execute.assert_not_called() + + @patch("azext_prototype.stages.escalation.EscalationTracker._escalate_to_web_search") + def test_level_2_to_3_web_search(self, mock_web, tmp_project): + """Level 2→3 triggers web search.""" + mock_web.return_value = "Found: Azure docs suggest..." + + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + entry.escalation_level = 2 # Already at level 2 + + registry, _, _ = _make_registry() + ctx = _make_context() + printed = [] + + result = tracker.escalate(entry, registry, ctx, printed.append) + + assert result["escalated"] is True + assert result["level"] == 3 + mock_web.assert_called_once() + + def test_level_3_to_4_human(self, tmp_project): + """Level 3→4 flags for human intervention.""" + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + entry.escalation_level = 3 # Already at level 3 + + registry, _, _ = _make_registry() + ctx = _make_context() + printed = [] + + result = tracker.escalate(entry, registry, ctx, printed.append) + + assert result["escalated"] is True + assert result["level"] == 4 + assert any("HUMAN INTERVENTION" in p for p in printed) + + def test_already_at_level_4_no_escalation(self, tmp_project): + """Cannot escalate past level 4.""" + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + entry.escalation_level = 4 + + registry, _, _ = _make_registry() + ctx = _make_context() + printed = [] + + result = tracker.escalate(entry, registry, ctx, printed.append) + + assert result["escalated"] is False + assert result["level"] == 4 + + def test_no_agent_available_for_escalation(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + + registry = MagicMock() + registry.find_by_capability.return_value = [] + ctx = _make_context() + printed = [] + + result = tracker.escalate(entry, registry, ctx, printed.append) + + assert result["level"] == 2 + assert "No cloud-architect available" in result["content"] + + def test_agent_escalation_failure(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + + registry, architect, _ = _make_registry() + architect.execute.side_effect = RuntimeError("AI crashed") + ctx = _make_context() + printed = [] + + result = tracker.escalate(entry, registry, ctx, printed.append) + + assert result["level"] == 2 + assert "failed" in result["content"].lower() + + def test_web_search_failure_graceful(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + entry.escalation_level = 2 + + printed = [] + + with patch("azext_prototype.stages.escalation.EscalationTracker._escalate_to_web_search") as mock_ws: + mock_ws.return_value = "Web search failed: connection error" + + registry, _, _ = _make_registry() + ctx = _make_context() + result = tracker.escalate(entry, registry, ctx, printed.append) + + assert result["level"] == 3 + assert "failed" in result["content"].lower() + +# ====================================================================== + + +class TestAutoEscalation: + + def test_timeout_triggers_escalation(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + + # Set last_escalated_at to 5 minutes ago + old_time = datetime.now(timezone.utc) - timedelta(minutes=5) + entry.last_escalated_at = old_time.isoformat() + + assert tracker.should_auto_escalate(entry, timeout_seconds=120) + + def test_not_yet_timed_out(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + + # Just created, so not timed out + assert not tracker.should_auto_escalate(entry, timeout_seconds=120) + + def test_resolved_stops_escalation(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + tracker.resolve(entry, "fixed") + + old_time = datetime.now(timezone.utc) - timedelta(minutes=5) + entry.last_escalated_at = old_time.isoformat() + + assert not tracker.should_auto_escalate(entry) + + def test_level_4_stops_escalation(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + entry.escalation_level = 4 + + old_time = datetime.now(timezone.utc) - timedelta(minutes=5) + entry.last_escalated_at = old_time.isoformat() + + assert not tracker.should_auto_escalate(entry) + + def test_invalid_timestamp_returns_false(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + entry = tracker.record_blocker("task", "blocked", "agent", "stage") + entry.last_escalated_at = "not-a-timestamp" + + assert not tracker.should_auto_escalate(entry) + +# ====================================================================== + + +class TestQARouterIntegration: + + def test_qa_router_records_blocker_on_undiagnosed(self, tmp_project): + from azext_prototype.ai.provider import AIResponse + from azext_prototype.stages.qa_router import route_error_to_qa + + tracker = EscalationTracker(str(tmp_project)) + + # QA returns empty — undiagnosed + qa = MagicMock() + qa.execute.return_value = AIResponse(content="", model="gpt-4o", usage={}) + + ctx = _make_context() + + result = route_error_to_qa( + "Deployment failed", + "Deploy Stage 1", + qa, + ctx, + None, + lambda m: None, + escalation_tracker=tracker, + source_agent="terraform-agent", + source_stage="deploy", + ) + + assert result["diagnosed"] is False + assert len(tracker.get_active_blockers()) == 1 + blocker = tracker.get_active_blockers()[0] + assert blocker.source_agent == "terraform-agent" + assert blocker.source_stage == "deploy" + + def test_qa_router_no_tracker_no_error(self, tmp_project): + from azext_prototype.ai.provider import AIResponse + from azext_prototype.stages.qa_router import route_error_to_qa + + qa = MagicMock() + qa.execute.return_value = AIResponse(content="", model="gpt-4o", usage={}) + + ctx = _make_context() + + # No escalation tracker — should not raise + result = route_error_to_qa( + "error", + "context", + qa, + ctx, + None, + lambda m: None, + escalation_tracker=None, + ) + + assert result["diagnosed"] is False + + @patch("azext_prototype.stages.qa_router._submit_knowledge") + def test_qa_router_diagnosed_no_blocker(self, mock_knowledge, tmp_project): + from azext_prototype.ai.provider import AIResponse + from azext_prototype.stages.qa_router import route_error_to_qa + + tracker = EscalationTracker(str(tmp_project)) + + qa = MagicMock() + qa.execute.return_value = AIResponse(content="Root cause: X", model="gpt-4o", usage={}) + + ctx = _make_context() + + result = route_error_to_qa( + "error", + "context", + qa, + ctx, + None, + lambda m: None, + escalation_tracker=tracker, + ) + + assert result["diagnosed"] is True + # No blocker should be recorded when QA diagnoses successfully + assert len(tracker.get_active_blockers()) == 0 + + def test_build_session_has_escalation_tracker(self, tmp_project): + from azext_prototype.agents.base import AgentContext + from azext_prototype.stages.build_session import BuildSession + + ctx = AgentContext( + project_config={"project": {"name": "test", "location": "eastus"}}, + project_dir=str(tmp_project), + ai_provider=MagicMock(), + ) + + registry = MagicMock() + registry.find_by_capability.return_value = [] + + with patch("azext_prototype.stages.build_session.ProjectConfig") as mock_config: + mock_config.return_value.load.return_value = None + mock_config.return_value.get.side_effect = lambda k, d=None: { + "project.iac_tool": "terraform", + "project.name": "test", + }.get(k, d) + mock_config.return_value.to_dict.return_value = { + "naming": {"strategy": "simple"}, + "project": {"name": "test"}, + } + session = BuildSession(ctx, registry) + + assert hasattr(session, "_escalation_tracker") + assert isinstance(session._escalation_tracker, EscalationTracker) + + def test_deploy_session_has_escalation_tracker(self, tmp_project): + from azext_prototype.agents.base import AgentContext + from azext_prototype.stages.deploy_session import DeploySession + from azext_prototype.stages.deploy_state import DeployState + + ctx = AgentContext( + project_config={"project": {"name": "test", "location": "eastus"}}, + project_dir=str(tmp_project), + ai_provider=MagicMock(), + ) + + registry = MagicMock() + registry.find_by_capability.return_value = [] + + with patch("azext_prototype.stages.deploy_session.ProjectConfig") as mock_config: + mock_config.return_value.load.return_value = None + mock_config.return_value.get.side_effect = lambda k, d=None: { + "project.iac_tool": "terraform", + }.get(k, d) + session = DeploySession(ctx, registry, deploy_state=DeployState(str(tmp_project))) + + assert hasattr(session, "_escalation_tracker") + assert isinstance(session._escalation_tracker, EscalationTracker) + + def test_backlog_session_has_escalation_tracker(self, tmp_project): + from azext_prototype.agents.base import AgentContext + from azext_prototype.stages.backlog_session import BacklogSession + from azext_prototype.stages.backlog_state import BacklogState + + ctx = AgentContext( + project_config={"project": {"name": "test", "location": "eastus"}}, + project_dir=str(tmp_project), + ai_provider=MagicMock(), + ) + + registry = MagicMock() + registry.find_by_capability.return_value = [] + + session = BacklogSession(ctx, registry, backlog_state=BacklogState(str(tmp_project))) + + assert hasattr(session, "_escalation_tracker") + assert isinstance(session._escalation_tracker, EscalationTracker) + +# ====================================================================== + + +class TestReportFormatting: + + def test_empty_report(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + report = tracker.format_escalation_report() + assert "No blockers recorded" in report + + def test_report_with_active_and_resolved(self, tmp_project): + tracker = EscalationTracker(str(tmp_project)) + e1 = tracker.record_blocker("Deploy Redis", "Premium needed", "tf", "build") # noqa: F841 + e2 = tracker.record_blocker("Deploy Cosmos", "Multi-region", "tf", "build") + tracker.resolve(e2, "Used single region") + + report = tracker.format_escalation_report() + + assert "Active Blockers (1)" in report + assert "Deploy Redis" in report + assert "Resolved (1)" in report + assert "Used single region" in report diff --git a/tests/test_intent.py b/tests/stages/test_intent.py similarity index 100% rename from tests/test_intent.py rename to tests/stages/test_intent.py diff --git a/tests/stages/test_knowledge_contributor.py b/tests/stages/test_knowledge_contributor.py new file mode 100644 index 0000000..f57e921 --- /dev/null +++ b/tests/stages/test_knowledge_contributor.py @@ -0,0 +1,598 @@ +"""Tests for knowledge_contributor — gap detection and contribution submission. + +Covers: +- Namespace-to-filename conversion +- Knowledge file path resolution (namespace lookup, friendly name fallback) +- Gap detection with fallbacks (missing files, namespace resolution, empty finding) +- Contribution formatting (title, body, new-service type promotion) +- Submission with label retry (auth check, label retry fallback, FileNotFoundError) +- QA finding builder +- Fire-and-forget wrapper (submit_if_gap) +""" + +from unittest.mock import MagicMock, patch + +_KC_MODULE = "azext_prototype.stages.knowledge_contributor" +_BP_MODULE = "azext_prototype.stages.backlog_push" +_CUSTOM_MODULE = "azext_prototype.custom" + +# ====================================================================== +# _namespace_to_filename +# ====================================================================== + + +class TestNamespaceToFilename: + """Test ARM namespace to knowledge filename conversion.""" + + def test_typical_namespace(self): + from azext_prototype.stages.knowledge_contributor import _namespace_to_filename + + assert _namespace_to_filename("Microsoft.Sql/servers/databases") == "sql-servers-databases" + + def test_container_apps(self): + from azext_prototype.stages.knowledge_contributor import _namespace_to_filename + + assert _namespace_to_filename("Microsoft.App/containerApps") == "app-containerapps" + + def test_empty_namespace(self): + from azext_prototype.stages.knowledge_contributor import _namespace_to_filename + + assert _namespace_to_filename("") == "unknown" + + def test_double_hyphens_cleaned(self): + from azext_prototype.stages.knowledge_contributor import _namespace_to_filename + + # Simulate edge case with consecutive separators + result = _namespace_to_filename("Microsoft..Foo//bar") + assert "--" not in result + + +# ====================================================================== +# _resolve_knowledge_file_path +# ====================================================================== + + +class TestResolveKnowledgeFilePath: + """Test file path resolution with namespace vs friendly name.""" + + def test_namespace_via_loader_index(self): + from azext_prototype.stages.knowledge_contributor import _resolve_knowledge_file_path + + mock_loader_cls = MagicMock() + mock_loader_cls.return_value._build_namespace_index.return_value = { + "Microsoft.Web/sites": "app-service.md", + } + with patch("azext_prototype.knowledge.KnowledgeLoader", mock_loader_cls): + result = _resolve_knowledge_file_path("Microsoft.Web/sites", "app-service") + assert result == "knowledge/services/app-service.md" + + def test_namespace_not_in_index_generates_from_namespace(self): + from azext_prototype.stages.knowledge_contributor import _resolve_knowledge_file_path + + mock_loader_cls = MagicMock() + mock_loader_cls.return_value._build_namespace_index.return_value = {} + with patch("azext_prototype.knowledge.KnowledgeLoader", mock_loader_cls): + result = _resolve_knowledge_file_path("Microsoft.NewService/items", "new-service") + assert result == "knowledge/services/newservice-items.md" + + def test_namespace_loader_import_fails(self): + """When KnowledgeLoader construction fails, still generates from namespace.""" + from azext_prototype.stages.knowledge_contributor import _resolve_knowledge_file_path + + with patch( + "azext_prototype.knowledge.KnowledgeLoader", + side_effect=RuntimeError("loader broken"), + ): + result = _resolve_knowledge_file_path("Microsoft.Storage/storageAccounts", "storage") + assert result == "knowledge/services/storage-storageaccounts.md" + + def test_friendly_name_fallback(self): + """When namespace is empty, falls back to friendly name.""" + from azext_prototype.stages.knowledge_contributor import _resolve_knowledge_file_path + + result = _resolve_knowledge_file_path("", "cosmos-db") + # Should contain the friendly name + assert "cosmos-db" in result + + def test_no_namespace_no_service(self): + from azext_prototype.stages.knowledge_contributor import _resolve_knowledge_file_path + + result = _resolve_knowledge_file_path("", "") + assert result == "knowledge/services/unknown.md" + + +# ====================================================================== +# check_knowledge_gap +# ====================================================================== + + +class TestCheckKnowledgeGap: + """Test gap detection logic.""" + + def test_empty_finding_returns_false(self): + from azext_prototype.stages.knowledge_contributor import check_knowledge_gap + + assert check_knowledge_gap({}, MagicMock()) is False + assert check_knowledge_gap(None, MagicMock()) is False + + def test_no_service_or_context_returns_false(self): + from azext_prototype.stages.knowledge_contributor import check_knowledge_gap + + assert check_knowledge_gap({"service": "cosmos-db"}, MagicMock()) is False + assert check_knowledge_gap({"context": "some error"}, MagicMock()) is False + + def test_no_existing_content_is_gap(self): + from azext_prototype.stages.knowledge_contributor import check_knowledge_gap + + loader = MagicMock() + loader.load_service.return_value = "" + finding = {"service": "cosmos-db", "context": "Missing retry logic for 429 errors"} + assert check_knowledge_gap(finding, loader) is True + + def test_loader_exception_treated_as_gap(self): + from azext_prototype.stages.knowledge_contributor import check_knowledge_gap + + loader = MagicMock() + loader.load_service.side_effect = FileNotFoundError("not found") + finding = {"service": "cosmos-db", "context": "Some new pitfall discovered"} + assert check_knowledge_gap(finding, loader) is True + + def test_context_already_covered_no_gap(self): + from azext_prototype.stages.knowledge_contributor import check_knowledge_gap + + loader = MagicMock() + loader.load_service.return_value = "Common issue: missing retry logic for 429 errors and throttling" + finding = {"service": "cosmos-db", "context": "Missing retry logic for 429 errors"} + assert check_knowledge_gap(finding, loader) is False + + def test_context_not_in_content_is_gap(self): + from azext_prototype.stages.knowledge_contributor import check_knowledge_gap + + loader = MagicMock() + loader.load_service.return_value = "This file covers connection strings only." + finding = {"service": "cosmos-db", "context": "Missing retry logic for 429 errors"} + assert check_knowledge_gap(finding, loader) is True + + def test_prefers_namespace_for_resolution(self): + from azext_prototype.stages.knowledge_contributor import check_knowledge_gap + + loader = MagicMock() + loader.load_service.return_value = "" + finding = { + "service_namespace": "Microsoft.DocumentDB/databaseAccounts", + "service": "cosmos-db", + "context": "Issue found", + } + check_knowledge_gap(finding, loader) + loader.load_service.assert_called_with("Microsoft.DocumentDB/databaseAccounts") + + def test_empty_context_snippet_returns_false(self): + from azext_prototype.stages.knowledge_contributor import check_knowledge_gap + + loader = MagicMock() + loader.load_service.return_value = "some content" + finding = {"service": "x", "context": " "} # whitespace only + assert check_knowledge_gap(finding, loader) is False + + +# ====================================================================== +# format_contribution_title +# ====================================================================== + + +class TestFormatContributionTitle: + """Test issue title formatting.""" + + def test_basic_title(self): + from azext_prototype.stages.knowledge_contributor import format_contribution_title + + finding = {"service": "cosmos-db", "context": "Short context"} + title = format_contribution_title(finding) + assert title == "[Knowledge] cosmos-db: Short context" + + def test_namespace_preferred(self): + from azext_prototype.stages.knowledge_contributor import format_contribution_title + + finding = { + "service_namespace": "Microsoft.DocumentDB/databaseAccounts", + "service": "cosmos-db", + "context": "Some issue", + } + title = format_contribution_title(finding) + assert "Microsoft.DocumentDB/databaseAccounts" in title + + def test_truncation_at_60_chars(self): + from azext_prototype.stages.knowledge_contributor import format_contribution_title + + finding = {"service": "x", "context": "A" * 100} + title = format_contribution_title(finding) + assert title.endswith("...") + # The context part should be 60 chars + ... + assert "A" * 60 in title + + def test_empty_context_fallback(self): + from azext_prototype.stages.knowledge_contributor import format_contribution_title + + finding = {"service": "x"} + title = format_contribution_title(finding) + assert "Knowledge contribution" in title + + +# ====================================================================== +# format_contribution_body +# ====================================================================== + + +class TestFormatContributionBody: + """Test issue body formatting.""" + + def test_basic_body_has_sections(self): + from azext_prototype.stages.knowledge_contributor import format_contribution_body + + finding = { + "service": "cosmos-db", + "service_namespace": "Microsoft.DocumentDB/databaseAccounts", + "context": "Missing retry", + "section": "Common Pitfalls", + "content": "Add retry for 429", + "source": "QA diagnosis", + } + body = format_contribution_body(finding) + assert "## Knowledge Contribution" in body + assert "### Context" in body + assert "### Rationale" in body + assert "### Content to Add" in body + assert "### Source" in body + assert "Common Pitfalls" in body + + def test_new_service_type_promotion(self): + """When file doesn't exist and type is Pitfall, promote to New service.""" + from azext_prototype.stages.knowledge_contributor import format_contribution_body + + finding = { + "type": "Pitfall", + "service": "brand-new", + "file": "knowledge/services/nonexistent.md", + "context": "New service info", + } + body = format_contribution_body(finding) + assert "**Type:** New service" in body + assert "### Required Knowledge File Sections" in body + assert "NEW FILE" in body + + def test_no_content_placeholder(self): + from azext_prototype.stages.knowledge_contributor import format_contribution_body + + finding = {"service": "x", "context": "some issue"} + body = format_contribution_body(finding) + assert "No specific content provided" in body + + +# ====================================================================== +# submit_contribution +# ====================================================================== + + +class TestSubmitContribution: + """Test issue submission with auth check and label retry.""" + + @patch("azext_prototype.stages.knowledge_contributor.format_contribution_body", return_value="body") + @patch("azext_prototype.stages.knowledge_contributor.format_contribution_title", return_value="title") + @patch("azext_prototype.stages.knowledge_contributor._run_gh_issue_create") + @patch("azext_prototype.stages.backlog_push.check_gh_auth", return_value=True) + @patch("azext_prototype.debug_log.log_flow") + def test_success_first_try(self, mock_log, mock_auth, mock_create, mock_title, mock_body): + from azext_prototype.stages.knowledge_contributor import submit_contribution + + mock_create.return_value = MagicMock(returncode=0, stdout="https://github.com/org/repo/issues/42\n") + result = submit_contribution({"service": "x", "context": "y"}) + assert result["url"] == "https://github.com/org/repo/issues/42" + assert result["number"] == "42" + + @patch("azext_prototype.stages.knowledge_contributor._run_gh_issue_create") + @patch("azext_prototype.stages.backlog_push.check_gh_auth", return_value=True) + @patch("azext_prototype.debug_log.log_flow") + def test_label_retry_fallback(self, mock_log, mock_auth, mock_create): + """First call fails (bad label), retry with fallback labels succeeds.""" + from azext_prototype.stages.knowledge_contributor import submit_contribution + + mock_create.side_effect = [ + MagicMock(returncode=1, stderr="label not found", stdout=""), + MagicMock(returncode=0, stdout="https://github.com/org/repo/issues/99\n"), + ] + result = submit_contribution({"service": "x", "context": "y"}) + assert result["url"] == "https://github.com/org/repo/issues/99" + assert mock_create.call_count == 2 + + @patch("azext_prototype.stages.knowledge_contributor._run_gh_issue_create") + @patch("azext_prototype.stages.backlog_push.check_gh_auth", return_value=True) + @patch("azext_prototype.debug_log.log_flow") + def test_both_attempts_fail(self, mock_log, mock_auth, mock_create): + from azext_prototype.stages.knowledge_contributor import submit_contribution + + mock_create.side_effect = [ + MagicMock(returncode=1, stderr="error1", stdout=""), + MagicMock(returncode=1, stderr="error2", stdout=""), + ] + result = submit_contribution({"service": "x", "context": "y"}) + assert "error" in result + + @patch("azext_prototype.stages.backlog_push.check_gh_auth", return_value=False) + def test_auth_check_fails(self, mock_auth): + from azext_prototype.stages.knowledge_contributor import submit_contribution + + result = submit_contribution({"service": "x", "context": "y"}) + assert "error" in result + assert "authenticated" in result["error"] + + @patch("azext_prototype.stages.knowledge_contributor._run_gh_issue_create", side_effect=FileNotFoundError) + @patch("azext_prototype.stages.backlog_push.check_gh_auth", return_value=True) + @patch("azext_prototype.debug_log.log_flow") + def test_gh_cli_not_found(self, mock_log, mock_auth, mock_create): + from azext_prototype.stages.knowledge_contributor import submit_contribution + + result = submit_contribution({"service": "x", "context": "y"}) + assert "error" in result + assert "gh CLI not found" in result["error"] + + @patch("azext_prototype.stages.knowledge_contributor._run_gh_issue_create") + @patch("azext_prototype.stages.backlog_push.check_gh_auth", return_value=True) + @patch("azext_prototype.debug_log.log_flow") + def test_type_label_mapping(self, mock_log, mock_auth, mock_create): + """Different contribution types map to correct labels.""" + from azext_prototype.stages.knowledge_contributor import submit_contribution + + mock_create.return_value = MagicMock(returncode=0, stdout="https://github.com/issues/1\n") + + for contrib_type, expected_label in [ + ("New service", "new-service"), + ("Tool pattern", "tool-pattern"), + ("Pitfall", "pitfall"), + ]: + submit_contribution({"service": "x", "context": "y", "type": contrib_type}) + call_args = mock_create.call_args + labels = call_args[0][3] if len(call_args[0]) > 3 else call_args[1].get("labels", []) + assert expected_label in labels + + +# ====================================================================== +# build_finding_from_qa +# ====================================================================== + + +class TestBuildFindingFromQa: + """Test QA finding builder.""" + + def test_basic_finding(self): + from azext_prototype.stages.knowledge_contributor import build_finding_from_qa + + finding = build_finding_from_qa( + "Error: timeout connecting to Cosmos DB", + service="cosmos-db", + service_namespace="Microsoft.DocumentDB/databaseAccounts", + section="Common Pitfalls", + ) + assert finding["service"] == "cosmos-db" + assert finding["service_namespace"] == "Microsoft.DocumentDB/databaseAccounts" + assert finding["section"] == "Common Pitfalls" + assert finding["type"] == "Pitfall" + assert "timeout" in finding["context"] + + def test_truncation(self): + from azext_prototype.stages.knowledge_contributor import build_finding_from_qa + + long_text = "A" * 1000 + finding = build_finding_from_qa(long_text) + assert len(finding["context"]) <= 500 + assert len(finding["content"]) <= 200 + + def test_empty_qa_content(self): + from azext_prototype.stages.knowledge_contributor import build_finding_from_qa + + finding = build_finding_from_qa("") + assert finding["context"] == "" + assert finding["content"] == "" + + +# ====================================================================== +# submit_if_gap +# ====================================================================== + + +class TestSubmitIfGap: + """Test fire-and-forget wrapper.""" + + @patch("azext_prototype.stages.knowledge_contributor.submit_contribution") + @patch("azext_prototype.stages.knowledge_contributor.check_knowledge_gap", return_value=True) + def test_gap_found_submits(self, mock_gap, mock_submit): + from azext_prototype.stages.knowledge_contributor import submit_if_gap + + mock_submit.return_value = {"url": "https://github.com/issues/1"} + printed = [] + result = submit_if_gap( + {"service": "x", "context": "y"}, + MagicMock(), + print_fn=printed.append, + ) + assert result["url"] == "https://github.com/issues/1" + assert any("submitted" in p for p in printed) + + @patch("azext_prototype.stages.knowledge_contributor.check_knowledge_gap", return_value=False) + def test_no_gap_returns_none(self, mock_gap): + from azext_prototype.stages.knowledge_contributor import submit_if_gap + + result = submit_if_gap({"service": "x", "context": "y"}, MagicMock()) + assert result is None + + @patch("azext_prototype.stages.knowledge_contributor.submit_contribution") + @patch("azext_prototype.stages.knowledge_contributor.check_knowledge_gap", return_value=True) + def test_submit_error_no_print(self, mock_gap, mock_submit): + from azext_prototype.stages.knowledge_contributor import submit_if_gap + + mock_submit.return_value = {"error": "auth failed"} + printed = [] + result = submit_if_gap( + {"service": "x", "context": "y"}, + MagicMock(), + print_fn=printed.append, + ) + assert result["error"] == "auth failed" + assert len(printed) == 0 + + @patch("azext_prototype.stages.knowledge_contributor.check_knowledge_gap", side_effect=RuntimeError("boom")) + def test_exception_caught_returns_none(self, mock_gap): + from azext_prototype.stages.knowledge_contributor import submit_if_gap + + result = submit_if_gap({"service": "x", "context": "y"}, MagicMock()) + assert result is None + +# --- Additional imports from merged flat test --- +import pytest + + +# ====================================================================== +# Helpers +# ====================================================================== + + +def _make_finding(**overrides) -> dict: + """Create a minimal finding dict with optional overrides.""" + finding = { + "service": "cosmos-db", + "type": "Pitfall", + "file": "knowledge/services/cosmos-db.md", + "section": "Terraform Patterns", + "context": "RU throughput must be set to at least 400 for serverless", + "rationale": "Setting below 400 causes deployment failure", + "content": "minimum_throughput = 400", + "source": "QA diagnosis", + } + finding.update(overrides) + return finding + + +def _make_loader(service_content: str = "") -> MagicMock: + """Create a mock KnowledgeLoader that returns *service_content*.""" + loader = MagicMock() + loader.load_service.return_value = service_content + return loader + + +# ====================================================================== +# TestKnowledgeContributeCommand +# ====================================================================== + + +class TestKnowledgeContributeCommand: + """Tests for ``prototype_knowledge_contribute()`` CLI command.""" + + def test_draft_mode(self, project_with_config): + from azext_prototype.custom import prototype_knowledge_contribute + + cmd = MagicMock() + with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)): + result = prototype_knowledge_contribute( + cmd, + service="cosmos-db", + description="RU throughput must be >= 400", + draft=True, + json_output=True, + ) + + assert result["status"] == "draft" + assert "cosmos-db" in result["title"] + + def test_noninteractive_submit(self, project_with_config): + from azext_prototype.custom import prototype_knowledge_contribute + + cmd = MagicMock() + with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)), patch( + f"{_BP_MODULE}.subprocess.run" + ) as mock_auth, patch(f"{_KC_MODULE}.subprocess.run") as mock_create: + mock_auth.return_value = MagicMock(returncode=0) + mock_create.return_value = MagicMock( + returncode=0, + stdout="https://github.com/Azure/az-prototype/issues/55\n", + ) + + result = prototype_knowledge_contribute( + cmd, + service="cosmos-db", + description="RU throughput must be >= 400", + json_output=True, + ) + + assert result["status"] == "submitted" + assert result["url"] == "https://github.com/Azure/az-prototype/issues/55" + + def test_gh_not_authed_raises(self, project_with_config): + from knack.util import CLIError + + from azext_prototype.custom import prototype_knowledge_contribute + + cmd = MagicMock() + with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)), patch( + f"{_BP_MODULE}.subprocess.run" + ) as mock_auth: + mock_auth.return_value = MagicMock(returncode=1) + + with pytest.raises(CLIError, match="not authenticated"): + prototype_knowledge_contribute( + cmd, + service="cosmos-db", + description="RU throughput", + ) + + def test_file_input(self, project_with_config): + from azext_prototype.custom import prototype_knowledge_contribute + + # Create a finding file + finding_file = project_with_config / "finding.md" + finding_file.write_text( + "Service: cosmos-db\nContext: RU must be >= 400\nContent: min_ru = 400", + encoding="utf-8", + ) + + cmd = MagicMock() + with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)): + result = prototype_knowledge_contribute( + cmd, + file=str(finding_file), + draft=True, + json_output=True, + ) + + assert result["status"] == "draft" + + def test_file_not_found_raises(self, project_with_config): + from knack.util import CLIError + + from azext_prototype.custom import prototype_knowledge_contribute + + cmd = MagicMock() + with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)): + with pytest.raises(CLIError, match="not found"): + prototype_knowledge_contribute( + cmd, + file="/nonexistent/path/finding.md", + draft=True, + ) + + def test_contribution_type_forwarded(self, project_with_config): + from azext_prototype.custom import prototype_knowledge_contribute + + cmd = MagicMock() + with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)): + result = prototype_knowledge_contribute( + cmd, + service="redis", + description="Cache eviction pitfall", + contribution_type="Service pattern update", + section="Pitfalls", + draft=True, + json_output=True, + ) + + assert result["status"] == "draft" + assert "Service pattern update" in result["body"] + assert "Pitfalls" in result["body"] diff --git a/tests/stages/test_policy_resolver.py b/tests/stages/test_policy_resolver.py new file mode 100644 index 0000000..b52b459 --- /dev/null +++ b/tests/stages/test_policy_resolver.py @@ -0,0 +1,480 @@ +"""Tests for PolicyResolver — 3-way policy violation resolution. + +Tier 2: Conditional branches with multiple paths. + +Covers: +- No violations → early return ([], False) +- Auto-accept mode → all violations auto-accepted +- Interactive accept path (default choice) +- Override path with justification (provided and empty/default) +- Regenerate path → needs_regen=True +- Mixed resolution paths in a single check +- build_fix_instructions() generation with regen-only and mixed items +- _extract_rule_id() with bracketed prefix and fallback +- build_state.add_policy_check/add_policy_override calls +""" + +from unittest.mock import MagicMock + +import pytest + +from azext_prototype.stages.policy_resolver import PolicyResolution, PolicyResolver + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def mock_governance(): + """GovernanceContext mock with configurable violations.""" + gov = MagicMock() + gov.check_response_for_violations.return_value = [] + return gov + + +@pytest.fixture +def mock_build_state(): + """BuildState mock that records policy checks and overrides.""" + state = MagicMock() + state.add_policy_check = MagicMock() + state.add_policy_override = MagicMock() + return state + + +@pytest.fixture +def resolver(mock_governance): + """Standard interactive PolicyResolver.""" + return PolicyResolver( + console=MagicMock(), + prompt=MagicMock(), + governance_context=mock_governance, + auto_accept=False, + ) + + +@pytest.fixture +def auto_resolver(mock_governance): + """PolicyResolver with auto_accept=True.""" + return PolicyResolver( + console=MagicMock(), + prompt=MagicMock(), + governance_context=mock_governance, + auto_accept=True, + ) + + +# ------------------------------------------------------------------ +# No violations — early return +# ------------------------------------------------------------------ + + +class TestNoViolations: + def test_no_violations_returns_empty(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = [] + resolutions, needs_regen = resolver.check_and_resolve( + "terraform-agent", + "resource aws_s3_bucket {}", + mock_build_state, + 1, + print_fn=MagicMock(), + ) + assert resolutions == [] + assert needs_regen is False + + def test_no_violations_does_not_call_build_state(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = [] + resolver.check_and_resolve( + "bicep-agent", + "resource storageAccount 'Microsoft.Storage/storageAccounts@2023-01-01' = {}", + mock_build_state, + 2, + print_fn=MagicMock(), + ) + mock_build_state.add_policy_check.assert_not_called() + mock_build_state.add_policy_override.assert_not_called() + + +# ------------------------------------------------------------------ +# Auto-accept mode +# ------------------------------------------------------------------ + + +class TestAutoAccept: + def test_auto_accept_all_violations(self, auto_resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = [ + "[managed-identity] Use managed identity instead of keys", + "[tls-version] Enforce TLS 1.2", + ] + printed = [] + resolutions, needs_regen = auto_resolver.check_and_resolve( + "terraform-agent", + "some content", + mock_build_state, + 1, + print_fn=printed.append, + ) + assert len(resolutions) == 2 + assert all(r.action == "accept" for r in resolutions) + assert needs_regen is False + + def test_auto_accept_extracts_rule_ids(self, auto_resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = [ + "[managed-identity] Use managed identity", + "[tls-version] Enforce TLS 1.2", + ] + resolutions, _ = auto_resolver.check_and_resolve( + "bicep-agent", "content", mock_build_state, 1, print_fn=MagicMock() + ) + assert resolutions[0].rule_id == "managed-identity" + assert resolutions[1].rule_id == "tls-version" + + def test_auto_accept_records_policy_check(self, auto_resolver, mock_governance, mock_build_state): + violations = ["[sec-001] No public endpoints"] + mock_governance.check_response_for_violations.return_value = violations + auto_resolver.check_and_resolve("terraform-agent", "code", mock_build_state, 3, print_fn=MagicMock()) + mock_build_state.add_policy_check.assert_called_once_with( + 3, + violations=violations, + overrides=[], + ) + + def test_auto_accept_does_not_call_override(self, auto_resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[rule-x] violation"] + auto_resolver.check_and_resolve("terraform-agent", "code", mock_build_state, 1, print_fn=MagicMock()) + mock_build_state.add_policy_override.assert_not_called() + + def test_auto_accept_prints_auto_accepted_message(self, auto_resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[rule-x] violation"] + printed = [] + auto_resolver.check_and_resolve("terraform-agent", "code", mock_build_state, 1, print_fn=printed.append) + assert any("Auto-accepted" in msg for msg in printed) + + +# ------------------------------------------------------------------ +# Interactive — Accept (default path) +# ------------------------------------------------------------------ + + +class TestInteractiveAccept: + def test_accept_explicit_a(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[rule-1] issue"] + resolutions, needs_regen = resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: "a", + print_fn=MagicMock(), + ) + assert len(resolutions) == 1 + assert resolutions[0].action == "accept" + assert needs_regen is False + + def test_accept_default_empty_input(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[rule-1] issue"] + resolutions, _ = resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: "", + print_fn=MagicMock(), + ) + assert resolutions[0].action == "accept" + + def test_accept_unknown_input_defaults_to_accept(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[rule-1] issue"] + resolutions, _ = resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: "xyz", + print_fn=MagicMock(), + ) + assert resolutions[0].action == "accept" + + def test_accept_prints_accepted_message(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[rule-1] issue"] + printed = [] + resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: "a", + print_fn=printed.append, + ) + assert any("Accepted compliant recommendation" in msg for msg in printed) + + +# ------------------------------------------------------------------ +# Interactive — Override path +# ------------------------------------------------------------------ + + +class TestInteractiveOverride: + def test_override_with_justification(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[managed-identity] use MI"] + inputs = iter(["o", "Legacy system requires key auth"]) + resolutions, needs_regen = resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 2, + input_fn=lambda _: next(inputs), + print_fn=MagicMock(), + ) + assert len(resolutions) == 1 + assert resolutions[0].action == "override" + assert resolutions[0].justification == "Legacy system requires key auth" + assert needs_regen is False + + def test_override_word_form(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[rule-1] issue"] + inputs = iter(["override", "needed"]) + resolutions, _ = resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: next(inputs), + print_fn=MagicMock(), + ) + assert resolutions[0].action == "override" + + def test_override_empty_justification_uses_default(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[rule-1] issue"] + inputs = iter(["o", ""]) + resolutions, _ = resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: next(inputs), + print_fn=MagicMock(), + ) + assert resolutions[0].action == "override" + assert resolutions[0].justification == "User chose to override" + + def test_override_calls_build_state_add_policy_override(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[sec-001] issue"] + inputs = iter(["o", "Approved by security team"]) + resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: next(inputs), + print_fn=MagicMock(), + ) + mock_build_state.add_policy_override.assert_called_once_with("sec-001", "Approved by security team") + + def test_override_recorded_in_policy_check_overrides(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[sec-001] issue"] + inputs = iter(["o", "Approved"]) + resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 5, + input_fn=lambda _: next(inputs), + print_fn=MagicMock(), + ) + mock_build_state.add_policy_check.assert_called_once() + args = mock_build_state.add_policy_check.call_args + assert ( + args.kwargs.get("overrides") + or args[1].get("overrides") + or [d for d in (args[1] if len(args) > 1 else []) if isinstance(d, list)] + ) + # Verify via the call — overrides list should have one item + call_args = mock_build_state.add_policy_check.call_args + overrides_arg = call_args[1]["overrides"] if "overrides" in call_args[1] else call_args[0][2] + assert len(overrides_arg) == 1 + assert overrides_arg[0]["rule_id"] == "sec-001" + + +# ------------------------------------------------------------------ +# Interactive — Regenerate path +# ------------------------------------------------------------------ + + +class TestInteractiveRegenerate: + def test_regenerate_sets_needs_regen_true(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[rule-1] issue"] + resolutions, needs_regen = resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: "r", + print_fn=MagicMock(), + ) + assert needs_regen is True + assert resolutions[0].action == "regenerate" + + def test_regenerate_word_form(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[rule-1] issue"] + resolutions, needs_regen = resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: "regenerate", + print_fn=MagicMock(), + ) + assert needs_regen is True + assert resolutions[0].action == "regenerate" + + def test_regenerate_prints_will_regenerate_message(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = ["[rule-1] issue"] + printed = [] + resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: "r", + print_fn=printed.append, + ) + assert any("regenerate" in msg.lower() for msg in printed) + + +# ------------------------------------------------------------------ +# Mixed resolutions in a single check +# ------------------------------------------------------------------ + + +class TestMixedResolutions: + def test_mixed_accept_override_regenerate(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = [ + "[rule-a] first issue", + "[rule-b] second issue", + "[rule-c] third issue", + ] + # First: accept, Second: override with justification, Third: regenerate + inputs = iter(["a", "o", "Because reasons", "r"]) + resolutions, needs_regen = resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: next(inputs), + print_fn=MagicMock(), + ) + assert len(resolutions) == 3 + assert resolutions[0].action == "accept" + assert resolutions[1].action == "override" + assert resolutions[1].justification == "Because reasons" + assert resolutions[2].action == "regenerate" + assert needs_regen is True + + def test_mixed_only_override_recorded_in_overrides(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = [ + "[rule-a] first", + "[rule-b] second", + ] + inputs = iter(["a", "o", "justified"]) + resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + input_fn=lambda _: next(inputs), + print_fn=MagicMock(), + ) + call_args = mock_build_state.add_policy_check.call_args + overrides = call_args[1]["overrides"] if "overrides" in call_args[1] else call_args[0][2] + assert len(overrides) == 1 + assert overrides[0]["rule_id"] == "rule-b" + + +# ------------------------------------------------------------------ +# build_fix_instructions() +# ------------------------------------------------------------------ + + +class TestBuildFixInstructions: + def test_no_regen_items_returns_empty(self, resolver): + resolutions = [ + PolicyResolution(rule_id="r1", action="accept", violation_text="v1"), + PolicyResolution(rule_id="r2", action="override", justification="ok", violation_text="v2"), + ] + assert resolver.build_fix_instructions(resolutions) == "" + + def test_regen_items_produces_fix_block(self, resolver): + resolutions = [ + PolicyResolution(rule_id="r1", action="regenerate", violation_text="missing MI auth"), + ] + result = resolver.build_fix_instructions(resolutions) + assert "## Policy Fix Instructions" in result + assert "missing MI auth" in result + assert "Fix these violations" in result + + def test_regen_with_overrides_includes_override_section(self, resolver): + resolutions = [ + PolicyResolution(rule_id="r1", action="regenerate", violation_text="issue A"), + PolicyResolution(rule_id="r2", action="override", justification="approved", violation_text="issue B"), + ] + result = resolver.build_fix_instructions(resolutions) + assert "## Policy Fix Instructions" in result + assert "issue A" in result + assert "overridden by the user" in result + assert "r2: approved" in result + + def test_multiple_regen_items(self, resolver): + resolutions = [ + PolicyResolution(rule_id="r1", action="regenerate", violation_text="A"), + PolicyResolution(rule_id="r2", action="regenerate", violation_text="B"), + ] + result = resolver.build_fix_instructions(resolutions) + assert "- A" in result + assert "- B" in result + + +# ------------------------------------------------------------------ +# _extract_rule_id() +# ------------------------------------------------------------------ + + +class TestExtractRuleId: + def test_bracketed_prefix(self): + assert PolicyResolver._extract_rule_id("[managed-identity] Use MI") == "managed-identity" + + def test_no_brackets_returns_unknown(self): + assert PolicyResolver._extract_rule_id("No brackets here") == "unknown" + + def test_empty_brackets_returns_empty_string(self): + # "[]" has end=1 > 0, so it extracts text[1:1] == "" + assert PolicyResolver._extract_rule_id("[] empty bracket") == "" + + def test_starts_with_bracket_no_close(self): + assert PolicyResolver._extract_rule_id("[no-close-bracket") == "unknown" + + def test_nested_brackets_takes_first(self): + assert PolicyResolver._extract_rule_id("[outer] some [inner] text") == "outer" + + +# ------------------------------------------------------------------ +# iac_tool parameter forwarding +# ------------------------------------------------------------------ + + +class TestIacToolForwarding: + def test_iac_tool_passed_to_governance(self, resolver, mock_governance, mock_build_state): + mock_governance.check_response_for_violations.return_value = [] + resolver.check_and_resolve( + "terraform-agent", + "code", + mock_build_state, + 1, + print_fn=MagicMock(), + iac_tool="terraform", + ) + mock_governance.check_response_for_violations.assert_called_once_with( + "terraform-agent", + "code", + iac_tool="terraform", + ) diff --git a/tests/test_qa_router.py b/tests/stages/test_qa_router.py similarity index 56% rename from tests/test_qa_router.py rename to tests/stages/test_qa_router.py index 5fa8a03..22abc14 100644 --- a/tests/test_qa_router.py +++ b/tests/stages/test_qa_router.py @@ -1,24 +1,521 @@ -"""Tests for azext_prototype.stages.qa_router — shared QA error routing.""" - from __future__ import annotations +"""Tests for route_error_to_qa() — QA error routing. + +Tier 2: Conditional branches with multiple paths. + +Covers: +- QA agent None -> early return +- agent_context None -> early return +- agent_context.ai_provider None -> early return +- QA agent executes successfully -> diagnosed=True +- QA agent returns empty content -> diagnosed=False +- QA agent returns None -> diagnosed=False +- QA agent raises exception -> diagnosed=False +- Token tracking when tracker provided +- Token tracking exception swallowed +- Knowledge contribution fire-and-forget (success + failure) +- Blocker recording when QA can't diagnose + escalation tracker present +- Blocker recording exception swallowed +- Error text truncation (max_error_chars) +- Display text truncation (max_display_chars) +- None/empty error handling +""" + from unittest.mock import MagicMock, patch import pytest +from azext_prototype.stages.qa_router import route_error_to_qa + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def qa_agent(): + """Mock QA agent with successful response.""" + agent = MagicMock() + agent.execute.return_value = MagicMock(content="Root cause: misconfigured SKU. Fix: use B1.") + return agent + + +@pytest.fixture +def agent_context(): + """Mock AgentContext with AI provider.""" + ctx = MagicMock() + ctx.ai_provider = MagicMock() + return ctx + + +@pytest.fixture +def token_tracker(): + """Mock TokenTracker.""" + return MagicMock() + + +@pytest.fixture +def escalation_tracker(): + """Mock EscalationTracker.""" + return MagicMock() + + +# ------------------------------------------------------------------ +# Early returns — no QA agent / no context / no provider +# ------------------------------------------------------------------ + + +class TestEarlyReturns: + def test_qa_agent_none(self, agent_context): + result = route_error_to_qa( + "Error occurred", + "Build Stage 3", + qa_agent=None, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["diagnosed"] is False + assert "Error occurred" in result["content"] + assert result["response"] is None + + def test_agent_context_none(self, qa_agent): + result = route_error_to_qa( + "Error", + "Build Stage 1", + qa_agent=qa_agent, + agent_context=None, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["diagnosed"] is False + assert result["response"] is None + + def test_ai_provider_none(self, qa_agent): + ctx = MagicMock() + ctx.ai_provider = None + result = route_error_to_qa( + "Error", + "Deploy", + qa_agent=qa_agent, + agent_context=ctx, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["diagnosed"] is False + + def test_all_none(self): + result = route_error_to_qa( + "Error", + "context", + qa_agent=None, + agent_context=None, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["diagnosed"] is False + + +# ------------------------------------------------------------------ +# Successful QA diagnosis +# ------------------------------------------------------------------ + + +class TestSuccessfulDiagnosis: + def test_diagnosed_true(self, qa_agent, agent_context): + result = route_error_to_qa( + "Terraform apply failed", + "Build Stage 2", + qa_agent=qa_agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["diagnosed"] is True + assert "Root cause" in result["content"] + assert result["response"] is not None + + def test_prints_qa_diagnosis(self, qa_agent, agent_context): + printed = [] + route_error_to_qa( + "Error", + "Build", + qa_agent=qa_agent, + agent_context=agent_context, + token_tracker=None, + print_fn=printed.append, + ) + assert any("QA Diagnosis" in msg for msg in printed) + + def test_exception_error_converted_to_string(self, qa_agent, agent_context): + route_error_to_qa( + RuntimeError("connection timeout"), + "Deploy Stage 1", + qa_agent=qa_agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + ) + task_arg = qa_agent.execute.call_args[0][1] + assert "connection timeout" in task_arg + + def test_services_kwarg_forwarded(self, qa_agent, agent_context): + with patch("azext_prototype.stages.qa_router._submit_knowledge") as mock_submit: + route_error_to_qa( + "Error", + "Build", + qa_agent=qa_agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + services=["key-vault", "cosmos-db"], + ) + mock_submit.assert_called_once() + _, _, services_arg, _ = mock_submit.call_args[0] + assert services_arg == ["key-vault", "cosmos-db"] + + +# ------------------------------------------------------------------ +# QA agent failures +# ------------------------------------------------------------------ + + +class TestQAFailures: + def test_qa_agent_raises_exception(self, agent_context): + bad_agent = MagicMock() + bad_agent.execute.side_effect = RuntimeError("model overloaded") + result = route_error_to_qa( + "Error", + "Build", + qa_agent=bad_agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["diagnosed"] is False + assert result["response"] is None + + def test_qa_returns_none(self, agent_context): + agent = MagicMock() + agent.execute.return_value = None + result = route_error_to_qa( + "Error", + "Build", + qa_agent=agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["diagnosed"] is False + + def test_qa_returns_empty_content(self, agent_context): + agent = MagicMock() + agent.execute.return_value = MagicMock(content="") + result = route_error_to_qa( + "Error", + "Build", + qa_agent=agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["diagnosed"] is False + assert result["response"] is not None + + def test_qa_returns_none_content(self, agent_context): + agent = MagicMock() + agent.execute.return_value = MagicMock(content=None) + result = route_error_to_qa( + "Error", + "Build", + qa_agent=agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["diagnosed"] is False + + +# ------------------------------------------------------------------ +# Token tracking +# ------------------------------------------------------------------ + + +class TestTokenTracking: + def test_tokens_recorded_on_success(self, qa_agent, agent_context, token_tracker): + route_error_to_qa( + "Error", + "Build", + qa_agent=qa_agent, + agent_context=agent_context, + token_tracker=token_tracker, + print_fn=MagicMock(), + ) + token_tracker.record.assert_called_once_with(qa_agent.execute.return_value) + + def test_no_token_tracker_no_error(self, qa_agent, agent_context): + # Should not raise even when token_tracker is None + result = route_error_to_qa( + "Error", + "Build", + qa_agent=qa_agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["diagnosed"] is True + + def test_token_tracker_exception_swallowed(self, qa_agent, agent_context): + bad_tracker = MagicMock() + bad_tracker.record.side_effect = RuntimeError("tracker broken") + result = route_error_to_qa( + "Error", + "Build", + qa_agent=qa_agent, + agent_context=agent_context, + token_tracker=bad_tracker, + print_fn=MagicMock(), + ) + assert result["diagnosed"] is True # Should still succeed + + def test_tokens_not_recorded_when_response_is_none(self, agent_context, token_tracker): + agent = MagicMock() + agent.execute.return_value = None + route_error_to_qa( + "Error", + "Build", + qa_agent=agent, + agent_context=agent_context, + token_tracker=token_tracker, + print_fn=MagicMock(), + ) + token_tracker.record.assert_not_called() + + +# ------------------------------------------------------------------ +# Knowledge contribution (fire-and-forget) +# ------------------------------------------------------------------ + + +class TestKnowledgeContribution: + def test_knowledge_submitted_on_success(self, qa_agent, agent_context): + with patch("azext_prototype.stages.qa_router._submit_knowledge") as mock_submit: + route_error_to_qa( + "Error", + "Build Stage 3", + qa_agent=qa_agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + services=["cosmos-db"], + ) + mock_submit.assert_called_once() + + def test_knowledge_exception_swallowed(self, qa_agent, agent_context): + with patch( + "azext_prototype.stages.qa_router._submit_knowledge", + side_effect=RuntimeError("GitHub down"), + ): + result = route_error_to_qa( + "Error", + "Build", + qa_agent=qa_agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + ) + # Should still return successfully + assert result["diagnosed"] is True + + +# ------------------------------------------------------------------ +# Blocker recording (escalation tracker) +# ------------------------------------------------------------------ + + +class TestBlockerRecording: + def test_blocker_recorded_when_qa_cant_diagnose(self, agent_context, escalation_tracker): + agent = MagicMock() + agent.execute.return_value = MagicMock(content="") + + route_error_to_qa( + "Deployment failed: quota exceeded", + "Deploy Stage 1", + qa_agent=agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + escalation_tracker=escalation_tracker, + source_agent="terraform-agent", + source_stage="deploy", + ) + escalation_tracker.record_blocker.assert_called_once() + call_args = escalation_tracker.record_blocker.call_args + assert call_args[0][0] == "Deploy Stage 1" + assert "quota exceeded" in call_args[0][1] + assert call_args[1]["source_agent"] == "terraform-agent" + assert call_args[1]["source_stage"] == "deploy" + + def test_default_source_agent_is_qa_engineer(self, agent_context, escalation_tracker): + agent = MagicMock() + agent.execute.return_value = MagicMock(content="") + + route_error_to_qa( + "Error", + "Build", + qa_agent=agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + escalation_tracker=escalation_tracker, + ) + call_args = escalation_tracker.record_blocker.call_args + assert call_args[1]["source_agent"] == "qa-engineer" + + def test_no_blocker_when_no_tracker(self, agent_context): + agent = MagicMock() + agent.execute.return_value = MagicMock(content="") + # Should not raise + result = route_error_to_qa( + "Error", + "Build", + qa_agent=agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + escalation_tracker=None, + ) + assert result["diagnosed"] is False + + def test_blocker_not_recorded_on_success(self, qa_agent, agent_context, escalation_tracker): + route_error_to_qa( + "Error", + "Build", + qa_agent=qa_agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + escalation_tracker=escalation_tracker, + ) + escalation_tracker.record_blocker.assert_not_called() + + def test_blocker_recording_exception_swallowed(self, agent_context): + agent = MagicMock() + agent.execute.return_value = MagicMock(content="") + + bad_tracker = MagicMock() + bad_tracker.record_blocker.side_effect = RuntimeError("disk full") + + result = route_error_to_qa( + "Error", + "Build", + qa_agent=agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + escalation_tracker=bad_tracker, + ) + assert result["diagnosed"] is False # Still returns gracefully + + +# ------------------------------------------------------------------ +# Error text handling +# ------------------------------------------------------------------ + + +class TestErrorTextHandling: + def test_error_text_truncated(self, agent_context): + long_error = "x" * 5000 + result = route_error_to_qa( + long_error, + "Build", + qa_agent=None, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + max_error_chars=100, + ) + assert len(result["content"]) == 100 + + def test_none_error_becomes_unknown(self, agent_context): + result = route_error_to_qa( + None, + "Build", + qa_agent=None, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["content"] == "Unknown error" + + def test_empty_string_error_becomes_unknown(self, agent_context): + result = route_error_to_qa( + "", + "Build", + qa_agent=None, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + ) + assert result["content"] == "Unknown error" + + def test_display_truncated(self, agent_context): + long_content = "R" * 3000 + agent = MagicMock() + agent.execute.return_value = MagicMock(content=long_content) + printed = [] + route_error_to_qa( + "Error", + "Build", + qa_agent=agent, + agent_context=agent_context, + token_tracker=None, + print_fn=printed.append, + max_display_chars=500, + ) + # The displayed content should be truncated + display_lines = [p for p in printed if p and "QA Diagnosis" not in p and p.strip()] + if display_lines: + assert len(display_lines[0]) <= 500 + + def test_custom_max_error_chars(self, qa_agent, agent_context): + long_error = "E" * 5000 + route_error_to_qa( + long_error, + "Build", + qa_agent=qa_agent, + agent_context=agent_context, + token_tracker=None, + print_fn=MagicMock(), + max_error_chars=50, + ) + task_arg = qa_agent.execute.call_args[0][1] + # The error text in the task should be truncated to 50 chars + assert "E" * 50 in task_arg + assert "E" * 51 not in task_arg + + +# --- Additional imports from merged flat test --- +from azext_prototype.agents.base import AgentCapability from azext_prototype.agents.base import AgentContext from azext_prototype.ai.provider import AIResponse -from azext_prototype.stages.qa_router import route_error_to_qa +from azext_prototype.stages.backlog_session import BacklogSession +from azext_prototype.stages.backlog_state import BacklogState +from azext_prototype.stages.build_session import BuildSession +from azext_prototype.stages.build_state import BuildState +from azext_prototype.stages.deploy_session import DeploySession +from azext_prototype.stages.deploy_state import DeployState +from azext_prototype.stages.discovery import DiscoverySession +import json + -# ====================================================================== -# Helpers # ====================================================================== def _make_response(content: str = "Root cause: X. Fix: do Y.") -> AIResponse: return AIResponse(content=content, model="gpt-4o", usage={}) - def _make_qa_agent(response: AIResponse | None = None, raises: Exception | None = None): agent = MagicMock() agent.name = "qa-engineer" @@ -28,7 +525,6 @@ def _make_qa_agent(response: AIResponse | None = None, raises: Exception | None agent.execute.return_value = response or _make_response() return agent - def _make_context(): return AgentContext( project_config={"project": {"name": "test"}}, @@ -36,14 +532,10 @@ def _make_context(): ai_provider=MagicMock(), ) - def _make_tracker(): tracker = MagicMock() return tracker - -# ====================================================================== -# Core routing tests # ====================================================================== @@ -391,9 +883,6 @@ def test_token_tracker_record_failure_swallowed(self): assert result["diagnosed"] is True - -# ====================================================================== -# Integration: Build session QA routing # ====================================================================== @@ -499,9 +988,6 @@ def test_empty_response_routes_to_qa(self, mock_knowledge, tmp_project): # QA should be called for empty response qa.execute.assert_called() - -# ====================================================================== -# Integration: Discovery session QA routing # ====================================================================== @@ -554,9 +1040,6 @@ def find_by_cap(cap): # QA should have been called for the error diagnosis qa.execute.assert_called_once() - -# ====================================================================== -# Integration: Backlog session QA routing # ====================================================================== @@ -637,9 +1120,6 @@ def test_push_error_triggers_qa(self, mock_push, mock_auth, mock_knowledge, tmp_ qa.execute.assert_called() - -# ====================================================================== -# Integration: Deploy session refactored QA routing # ====================================================================== diff --git a/tests/test_stages.py b/tests/stages/test_stages.py similarity index 72% rename from tests/test_stages.py rename to tests/stages/test_stages.py index 299b59a..7d13a21 100644 --- a/tests/test_stages.py +++ b/tests/stages/test_stages.py @@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch +import pytest + from azext_prototype.stages.base import StageGuard, StageState from azext_prototype.stages.guards import ( _check_az_logged_in, @@ -177,212 +179,10 @@ def test_init_stage_has_guards(self): assert len(guards) == 0 -class TestDeployStage: - """Test the deploy stage.""" - - def test_deploy_stage_instantiates(self): - from azext_prototype.stages.deploy_stage import DeployStage - - stage = DeployStage() - assert stage is not None - assert stage.name == "deploy" - - def test_deploy_stage_has_execute(self): - from azext_prototype.stages.deploy_stage import DeployStage - - stage = DeployStage() - assert hasattr(stage, "execute") - assert callable(stage.execute) - - -class TestDeployBicepStaging: - """Test Bicep staged deployment capabilities (via deploy_helpers).""" - - def test_find_bicep_params_main_parameters_json(self, tmp_path): - from azext_prototype.stages.deploy_helpers import find_bicep_params - - template = tmp_path / "main.bicep" - template.write_text("resource rg 'Microsoft.Resources/resourceGroups@2023-07-01' = {}") - params = tmp_path / "main.parameters.json" - params.write_text('{"parameters": {"location": {"value": "eastus"}}}') - - result = find_bicep_params(tmp_path, template) - assert result == params - - def test_find_bicep_params_bicepparam(self, tmp_path): - from azext_prototype.stages.deploy_helpers import find_bicep_params - - template = tmp_path / "main.bicep" - template.write_text("") - bp = tmp_path / "main.bicepparam" - bp.write_text("using './main.bicep'\nparam location = 'eastus'") - - result = find_bicep_params(tmp_path, template) - assert result == bp - - def test_find_bicep_params_fallback_parameters_json(self, tmp_path): - from azext_prototype.stages.deploy_helpers import find_bicep_params - - template = tmp_path / "network.bicep" - template.write_text("") - params = tmp_path / "parameters.json" - params.write_text('{"parameters": {}}') - - result = find_bicep_params(tmp_path, template) - assert result == params - - def test_find_bicep_params_none_when_missing(self, tmp_path): - from azext_prototype.stages.deploy_helpers import find_bicep_params - - template = tmp_path / "main.bicep" - template.write_text("") - - result = find_bicep_params(tmp_path, template) - assert result is None - - def test_is_subscription_scoped_true(self, tmp_path): - from azext_prototype.stages.deploy_helpers import is_subscription_scoped - - bicep_file = tmp_path / "main.bicep" - bicep_file.write_text( - "targetScope = 'subscription'\n\nresource rg 'Microsoft.Resources/resourceGroups@2023-07-01' = {}" - ) - - assert is_subscription_scoped(bicep_file) is True - - def test_is_subscription_scoped_false(self, tmp_path): - from azext_prototype.stages.deploy_helpers import is_subscription_scoped - - bicep_file = tmp_path / "main.bicep" - bicep_file.write_text("resource sa 'Microsoft.Storage/storageAccounts@2023-01-01' = {}") - - assert is_subscription_scoped(bicep_file) is False - - def test_get_deploy_location_from_params(self, tmp_path): - from azext_prototype.stages.deploy_helpers import get_deploy_location - - params = tmp_path / "parameters.json" - params.write_text('{"parameters": {"location": {"value": "westus2"}}}') - - result = get_deploy_location(tmp_path) - assert result == "westus2" - - def test_get_deploy_location_returns_none(self, tmp_path): - from azext_prototype.stages.deploy_helpers import get_deploy_location - - result = get_deploy_location(tmp_path) - assert result is None - - @patch("subprocess.run") - def test_deploy_bicep_resource_group_scope(self, mock_run, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_bicep - - bicep_dir = tmp_path / "stage1" - bicep_dir.mkdir() - (bicep_dir / "main.bicep").write_text("resource sa 'Microsoft.Storage/storageAccounts@2023-01-01' = {}") - - mock_run.return_value = MagicMock(returncode=0, stdout='{"properties":{}}', stderr="") - - result = deploy_bicep(bicep_dir, "sub-123", "my-rg") - assert result["status"] == "deployed" - assert result["scope"] == "resourceGroup" - assert result["template"] == "main.bicep" - - # Verify az deployment group create was called (not sub create) - cmd_parts = mock_run.call_args[0][0] - assert "group" in cmd_parts - assert "create" in cmd_parts - - @patch("subprocess.run") - def test_deploy_bicep_subscription_scope(self, mock_run, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_bicep - - bicep_dir = tmp_path / "stage1" - bicep_dir.mkdir() - (bicep_dir / "main.bicep").write_text( - "targetScope = 'subscription'\n\nresource rg 'Microsoft.Resources/resourceGroups@2023-07-01' = {}" - ) - - mock_run.return_value = MagicMock(returncode=0, stdout='{"properties":{}}', stderr="") - - result = deploy_bicep(bicep_dir, "sub-123", "") - assert result["status"] == "deployed" - assert result["scope"] == "subscription" - - # Verify az deployment sub create was called - cmd_parts = mock_run.call_args[0][0] - assert "sub" in cmd_parts - - @patch("subprocess.run") - def test_deploy_bicep_with_params_file(self, mock_run, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_bicep - - (tmp_path / "main.bicep").write_text("param location string\n") - (tmp_path / "main.parameters.json").write_text('{"parameters":{"location":{"value":"eastus"}}}') - - mock_run.return_value = MagicMock(returncode=0, stdout="{}", stderr="") - - deploy_bicep(tmp_path, "sub-123", "my-rg") - - cmd_parts = mock_run.call_args[0][0] - assert "--parameters" in cmd_parts - - def test_deploy_bicep_no_bicep_files_skips(self, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_bicep - - empty_dir = tmp_path / "empty" - empty_dir.mkdir() - - result = deploy_bicep(empty_dir, "sub-123", "my-rg") - assert result["status"] == "skipped" - - def test_deploy_bicep_fallback_to_first_file(self, tmp_path): - """When no main.bicep exists, uses the first .bicep file.""" - from azext_prototype.stages.deploy_helpers import deploy_bicep - - (tmp_path / "network.bicep").write_text("resource vnet 'Microsoft.Network/virtualNetworks@2023-05-01' = {}") - - with patch("subprocess.run") as mock_run: - mock_run.return_value = MagicMock(returncode=0, stdout="{}", stderr="") - result = deploy_bicep(tmp_path, "sub-123", "my-rg") - - assert result["status"] == "deployed" - assert result["template"] == "network.bicep" - - def test_deploy_bicep_rg_required_for_rg_scope(self, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_bicep - - (tmp_path / "main.bicep").write_text("resource sa 'Microsoft.Storage/storageAccounts@2023-01-01' = {}") - - result = deploy_bicep(tmp_path, "sub-123", "") - assert result["status"] == "failed" - assert "Resource group required" in result["error"] - - @patch("subprocess.run") - def test_whatif_bicep_runs(self, mock_run, tmp_path): - from azext_prototype.stages.deploy_helpers import whatif_bicep - - (tmp_path / "main.bicep").write_text("resource sa 'Microsoft.Storage/storageAccounts@2023-01-01' = {}") - mock_run.return_value = MagicMock(returncode=0, stdout="Resource changes: 1 to create", stderr="") - - result = whatif_bicep(tmp_path, "sub-123", "my-rg") - assert result["status"] == "previewed" - assert "Resource changes" in result["output"] - - cmd_parts = mock_run.call_args[0][0] - assert "what-if" in cmd_parts - class TestDesignStage: """Test the design stage.""" - def test_design_stage_instantiates(self): - from azext_prototype.stages.design_stage import DesignStage - - stage = DesignStage() - assert stage.name == "design" - assert stage.reentrant is True - def test_design_stage_has_execute(self): from azext_prototype.stages.design_stage import DesignStage @@ -1019,20 +819,347 @@ def test_design_skip_discovery_fails_without_state( ) -class TestBuildStage: - """Test the build stage.""" +# --- Additional imports from merged flat test --- +from knack.util import CLIError + +from azext_prototype.agents.base import AgentContext +from azext_prototype.agents.registry import AgentRegistry +from azext_prototype.ai.provider import AIResponse +from azext_prototype.stages.build_session import BuildResult - def test_build_stage_instantiates(self): - from azext_prototype.stages.build_stage import BuildStage - stage = BuildStage() - assert stage is not None - assert stage.name == "build" +# ====================================================================== - def test_match_templates_empty_architecture(self): + +class TestBuildStageExecution: + """Test BuildStage methods.""" + + def _make_stage(self): from azext_prototype.stages.build_stage import BuildStage - stage = BuildStage() - config = MagicMock() - result = stage._match_templates({"architecture": ""}, config) - assert result == [] + return BuildStage() + + def test_execute_dry_run(self, project_with_design, mock_agent_context, populated_registry): + stage = self._make_stage() + stage.get_guards = lambda: [] + mock_agent_context.project_dir = str(project_with_design) + mock_agent_context.ai_provider.chat.return_value = AIResponse(content="Generated code", model="gpt-4o") + + result = stage.execute(mock_agent_context, populated_registry, scope="docs", dry_run=True) + assert result["status"] == "dry-run" + + def test_execute_all_scopes_dry_run(self, project_with_design, mock_agent_context, populated_registry): + stage = self._make_stage() + stage.get_guards = lambda: [] + mock_agent_context.project_dir = str(project_with_design) + + result = stage.execute(mock_agent_context, populated_registry, scope="all", dry_run=True) + assert result["status"] == "dry-run" + assert result["scope"] == "all" + + @patch("azext_prototype.stages.build_stage.BuildSession") + def test_execute_interactive_delegates_to_session( + self, mock_session_cls, project_with_design, mock_agent_context, populated_registry + ): + stage = self._make_stage() + stage.get_guards = lambda: [] + mock_agent_context.project_dir = str(project_with_design) + + mock_result = BuildResult( + files_generated=["main.tf"], + deployment_stages=[{"stage": 1, "name": "Foundation"}], + policy_overrides=[], + resources=[{"resourceType": "Microsoft.Compute/virtualMachines", "sku": "Standard_B2s"}], + review_accepted=True, + cancelled=False, + ) + mock_session_cls.return_value.run.return_value = mock_result + + result = stage.execute(mock_agent_context, populated_registry, scope="all", dry_run=False) + assert result["status"] == "success" + assert result["scope"] == "all" + assert result["files_generated"] == ["main.tf"] + mock_session_cls.return_value.run.assert_called_once() + +# ====================================================================== + + +class TestInitStageExecution: + """Test InitStage methods.""" + + def _make_stage(self): + from azext_prototype.stages.init_stage import InitStage + + return InitStage() + + def test_init_guards(self): + """Init has no unconditional guards; gh check is conditional inside execute().""" + stage = self._make_stage() + guards = stage.get_guards() + assert len(guards) == 0 + + @patch("subprocess.run") + def test_check_gh_true(self, mock_run): + stage = self._make_stage() + mock_run.return_value = MagicMock(returncode=0) + assert stage._check_gh() is True + + @patch("subprocess.run", side_effect=FileNotFoundError) + def test_check_gh_false(self, mock_run): + stage = self._make_stage() + assert stage._check_gh() is False + + def test_create_scaffold(self, tmp_path): + stage = self._make_stage() + project_dir = tmp_path / "my-project" + stage._create_scaffold(project_dir) + + assert (project_dir / "concept" / "docs").is_dir() + assert (project_dir / ".prototype" / "agents").is_dir() + # infra, apps, db dirs are NOT created at init — only during build + assert not (project_dir / "concept" / "apps").exists() + assert not (project_dir / "concept" / "infra").exists() + assert not (project_dir / "concept" / "db").exists() + + def test_create_gitignore(self, tmp_path): + stage = self._make_stage() + stage._create_gitignore(tmp_path) + gi = tmp_path / ".gitignore" + assert gi.exists() + content = gi.read_text() + assert ".terraform/" in content + assert "__pycache__/" in content + + def test_create_gitignore_no_overwrite(self, tmp_path): + stage = self._make_stage() + gi = tmp_path / ".gitignore" + gi.write_text("custom content", encoding="utf-8") + stage._create_gitignore(tmp_path) + assert gi.read_text() == "custom content" + + @patch("azext_prototype.auth.copilot_license.CopilotLicenseValidator") + @patch("azext_prototype.auth.github_auth.GitHubAuthManager") + @patch("azext_prototype.stages.init_stage.InitStage._check_gh", return_value=True) + def test_execute_full(self, mock_gh, mock_auth_cls, mock_lic_cls, tmp_path): + stage = self._make_stage() + stage.get_guards = lambda: [] + + mock_auth = MagicMock() + mock_auth.ensure_authenticated.return_value = {"login": "devuser"} + mock_auth_cls.return_value = mock_auth + mock_lic = MagicMock() + mock_lic.validate_license.return_value = {"plan": "business"} + mock_lic_cls.return_value = mock_lic + + from azext_prototype.agents.base import AgentContext + from azext_prototype.agents.registry import AgentRegistry + + ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) + registry = AgentRegistry() + + out = tmp_path / "test-proj" + result = stage.execute( + ctx, + registry, + name="test-proj", + location="westus2", + iac_tool="bicep", + ai_provider="github-models", + output_dir=str(out), + ) + assert result["status"] == "success" + assert (out / "prototype.yaml").exists() + + @patch("azext_prototype.auth.copilot_license.CopilotLicenseValidator") + @patch("azext_prototype.auth.github_auth.GitHubAuthManager") + @patch("azext_prototype.stages.init_stage.InitStage._check_gh", return_value=True) + def test_execute_license_failure_continues(self, mock_gh, mock_auth_cls, mock_lic_cls, tmp_path): + """License validation failure should warn but continue.""" + stage = self._make_stage() + stage.get_guards = lambda: [] + + mock_auth = MagicMock() + mock_auth.ensure_authenticated.return_value = {"login": "devuser"} + mock_auth_cls.return_value = mock_auth + mock_lic = MagicMock() + mock_lic.validate_license.side_effect = CLIError("No license") + mock_lic_cls.return_value = mock_lic + + from azext_prototype.agents.base import AgentContext + from azext_prototype.agents.registry import AgentRegistry + + ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) + registry = AgentRegistry() + + result = stage.execute( + ctx, + registry, + name="lic-test", + location="eastus", + ai_provider="github-models", + output_dir=str(tmp_path / "lic-test"), + ) + assert result["status"] == "success" + assert result["copilot_license"]["status"] == "unverified" + + def test_execute_no_name_raises(self, tmp_path): + stage = self._make_stage() + stage.get_guards = lambda: [] + + from azext_prototype.agents.base import AgentContext + from azext_prototype.agents.registry import AgentRegistry + + ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) + registry = AgentRegistry() + + with pytest.raises(CLIError, match="Project name"): + stage.execute(ctx, registry, name="", output_dir=str(tmp_path / "empty-name")) + + def test_execute_no_location_raises(self, tmp_path): + stage = self._make_stage() + stage.get_guards = lambda: [] + + from azext_prototype.agents.base import AgentContext + from azext_prototype.agents.registry import AgentRegistry + + ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) + registry = AgentRegistry() + + with pytest.raises(CLIError, match="region is required"): + stage.execute( + ctx, + registry, + name="test-proj", + location=None, + output_dir=str(tmp_path / "test-proj"), + ) + + def test_execute_azure_openai_skips_auth(self, tmp_path): + """azure-openai provider should skip GitHub auth entirely.""" + stage = self._make_stage() + stage.get_guards = lambda: [] + + from azext_prototype.agents.base import AgentContext + from azext_prototype.agents.registry import AgentRegistry + + ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) + registry = AgentRegistry() + + result = stage.execute( + ctx, + registry, + name="aoai-test", + location="eastus", + ai_provider="azure-openai", + output_dir=str(tmp_path / "aoai-test"), + ) + assert result["status"] == "success" + assert result["github_user"] is None + assert "copilot_license" not in result + + def test_execute_environment_stored(self, tmp_path): + """--environment should be persisted in config.""" + stage = self._make_stage() + stage.get_guards = lambda: [] + + from azext_prototype.agents.base import AgentContext + from azext_prototype.agents.registry import AgentRegistry + from azext_prototype.config import ProjectConfig + + ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) + registry = AgentRegistry() + + out = tmp_path / "env-test" + stage.execute( + ctx, + registry, + name="env-test", + location="westus2", + ai_provider="azure-openai", + environment="prod", + output_dir=str(out), + ) + config = ProjectConfig(str(out)) + config.load() + assert config.get("project.environment") == "prod" + assert config.get("naming.env") == "prd" + assert config.get("naming.zone_id") == "zp" + + def test_execute_model_override(self, tmp_path): + """Explicit --model should override provider default.""" + stage = self._make_stage() + stage.get_guards = lambda: [] + + from azext_prototype.agents.base import AgentContext + from azext_prototype.agents.registry import AgentRegistry + from azext_prototype.config import ProjectConfig + + ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) + registry = AgentRegistry() + + out = tmp_path / "model-test" + stage.execute( + ctx, + registry, + name="model-test", + location="eastus", + ai_provider="azure-openai", + model="gpt-4o-mini", + output_dir=str(out), + ) + config = ProjectConfig(str(out)) + config.load() + assert config.get("ai.model") == "gpt-4o-mini" + + def test_execute_idempotency_cancel(self, tmp_path): + """Existing project + user declining should cancel.""" + stage = self._make_stage() + stage.get_guards = lambda: [] + + from azext_prototype.agents.base import AgentContext + from azext_prototype.agents.registry import AgentRegistry + + # Pre-create project directory with config + proj = tmp_path / "idem-test" + proj.mkdir() + (proj / "prototype.yaml").write_text("project:\n name: old\n") + + ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) + registry = AgentRegistry() + + with patch("builtins.input", return_value="n"): + result = stage.execute( + ctx, + registry, + name="idem-test", + location="eastus", + ai_provider="azure-openai", + output_dir=str(proj), + ) + assert result["status"] == "cancelled" + + def test_execute_marks_init_complete(self, tmp_path): + """Init stage should set stages.init.completed and timestamp.""" + stage = self._make_stage() + stage.get_guards = lambda: [] + + from azext_prototype.agents.base import AgentContext + from azext_prototype.agents.registry import AgentRegistry + from azext_prototype.config import ProjectConfig + + ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) + registry = AgentRegistry() + + out = tmp_path / "complete-test" + stage.execute( + ctx, + registry, + name="complete-test", + location="eastus", + ai_provider="azure-openai", + output_dir=str(out), + ) + config = ProjectConfig(str(out)) + config.load() + assert config.get("stages.init.completed") is True + assert config.get("stages.init.timestamp") is not None diff --git a/tests/telemetry/__init__.py b/tests/telemetry/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_telemetry.py b/tests/telemetry/test___init__.py similarity index 99% rename from tests/test_telemetry.py rename to tests/telemetry/test___init__.py index 8240f10..47578b6 100644 --- a/tests/test_telemetry.py +++ b/tests/telemetry/test___init__.py @@ -355,7 +355,7 @@ def test_reads_from_metadata(self): from azext_prototype.telemetry import _get_extension_version version = _get_extension_version() - assert version == "0.2.1b6" + assert version == "0.2.1b7" def test_returns_unknown_on_error(self): from azext_prototype.telemetry import _get_extension_version diff --git a/tests/templates/__init__.py b/tests/templates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_template_compliance.py b/tests/templates/test_template_compliance.py similarity index 99% rename from tests/test_template_compliance.py rename to tests/templates/test_template_compliance.py index f32e57c..2507d1d 100644 --- a/tests/test_template_compliance.py +++ b/tests/templates/test_template_compliance.py @@ -25,9 +25,9 @@ # Helpers # ------------------------------------------------------------------ # -BUILTIN_DIR = Path(__file__).resolve().parent.parent / "azext_prototype" / "templates" / "workloads" +BUILTIN_DIR = Path(__file__).resolve().parent.parent.parent / "azext_prototype" / "templates" / "workloads" -BUILTIN_POLICY_DIR = Path(__file__).resolve().parent.parent / "azext_prototype" / "governance" / "policies" +BUILTIN_POLICY_DIR = Path(__file__).resolve().parent.parent.parent / "azext_prototype" / "governance" / "policies" def _write_yaml(dest: Path, data: dict | list | str) -> Path: diff --git a/tests/test_templates.py b/tests/templates/test_templates.py similarity index 99% rename from tests/test_templates.py rename to tests/templates/test_templates.py index f40f80a..dae5849 100644 --- a/tests/test_templates.py +++ b/tests/templates/test_templates.py @@ -17,7 +17,7 @@ # Helpers # ------------------------------------------------------------------ # -BUILTIN_DIR = Path(__file__).resolve().parent.parent / "azext_prototype" / "templates" / "workloads" +BUILTIN_DIR = Path(__file__).resolve().parent.parent.parent / "azext_prototype" / "templates" / "workloads" EXPECTED_BUILTIN_NAMES = sorted( [ @@ -669,7 +669,7 @@ def test_sql_auto_pause(self): class TestTemplateSchema: """Verify the JSON schema file exists and is valid JSON.""" - SCHEMA_PATH = Path(__file__).resolve().parent.parent / "azext_prototype" / "templates" / "template.schema.json" + SCHEMA_PATH = Path(__file__).resolve().parent.parent.parent / "azext_prototype" / "templates" / "template.schema.json" def test_schema_file_exists(self): assert self.SCHEMA_PATH.exists() diff --git a/tests/test_custom.py b/tests/test_custom.py index 1126d32..0c00910 100644 --- a/tests/test_custom.py +++ b/tests/test_custom.py @@ -10,7 +10,7 @@ # All command functions call _get_project_dir() internally (uses Path.cwd()), # so we mock it to point at our tmp fixture directories. -_CUSTOM_MODULE = "azext_prototype.custom" +_MOD = "azext_prototype.custom" class TestGetProjectDir: @@ -42,7 +42,7 @@ def test_missing_config_raises(self, tmp_project): class TestPrototypeStatus: """Test az prototype status command.""" - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_status_with_config(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_status @@ -60,7 +60,7 @@ def test_status_with_config(self, mock_dir, project_with_config): assert "build" in result["stages"] assert "deploy" in result["stages"] - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_status_without_config(self, mock_dir, tmp_project): from azext_prototype.custom import prototype_status @@ -75,7 +75,7 @@ def test_status_without_config(self, mock_dir, tmp_project): class TestPrototypeConfigShow: """Test az prototype config show command.""" - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_config_show(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_config_show @@ -89,7 +89,7 @@ def test_config_show(self, mock_dir, project_with_config): class TestPrototypeConfigSet: """Test az prototype config set command.""" - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_config_set(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_config_set @@ -111,7 +111,7 @@ def test_config_set_missing_key_raises(self): class TestPrototypeAgentList: """Test az prototype agent list command.""" - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_agent_list(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_agent_list @@ -122,7 +122,7 @@ def test_agent_list(self, mock_dir, project_with_config): assert isinstance(result, list) assert len(result) >= 8 # 8 built-in agents - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_agent_list_no_builtin(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_agent_list @@ -137,7 +137,7 @@ def test_agent_list_no_builtin(self, mock_dir, project_with_config): class TestPrototypeAgentShow: """Test az prototype agent show command.""" - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_agent_show_builtin(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_agent_show @@ -159,7 +159,7 @@ def test_agent_show_missing_name_raises(self): class TestPrototypeAgentAdd: """Test az prototype agent add command — all three modes.""" - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_add_default_template(self, mock_dir, project_with_config): """Mode 1: --name only with interactive input → creates agent from prompts.""" from azext_prototype.custom import prototype_agent_add @@ -184,7 +184,7 @@ def test_add_default_template(self, mock_dir, project_with_config): content = _yaml.safe_load(agent_file.read_text(encoding="utf-8")) assert content["name"] == "my-data-agent" - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_add_from_builtin_definition(self, mock_dir, project_with_config): """Mode 2: --name + --definition → copies named builtin definition.""" from azext_prototype.custom import prototype_agent_add @@ -205,7 +205,7 @@ def test_add_from_builtin_definition(self, mock_dir, project_with_config): content = _yaml.safe_load(agent_file.read_text(encoding="utf-8")) assert content["name"] == "my-architect" - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_add_from_user_file(self, mock_dir, project_with_config): """Mode 3: --name + --file → copies user-supplied file.""" from azext_prototype.custom import prototype_agent_add @@ -235,7 +235,7 @@ def test_add_missing_name_raises(self): with pytest.raises(CLIError, match="--name"): prototype_agent_add(cmd, name=None) - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_add_file_and_definition_mutually_exclusive(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_agent_add @@ -245,7 +245,7 @@ def test_add_file_and_definition_mutually_exclusive(self, mock_dir, project_with with pytest.raises(CLIError, match="mutually exclusive"): prototype_agent_add(cmd, name="x", file="./a.yaml", definition="cloud_architect") - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_add_unknown_definition_raises(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_agent_add @@ -255,7 +255,7 @@ def test_add_unknown_definition_raises(self, mock_dir, project_with_config): with pytest.raises(CLIError, match="Unknown definition"): prototype_agent_add(cmd, name="x", definition="nonexistent_agent") - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_add_duplicate_name_raises(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_agent_add @@ -266,7 +266,7 @@ def test_add_duplicate_name_raises(self, mock_dir, project_with_config): with pytest.raises(CLIError, match="already exists"): prototype_agent_add(cmd, name="dup-agent", definition="cloud_architect") - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_add_records_config_manifest(self, mock_dir, project_with_config): """Verify the agent is recorded in prototype.yaml.""" from azext_prototype.custom import _load_config, prototype_agent_add @@ -283,7 +283,7 @@ def test_add_records_config_manifest(self, mock_dir, project_with_config): assert "file" in custom["manifest-test"] assert "capabilities" in custom["manifest-test"] - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_add_file_not_found_raises(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_agent_add @@ -347,7 +347,7 @@ def test_rewrites_name_field(self, tmp_path): class TestPrototypeGenerateDocs: """Test az prototype generate docs command.""" - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_generate_docs_creates_files(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_generate_docs @@ -364,7 +364,7 @@ def test_generate_docs_creates_files(self, mock_dir, project_with_config): md_files = list(docs_path.glob("*.md")) assert len(md_files) >= 1 - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_generate_docs_default_output(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_generate_docs @@ -379,7 +379,7 @@ def test_generate_docs_default_output(self, mock_dir, project_with_config): class TestPrototypeGenerateSpeckit: """Test az prototype generate speckit command.""" - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_generate_speckit_creates_files(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_generate_speckit @@ -391,7 +391,7 @@ def test_generate_speckit_creates_files(self, mock_dir, project_with_config): assert result is not None assert result["status"] == "generated" - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._get_project_dir") def test_generate_speckit_manifest(self, mock_dir, project_with_config): from azext_prototype.custom import prototype_generate_speckit @@ -415,8 +415,8 @@ def test_generate_speckit_manifest(self, mock_dir, project_with_config): class TestPrototypeGenerateBacklog: """Test az prototype generate backlog command.""" - @patch(f"{_CUSTOM_MODULE}._check_requirements") - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._get_project_dir") def test_generate_backlog_github(self, mock_dir, mock_check_req, project_with_design, mock_ai_provider): """Backlog session runs and returns result for github provider.""" from azext_prototype.custom import prototype_generate_backlog @@ -427,7 +427,7 @@ def test_generate_backlog_github(self, mock_dir, mock_check_req, project_with_de mock_result = BacklogResult(items_generated=3, items_pushed=0) - with patch(f"{_CUSTOM_MODULE}._build_context") as mock_ctx, patch( + with patch(f"{_MOD}._build_context") as mock_ctx, patch( "azext_prototype.stages.backlog_session.BacklogSession" ) as MockSession: from azext_prototype.agents.base import AgentContext @@ -446,8 +446,8 @@ def test_generate_backlog_github(self, mock_dir, mock_check_req, project_with_de assert result["provider"] == "github" assert result["items_generated"] == 3 - @patch(f"{_CUSTOM_MODULE}._check_requirements") - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._get_project_dir") def test_generate_backlog_devops(self, mock_dir, mock_check_req, project_with_design, mock_ai_provider): """Backlog session runs for devops provider.""" from azext_prototype.custom import prototype_generate_backlog @@ -458,7 +458,7 @@ def test_generate_backlog_devops(self, mock_dir, mock_check_req, project_with_de mock_result = BacklogResult(items_generated=2, items_pushed=0) - with patch(f"{_CUSTOM_MODULE}._build_context") as mock_ctx, patch( + with patch(f"{_MOD}._build_context") as mock_ctx, patch( "azext_prototype.stages.backlog_session.BacklogSession" ) as MockSession: from azext_prototype.agents.base import AgentContext @@ -476,8 +476,8 @@ def test_generate_backlog_devops(self, mock_dir, mock_check_req, project_with_de assert result["status"] == "generated" assert result["provider"] == "devops" - @patch(f"{_CUSTOM_MODULE}._check_requirements") - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._get_project_dir") def test_generate_backlog_invalid_provider_raises( self, mock_dir, mock_check_req, project_with_design, mock_ai_provider ): @@ -486,7 +486,7 @@ def test_generate_backlog_invalid_provider_raises( mock_dir.return_value = str(project_with_design) cmd = MagicMock() - with patch(f"{_CUSTOM_MODULE}._build_context") as mock_ctx: + with patch(f"{_MOD}._build_context") as mock_ctx: from azext_prototype.agents.base import AgentContext ctx = AgentContext( @@ -499,15 +499,15 @@ def test_generate_backlog_invalid_provider_raises( with pytest.raises(CLIError, match="Unsupported backlog provider"): prototype_generate_backlog(cmd, provider="jira", org="x", project="y") - @patch(f"{_CUSTOM_MODULE}._check_requirements") - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._get_project_dir") def test_generate_backlog_no_design_raises(self, mock_dir, mock_check_req, project_with_config, mock_ai_provider): from azext_prototype.custom import prototype_generate_backlog mock_dir.return_value = str(project_with_config) cmd = MagicMock() - with patch(f"{_CUSTOM_MODULE}._build_context") as mock_ctx: + with patch(f"{_MOD}._build_context") as mock_ctx: from azext_prototype.agents.base import AgentContext ctx = AgentContext( @@ -520,8 +520,8 @@ def test_generate_backlog_no_design_raises(self, mock_dir, mock_check_req, proje with pytest.raises(CLIError, match="No architecture design found"): prototype_generate_backlog(cmd, provider="github", org="x", project="y") - @patch(f"{_CUSTOM_MODULE}._check_requirements") - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._get_project_dir") def test_generate_backlog_defaults_from_config( self, mock_dir, mock_check_req, project_with_design, mock_ai_provider ): @@ -544,7 +544,7 @@ def test_generate_backlog_defaults_from_config( mock_result = BacklogResult(items_generated=1, items_pushed=0) - with patch(f"{_CUSTOM_MODULE}._build_context") as mock_ctx, patch( + with patch(f"{_MOD}._build_context") as mock_ctx, patch( "azext_prototype.stages.backlog_session.BacklogSession" ) as MockSession: from azext_prototype.agents.base import AgentContext @@ -561,8 +561,8 @@ def test_generate_backlog_defaults_from_config( assert result["provider"] == "devops" - @patch(f"{_CUSTOM_MODULE}._check_requirements") - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._get_project_dir") def test_generate_backlog_result_fields(self, mock_dir, mock_check_req, project_with_design, mock_ai_provider): """Result dict includes expected fields.""" from azext_prototype.custom import prototype_generate_backlog @@ -573,7 +573,7 @@ def test_generate_backlog_result_fields(self, mock_dir, mock_check_req, project_ mock_result = BacklogResult(items_generated=1, items_pushed=0) - with patch(f"{_CUSTOM_MODULE}._build_context") as mock_ctx, patch( + with patch(f"{_MOD}._build_context") as mock_ctx, patch( "azext_prototype.stages.backlog_session.BacklogSession" ) as MockSession: from azext_prototype.agents.base import AgentContext @@ -591,8 +591,8 @@ def test_generate_backlog_result_fields(self, mock_dir, mock_check_req, project_ assert result["status"] == "generated" assert result["items_generated"] == 1 - @patch(f"{_CUSTOM_MODULE}._check_requirements") - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._get_project_dir") def test_generate_backlog_prompts_when_unconfigured( self, mock_dir, mock_check_req, project_with_design, mock_ai_provider ): @@ -605,8 +605,8 @@ def test_generate_backlog_prompts_when_unconfigured( mock_result = BacklogResult(items_generated=1, items_pushed=0) - with patch(f"{_CUSTOM_MODULE}._build_context") as mock_ctx, patch( - f"{_CUSTOM_MODULE}._prompt_backlog_config" + with patch(f"{_MOD}._build_context") as mock_ctx, patch( + f"{_MOD}._prompt_backlog_config" ) as mock_prompt, patch("azext_prototype.stages.backlog_session.BacklogSession") as MockSession: from azext_prototype.agents.base import AgentContext @@ -639,8 +639,8 @@ def test_generate_backlog_prompts_when_unconfigured( assert saved["backlog"]["org"] == "prompted-org" assert saved["backlog"]["project"] == "prompted-repo" - @patch(f"{_CUSTOM_MODULE}._check_requirements") - @patch(f"{_CUSTOM_MODULE}._get_project_dir") + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._get_project_dir") def test_generate_backlog_no_prompt_when_fully_configured( self, mock_dir, mock_check_req, project_with_design, mock_ai_provider ): @@ -653,8 +653,8 @@ def test_generate_backlog_no_prompt_when_fully_configured( mock_result = BacklogResult(items_generated=1, items_pushed=0) - with patch(f"{_CUSTOM_MODULE}._build_context") as mock_ctx, patch( - f"{_CUSTOM_MODULE}._prompt_backlog_config" + with patch(f"{_MOD}._build_context") as mock_ctx, patch( + f"{_MOD}._prompt_backlog_config" ) as mock_prompt, patch("azext_prototype.stages.backlog_session.BacklogSession") as MockSession: from azext_prototype.agents.base import AgentContext @@ -729,3 +729,2193 @@ def test_invalid_choice_retries(self): result = _prompt_backlog_config() assert result["provider"] == "github" + +# ====================================================================== +# Helper functions +# ====================================================================== + + +class TestBuildRegistry: + """Test _build_registry helper.""" + + def test_build_registry_builtin_only(self): + from azext_prototype.custom import _build_registry + + registry = _build_registry(config=None, project_dir=None) + agents = registry.list_all() + assert len(agents) >= 8 + + def test_build_registry_with_custom_agents(self, project_with_config): + from azext_prototype.custom import _build_registry, _load_config + + # Create a custom YAML agent + agent_dir = project_with_config / ".prototype" / "agents" + agent_dir.mkdir(parents=True, exist_ok=True) + (agent_dir / "test-agent.yaml").write_text( + "name: test-agent\ndescription: A test\ncapabilities:\n - develop\n" "system_prompt: You are a test.\n", + encoding="utf-8", + ) + + config = _load_config(str(project_with_config)) + registry = _build_registry(config, str(project_with_config)) + names = [a.name for a in registry.list_all()] + assert "test-agent" in names + + def test_build_registry_with_overrides(self, project_with_config): + from azext_prototype.custom import _build_registry, _load_config + + # Write a YAML agent to use as override + override_file = project_with_config / "override.yaml" + override_file.write_text( + "name: cloud-architect\ndescription: Override\ncapabilities:\n - architect\n" + "system_prompt: Override prompt.\n", + encoding="utf-8", + ) + + config = _load_config(str(project_with_config)) + config.set("agents.overrides", {"cloud-architect": "override.yaml"}) + + registry = _build_registry(config, str(project_with_config)) + agent = registry.get("cloud-architect") + assert "Override" in agent.description + + +class TestBuildContext: + """Test _build_context helper.""" + + @patch("azext_prototype.ai.factory.create_ai_provider") + def test_build_context_creates_agent_context(self, mock_factory, project_with_config): + from azext_prototype.custom import _build_context, _load_config + + mock_provider = MagicMock() + mock_factory.return_value = mock_provider + config = _load_config(str(project_with_config)) + + ctx = _build_context(config, str(project_with_config)) + assert ctx.project_dir == str(project_with_config) + assert ctx.ai_provider is mock_provider + + +class TestPrepareCommand: + """Test _prepare_command helper.""" + + @patch(f"{_MOD}._check_requirements") + @patch("azext_prototype.ai.factory.create_ai_provider") + def test_prepare_command(self, mock_factory, mock_check_req, project_with_config): + from azext_prototype.custom import _prepare_command + + mock_factory.return_value = MagicMock() + pd, config, registry, ctx = _prepare_command(str(project_with_config)) + assert pd == str(project_with_config) + assert config is not None + assert registry is not None + assert ctx is not None + + +class TestCheckRequirements: + """Test _check_requirements wiring in command entry points.""" + + def test_check_requirements_passes_when_all_ok(self): + from azext_prototype.custom import _check_requirements + from azext_prototype.requirements import CheckResult + + with patch("azext_prototype.requirements.check_all") as mock_check: + mock_check.return_value = [ + CheckResult(name="Python", status="pass", installed_version="3.12.0", required=">=3.9.0", message="ok"), + ] + # Should not raise + _check_requirements("terraform") + + def test_check_requirements_raises_on_missing(self): + from azext_prototype.custom import _check_requirements + from azext_prototype.requirements import CheckResult + + with patch("azext_prototype.requirements.check_all") as mock_check: + mock_check.return_value = [ + CheckResult( + name="Terraform", + status="missing", + installed_version=None, + required=">=1.14.0", + message="Terraform is not installed", + install_hint="https://developer.hashicorp.com/terraform/install", + ), + ] + with pytest.raises(CLIError, match="Tool requirements not met"): + _check_requirements("terraform") + + def test_check_requirements_raises_on_version_fail(self): + from azext_prototype.custom import _check_requirements + from azext_prototype.requirements import CheckResult + + with patch("azext_prototype.requirements.check_all") as mock_check: + mock_check.return_value = [ + CheckResult( + name="Azure CLI", + status="fail", + installed_version="2.40.0", + required=">=2.50.0", + message="Azure CLI 2.40.0 does not satisfy >=2.50.0", + install_hint="https://learn.microsoft.com/cli/azure/install-azure-cli", + ), + ] + with pytest.raises(CLIError, match="Azure CLI"): + _check_requirements(None) + + def test_check_requirements_includes_install_hint(self): + from azext_prototype.custom import _check_requirements + from azext_prototype.requirements import CheckResult + + with patch("azext_prototype.requirements.check_all") as mock_check: + mock_check.return_value = [ + CheckResult( + name="Terraform", + status="missing", + installed_version=None, + required=">=1.14.0", + message="Terraform is not installed", + install_hint="https://developer.hashicorp.com/terraform/install", + ), + ] + with pytest.raises(CLIError, match="Install:.*hashicorp"): + _check_requirements("terraform") + + @patch("azext_prototype.ai.factory.create_ai_provider") + def test_prepare_command_calls_check_requirements(self, mock_factory, project_with_config): + from azext_prototype.custom import _prepare_command + + mock_factory.return_value = MagicMock() + with patch(f"{_MOD}._check_requirements") as mock_check: + _prepare_command(str(project_with_config)) + mock_check.assert_called_once() + + def test_init_calls_check_requirements(self, tmp_path): + with patch(f"{_MOD}._check_requirements") as mock_check, patch( + "azext_prototype.stages.init_stage.InitStage" + ) as MockStage: + from azext_prototype.custom import prototype_init + + mock_stage = MockStage.return_value + mock_stage.can_run.return_value = (True, []) + mock_stage.execute.return_value = {"status": "success"} + + cmd = MagicMock() + prototype_init(cmd, name="test", location="eastus", output_dir=str(tmp_path)) + mock_check.assert_called_once_with("terraform") # default iac_tool + + +class TestCheckGuards: + """Test _check_guards helper.""" + + def test_check_guards_pass(self): + from azext_prototype.custom import _check_guards + + stage = MagicMock() + stage.can_run.return_value = (True, []) + _check_guards(stage) # Should not raise + + def test_check_guards_fail(self): + from azext_prototype.custom import _check_guards + + stage = MagicMock() + stage.can_run.return_value = (False, ["Missing gh CLI"]) + with pytest.raises(CLIError, match="Prerequisites not met"): + _check_guards(stage) + + +class TestGetRegistryWithFallback: + """Test _get_registry_with_fallback helper.""" + + def test_with_valid_config(self, project_with_config): + from azext_prototype.custom import _get_registry_with_fallback + + registry = _get_registry_with_fallback(str(project_with_config)) + assert len(registry.list_all()) >= 8 + + def test_without_config_falls_back(self, tmp_project): + from azext_prototype.custom import _get_registry_with_fallback + + registry = _get_registry_with_fallback(str(tmp_project)) + assert len(registry.list_all()) >= 8 + + +# ====================================================================== +# Stage commands +# ====================================================================== + + +class TestPrototypeInit: + """Test the init command.""" + + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._check_guards") + @patch("azext_prototype.auth.copilot_license.CopilotLicenseValidator") + @patch("azext_prototype.auth.github_auth.GitHubAuthManager") + @patch("azext_prototype.stages.init_stage.InitStage._check_gh", return_value=True) + def test_init_success(self, mock_gh, mock_auth_cls, mock_lic_cls, mock_guards, mock_check_req, tmp_path): + from azext_prototype.custom import prototype_init + + mock_auth = MagicMock() + mock_auth.ensure_authenticated.return_value = {"login": "testuser"} + mock_auth_cls.return_value = mock_auth + + mock_lic = MagicMock() + mock_lic.validate_license.return_value = {"plan": "business", "status": "active"} + mock_lic_cls.return_value = mock_lic + + cmd = MagicMock() + out = tmp_path / "test-proj" + result = prototype_init( + cmd, + name="test-proj", + location="eastus", + output_dir=str(out), + ai_provider="github-models", + json_output=True, + ) + + assert result["status"] == "success" + assert result["github_user"] == "testuser" + assert out.is_dir() + assert (out / "prototype.yaml").exists() + assert (out / ".gitignore").exists() + + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._check_guards") + def test_init_azure_openai_skips_license(self, mock_guards, mock_check_req, tmp_path): + from azext_prototype.custom import prototype_init + + cmd = MagicMock() + result = prototype_init( + cmd, + name="aoai-proj", + location="eastus", + output_dir=str(tmp_path / "aoai-proj"), + ai_provider="azure-openai", + json_output=True, + ) + + assert result["status"] == "success" + assert "copilot_license" not in result + assert result["github_user"] is None + + @patch(f"{_MOD}._check_requirements") + def test_init_missing_name_raises(self, mock_check_req, tmp_path): + from azext_prototype.custom import prototype_init + from azext_prototype.stages.init_stage import InitStage + + cmd = MagicMock() + # Need to bypass guards + with patch.object(InitStage, "get_guards", return_value=[]): + with pytest.raises(CLIError, match="Project name"): + prototype_init(cmd, name=None, location="eastus", output_dir=str(tmp_path / "no-name")) + + @patch(f"{_MOD}._check_requirements") + def test_init_missing_location_raises(self, mock_check_req, tmp_path): + from azext_prototype.custom import prototype_init + from azext_prototype.stages.init_stage import InitStage + + cmd = MagicMock() + with patch.object(InitStage, "get_guards", return_value=[]): + with pytest.raises(CLIError, match="region is required"): + prototype_init(cmd, name="test-proj", location=None, output_dir=str(tmp_path / "test-proj")) + + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._check_guards") + def test_init_idempotency_cancel(self, mock_guards, mock_check_req, tmp_path): + """If project exists and user declines, init should cancel.""" + from azext_prototype.custom import prototype_init + + # Create existing project + proj_dir = tmp_path / "existing-proj" + proj_dir.mkdir() + (proj_dir / "prototype.yaml").write_text("project:\n name: old\n") + + cmd = MagicMock() + with patch("builtins.input", return_value="n"): + result = prototype_init( + cmd, + name="existing-proj", + location="eastus", + output_dir=str(proj_dir), + ai_provider="azure-openai", + json_output=True, + ) + assert result["status"] == "cancelled" + + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._check_guards") + def test_init_idempotency_reinitialize(self, mock_guards, mock_check_req, tmp_path): + """If project exists and user confirms, init should proceed.""" + from azext_prototype.custom import prototype_init + + proj_dir = tmp_path / "reinit-proj" + proj_dir.mkdir() + (proj_dir / "prototype.yaml").write_text("project:\n name: old\n") + + cmd = MagicMock() + with patch("builtins.input", return_value="y"): + result = prototype_init( + cmd, + name="reinit-proj", + location="eastus", + output_dir=str(proj_dir), + ai_provider="azure-openai", + json_output=True, + ) + assert result["status"] == "success" + + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._check_guards") + def test_init_environment_parameter(self, mock_guards, mock_check_req, tmp_path): + """--environment should be stored in config.""" + from azext_prototype.config import ProjectConfig + from azext_prototype.custom import prototype_init + + cmd = MagicMock() + out = tmp_path / "env-proj" + result = prototype_init( + cmd, + name="env-proj", + location="westus2", + output_dir=str(out), + ai_provider="azure-openai", + environment="staging", + json_output=True, + ) + assert result["status"] == "success" + config = ProjectConfig(str(out)) + config.load() + assert config.get("project.environment") == "staging" + assert config.get("naming.env") == "stg" + assert config.get("naming.zone_id") == "zs" + + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._check_guards") + def test_init_model_parameter(self, mock_guards, mock_check_req, tmp_path): + """--model should override the provider default.""" + from azext_prototype.config import ProjectConfig + from azext_prototype.custom import prototype_init + + cmd = MagicMock() + out = tmp_path / "model-proj" + result = prototype_init( + cmd, + name="model-proj", + location="eastus", + output_dir=str(out), + ai_provider="azure-openai", + model="gpt-4o-mini", + json_output=True, + ) + assert result["status"] == "success" + config = ProjectConfig(str(out)) + config.load() + assert config.get("ai.model") == "gpt-4o-mini" + + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._check_guards") + def test_init_default_model_per_provider(self, mock_guards, mock_check_req, tmp_path): + """Without --model, the default should be provider-specific.""" + from azext_prototype.config import ProjectConfig + from azext_prototype.custom import prototype_init + + cmd = MagicMock() + out = tmp_path / "defmodel-proj" + result = prototype_init( + cmd, + name="defmodel-proj", + location="eastus", + output_dir=str(out), + ai_provider="azure-openai", + json_output=True, + ) + assert result["status"] == "success" + config = ProjectConfig(str(out)) + config.load() + assert config.get("ai.model") == "gpt-4o" + + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._check_guards") + def test_init_sends_telemetry_overrides(self, mock_guards, mock_check_req, tmp_path): + """Init should set _telemetry_overrides with resolved values.""" + from azext_prototype.custom import prototype_init + + cmd = MagicMock() + prototype_init( + cmd, + name="telem-proj", + location="westeurope", + output_dir=str(tmp_path / "telem-proj"), + ai_provider="azure-openai", + environment="staging", + iac_tool="bicep", + ) + + assert isinstance(cmd._telemetry_overrides, dict) + overrides = cmd._telemetry_overrides + assert overrides["location"] == "westeurope" + assert overrides["ai_provider"] == "azure-openai" + assert overrides["model"] == "gpt-4o" # resolved default + assert overrides["iac_tool"] == "bicep" + assert overrides["environment"] == "staging" + + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._check_guards") + def test_init_telemetry_overrides_explicit_model(self, mock_guards, mock_check_req, tmp_path): + """When --model is explicit, overrides should use that value.""" + from azext_prototype.custom import prototype_init + + cmd = MagicMock() + prototype_init( + cmd, + name="telem-model-proj", + location="eastus", + output_dir=str(tmp_path / "telem-model-proj"), + ai_provider="azure-openai", + model="gpt-4o-mini", + ) + + overrides = cmd._telemetry_overrides + assert overrides["model"] == "gpt-4o-mini" + assert overrides["ai_provider"] == "azure-openai" + + +class TestPrototypeConfigGet: + """Test the config get command.""" + + def test_config_get_basic(self, project_with_config): + from azext_prototype.custom import prototype_config_get + + cmd = MagicMock() + with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): + result = prototype_config_get(cmd, key="ai.provider", json_output=True) + assert result == {"key": "ai.provider", "value": "github-models"} + + def test_config_get_missing_key(self, project_with_config): + from azext_prototype.custom import prototype_config_get + + cmd = MagicMock() + with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): + with pytest.raises(CLIError, match="not found"): + prototype_config_get(cmd, key="nonexistent.key") + + def test_config_get_masks_secret(self, project_with_config): + from azext_prototype.config import ProjectConfig + from azext_prototype.custom import prototype_config_get + + # Set a secret value first + config = ProjectConfig(str(project_with_config)) + config.load() + config._secrets = {"deploy": {"subscription": "secret-sub-id"}} + config._config["deploy"]["subscription"] = "secret-sub-id" + config.save() + config.save_secrets() + + cmd = MagicMock() + with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): + result = prototype_config_get(cmd, key="deploy.subscription", json_output=True) + assert result == {"key": "deploy.subscription", "value": "***"} + + +class TestPrototypeConfigShowMasking: + """Test that config show masks secrets.""" + + def test_config_show_masks_secret_values(self, project_with_config): + from azext_prototype.config import ProjectConfig + from azext_prototype.custom import prototype_config_show + + # Set a secret value + config = ProjectConfig(str(project_with_config)) + config.load() + config._secrets = {"deploy": {"subscription": "my-secret-sub"}} + config._config["deploy"]["subscription"] = "my-secret-sub" + config.save() + config.save_secrets() + + cmd = MagicMock() + with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): + result = prototype_config_show(cmd, json_output=True) + assert result["deploy"]["subscription"] == "***" + + def test_config_show_preserves_non_secrets(self, project_with_config): + from azext_prototype.custom import prototype_config_show + + cmd = MagicMock() + with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): + result = prototype_config_show(cmd, json_output=True) + # Non-secret value should not be masked + assert result["ai"]["provider"] == "github-models" + + +class TestPrototypeConfigInit: + """Test config init marks init complete.""" + + @patch( + "builtins.input", + side_effect=[ + "y", # overwrite existing prototype.yaml + "my-project", # project name + "eastus", # location + "dev", # environment + "terraform", # iac tool + "1", # naming strategy choice (microsoft-alz) + "myorg", # org + "zd", # zone_id (ALZ-specific) + "copilot", # ai provider + "", # model (accept default) + "", # subscription + "", # resource group + ], + ) + def test_config_init_marks_init_complete(self, mock_input, project_with_config): + from azext_prototype.config import ProjectConfig + from azext_prototype.custom import prototype_config_init + + cmd = MagicMock() + with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): + prototype_config_init(cmd) + + config = ProjectConfig(str(project_with_config)) + config.load() + assert config.get("stages.init.completed") is True + assert config.get("stages.init.timestamp") is not None + + @patch( + "builtins.input", + side_effect=[ + "y", # overwrite existing prototype.yaml + "telemetry-proj", # project name + "westus2", # location + "staging", # environment + "bicep", # iac tool + "2", # naming strategy choice (microsoft-caf) + "myorg", # org + "azure-openai", # ai provider + "gpt-4o", # model + "https://myres.openai.azure.com/", # Azure OpenAI endpoint + "gpt-4o", # deployment name + "", # subscription + "", # resource group + ], + ) + def test_config_init_sends_telemetry_overrides(self, mock_input, project_with_config): + """After prompting, config init should set _telemetry_overrides on cmd.""" + from azext_prototype.custom import prototype_config_init + + cmd = MagicMock() + with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): + prototype_config_init(cmd) + + assert hasattr(cmd, "_telemetry_overrides") + overrides = cmd._telemetry_overrides + assert overrides["location"] == "westus2" + assert overrides["ai_provider"] == "azure-openai" + assert overrides["model"] == "gpt-4o" + assert overrides["iac_tool"] == "bicep" + assert overrides["environment"] == "staging" + assert overrides["naming_strategy"] == "microsoft-caf" + + def test_config_init_cancelled_no_overrides(self, project_with_config): + """If config init is cancelled, no telemetry overrides should be set.""" + from azext_prototype.custom import prototype_config_init + + cmd = MagicMock(spec=[]) # strict spec — no auto-attributes + with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): + with patch("builtins.input", return_value="n"): + result = prototype_config_init(cmd, json_output=True) + assert result["status"] == "cancelled" + assert not hasattr(cmd, "_telemetry_overrides") + + +class TestPrototypeBuild: + """Test the build command.""" + + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._get_project_dir") + @patch("azext_prototype.ai.factory.create_ai_provider") + @patch(f"{_MOD}._check_guards") + def test_build_calls_stage( + self, mock_guards, mock_factory, mock_dir, mock_check_req, project_with_design, mock_ai_provider + ): + from azext_prototype.ai.provider import AIResponse + from azext_prototype.custom import prototype_build + + mock_dir.return_value = str(project_with_design) + mock_factory.return_value = mock_ai_provider + mock_ai_provider.chat.return_value = AIResponse( + content="```main.tf\nresource null {}\n```", + model="gpt-4o", + ) + + cmd = MagicMock() + result = prototype_build(cmd, scope="docs", dry_run=True, json_output=True) + assert result["status"] == "dry-run" + + +class TestPrototypeDeploy: + """Test the deploy command.""" + + @patch(f"{_MOD}._check_requirements") + @patch(f"{_MOD}._get_project_dir") + @patch("azext_prototype.ai.factory.create_ai_provider") + def test_deploy_status(self, mock_factory, mock_dir, mock_check_req, project_with_build, mock_ai_provider): + from azext_prototype.custom import prototype_deploy + + mock_dir.return_value = str(project_with_build) + mock_factory.return_value = mock_ai_provider + + cmd = MagicMock() + result = prototype_deploy(cmd, status=True, json_output=True) + assert result["status"] == "displayed" + + +class TestPrototypeDeployOutputs: + """Test deploy --outputs flag.""" + + @patch(f"{_MOD}._get_project_dir") + def test_no_outputs(self, mock_dir, project_with_build): + from azext_prototype.custom import prototype_deploy + + mock_dir.return_value = str(project_with_build) + cmd = MagicMock() + result = prototype_deploy(cmd, outputs=True, json_output=True) + assert result["status"] == "empty" + + @patch(f"{_MOD}._get_project_dir") + def test_with_outputs(self, mock_dir, project_with_build): + from azext_prototype.custom import prototype_deploy + + mock_dir.return_value = str(project_with_build) + # Write outputs file + outputs_dir = project_with_build / ".prototype" / "state" + outputs_dir.mkdir(parents=True, exist_ok=True) + (outputs_dir / "deploy_outputs.json").write_text(json.dumps({"rg_name": "test-rg"}), encoding="utf-8") + cmd = MagicMock() + result = prototype_deploy(cmd, outputs=True, json_output=True) + # May return empty or dict depending on DeploymentOutputCapture impl + assert isinstance(result, dict) + + +class TestPrototypeDeployRollbackInfo: + """Test deploy --rollback-info flag.""" + + @patch(f"{_MOD}._get_project_dir") + def test_rollback_info(self, mock_dir, project_with_build): + from azext_prototype.custom import prototype_deploy + + mock_dir.return_value = str(project_with_build) + cmd = MagicMock() + result = prototype_deploy(cmd, rollback_info=True, json_output=True) + assert "last_deployment" in result + assert "rollback_instructions" in result + + +class TestPrototypeDeployGenerateScripts: + """Test deploy --generate-scripts flag.""" + + @patch(f"{_MOD}._get_project_dir") + def test_generate_scripts_no_apps(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_deploy + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + # concept/apps exists but empty (not created by init; build creates it) + (project_with_config / "concept" / "apps").mkdir(parents=True, exist_ok=True) + result = prototype_deploy(cmd, generate_scripts=True, json_output=True) + assert result["status"] == "generated" + assert len(result["scripts"]) == 0 + + @patch(f"{_MOD}._get_project_dir") + def test_generate_scripts_with_apps(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_deploy + + mock_dir.return_value = str(project_with_config) + # Create app directories + apps_dir = project_with_config / "concept" / "apps" + (apps_dir / "backend").mkdir(parents=True, exist_ok=True) + (apps_dir / "frontend").mkdir(parents=True, exist_ok=True) + + cmd = MagicMock() + result = prototype_deploy(cmd, generate_scripts=True, script_deploy_type="webapp", json_output=True) + assert result["status"] == "generated" + assert len(result["scripts"]) == 2 + + @patch(f"{_MOD}._get_project_dir") + def test_generate_scripts_no_apps_dir_raises(self, mock_dir, project_with_config): + # Remove apps dir if present + import shutil + + from azext_prototype.custom import prototype_deploy + + apps_dir = project_with_config / "concept" / "apps" + if apps_dir.exists(): + shutil.rmtree(apps_dir) + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + with pytest.raises(CLIError, match="No apps directory"): + prototype_deploy(cmd, generate_scripts=True) + + +class TestPrototypeAgentOverride: + """Test agent override command.""" + + @patch(f"{_MOD}._get_project_dir") + def test_override_registers(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_override + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + # Create a real YAML file for the override + override_file = project_with_config / "my_arch.yaml" + override_file.write_text( + "name: cloud-architect\ndescription: Custom Override\n" + "capabilities:\n - architect\nsystem_prompt: Custom prompt.\n", + encoding="utf-8", + ) + + result = prototype_agent_override(cmd, name="cloud-architect", file="my_arch.yaml", json_output=True) + assert result["status"] == "override_registered" + assert result["name"] == "cloud-architect" + + def test_override_missing_name_raises(self): + from azext_prototype.custom import prototype_agent_override + + cmd = MagicMock() + with pytest.raises(CLIError, match="--name"): + prototype_agent_override(cmd, name=None, file="x.yaml") + + def test_override_missing_file_raises(self): + from azext_prototype.custom import prototype_agent_override + + cmd = MagicMock() + with pytest.raises(CLIError, match="--file"): + prototype_agent_override(cmd, name="x", file=None) + + +class TestPrototypeAgentRemove: + """Test agent remove command.""" + + @patch(f"{_MOD}._get_project_dir") + def test_remove_custom_agent(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_add, prototype_agent_remove + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + # Add then remove + prototype_agent_add(cmd, name="to-remove", definition="cloud_architect") + result = prototype_agent_remove(cmd, name="to-remove", json_output=True) + assert result["status"] == "removed" + + @patch(f"{_MOD}._get_project_dir") + def test_remove_override_agent(self, mock_dir, project_with_config): + from azext_prototype.custom import ( + prototype_agent_override, + prototype_agent_remove, + ) + + mock_dir.return_value = str(project_with_config) + + # Create a real YAML file for the override + override_file = project_with_config / "my_arch.yaml" + override_file.write_text( + "name: cloud-architect\ndescription: Override\n" "capabilities:\n - architect\nsystem_prompt: Override.\n", + encoding="utf-8", + ) + + cmd = MagicMock() + prototype_agent_override(cmd, name="cloud-architect", file="my_arch.yaml") + result = prototype_agent_remove(cmd, name="cloud-architect", json_output=True) + assert result["status"] == "override_removed" + + @patch(f"{_MOD}._get_project_dir") + def test_remove_builtin_raises(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_remove + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + # bicep-agent is builtin and not custom/override → should raise + with pytest.raises(CLIError, match="Built-in agents cannot be removed"): + prototype_agent_remove(cmd, name="app-developer") + + def test_remove_missing_name_raises(self): + from azext_prototype.custom import prototype_agent_remove + + cmd = MagicMock() + with pytest.raises(CLIError, match="--name"): + prototype_agent_remove(cmd, name=None) + + +class TestPrototypeAnalyzeError: + """Test the error analysis command.""" + + def test_missing_input_raises(self): + from azext_prototype.custom import prototype_analyze_error + + cmd = MagicMock() + with pytest.raises(CLIError, match="Error input is required"): + prototype_analyze_error(cmd, input=None) + + @patch(f"{_MOD}._prepare_command") + def test_analyze_inline_error(self, mock_prep, project_with_design, mock_ai_provider): + from azext_prototype.ai.provider import AIResponse + from azext_prototype.custom import prototype_analyze_error + + mock_qa = MagicMock() + mock_qa.name = "qa-engineer" + mock_qa.execute.return_value = AIResponse(content="Root cause: missing RBAC", model="gpt-4o") + + mock_registry = MagicMock() + mock_registry.find_by_capability.return_value = [mock_qa] + + mock_ctx = MagicMock() + mock_prep.return_value = (str(project_with_design), MagicMock(), mock_registry, mock_ctx) + + cmd = MagicMock() + result = prototype_analyze_error(cmd, input="ResourceNotFound error", json_output=True) + assert result["status"] == "analyzed" + + @patch(f"{_MOD}._prepare_command") + def test_analyze_log_file(self, mock_prep, project_with_design, mock_ai_provider): + from azext_prototype.ai.provider import AIResponse + from azext_prototype.custom import prototype_analyze_error + + mock_qa = MagicMock() + mock_qa.name = "qa-engineer" + mock_qa.execute.return_value = AIResponse(content="Root cause: config error", model="gpt-4o") + + mock_registry = MagicMock() + mock_registry.find_by_capability.return_value = [mock_qa] + + mock_ctx = MagicMock() + mock_prep.return_value = (str(project_with_design), MagicMock(), mock_registry, mock_ctx) + + log_file = project_with_design / "error.log" + log_file.write_text("ERROR: Connection refused", encoding="utf-8") + + cmd = MagicMock() + result = prototype_analyze_error(cmd, input=str(log_file), json_output=True) + assert result["status"] == "analyzed" + + @patch(f"{_MOD}._prepare_command") + def test_analyze_screenshot(self, mock_prep, project_with_design, mock_ai_provider): + from azext_prototype.ai.provider import AIResponse + from azext_prototype.custom import prototype_analyze_error + + mock_qa = MagicMock() + mock_qa.name = "qa-engineer" + mock_qa.execute_with_image.return_value = AIResponse(content="Screenshot analysis", model="gpt-4o") + + mock_registry = MagicMock() + mock_registry.find_by_capability.return_value = [mock_qa] + + mock_ctx = MagicMock() + mock_prep.return_value = (str(project_with_design), MagicMock(), mock_registry, mock_ctx) + + img = project_with_design / "error.png" + img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) + + cmd = MagicMock() + result = prototype_analyze_error(cmd, input=str(img), json_output=True) + assert result["status"] == "analyzed" + + +class TestPrototypeAnalyzeCosts: + """Test the cost analysis command.""" + + @patch(f"{_MOD}._prepare_command") + def test_analyze_costs(self, mock_prep, project_with_design, mock_ai_provider): + from azext_prototype.ai.provider import AIResponse + from azext_prototype.custom import prototype_analyze_costs + + mock_cost = MagicMock() + mock_cost.name = "cost-analyst" + mock_cost.execute.return_value = AIResponse(content="Cost report content", model="gpt-4o") + + mock_registry = MagicMock() + mock_registry.find_by_capability.return_value = [mock_cost] + + mock_ctx = MagicMock() + mock_prep.return_value = (str(project_with_design), MagicMock(), mock_registry, mock_ctx) + + cmd = MagicMock() + result = prototype_analyze_costs(cmd, json_output=True) + assert result["status"] == "analyzed" + + @patch(f"{_MOD}._prepare_command") + def test_analyze_costs_no_agent_raises(self, mock_prep, project_with_design): + from azext_prototype.custom import prototype_analyze_costs + + mock_registry = MagicMock() + mock_registry.find_by_capability.return_value = [] + mock_prep.return_value = (str(project_with_design), MagicMock(), mock_registry, MagicMock()) + + cmd = MagicMock() + with pytest.raises(CLIError, match="No cost analyst"): + prototype_analyze_costs(cmd) + + +class TestExtractCostTable: + """Test _extract_cost_table helper.""" + + def test_extracts_summary_table(self): + from azext_prototype.custom import _extract_cost_table + + content = ( + "# Executive Summary\n\nSome intro text.\n\n---\n\n" + "## Cost Summary Table\n\n" + " Service Small Medium Large\n" + " ──────────────────────────────────────────\n" + " App Service $0.00 $13.14 $74.00\n" + " TOTAL $0.00 $13.14 $74.00\n" + "\n\n---\n\n" + "## T-Shirt Size Definitions\n\nMore details...\n" + ) + result = _extract_cost_table(content) + assert "Cost Summary Table" in result + assert "$13.14" in result + assert "T-Shirt Size" not in result + + def test_fallback_on_no_heading(self): + from azext_prototype.custom import _extract_cost_table + + content = "No table here, just text about the architecture." + result = _extract_cost_table(content) + assert result == content + + +class TestPrototypeConfigSetExtended: + """Additional config set tests.""" + + def test_config_set_missing_value_raises(self): + from azext_prototype.custom import prototype_config_set + + cmd = MagicMock() + with pytest.raises(CLIError, match="--value"): + prototype_config_set(cmd, key="some.key", value=None) + + @patch(f"{_MOD}._get_project_dir") + def test_config_set_json_value(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_config_set + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + result = prototype_config_set(cmd, key="deploy.tags", value='{"env":"dev"}', json_output=True) + assert result["status"] == "updated" + + +class TestPrototypeStatusExtended: + """Extended status tests.""" + + @patch(f"{_MOD}._get_project_dir") + def test_status_with_build_shows_changes(self, mock_dir, project_with_build): + from azext_prototype.custom import prototype_status + + mock_dir.return_value = str(project_with_build) + cmd = MagicMock() + result = prototype_status(cmd, json_output=True) + # If build stage is marked completed, pending_changes should exist + if result.get("stages", {}).get("build", {}).get("completed"): + assert "pending_changes" in result + else: + # Build state exists → pending_changes may still be present + assert "stages" in result + + @patch(f"{_MOD}._get_project_dir") + def test_status_default_uses_console(self, mock_dir, project_with_config): + """Default mode (no flags) uses console output and returns None (suppressed).""" + from azext_prototype.custom import prototype_status + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + with patch("azext_prototype.custom.console", create=True): + result = prototype_status(cmd) + + assert result is None + + @patch(f"{_MOD}._get_project_dir") + def test_status_json_returns_enriched_dict(self, mock_dir, project_with_config): + """--json returns enriched dict with all new fields.""" + from azext_prototype.custom import prototype_status + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + result = prototype_status(cmd, json_output=True) + + assert isinstance(result, dict) + assert result["project"] == "test-project" + assert "environment" in result + assert "naming_strategy" in result + assert "project_id" in result + assert "deployment_history" in result + # All three stages present + for stage in ("design", "build", "deploy"): + assert stage in result["stages"] + assert "completed" in result["stages"][stage] + + @patch(f"{_MOD}._get_project_dir") + def test_status_detailed_prints_detail(self, mock_dir, project_with_config): + """--detailed prints expanded output and returns None (suppressed).""" + from azext_prototype.custom import prototype_status + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + with patch("azext_prototype.custom.console", create=True): + result = prototype_status(cmd, detailed=True) + + assert result is None + + @patch(f"{_MOD}._get_project_dir") + def test_status_with_discovery_state(self, mock_dir, project_with_config): + """Discovery state populates exchanges/confirmed/open.""" + import yaml + + from azext_prototype.custom import prototype_status + + state_dir = project_with_config / ".prototype" / "state" + state_dir.mkdir(parents=True, exist_ok=True) + state_file = state_dir / "discovery.yaml" + state_file.write_text( + yaml.dump( + { + "open_items": ["item1"], + "confirmed_items": ["item2", "item3"], + "conversation_history": [], + "_metadata": { + "exchange_count": 5, + "created": "2026-01-01T00:00:00", + "last_updated": "2026-01-01T01:00:00", + }, + } + ), + encoding="utf-8", + ) + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + result = prototype_status(cmd, json_output=True) + + d = result["stages"]["design"] + assert d["exchanges"] == 5 + assert d["confirmed"] == 2 + assert d["open"] == 1 + + @patch(f"{_MOD}._get_project_dir") + def test_status_with_build_state(self, mock_dir, project_with_build): + """Build state populates templates/stages/files/overrides.""" + from azext_prototype.custom import prototype_status + + mock_dir.return_value = str(project_with_build) + cmd = MagicMock() + result = prototype_status(cmd, json_output=True) + + b = result["stages"]["build"] + assert "templates_used" in b + assert "total_stages" in b + assert "accepted_stages" in b + assert "files_generated" in b + assert "policy_overrides" in b + assert b["total_stages"] >= 0 + + @patch(f"{_MOD}._get_project_dir") + def test_status_with_deploy_state(self, mock_dir, project_with_config): + """Deploy state populates deployed/failed/rolled_back/outputs.""" + import yaml + + from azext_prototype.custom import prototype_status + + state_dir = project_with_config / ".prototype" / "state" + state_dir.mkdir(parents=True, exist_ok=True) + state_file = state_dir / "deploy.yaml" + state_file.write_text( + yaml.dump( + { + "deployment_stages": [ + {"stage": 1, "name": "Foundation", "deploy_status": "deployed", "services": []}, + { + "stage": 2, + "name": "App", + "deploy_status": "failed", + "deploy_error": "timeout", + "services": [], + }, + ], + "captured_outputs": {"terraform": {"endpoint": "https://example.com"}}, + "_metadata": {"created": "2026-01-01T00:00:00", "last_updated": "2026-01-01T01:00:00"}, + } + ), + encoding="utf-8", + ) + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + result = prototype_status(cmd, json_output=True) + + dp = result["stages"]["deploy"] + assert dp["total_stages"] == 2 + assert dp["deployed"] == 1 + assert dp["failed"] == 1 + assert dp["rolled_back"] == 0 + assert dp["outputs_captured"] == 1 + + @patch(f"{_MOD}._get_project_dir") + def test_status_no_state_files(self, mock_dir, project_with_config): + """Config exists but no state files — stages show zero counts.""" + from azext_prototype.custom import prototype_status + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + result = prototype_status(cmd, json_output=True) + + d = result["stages"]["design"] + assert d["exchanges"] == 0 + assert d["confirmed"] == 0 + assert d["open"] == 0 + + b = result["stages"]["build"] + assert b["total_stages"] == 0 + assert b["files_generated"] == 0 + + dp = result["stages"]["deploy"] + assert dp["total_stages"] == 0 + assert dp["deployed"] == 0 + + @patch(f"{_MOD}._get_project_dir") + def test_status_deployment_history(self, mock_dir, project_with_config): + """Deployment history from ChangeTracker is included.""" + import json as json_mod + + from azext_prototype.custom import prototype_status + + # Create a manifest with deployment history + manifest_dir = project_with_config / ".prototype" / "state" + manifest_dir.mkdir(parents=True, exist_ok=True) + manifest_path = manifest_dir / "change_manifest.json" + manifest_path.write_text( + json_mod.dumps( + { + "files": {}, + "deployments": [ + {"scope": "all", "timestamp": "2026-01-15T10:00:00", "files_count": 12}, + ], + } + ), + encoding="utf-8", + ) + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + result = prototype_status(cmd, json_output=True) + + assert len(result["deployment_history"]) == 1 + assert result["deployment_history"][0]["scope"] == "all" + + @patch(f"{_MOD}._get_project_dir") + def test_status_detailed_json_returns_dict(self, mock_dir, project_with_config): + """When both detailed and json_output are True, json wins — returns dict.""" + from azext_prototype.custom import prototype_status + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + result = prototype_status(cmd, detailed=True, json_output=True) + + # json_output takes precedence — returns the enriched dict, not displayed + assert isinstance(result, dict) + assert "project" in result + assert result.get("status") != "displayed" + + +class TestLoadDesignContext: + """Test _load_design_context.""" + + def test_loads_from_design_json(self, project_with_design): + from azext_prototype.custom import _load_design_context + + result = _load_design_context(str(project_with_design)) + assert "Sample architecture" in result + + def test_loads_from_architecture_md(self, project_with_config): + from azext_prototype.custom import _load_design_context + + arch_md = project_with_config / "concept" / "docs" / "ARCHITECTURE.md" + arch_md.parent.mkdir(parents=True, exist_ok=True) + arch_md.write_text("# My Architecture\nDetails here.", encoding="utf-8") + + result = _load_design_context(str(project_with_config)) + assert "My Architecture" in result + + def test_returns_empty_when_no_design(self, tmp_project): + from azext_prototype.custom import _load_design_context + + result = _load_design_context(str(tmp_project)) + assert result == "" + + +class TestRenderTemplate: + """Test _render_template.""" + + def test_replaces_placeholders(self): + from azext_prototype.custom import _render_template + + template = "Project: [PROJECT_NAME], Region: [LOCATION], Date: [DATE]" + config = {"project": {"name": "my-proj", "location": "westus2"}} + result = _render_template(template, config) + assert "my-proj" in result + assert "westus2" in result + assert "[PROJECT_NAME]" not in result + + def test_keeps_unknown_placeholders(self): + from azext_prototype.custom import _render_template + + template = "[UNKNOWN_PLACEHOLDER] stays" + result = _render_template(template, {}) + assert "[UNKNOWN_PLACEHOLDER]" in result + + +class TestGenerateTemplates: + """Test _generate_templates shared helper.""" + + def test_generates_all_templates(self, project_with_config): + from azext_prototype.custom import _generate_templates, _load_config + + config = _load_config(str(project_with_config)) + output_dir = project_with_config / "test_output" + + generated = _generate_templates(output_dir, str(project_with_config), config.to_dict(), "test") + assert len(generated) >= 1 + assert output_dir.is_dir() + + def test_generates_with_manifest(self, project_with_config): + from azext_prototype.custom import _generate_templates, _load_config + + config = _load_config(str(project_with_config)) + output_dir = project_with_config / "speckit_output" + + _generate_templates( + output_dir, + str(project_with_config), + config.to_dict(), + "speckit", + include_manifest=True, + ) + assert (output_dir / "manifest.json").exists() + manifest = json.loads((output_dir / "manifest.json").read_text()) + assert "speckit_version" in manifest + + +# ====================================================================== +# _load_design_context — 3-source cascade +# ====================================================================== + + +class TestLoadDesignContextCascade: + """Test the 3-source cascade in _load_design_context.""" + + def test_loads_from_design_json(self, project_with_design): + """Source 1: design.json is used when present.""" + from azext_prototype.custom import _load_design_context + + result = _load_design_context(str(project_with_design)) + assert "Sample architecture" in result + + def test_falls_back_to_discovery_yaml(self, project_with_discovery): + """Source 2: discovery.yaml used when no design.json.""" + from azext_prototype.custom import _load_design_context + + result = _load_design_context(str(project_with_discovery)) + assert result # Should get non-empty context from discovery state + + def test_design_json_takes_priority(self, project_with_design): + """design.json takes priority over discovery.yaml when both exist.""" + import yaml as _yaml + + from azext_prototype.custom import _load_design_context + + # Add a discovery.yaml alongside the existing design.json + state_dir = project_with_design / ".prototype" / "state" + discovery = { + "project": {"summary": "Different content from discovery"}, + "confirmed_items": ["Different item"], + "_metadata": {"exchange_count": 1, "created": "2026-01-01T00:00:00", "last_updated": "2026-01-01T00:00:00"}, + } + (state_dir / "discovery.yaml").write_text(_yaml.dump(discovery), encoding="utf-8") + + result = _load_design_context(str(project_with_design)) + assert "Sample architecture" in result # design.json content, not discovery + + def test_falls_back_to_architecture_md(self, project_with_config): + """Source 3: ARCHITECTURE.md used when no state files exist.""" + from azext_prototype.custom import _load_design_context + + arch_md = project_with_config / "concept" / "docs" / "ARCHITECTURE.md" + arch_md.parent.mkdir(parents=True, exist_ok=True) + arch_md.write_text("# Architecture from markdown", encoding="utf-8") + + result = _load_design_context(str(project_with_config)) + assert "Architecture from markdown" in result + + def test_returns_empty_when_nothing(self, project_with_config): + """Returns empty string when no sources exist.""" + from azext_prototype.custom import _load_design_context + + result = _load_design_context(str(project_with_config)) + assert result == "" + + +# ====================================================================== +# Analyze costs — cache behavior +# ====================================================================== + + +class TestAnalyzeCostsCache: + """Test cost analysis caching (deterministic results).""" + + def _make_mock_prep(self, project_dir, mock_registry, mock_context): + """Build a _prepare_command return tuple.""" + from azext_prototype.config import ProjectConfig + + config = ProjectConfig(str(project_dir)) + config.load() + return (str(project_dir), config, mock_registry, mock_context) + + def _make_registry_with_cost_agent(self): + from tests.conftest import make_ai_response + + agent = MagicMock() + agent.name = "cost-analyst" + agent.execute.return_value = make_ai_response("## Cost Report\n| Service | Small | Medium | Large |") + + registry = MagicMock() + registry.find_by_capability.return_value = [agent] + return registry, agent + + @patch(f"{_MOD}._prepare_command") + def test_first_run_calls_agent_and_caches(self, mock_prep, project_with_design): + from azext_prototype.custom import prototype_analyze_costs + + registry, agent = self._make_registry_with_cost_agent() + mock_ctx = MagicMock() + mock_ctx.project_config = {"project": {"location": "eastus"}} + mock_prep.return_value = self._make_mock_prep(project_with_design, registry, mock_ctx) + + cmd = MagicMock() + result = prototype_analyze_costs(cmd, refresh=False, json_output=True) + + assert result["status"] == "analyzed" + agent.execute.assert_called_once() + + # Cache file should exist + cache = project_with_design / ".prototype" / "state" / "cost_analysis.yaml" + assert cache.exists() + + @patch(f"{_MOD}._prepare_command") + def test_second_run_returns_cached(self, mock_prep, project_with_design): + """Cached result returned without calling agent.""" + import yaml as _yaml + + from azext_prototype.custom import prototype_analyze_costs + + registry, agent = self._make_registry_with_cost_agent() + mock_ctx = MagicMock() + mock_ctx.project_config = {"project": {"location": "eastus"}} + mock_prep.return_value = self._make_mock_prep(project_with_design, registry, mock_ctx) + + # Pre-populate cache with matching hash + import hashlib + + from azext_prototype.custom import _load_design_context + + design_context = _load_design_context(str(project_with_design)) + context_hash = hashlib.sha256(design_context.encode("utf-8")).hexdigest()[:16] + + cache_data = { + "context_hash": context_hash, + "content": "Cached cost report content", + "result": {"status": "analyzed", "agent": "cost-analyst"}, + "timestamp": "2026-01-01T00:00:00+00:00", + } + cache_path = project_with_design / ".prototype" / "state" / "cost_analysis.yaml" + cache_path.write_text(_yaml.dump(cache_data, default_flow_style=False), encoding="utf-8") + + cmd = MagicMock() + result = prototype_analyze_costs(cmd, refresh=False, json_output=True) + + assert result["status"] == "analyzed" + agent.execute.assert_not_called() # Should NOT have called the agent + + @patch(f"{_MOD}._prepare_command") + def test_refresh_bypasses_cache(self, mock_prep, project_with_design): + """--refresh forces fresh analysis even when cache matches.""" + import yaml as _yaml + + from azext_prototype.custom import prototype_analyze_costs + + registry, agent = self._make_registry_with_cost_agent() + mock_ctx = MagicMock() + mock_ctx.project_config = {"project": {"location": "eastus"}} + mock_prep.return_value = self._make_mock_prep(project_with_design, registry, mock_ctx) + + # Pre-populate cache with matching hash + import hashlib + + from azext_prototype.custom import _load_design_context + + design_context = _load_design_context(str(project_with_design)) + context_hash = hashlib.sha256(design_context.encode("utf-8")).hexdigest()[:16] + + cache_data = { + "context_hash": context_hash, + "content": "Old cached content", + "result": {"status": "analyzed", "agent": "cost-analyst"}, + } + cache_path = project_with_design / ".prototype" / "state" / "cost_analysis.yaml" + cache_path.write_text(_yaml.dump(cache_data, default_flow_style=False), encoding="utf-8") + + cmd = MagicMock() + result = prototype_analyze_costs(cmd, refresh=True, json_output=True) + + assert result["status"] == "analyzed" + agent.execute.assert_called_once() # Should HAVE called the agent + + @patch(f"{_MOD}._prepare_command") + def test_cache_invalidated_on_design_change(self, mock_prep, project_with_design): + """Different design context hash invalidates the cache.""" + import yaml as _yaml + + from azext_prototype.custom import prototype_analyze_costs + + registry, agent = self._make_registry_with_cost_agent() + mock_ctx = MagicMock() + mock_ctx.project_config = {"project": {"location": "eastus"}} + mock_prep.return_value = self._make_mock_prep(project_with_design, registry, mock_ctx) + + # Pre-populate cache with a DIFFERENT hash + cache_data = { + "context_hash": "stale_hash_0000", + "content": "Stale cached content", + "result": {"status": "analyzed", "agent": "cost-analyst"}, + } + cache_path = project_with_design / ".prototype" / "state" / "cost_analysis.yaml" + cache_path.write_text(_yaml.dump(cache_data, default_flow_style=False), encoding="utf-8") + + cmd = MagicMock() + result = prototype_analyze_costs(cmd, refresh=False, json_output=True) + + assert result["status"] == "analyzed" + agent.execute.assert_called_once() # Stale cache — must re-run + + @patch(f"{_MOD}._prepare_command") + def test_cache_file_written_to_state_dir(self, mock_prep, project_with_design): + """Cache is written to .prototype/state/cost_analysis.yaml.""" + import yaml as _yaml + + from azext_prototype.custom import prototype_analyze_costs + + registry, agent = self._make_registry_with_cost_agent() + mock_ctx = MagicMock() + mock_ctx.project_config = {"project": {"location": "eastus"}} + mock_prep.return_value = self._make_mock_prep(project_with_design, registry, mock_ctx) + + cmd = MagicMock() + prototype_analyze_costs(cmd, refresh=False) + + cache_path = project_with_design / ".prototype" / "state" / "cost_analysis.yaml" + assert cache_path.exists() + cached = _yaml.safe_load(cache_path.read_text(encoding="utf-8")) + assert "context_hash" in cached + assert "content" in cached + assert "timestamp" in cached + + +# ====================================================================== +# Console output — analyze commands +# ====================================================================== + + +class TestAnalyzeConsoleOutput: + """Verify analyze commands use console.* methods (not raw print).""" + + @patch(f"{_MOD}._prepare_command") + @patch(f"{_MOD}.console", create=True) + def test_analyze_error_uses_console(self, mock_console, mock_prep, project_with_design): + from azext_prototype.custom import prototype_analyze_error + from tests.conftest import make_ai_response + + agent = MagicMock() + agent.name = "qa-engineer" + agent.execute.return_value = make_ai_response("## Fix\nDo something") + + registry = MagicMock() + registry.find_by_capability.return_value = [agent] + + config = MagicMock() + mock_prep.return_value = (str(project_with_design), config, registry, MagicMock()) + + cmd = MagicMock() + result = prototype_analyze_error(cmd, input="some error text", json_output=True) + + assert result["status"] == "analyzed" + + @patch(f"{_MOD}._prepare_command") + def test_analyze_error_warns_no_context(self, mock_prep, project_with_config): + """When no design context exists, a warning should be shown.""" + from azext_prototype.custom import prototype_analyze_error + from tests.conftest import make_ai_response + + agent = MagicMock() + agent.name = "qa-engineer" + agent.execute.return_value = make_ai_response("## Fix\nDo something") + + registry = MagicMock() + registry.find_by_capability.return_value = [agent] + + config = MagicMock() + mock_prep.return_value = (str(project_with_config), config, registry, MagicMock()) + + cmd = MagicMock() + + # Patch the module-level console singleton. We must use importlib + # because `import azext_prototype.ui.console` can resolve to the + # `console` variable re-exported in azext_prototype.ui.__init__ + # instead of the submodule (name collision on Python 3.10). + import importlib + + _console_mod = importlib.import_module("azext_prototype.ui.console") + + with patch.object(_console_mod, "console") as mock_console: # noqa: F841 + result = prototype_analyze_error(cmd, input="some error", json_output=True) + + assert result["status"] == "analyzed" + + @patch(f"{_MOD}._prepare_command") + def test_analyze_costs_uses_console(self, mock_prep, project_with_design): + from azext_prototype.custom import prototype_analyze_costs + from tests.conftest import make_ai_response + + agent = MagicMock() + agent.name = "cost-analyst" + agent.execute.return_value = make_ai_response("## Costs\n$100/mo") + + registry = MagicMock() + registry.find_by_capability.return_value = [agent] + + from azext_prototype.config import ProjectConfig + + config = ProjectConfig(str(project_with_design)) + config.load() + + mock_ctx = MagicMock() + mock_ctx.project_config = {"project": {"location": "eastus"}} + mock_prep.return_value = (str(project_with_design), config, registry, mock_ctx) + + cmd = MagicMock() + result = prototype_analyze_costs(cmd, refresh=True, json_output=True) + + assert result["status"] == "analyzed" + + +# ====================================================================== +# Console output — deploy subcommands +# ====================================================================== + + +class TestDeploySubcommandConsole: + """Verify deploy flag sub-actions use console.* methods.""" + + @patch(f"{_MOD}._get_project_dir") + def test_deploy_outputs_empty_warns(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_deploy + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + with patch("azext_prototype.stages.deploy_helpers.DeploymentOutputCapture") as MockCapture: + MockCapture.return_value.get_all.return_value = {} + result = prototype_deploy(cmd, outputs=True, json_output=True) + + assert result["status"] == "empty" + + @patch(f"{_MOD}._get_project_dir") + def test_deploy_rollback_info_empty_warns(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_deploy + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + with patch("azext_prototype.stages.deploy_helpers.RollbackManager") as MockMgr: + MockMgr.return_value.get_last_snapshot.return_value = None + MockMgr.return_value.get_rollback_instructions.return_value = None + result = prototype_deploy(cmd, rollback_info=True, json_output=True) + + assert result["last_deployment"] is None + assert result["rollback_instructions"] is None + + @patch(f"{_MOD}._get_project_dir") + @patch(f"{_MOD}._load_config") + def test_generate_scripts_uses_console(self, mock_config, mock_dir, project_with_config): + from azext_prototype.custom import prototype_deploy + + mock_dir.return_value = str(project_with_config) + mock_config.return_value = MagicMock() + mock_config.return_value.get.return_value = "" + + # Create an apps directory with a subdirectory + apps_dir = project_with_config / "concept" / "apps" + apps_dir.mkdir(parents=True, exist_ok=True) + (apps_dir / "my-app").mkdir() + + cmd = MagicMock() + + with patch("azext_prototype.stages.deploy_helpers.DeployScriptGenerator") as MockGen: # noqa: F841 + result = prototype_deploy(cmd, generate_scripts=True, json_output=True) + + assert result["status"] == "generated" + assert "my-app/deploy.sh" in result["scripts"] + + +# ====================================================================== +# Agent commands — Rich UI, new commands, validation +# ====================================================================== + + +class TestPrototypeAgentListRichUI: + """Test agent list Rich UI, json, and detailed modes.""" + + @patch(f"{_MOD}._get_project_dir") + def test_list_json_returns_list(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_list + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + result = prototype_agent_list(cmd, json_output=True) + assert isinstance(result, list) + assert len(result) >= 8 + + @patch(f"{_MOD}._get_project_dir") + def test_list_console_mode(self, mock_dir, project_with_config): + """Default (non-json) returns list and uses console.""" + from azext_prototype.custom import prototype_agent_list + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + result = prototype_agent_list(cmd, json_output=True) + assert isinstance(result, list) + + @patch(f"{_MOD}._get_project_dir") + def test_list_detailed_mode(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_list + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + result = prototype_agent_list(cmd, detailed=True, json_output=True) + assert isinstance(result, list) + + @patch(f"{_MOD}._get_project_dir") + def test_list_agents_have_source(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_list + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + result = prototype_agent_list(cmd, json_output=True) + for agent in result: + assert "source" in agent + + +class TestPrototypeAgentShowRichUI: + """Test agent show Rich UI, json, and detailed modes.""" + + @patch(f"{_MOD}._get_project_dir") + def test_show_json_returns_dict(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_show + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + result = prototype_agent_show(cmd, name="cloud-architect", json_output=True) + assert isinstance(result, dict) + assert result["name"] == "cloud-architect" + assert "system_prompt_preview" in result + + @patch(f"{_MOD}._get_project_dir") + def test_show_detailed_includes_full_prompt(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_show + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + result = prototype_agent_show(cmd, name="cloud-architect", detailed=True, json_output=True) + assert "system_prompt" in result + # detailed should not have preview + assert "system_prompt_preview" not in result + + @patch(f"{_MOD}._get_project_dir") + def test_show_console_mode(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_show + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + result = prototype_agent_show(cmd, name="cloud-architect", json_output=True) + assert isinstance(result, dict) + + +class TestPrototypeAgentUpdate: + """Test agent update command.""" + + @patch(f"{_MOD}._get_project_dir") + def test_update_description(self, mock_dir, project_with_config): + """Targeted field update — description only.""" + from azext_prototype.custom import prototype_agent_add, prototype_agent_update + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + prototype_agent_add(cmd, name="updatable", definition="cloud_architect") + result = prototype_agent_update(cmd, name="updatable", description="New desc", json_output=True) + assert result["status"] == "updated" + assert result["description"] == "New desc" + + @patch(f"{_MOD}._get_project_dir") + def test_update_capabilities(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_add, prototype_agent_update + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + prototype_agent_add(cmd, name="cap-update", definition="cloud_architect") + result = prototype_agent_update(cmd, name="cap-update", capabilities="architect,deploy", json_output=True) + assert result["status"] == "updated" + assert "architect" in result["capabilities"] + assert "deploy" in result["capabilities"] + + @patch(f"{_MOD}._get_project_dir") + def test_update_system_prompt_from_file(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_add, prototype_agent_update + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + prototype_agent_add(cmd, name="prompt-update", definition="cloud_architect") + + prompt_file = project_with_config / "new_prompt.txt" + prompt_file.write_text("You are an updated agent.", encoding="utf-8") + + result = prototype_agent_update( + cmd, name="prompt-update", system_prompt_file=str(prompt_file), json_output=True + ) + assert result["status"] == "updated" + + import yaml as _yaml + + agent_file = project_with_config / ".prototype" / "agents" / "prompt-update.yaml" + content = _yaml.safe_load(agent_file.read_text(encoding="utf-8")) + assert content["system_prompt"] == "You are an updated agent." + + @patch(f"{_MOD}._get_project_dir") + def test_update_interactive_mode(self, mock_dir, project_with_config): + """Interactive mode with mocked input.""" + from azext_prototype.custom import prototype_agent_add, prototype_agent_update + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + prototype_agent_add(cmd, name="interactive-up", definition="cloud_architect") + + # Mock interactive prompts: description, role, capabilities, constraints (empty), system prompt (empty=keep) + inputs = [ + "Updated description", # description + "architect", # role + "architect", # capabilities + "", # end constraints + "", # system prompt (keep existing - first empty line) + "", # examples (skip) + ] + with patch("builtins.input", side_effect=inputs): + result = prototype_agent_update(cmd, name="interactive-up", json_output=True) + + assert result["status"] == "updated" + assert result["description"] == "Updated description" + + @patch(f"{_MOD}._get_project_dir") + def test_update_manifest_sync(self, mock_dir, project_with_config): + """Manifest entry is updated after field update.""" + from azext_prototype.custom import ( + _load_config, + prototype_agent_add, + prototype_agent_update, + ) + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + prototype_agent_add(cmd, name="manifest-sync", definition="cloud_architect") + prototype_agent_update(cmd, name="manifest-sync", description="Synced desc") + + config = _load_config(str(project_with_config)) + custom = config.get("agents.custom", {}) + assert custom["manifest-sync"]["description"] == "Synced desc" + + def test_update_missing_name_raises(self): + from azext_prototype.custom import prototype_agent_update + + cmd = MagicMock() + with pytest.raises(CLIError, match="--name"): + prototype_agent_update(cmd, name=None) + + @patch(f"{_MOD}._get_project_dir") + def test_update_nonexistent_raises(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_update + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + with pytest.raises(CLIError, match="not found"): + prototype_agent_update(cmd, name="nonexistent-agent") + + @patch(f"{_MOD}._get_project_dir") + def test_update_invalid_capability_raises(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_add, prototype_agent_update + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + prototype_agent_add(cmd, name="bad-cap", definition="cloud_architect") + with pytest.raises(CLIError, match="Unknown capability"): + prototype_agent_update(cmd, name="bad-cap", capabilities="invalid_cap") + + @patch(f"{_MOD}._get_project_dir") + def test_update_prompt_file_not_found_raises(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_add, prototype_agent_update + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + prototype_agent_add(cmd, name="no-prompt", definition="cloud_architect") + with pytest.raises(CLIError, match="not found"): + prototype_agent_update(cmd, name="no-prompt", system_prompt_file="./does_not_exist.txt") + + +class TestPrototypeAgentTest: + """Test agent test command.""" + + @patch(f"{_MOD}._prepare_command") + def test_default_prompt(self, mock_prep, project_with_config, mock_ai_provider): + from azext_prototype.ai.provider import AIResponse + from azext_prototype.custom import prototype_agent_test + + mock_agent = MagicMock() + mock_agent.name = "cloud-architect" + mock_agent.execute.return_value = AIResponse( + content="I am the cloud architect.", + model="gpt-4o", + usage={"prompt_tokens": 50, "completion_tokens": 20, "total_tokens": 70}, + ) + + mock_registry = MagicMock() + mock_registry.get.return_value = mock_agent + mock_prep.return_value = (str(project_with_config), MagicMock(), mock_registry, MagicMock()) + + cmd = MagicMock() + result = prototype_agent_test(cmd, name="cloud-architect", json_output=True) + + assert result["status"] == "tested" + assert result["name"] == "cloud-architect" + assert result["model"] == "gpt-4o" + assert result["tokens"] == 70 + mock_agent.execute.assert_called_once() + + @patch(f"{_MOD}._prepare_command") + def test_custom_prompt(self, mock_prep, project_with_config, mock_ai_provider): + from azext_prototype.ai.provider import AIResponse + from azext_prototype.custom import prototype_agent_test + + mock_agent = MagicMock() + mock_agent.name = "cloud-architect" + mock_agent.execute.return_value = AIResponse( + content="Here is a web app design.", + model="gpt-4o", + usage={"total_tokens": 100}, + ) + + mock_registry = MagicMock() + mock_registry.get.return_value = mock_agent + mock_prep.return_value = (str(project_with_config), MagicMock(), mock_registry, MagicMock()) + + cmd = MagicMock() + result = prototype_agent_test(cmd, name="cloud-architect", prompt="Design a web app", json_output=True) + + assert result["status"] == "tested" + # Verify custom prompt was passed + call_args = mock_agent.execute.call_args + assert "Design a web app" in call_args[0][1] + + def test_test_missing_name_raises(self): + from azext_prototype.custom import prototype_agent_test + + cmd = MagicMock() + with pytest.raises(CLIError, match="--name"): + prototype_agent_test(cmd, name=None) + + +class TestPrototypeAgentExport: + """Test agent export command.""" + + @patch(f"{_MOD}._get_project_dir") + def test_export_builtin(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_export + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + output_path = str(project_with_config / "exported.yaml") + result = prototype_agent_export(cmd, name="cloud-architect", output_file=output_path, json_output=True) + + assert result["status"] == "exported" + assert result["name"] == "cloud-architect" + + import yaml as _yaml + + exported = _yaml.safe_load((project_with_config / "exported.yaml").read_text(encoding="utf-8")) + assert exported["name"] == "cloud-architect" + assert "capabilities" in exported + assert "system_prompt" in exported + + @patch(f"{_MOD}._get_project_dir") + def test_export_custom(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_add, prototype_agent_export + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + prototype_agent_add(cmd, name="export-test", definition="bicep_agent") + output_path = str(project_with_config / "custom_export.yaml") + result = prototype_agent_export(cmd, name="export-test", output_file=output_path, json_output=True) + + assert result["status"] == "exported" + assert (project_with_config / "custom_export.yaml").exists() + + @patch(f"{_MOD}._get_project_dir") + def test_export_default_path(self, mock_dir, project_with_config): + """Default output path is ./{name}.yaml.""" + import os + + from azext_prototype.custom import prototype_agent_export + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + # Change cwd to project dir for default path + original_cwd = os.getcwd() + try: + os.chdir(str(project_with_config)) + result = prototype_agent_export(cmd, name="cloud-architect", json_output=True) + assert result["status"] == "exported" + assert (project_with_config / "cloud-architect.yaml").exists() + finally: + os.chdir(original_cwd) + + @patch(f"{_MOD}._get_project_dir") + def test_export_loadable_by_loader(self, mock_dir, project_with_config): + """Exported YAML is loadable by load_yaml_agent.""" + from azext_prototype.agents.loader import load_yaml_agent + from azext_prototype.custom import prototype_agent_export + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + output_path = str(project_with_config / "loadable.yaml") + prototype_agent_export(cmd, name="cloud-architect", output_file=output_path) + + agent = load_yaml_agent(output_path) + assert agent.name == "cloud-architect" + + def test_export_missing_name_raises(self): + from azext_prototype.custom import prototype_agent_export + + cmd = MagicMock() + with pytest.raises(CLIError, match="--name"): + prototype_agent_export(cmd, name=None) + + +class TestPrototypeAgentOverrideValidation: + """Test override validation enhancements.""" + + @patch(f"{_MOD}._get_project_dir") + def test_override_file_not_found_raises(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_override + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + with pytest.raises(CLIError, match="not found"): + prototype_agent_override(cmd, name="cloud-architect", file="./does_not_exist.yaml") + + @patch(f"{_MOD}._get_project_dir") + def test_override_invalid_yaml_raises(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_override + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + bad_yaml = project_with_config / "bad.yaml" + bad_yaml.write_text("{{invalid yaml::", encoding="utf-8") + + with pytest.raises(CLIError, match="Invalid YAML"): + prototype_agent_override(cmd, name="cloud-architect", file="bad.yaml") + + @patch(f"{_MOD}._get_project_dir") + def test_override_missing_name_field_raises(self, mock_dir, project_with_config): + from azext_prototype.custom import prototype_agent_override + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + no_name = project_with_config / "no_name.yaml" + no_name.write_text("description: test\n", encoding="utf-8") + + with pytest.raises(CLIError, match="name"): + prototype_agent_override(cmd, name="cloud-architect", file="no_name.yaml") + + @patch(f"{_MOD}._get_project_dir") + def test_override_non_builtin_warns(self, mock_dir, project_with_config): + """Overriding a non-builtin name should warn but succeed.""" + from azext_prototype.custom import prototype_agent_override + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + valid_yaml = project_with_config / "valid.yaml" + valid_yaml.write_text( + "name: nonexistent-agent\ndescription: test\ncapabilities:\n - develop\n" "system_prompt: test\n", + encoding="utf-8", + ) + + result = prototype_agent_override(cmd, name="nonexistent-agent", file="valid.yaml", json_output=True) + assert result["status"] == "override_registered" + + @patch(f"{_MOD}._get_project_dir") + def test_override_valid_builtin(self, mock_dir, project_with_config): + """Overriding a known builtin should succeed without warnings.""" + from azext_prototype.custom import prototype_agent_override + + mock_dir.return_value = str(project_with_config) + cmd = MagicMock() + + valid_yaml = project_with_config / "arch_override.yaml" + valid_yaml.write_text( + "name: cloud-architect\ndescription: Custom arch\ncapabilities:\n - architect\n" + "system_prompt: Custom prompt.\n", + encoding="utf-8", + ) + + result = prototype_agent_override(cmd, name="cloud-architect", file="arch_override.yaml", json_output=True) + assert result["status"] == "override_registered" + + +class TestPromptAgentDefinition: + """Test the _prompt_agent_definition interactive helper.""" + + def test_full_walkthrough(self): + from azext_prototype.custom import _prompt_agent_definition + from azext_prototype.ui.console import Console + + console = Console() + inputs = [ + "My agent description", # description + "architect", # role + "architect,deploy", # capabilities + "Must use PaaS only", # constraint 1 + "", # end constraints + "You are a custom agent.", # system prompt line 1 + "END", # end system prompt + "", # no examples + ] + with patch("builtins.input", side_effect=inputs): + result = _prompt_agent_definition(console, "test-agent") + + assert result["name"] == "test-agent" + assert result["description"] == "My agent description" + assert result["role"] == "architect" + assert "architect" in result["capabilities"] + assert "deploy" in result["capabilities"] + assert "Must use PaaS only" in result["constraints"] + assert "You are a custom agent." in result["system_prompt"] + + def test_existing_defaults(self): + from azext_prototype.custom import _prompt_agent_definition + from azext_prototype.ui.console import Console + + console = Console() + existing = { + "description": "Old desc", + "role": "developer", + "capabilities": ["develop"], + "constraints": ["Old constraint"], + "system_prompt": "Old prompt.", + "examples": [{"user": "hello", "assistant": "hi"}], + } + # All empty inputs → keep existing values + inputs = [ + "", # description (keep) + "", # role (keep) + "", # capabilities (keep) + "", # constraints (keep existing) + "", # system prompt (keep existing) + "", # examples (keep existing) + ] + with patch("builtins.input", side_effect=inputs): + result = _prompt_agent_definition(console, "test-agent", existing=existing) + + assert result["description"] == "Old desc" + assert result["role"] == "developer" + assert result["capabilities"] == ["develop"] + assert result["constraints"] == ["Old constraint"] + assert result["system_prompt"] == "Old prompt." + assert result["examples"] == [{"user": "hello", "assistant": "hi"}] + + def test_invalid_capability_skipped(self): + from azext_prototype.custom import _prompt_agent_definition + from azext_prototype.ui.console import Console + + console = Console() + inputs = [ + "desc", # description + "role", # role + "invalid_cap,architect", # capabilities — one invalid + "", # end constraints + "prompt", # system prompt + "END", # end system prompt + "", # no examples + ] + with patch("builtins.input", side_effect=inputs): + result = _prompt_agent_definition(console, "test-agent") + + assert "architect" in result["capabilities"] + assert "invalid_cap" not in result["capabilities"] + + +class TestReadMultilineInput: + """Test _read_multiline_input helper.""" + + def test_reads_until_end(self): + from azext_prototype.custom import _read_multiline_input + + with patch("builtins.input", side_effect=["line 1", "line 2", "END"]): + result = _read_multiline_input() + assert result == "line 1\nline 2" + + def test_empty_first_line_returns_empty(self): + from azext_prototype.custom import _read_multiline_input + + with patch("builtins.input", side_effect=[""]): + result = _read_multiline_input() + assert result == "" diff --git a/tests/test_custom_extended.py b/tests/test_custom_extended.py deleted file mode 100644 index 85f3548..0000000 --- a/tests/test_custom_extended.py +++ /dev/null @@ -1,2204 +0,0 @@ -"""Tests for custom.py — additional coverage for stage commands and helpers.""" - -import json -from unittest.mock import MagicMock, patch - -import pytest -from knack.util import CLIError - -_MOD = "azext_prototype.custom" - - -# ====================================================================== -# Helper functions -# ====================================================================== - - -class TestBuildRegistry: - """Test _build_registry helper.""" - - def test_build_registry_builtin_only(self): - from azext_prototype.custom import _build_registry - - registry = _build_registry(config=None, project_dir=None) - agents = registry.list_all() - assert len(agents) >= 8 - - def test_build_registry_with_custom_agents(self, project_with_config): - from azext_prototype.custom import _build_registry, _load_config - - # Create a custom YAML agent - agent_dir = project_with_config / ".prototype" / "agents" - agent_dir.mkdir(parents=True, exist_ok=True) - (agent_dir / "test-agent.yaml").write_text( - "name: test-agent\ndescription: A test\ncapabilities:\n - develop\n" "system_prompt: You are a test.\n", - encoding="utf-8", - ) - - config = _load_config(str(project_with_config)) - registry = _build_registry(config, str(project_with_config)) - names = [a.name for a in registry.list_all()] - assert "test-agent" in names - - def test_build_registry_with_overrides(self, project_with_config): - from azext_prototype.custom import _build_registry, _load_config - - # Write a YAML agent to use as override - override_file = project_with_config / "override.yaml" - override_file.write_text( - "name: cloud-architect\ndescription: Override\ncapabilities:\n - architect\n" - "system_prompt: Override prompt.\n", - encoding="utf-8", - ) - - config = _load_config(str(project_with_config)) - config.set("agents.overrides", {"cloud-architect": "override.yaml"}) - - registry = _build_registry(config, str(project_with_config)) - agent = registry.get("cloud-architect") - assert "Override" in agent.description - - -class TestBuildContext: - """Test _build_context helper.""" - - @patch("azext_prototype.ai.factory.create_ai_provider") - def test_build_context_creates_agent_context(self, mock_factory, project_with_config): - from azext_prototype.custom import _build_context, _load_config - - mock_provider = MagicMock() - mock_factory.return_value = mock_provider - config = _load_config(str(project_with_config)) - - ctx = _build_context(config, str(project_with_config)) - assert ctx.project_dir == str(project_with_config) - assert ctx.ai_provider is mock_provider - - -class TestPrepareCommand: - """Test _prepare_command helper.""" - - @patch(f"{_MOD}._check_requirements") - @patch("azext_prototype.ai.factory.create_ai_provider") - def test_prepare_command(self, mock_factory, mock_check_req, project_with_config): - from azext_prototype.custom import _prepare_command - - mock_factory.return_value = MagicMock() - pd, config, registry, ctx = _prepare_command(str(project_with_config)) - assert pd == str(project_with_config) - assert config is not None - assert registry is not None - assert ctx is not None - - -class TestCheckRequirements: - """Test _check_requirements wiring in command entry points.""" - - def test_check_requirements_passes_when_all_ok(self): - from azext_prototype.custom import _check_requirements - from azext_prototype.requirements import CheckResult - - with patch("azext_prototype.requirements.check_all") as mock_check: - mock_check.return_value = [ - CheckResult(name="Python", status="pass", installed_version="3.12.0", required=">=3.9.0", message="ok"), - ] - # Should not raise - _check_requirements("terraform") - - def test_check_requirements_raises_on_missing(self): - from azext_prototype.custom import _check_requirements - from azext_prototype.requirements import CheckResult - - with patch("azext_prototype.requirements.check_all") as mock_check: - mock_check.return_value = [ - CheckResult( - name="Terraform", - status="missing", - installed_version=None, - required=">=1.14.0", - message="Terraform is not installed", - install_hint="https://developer.hashicorp.com/terraform/install", - ), - ] - with pytest.raises(CLIError, match="Tool requirements not met"): - _check_requirements("terraform") - - def test_check_requirements_raises_on_version_fail(self): - from azext_prototype.custom import _check_requirements - from azext_prototype.requirements import CheckResult - - with patch("azext_prototype.requirements.check_all") as mock_check: - mock_check.return_value = [ - CheckResult( - name="Azure CLI", - status="fail", - installed_version="2.40.0", - required=">=2.50.0", - message="Azure CLI 2.40.0 does not satisfy >=2.50.0", - install_hint="https://learn.microsoft.com/cli/azure/install-azure-cli", - ), - ] - with pytest.raises(CLIError, match="Azure CLI"): - _check_requirements(None) - - def test_check_requirements_includes_install_hint(self): - from azext_prototype.custom import _check_requirements - from azext_prototype.requirements import CheckResult - - with patch("azext_prototype.requirements.check_all") as mock_check: - mock_check.return_value = [ - CheckResult( - name="Terraform", - status="missing", - installed_version=None, - required=">=1.14.0", - message="Terraform is not installed", - install_hint="https://developer.hashicorp.com/terraform/install", - ), - ] - with pytest.raises(CLIError, match="Install:.*hashicorp"): - _check_requirements("terraform") - - @patch("azext_prototype.ai.factory.create_ai_provider") - def test_prepare_command_calls_check_requirements(self, mock_factory, project_with_config): - from azext_prototype.custom import _prepare_command - - mock_factory.return_value = MagicMock() - with patch(f"{_MOD}._check_requirements") as mock_check: - _prepare_command(str(project_with_config)) - mock_check.assert_called_once() - - def test_init_calls_check_requirements(self, tmp_path): - with patch(f"{_MOD}._check_requirements") as mock_check, patch( - "azext_prototype.stages.init_stage.InitStage" - ) as MockStage: - from azext_prototype.custom import prototype_init - - mock_stage = MockStage.return_value - mock_stage.can_run.return_value = (True, []) - mock_stage.execute.return_value = {"status": "success"} - - cmd = MagicMock() - prototype_init(cmd, name="test", location="eastus", output_dir=str(tmp_path)) - mock_check.assert_called_once_with("terraform") # default iac_tool - - -class TestCheckGuards: - """Test _check_guards helper.""" - - def test_check_guards_pass(self): - from azext_prototype.custom import _check_guards - - stage = MagicMock() - stage.can_run.return_value = (True, []) - _check_guards(stage) # Should not raise - - def test_check_guards_fail(self): - from azext_prototype.custom import _check_guards - - stage = MagicMock() - stage.can_run.return_value = (False, ["Missing gh CLI"]) - with pytest.raises(CLIError, match="Prerequisites not met"): - _check_guards(stage) - - -class TestGetRegistryWithFallback: - """Test _get_registry_with_fallback helper.""" - - def test_with_valid_config(self, project_with_config): - from azext_prototype.custom import _get_registry_with_fallback - - registry = _get_registry_with_fallback(str(project_with_config)) - assert len(registry.list_all()) >= 8 - - def test_without_config_falls_back(self, tmp_project): - from azext_prototype.custom import _get_registry_with_fallback - - registry = _get_registry_with_fallback(str(tmp_project)) - assert len(registry.list_all()) >= 8 - - -# ====================================================================== -# Stage commands -# ====================================================================== - - -class TestPrototypeInit: - """Test the init command.""" - - @patch(f"{_MOD}._check_requirements") - @patch(f"{_MOD}._check_guards") - @patch("azext_prototype.auth.copilot_license.CopilotLicenseValidator") - @patch("azext_prototype.auth.github_auth.GitHubAuthManager") - @patch("azext_prototype.stages.init_stage.InitStage._check_gh", return_value=True) - def test_init_success(self, mock_gh, mock_auth_cls, mock_lic_cls, mock_guards, mock_check_req, tmp_path): - from azext_prototype.custom import prototype_init - - mock_auth = MagicMock() - mock_auth.ensure_authenticated.return_value = {"login": "testuser"} - mock_auth_cls.return_value = mock_auth - - mock_lic = MagicMock() - mock_lic.validate_license.return_value = {"plan": "business", "status": "active"} - mock_lic_cls.return_value = mock_lic - - cmd = MagicMock() - out = tmp_path / "test-proj" - result = prototype_init( - cmd, - name="test-proj", - location="eastus", - output_dir=str(out), - ai_provider="github-models", - json_output=True, - ) - - assert result["status"] == "success" - assert result["github_user"] == "testuser" - assert out.is_dir() - assert (out / "prototype.yaml").exists() - assert (out / ".gitignore").exists() - - @patch(f"{_MOD}._check_requirements") - @patch(f"{_MOD}._check_guards") - def test_init_azure_openai_skips_license(self, mock_guards, mock_check_req, tmp_path): - from azext_prototype.custom import prototype_init - - cmd = MagicMock() - result = prototype_init( - cmd, - name="aoai-proj", - location="eastus", - output_dir=str(tmp_path / "aoai-proj"), - ai_provider="azure-openai", - json_output=True, - ) - - assert result["status"] == "success" - assert "copilot_license" not in result - assert result["github_user"] is None - - @patch(f"{_MOD}._check_requirements") - def test_init_missing_name_raises(self, mock_check_req, tmp_path): - from azext_prototype.custom import prototype_init - from azext_prototype.stages.init_stage import InitStage - - cmd = MagicMock() - # Need to bypass guards - with patch.object(InitStage, "get_guards", return_value=[]): - with pytest.raises(CLIError, match="Project name"): - prototype_init(cmd, name=None, location="eastus", output_dir=str(tmp_path / "no-name")) - - @patch(f"{_MOD}._check_requirements") - def test_init_missing_location_raises(self, mock_check_req, tmp_path): - from azext_prototype.custom import prototype_init - from azext_prototype.stages.init_stage import InitStage - - cmd = MagicMock() - with patch.object(InitStage, "get_guards", return_value=[]): - with pytest.raises(CLIError, match="region is required"): - prototype_init(cmd, name="test-proj", location=None, output_dir=str(tmp_path / "test-proj")) - - @patch(f"{_MOD}._check_requirements") - @patch(f"{_MOD}._check_guards") - def test_init_idempotency_cancel(self, mock_guards, mock_check_req, tmp_path): - """If project exists and user declines, init should cancel.""" - from azext_prototype.custom import prototype_init - - # Create existing project - proj_dir = tmp_path / "existing-proj" - proj_dir.mkdir() - (proj_dir / "prototype.yaml").write_text("project:\n name: old\n") - - cmd = MagicMock() - with patch("builtins.input", return_value="n"): - result = prototype_init( - cmd, - name="existing-proj", - location="eastus", - output_dir=str(proj_dir), - ai_provider="azure-openai", - json_output=True, - ) - assert result["status"] == "cancelled" - - @patch(f"{_MOD}._check_requirements") - @patch(f"{_MOD}._check_guards") - def test_init_idempotency_reinitialize(self, mock_guards, mock_check_req, tmp_path): - """If project exists and user confirms, init should proceed.""" - from azext_prototype.custom import prototype_init - - proj_dir = tmp_path / "reinit-proj" - proj_dir.mkdir() - (proj_dir / "prototype.yaml").write_text("project:\n name: old\n") - - cmd = MagicMock() - with patch("builtins.input", return_value="y"): - result = prototype_init( - cmd, - name="reinit-proj", - location="eastus", - output_dir=str(proj_dir), - ai_provider="azure-openai", - json_output=True, - ) - assert result["status"] == "success" - - @patch(f"{_MOD}._check_requirements") - @patch(f"{_MOD}._check_guards") - def test_init_environment_parameter(self, mock_guards, mock_check_req, tmp_path): - """--environment should be stored in config.""" - from azext_prototype.config import ProjectConfig - from azext_prototype.custom import prototype_init - - cmd = MagicMock() - out = tmp_path / "env-proj" - result = prototype_init( - cmd, - name="env-proj", - location="westus2", - output_dir=str(out), - ai_provider="azure-openai", - environment="staging", - json_output=True, - ) - assert result["status"] == "success" - config = ProjectConfig(str(out)) - config.load() - assert config.get("project.environment") == "staging" - assert config.get("naming.env") == "stg" - assert config.get("naming.zone_id") == "zs" - - @patch(f"{_MOD}._check_requirements") - @patch(f"{_MOD}._check_guards") - def test_init_model_parameter(self, mock_guards, mock_check_req, tmp_path): - """--model should override the provider default.""" - from azext_prototype.config import ProjectConfig - from azext_prototype.custom import prototype_init - - cmd = MagicMock() - out = tmp_path / "model-proj" - result = prototype_init( - cmd, - name="model-proj", - location="eastus", - output_dir=str(out), - ai_provider="azure-openai", - model="gpt-4o-mini", - json_output=True, - ) - assert result["status"] == "success" - config = ProjectConfig(str(out)) - config.load() - assert config.get("ai.model") == "gpt-4o-mini" - - @patch(f"{_MOD}._check_requirements") - @patch(f"{_MOD}._check_guards") - def test_init_default_model_per_provider(self, mock_guards, mock_check_req, tmp_path): - """Without --model, the default should be provider-specific.""" - from azext_prototype.config import ProjectConfig - from azext_prototype.custom import prototype_init - - cmd = MagicMock() - out = tmp_path / "defmodel-proj" - result = prototype_init( - cmd, - name="defmodel-proj", - location="eastus", - output_dir=str(out), - ai_provider="azure-openai", - json_output=True, - ) - assert result["status"] == "success" - config = ProjectConfig(str(out)) - config.load() - assert config.get("ai.model") == "gpt-4o" - - @patch(f"{_MOD}._check_requirements") - @patch(f"{_MOD}._check_guards") - def test_init_sends_telemetry_overrides(self, mock_guards, mock_check_req, tmp_path): - """Init should set _telemetry_overrides with resolved values.""" - from azext_prototype.custom import prototype_init - - cmd = MagicMock() - prototype_init( - cmd, - name="telem-proj", - location="westeurope", - output_dir=str(tmp_path / "telem-proj"), - ai_provider="azure-openai", - environment="staging", - iac_tool="bicep", - ) - - assert isinstance(cmd._telemetry_overrides, dict) - overrides = cmd._telemetry_overrides - assert overrides["location"] == "westeurope" - assert overrides["ai_provider"] == "azure-openai" - assert overrides["model"] == "gpt-4o" # resolved default - assert overrides["iac_tool"] == "bicep" - assert overrides["environment"] == "staging" - - @patch(f"{_MOD}._check_requirements") - @patch(f"{_MOD}._check_guards") - def test_init_telemetry_overrides_explicit_model(self, mock_guards, mock_check_req, tmp_path): - """When --model is explicit, overrides should use that value.""" - from azext_prototype.custom import prototype_init - - cmd = MagicMock() - prototype_init( - cmd, - name="telem-model-proj", - location="eastus", - output_dir=str(tmp_path / "telem-model-proj"), - ai_provider="azure-openai", - model="gpt-4o-mini", - ) - - overrides = cmd._telemetry_overrides - assert overrides["model"] == "gpt-4o-mini" - assert overrides["ai_provider"] == "azure-openai" - - -class TestPrototypeConfigGet: - """Test the config get command.""" - - def test_config_get_basic(self, project_with_config): - from azext_prototype.custom import prototype_config_get - - cmd = MagicMock() - with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): - result = prototype_config_get(cmd, key="ai.provider", json_output=True) - assert result == {"key": "ai.provider", "value": "github-models"} - - def test_config_get_missing_key(self, project_with_config): - from azext_prototype.custom import prototype_config_get - - cmd = MagicMock() - with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): - with pytest.raises(CLIError, match="not found"): - prototype_config_get(cmd, key="nonexistent.key") - - def test_config_get_masks_secret(self, project_with_config): - from azext_prototype.config import ProjectConfig - from azext_prototype.custom import prototype_config_get - - # Set a secret value first - config = ProjectConfig(str(project_with_config)) - config.load() - config._secrets = {"deploy": {"subscription": "secret-sub-id"}} - config._config["deploy"]["subscription"] = "secret-sub-id" - config.save() - config.save_secrets() - - cmd = MagicMock() - with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): - result = prototype_config_get(cmd, key="deploy.subscription", json_output=True) - assert result == {"key": "deploy.subscription", "value": "***"} - - -class TestPrototypeConfigShowMasking: - """Test that config show masks secrets.""" - - def test_config_show_masks_secret_values(self, project_with_config): - from azext_prototype.config import ProjectConfig - from azext_prototype.custom import prototype_config_show - - # Set a secret value - config = ProjectConfig(str(project_with_config)) - config.load() - config._secrets = {"deploy": {"subscription": "my-secret-sub"}} - config._config["deploy"]["subscription"] = "my-secret-sub" - config.save() - config.save_secrets() - - cmd = MagicMock() - with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): - result = prototype_config_show(cmd, json_output=True) - assert result["deploy"]["subscription"] == "***" - - def test_config_show_preserves_non_secrets(self, project_with_config): - from azext_prototype.custom import prototype_config_show - - cmd = MagicMock() - with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): - result = prototype_config_show(cmd, json_output=True) - # Non-secret value should not be masked - assert result["ai"]["provider"] == "github-models" - - -class TestPrototypeConfigInit: - """Test config init marks init complete.""" - - @patch( - "builtins.input", - side_effect=[ - "y", # overwrite existing prototype.yaml - "my-project", # project name - "eastus", # location - "dev", # environment - "terraform", # iac tool - "1", # naming strategy choice (microsoft-alz) - "myorg", # org - "zd", # zone_id (ALZ-specific) - "copilot", # ai provider - "", # model (accept default) - "", # subscription - "", # resource group - ], - ) - def test_config_init_marks_init_complete(self, mock_input, project_with_config): - from azext_prototype.config import ProjectConfig - from azext_prototype.custom import prototype_config_init - - cmd = MagicMock() - with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): - prototype_config_init(cmd) - - config = ProjectConfig(str(project_with_config)) - config.load() - assert config.get("stages.init.completed") is True - assert config.get("stages.init.timestamp") is not None - - @patch( - "builtins.input", - side_effect=[ - "y", # overwrite existing prototype.yaml - "telemetry-proj", # project name - "westus2", # location - "staging", # environment - "bicep", # iac tool - "2", # naming strategy choice (microsoft-caf) - "myorg", # org - "azure-openai", # ai provider - "gpt-4o", # model - "https://myres.openai.azure.com/", # Azure OpenAI endpoint - "gpt-4o", # deployment name - "", # subscription - "", # resource group - ], - ) - def test_config_init_sends_telemetry_overrides(self, mock_input, project_with_config): - """After prompting, config init should set _telemetry_overrides on cmd.""" - from azext_prototype.custom import prototype_config_init - - cmd = MagicMock() - with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): - prototype_config_init(cmd) - - assert hasattr(cmd, "_telemetry_overrides") - overrides = cmd._telemetry_overrides - assert overrides["location"] == "westus2" - assert overrides["ai_provider"] == "azure-openai" - assert overrides["model"] == "gpt-4o" - assert overrides["iac_tool"] == "bicep" - assert overrides["environment"] == "staging" - assert overrides["naming_strategy"] == "microsoft-caf" - - def test_config_init_cancelled_no_overrides(self, project_with_config): - """If config init is cancelled, no telemetry overrides should be set.""" - from azext_prototype.custom import prototype_config_init - - cmd = MagicMock(spec=[]) # strict spec — no auto-attributes - with patch(f"{_MOD}._get_project_dir", return_value=str(project_with_config)): - with patch("builtins.input", return_value="n"): - result = prototype_config_init(cmd, json_output=True) - assert result["status"] == "cancelled" - assert not hasattr(cmd, "_telemetry_overrides") - - -class TestPrototypeBuild: - """Test the build command.""" - - @patch(f"{_MOD}._check_requirements") - @patch(f"{_MOD}._get_project_dir") - @patch("azext_prototype.ai.factory.create_ai_provider") - @patch(f"{_MOD}._check_guards") - def test_build_calls_stage( - self, mock_guards, mock_factory, mock_dir, mock_check_req, project_with_design, mock_ai_provider - ): - from azext_prototype.ai.provider import AIResponse - from azext_prototype.custom import prototype_build - - mock_dir.return_value = str(project_with_design) - mock_factory.return_value = mock_ai_provider - mock_ai_provider.chat.return_value = AIResponse( - content="```main.tf\nresource null {}\n```", - model="gpt-4o", - ) - - cmd = MagicMock() - result = prototype_build(cmd, scope="docs", dry_run=True, json_output=True) - assert result["status"] == "dry-run" - - -class TestPrototypeDeploy: - """Test the deploy command.""" - - @patch(f"{_MOD}._check_requirements") - @patch(f"{_MOD}._get_project_dir") - @patch("azext_prototype.ai.factory.create_ai_provider") - def test_deploy_status(self, mock_factory, mock_dir, mock_check_req, project_with_build, mock_ai_provider): - from azext_prototype.custom import prototype_deploy - - mock_dir.return_value = str(project_with_build) - mock_factory.return_value = mock_ai_provider - - cmd = MagicMock() - result = prototype_deploy(cmd, status=True, json_output=True) - assert result["status"] == "displayed" - - -class TestPrototypeDeployOutputs: - """Test deploy --outputs flag.""" - - @patch(f"{_MOD}._get_project_dir") - def test_no_outputs(self, mock_dir, project_with_build): - from azext_prototype.custom import prototype_deploy - - mock_dir.return_value = str(project_with_build) - cmd = MagicMock() - result = prototype_deploy(cmd, outputs=True, json_output=True) - assert result["status"] == "empty" - - @patch(f"{_MOD}._get_project_dir") - def test_with_outputs(self, mock_dir, project_with_build): - from azext_prototype.custom import prototype_deploy - - mock_dir.return_value = str(project_with_build) - # Write outputs file - outputs_dir = project_with_build / ".prototype" / "state" - outputs_dir.mkdir(parents=True, exist_ok=True) - (outputs_dir / "deploy_outputs.json").write_text(json.dumps({"rg_name": "test-rg"}), encoding="utf-8") - cmd = MagicMock() - result = prototype_deploy(cmd, outputs=True, json_output=True) - # May return empty or dict depending on DeploymentOutputCapture impl - assert isinstance(result, dict) - - -class TestPrototypeDeployRollbackInfo: - """Test deploy --rollback-info flag.""" - - @patch(f"{_MOD}._get_project_dir") - def test_rollback_info(self, mock_dir, project_with_build): - from azext_prototype.custom import prototype_deploy - - mock_dir.return_value = str(project_with_build) - cmd = MagicMock() - result = prototype_deploy(cmd, rollback_info=True, json_output=True) - assert "last_deployment" in result - assert "rollback_instructions" in result - - -class TestPrototypeDeployGenerateScripts: - """Test deploy --generate-scripts flag.""" - - @patch(f"{_MOD}._get_project_dir") - def test_generate_scripts_no_apps(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_deploy - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - # concept/apps exists but empty (not created by init; build creates it) - (project_with_config / "concept" / "apps").mkdir(parents=True, exist_ok=True) - result = prototype_deploy(cmd, generate_scripts=True, json_output=True) - assert result["status"] == "generated" - assert len(result["scripts"]) == 0 - - @patch(f"{_MOD}._get_project_dir") - def test_generate_scripts_with_apps(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_deploy - - mock_dir.return_value = str(project_with_config) - # Create app directories - apps_dir = project_with_config / "concept" / "apps" - (apps_dir / "backend").mkdir(parents=True, exist_ok=True) - (apps_dir / "frontend").mkdir(parents=True, exist_ok=True) - - cmd = MagicMock() - result = prototype_deploy(cmd, generate_scripts=True, script_deploy_type="webapp", json_output=True) - assert result["status"] == "generated" - assert len(result["scripts"]) == 2 - - @patch(f"{_MOD}._get_project_dir") - def test_generate_scripts_no_apps_dir_raises(self, mock_dir, project_with_config): - # Remove apps dir if present - import shutil - - from azext_prototype.custom import prototype_deploy - - apps_dir = project_with_config / "concept" / "apps" - if apps_dir.exists(): - shutil.rmtree(apps_dir) - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - with pytest.raises(CLIError, match="No apps directory"): - prototype_deploy(cmd, generate_scripts=True) - - -class TestPrototypeAgentOverride: - """Test agent override command.""" - - @patch(f"{_MOD}._get_project_dir") - def test_override_registers(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_override - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - # Create a real YAML file for the override - override_file = project_with_config / "my_arch.yaml" - override_file.write_text( - "name: cloud-architect\ndescription: Custom Override\n" - "capabilities:\n - architect\nsystem_prompt: Custom prompt.\n", - encoding="utf-8", - ) - - result = prototype_agent_override(cmd, name="cloud-architect", file="my_arch.yaml", json_output=True) - assert result["status"] == "override_registered" - assert result["name"] == "cloud-architect" - - def test_override_missing_name_raises(self): - from azext_prototype.custom import prototype_agent_override - - cmd = MagicMock() - with pytest.raises(CLIError, match="--name"): - prototype_agent_override(cmd, name=None, file="x.yaml") - - def test_override_missing_file_raises(self): - from azext_prototype.custom import prototype_agent_override - - cmd = MagicMock() - with pytest.raises(CLIError, match="--file"): - prototype_agent_override(cmd, name="x", file=None) - - -class TestPrototypeAgentRemove: - """Test agent remove command.""" - - @patch(f"{_MOD}._get_project_dir") - def test_remove_custom_agent(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_add, prototype_agent_remove - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - # Add then remove - prototype_agent_add(cmd, name="to-remove", definition="cloud_architect") - result = prototype_agent_remove(cmd, name="to-remove", json_output=True) - assert result["status"] == "removed" - - @patch(f"{_MOD}._get_project_dir") - def test_remove_override_agent(self, mock_dir, project_with_config): - from azext_prototype.custom import ( - prototype_agent_override, - prototype_agent_remove, - ) - - mock_dir.return_value = str(project_with_config) - - # Create a real YAML file for the override - override_file = project_with_config / "my_arch.yaml" - override_file.write_text( - "name: cloud-architect\ndescription: Override\n" "capabilities:\n - architect\nsystem_prompt: Override.\n", - encoding="utf-8", - ) - - cmd = MagicMock() - prototype_agent_override(cmd, name="cloud-architect", file="my_arch.yaml") - result = prototype_agent_remove(cmd, name="cloud-architect", json_output=True) - assert result["status"] == "override_removed" - - @patch(f"{_MOD}._get_project_dir") - def test_remove_builtin_raises(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_remove - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - # bicep-agent is builtin and not custom/override → should raise - with pytest.raises(CLIError, match="Built-in agents cannot be removed"): - prototype_agent_remove(cmd, name="app-developer") - - def test_remove_missing_name_raises(self): - from azext_prototype.custom import prototype_agent_remove - - cmd = MagicMock() - with pytest.raises(CLIError, match="--name"): - prototype_agent_remove(cmd, name=None) - - -class TestPrototypeAnalyzeError: - """Test the error analysis command.""" - - def test_missing_input_raises(self): - from azext_prototype.custom import prototype_analyze_error - - cmd = MagicMock() - with pytest.raises(CLIError, match="Error input is required"): - prototype_analyze_error(cmd, input=None) - - @patch(f"{_MOD}._prepare_command") - def test_analyze_inline_error(self, mock_prep, project_with_design, mock_ai_provider): - from azext_prototype.ai.provider import AIResponse - from azext_prototype.custom import prototype_analyze_error - - mock_qa = MagicMock() - mock_qa.name = "qa-engineer" - mock_qa.execute.return_value = AIResponse(content="Root cause: missing RBAC", model="gpt-4o") - - mock_registry = MagicMock() - mock_registry.find_by_capability.return_value = [mock_qa] - - mock_ctx = MagicMock() - mock_prep.return_value = (str(project_with_design), MagicMock(), mock_registry, mock_ctx) - - cmd = MagicMock() - result = prototype_analyze_error(cmd, input="ResourceNotFound error", json_output=True) - assert result["status"] == "analyzed" - - @patch(f"{_MOD}._prepare_command") - def test_analyze_log_file(self, mock_prep, project_with_design, mock_ai_provider): - from azext_prototype.ai.provider import AIResponse - from azext_prototype.custom import prototype_analyze_error - - mock_qa = MagicMock() - mock_qa.name = "qa-engineer" - mock_qa.execute.return_value = AIResponse(content="Root cause: config error", model="gpt-4o") - - mock_registry = MagicMock() - mock_registry.find_by_capability.return_value = [mock_qa] - - mock_ctx = MagicMock() - mock_prep.return_value = (str(project_with_design), MagicMock(), mock_registry, mock_ctx) - - log_file = project_with_design / "error.log" - log_file.write_text("ERROR: Connection refused", encoding="utf-8") - - cmd = MagicMock() - result = prototype_analyze_error(cmd, input=str(log_file), json_output=True) - assert result["status"] == "analyzed" - - @patch(f"{_MOD}._prepare_command") - def test_analyze_screenshot(self, mock_prep, project_with_design, mock_ai_provider): - from azext_prototype.ai.provider import AIResponse - from azext_prototype.custom import prototype_analyze_error - - mock_qa = MagicMock() - mock_qa.name = "qa-engineer" - mock_qa.execute_with_image.return_value = AIResponse(content="Screenshot analysis", model="gpt-4o") - - mock_registry = MagicMock() - mock_registry.find_by_capability.return_value = [mock_qa] - - mock_ctx = MagicMock() - mock_prep.return_value = (str(project_with_design), MagicMock(), mock_registry, mock_ctx) - - img = project_with_design / "error.png" - img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) - - cmd = MagicMock() - result = prototype_analyze_error(cmd, input=str(img), json_output=True) - assert result["status"] == "analyzed" - - -class TestPrototypeAnalyzeCosts: - """Test the cost analysis command.""" - - @patch(f"{_MOD}._prepare_command") - def test_analyze_costs(self, mock_prep, project_with_design, mock_ai_provider): - from azext_prototype.ai.provider import AIResponse - from azext_prototype.custom import prototype_analyze_costs - - mock_cost = MagicMock() - mock_cost.name = "cost-analyst" - mock_cost.execute.return_value = AIResponse(content="Cost report content", model="gpt-4o") - - mock_registry = MagicMock() - mock_registry.find_by_capability.return_value = [mock_cost] - - mock_ctx = MagicMock() - mock_prep.return_value = (str(project_with_design), MagicMock(), mock_registry, mock_ctx) - - cmd = MagicMock() - result = prototype_analyze_costs(cmd, json_output=True) - assert result["status"] == "analyzed" - - @patch(f"{_MOD}._prepare_command") - def test_analyze_costs_no_agent_raises(self, mock_prep, project_with_design): - from azext_prototype.custom import prototype_analyze_costs - - mock_registry = MagicMock() - mock_registry.find_by_capability.return_value = [] - mock_prep.return_value = (str(project_with_design), MagicMock(), mock_registry, MagicMock()) - - cmd = MagicMock() - with pytest.raises(CLIError, match="No cost analyst"): - prototype_analyze_costs(cmd) - - -class TestExtractCostTable: - """Test _extract_cost_table helper.""" - - def test_extracts_summary_table(self): - from azext_prototype.custom import _extract_cost_table - - content = ( - "# Executive Summary\n\nSome intro text.\n\n---\n\n" - "## Cost Summary Table\n\n" - " Service Small Medium Large\n" - " ──────────────────────────────────────────\n" - " App Service $0.00 $13.14 $74.00\n" - " TOTAL $0.00 $13.14 $74.00\n" - "\n\n---\n\n" - "## T-Shirt Size Definitions\n\nMore details...\n" - ) - result = _extract_cost_table(content) - assert "Cost Summary Table" in result - assert "$13.14" in result - assert "T-Shirt Size" not in result - - def test_fallback_on_no_heading(self): - from azext_prototype.custom import _extract_cost_table - - content = "No table here, just text about the architecture." - result = _extract_cost_table(content) - assert result == content - - -class TestPrototypeConfigSet: - """Additional config set tests.""" - - def test_config_set_missing_value_raises(self): - from azext_prototype.custom import prototype_config_set - - cmd = MagicMock() - with pytest.raises(CLIError, match="--value"): - prototype_config_set(cmd, key="some.key", value=None) - - @patch(f"{_MOD}._get_project_dir") - def test_config_set_json_value(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_config_set - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - result = prototype_config_set(cmd, key="deploy.tags", value='{"env":"dev"}', json_output=True) - assert result["status"] == "updated" - - -class TestPrototypeStatusExtended: - """Extended status tests.""" - - @patch(f"{_MOD}._get_project_dir") - def test_status_with_build_shows_changes(self, mock_dir, project_with_build): - from azext_prototype.custom import prototype_status - - mock_dir.return_value = str(project_with_build) - cmd = MagicMock() - result = prototype_status(cmd, json_output=True) - # If build stage is marked completed, pending_changes should exist - if result.get("stages", {}).get("build", {}).get("completed"): - assert "pending_changes" in result - else: - # Build state exists → pending_changes may still be present - assert "stages" in result - - @patch(f"{_MOD}._get_project_dir") - def test_status_default_uses_console(self, mock_dir, project_with_config): - """Default mode (no flags) uses console output and returns None (suppressed).""" - from azext_prototype.custom import prototype_status - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - with patch("azext_prototype.custom.console", create=True): - result = prototype_status(cmd) - - assert result is None - - @patch(f"{_MOD}._get_project_dir") - def test_status_json_returns_enriched_dict(self, mock_dir, project_with_config): - """--json returns enriched dict with all new fields.""" - from azext_prototype.custom import prototype_status - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - result = prototype_status(cmd, json_output=True) - - assert isinstance(result, dict) - assert result["project"] == "test-project" - assert "environment" in result - assert "naming_strategy" in result - assert "project_id" in result - assert "deployment_history" in result - # All three stages present - for stage in ("design", "build", "deploy"): - assert stage in result["stages"] - assert "completed" in result["stages"][stage] - - @patch(f"{_MOD}._get_project_dir") - def test_status_detailed_prints_detail(self, mock_dir, project_with_config): - """--detailed prints expanded output and returns None (suppressed).""" - from azext_prototype.custom import prototype_status - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - with patch("azext_prototype.custom.console", create=True): - result = prototype_status(cmd, detailed=True) - - assert result is None - - @patch(f"{_MOD}._get_project_dir") - def test_status_with_discovery_state(self, mock_dir, project_with_config): - """Discovery state populates exchanges/confirmed/open.""" - import yaml - - from azext_prototype.custom import prototype_status - - state_dir = project_with_config / ".prototype" / "state" - state_dir.mkdir(parents=True, exist_ok=True) - state_file = state_dir / "discovery.yaml" - state_file.write_text( - yaml.dump( - { - "open_items": ["item1"], - "confirmed_items": ["item2", "item3"], - "conversation_history": [], - "_metadata": { - "exchange_count": 5, - "created": "2026-01-01T00:00:00", - "last_updated": "2026-01-01T01:00:00", - }, - } - ), - encoding="utf-8", - ) - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - result = prototype_status(cmd, json_output=True) - - d = result["stages"]["design"] - assert d["exchanges"] == 5 - assert d["confirmed"] == 2 - assert d["open"] == 1 - - @patch(f"{_MOD}._get_project_dir") - def test_status_with_build_state(self, mock_dir, project_with_build): - """Build state populates templates/stages/files/overrides.""" - from azext_prototype.custom import prototype_status - - mock_dir.return_value = str(project_with_build) - cmd = MagicMock() - result = prototype_status(cmd, json_output=True) - - b = result["stages"]["build"] - assert "templates_used" in b - assert "total_stages" in b - assert "accepted_stages" in b - assert "files_generated" in b - assert "policy_overrides" in b - assert b["total_stages"] >= 0 - - @patch(f"{_MOD}._get_project_dir") - def test_status_with_deploy_state(self, mock_dir, project_with_config): - """Deploy state populates deployed/failed/rolled_back/outputs.""" - import yaml - - from azext_prototype.custom import prototype_status - - state_dir = project_with_config / ".prototype" / "state" - state_dir.mkdir(parents=True, exist_ok=True) - state_file = state_dir / "deploy.yaml" - state_file.write_text( - yaml.dump( - { - "deployment_stages": [ - {"stage": 1, "name": "Foundation", "deploy_status": "deployed", "services": []}, - { - "stage": 2, - "name": "App", - "deploy_status": "failed", - "deploy_error": "timeout", - "services": [], - }, - ], - "captured_outputs": {"terraform": {"endpoint": "https://example.com"}}, - "_metadata": {"created": "2026-01-01T00:00:00", "last_updated": "2026-01-01T01:00:00"}, - } - ), - encoding="utf-8", - ) - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - result = prototype_status(cmd, json_output=True) - - dp = result["stages"]["deploy"] - assert dp["total_stages"] == 2 - assert dp["deployed"] == 1 - assert dp["failed"] == 1 - assert dp["rolled_back"] == 0 - assert dp["outputs_captured"] == 1 - - @patch(f"{_MOD}._get_project_dir") - def test_status_no_state_files(self, mock_dir, project_with_config): - """Config exists but no state files — stages show zero counts.""" - from azext_prototype.custom import prototype_status - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - result = prototype_status(cmd, json_output=True) - - d = result["stages"]["design"] - assert d["exchanges"] == 0 - assert d["confirmed"] == 0 - assert d["open"] == 0 - - b = result["stages"]["build"] - assert b["total_stages"] == 0 - assert b["files_generated"] == 0 - - dp = result["stages"]["deploy"] - assert dp["total_stages"] == 0 - assert dp["deployed"] == 0 - - @patch(f"{_MOD}._get_project_dir") - def test_status_deployment_history(self, mock_dir, project_with_config): - """Deployment history from ChangeTracker is included.""" - import json as json_mod - - from azext_prototype.custom import prototype_status - - # Create a manifest with deployment history - manifest_dir = project_with_config / ".prototype" / "state" - manifest_dir.mkdir(parents=True, exist_ok=True) - manifest_path = manifest_dir / "change_manifest.json" - manifest_path.write_text( - json_mod.dumps( - { - "files": {}, - "deployments": [ - {"scope": "all", "timestamp": "2026-01-15T10:00:00", "files_count": 12}, - ], - } - ), - encoding="utf-8", - ) - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - result = prototype_status(cmd, json_output=True) - - assert len(result["deployment_history"]) == 1 - assert result["deployment_history"][0]["scope"] == "all" - - @patch(f"{_MOD}._get_project_dir") - def test_status_detailed_json_returns_dict(self, mock_dir, project_with_config): - """When both detailed and json_output are True, json wins — returns dict.""" - from azext_prototype.custom import prototype_status - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - result = prototype_status(cmd, detailed=True, json_output=True) - - # json_output takes precedence — returns the enriched dict, not displayed - assert isinstance(result, dict) - assert "project" in result - assert result.get("status") != "displayed" - - -class TestLoadDesignContext: - """Test _load_design_context.""" - - def test_loads_from_design_json(self, project_with_design): - from azext_prototype.custom import _load_design_context - - result = _load_design_context(str(project_with_design)) - assert "Sample architecture" in result - - def test_loads_from_architecture_md(self, project_with_config): - from azext_prototype.custom import _load_design_context - - arch_md = project_with_config / "concept" / "docs" / "ARCHITECTURE.md" - arch_md.parent.mkdir(parents=True, exist_ok=True) - arch_md.write_text("# My Architecture\nDetails here.", encoding="utf-8") - - result = _load_design_context(str(project_with_config)) - assert "My Architecture" in result - - def test_returns_empty_when_no_design(self, tmp_project): - from azext_prototype.custom import _load_design_context - - result = _load_design_context(str(tmp_project)) - assert result == "" - - -class TestRenderTemplate: - """Test _render_template.""" - - def test_replaces_placeholders(self): - from azext_prototype.custom import _render_template - - template = "Project: [PROJECT_NAME], Region: [LOCATION], Date: [DATE]" - config = {"project": {"name": "my-proj", "location": "westus2"}} - result = _render_template(template, config) - assert "my-proj" in result - assert "westus2" in result - assert "[PROJECT_NAME]" not in result - - def test_keeps_unknown_placeholders(self): - from azext_prototype.custom import _render_template - - template = "[UNKNOWN_PLACEHOLDER] stays" - result = _render_template(template, {}) - assert "[UNKNOWN_PLACEHOLDER]" in result - - -class TestGenerateTemplates: - """Test _generate_templates shared helper.""" - - def test_generates_all_templates(self, project_with_config): - from azext_prototype.custom import _generate_templates, _load_config - - config = _load_config(str(project_with_config)) - output_dir = project_with_config / "test_output" - - generated = _generate_templates(output_dir, str(project_with_config), config.to_dict(), "test") - assert len(generated) >= 1 - assert output_dir.is_dir() - - def test_generates_with_manifest(self, project_with_config): - from azext_prototype.custom import _generate_templates, _load_config - - config = _load_config(str(project_with_config)) - output_dir = project_with_config / "speckit_output" - - _generate_templates( - output_dir, - str(project_with_config), - config.to_dict(), - "speckit", - include_manifest=True, - ) - assert (output_dir / "manifest.json").exists() - manifest = json.loads((output_dir / "manifest.json").read_text()) - assert "speckit_version" in manifest - - -# ====================================================================== -# _load_design_context — 3-source cascade -# ====================================================================== - -_MOD = "azext_prototype.custom" - - -class TestLoadDesignContextCascade: - """Test the 3-source cascade in _load_design_context.""" - - def test_loads_from_design_json(self, project_with_design): - """Source 1: design.json is used when present.""" - from azext_prototype.custom import _load_design_context - - result = _load_design_context(str(project_with_design)) - assert "Sample architecture" in result - - def test_falls_back_to_discovery_yaml(self, project_with_discovery): - """Source 2: discovery.yaml used when no design.json.""" - from azext_prototype.custom import _load_design_context - - result = _load_design_context(str(project_with_discovery)) - assert result # Should get non-empty context from discovery state - - def test_design_json_takes_priority(self, project_with_design): - """design.json takes priority over discovery.yaml when both exist.""" - import yaml as _yaml - - from azext_prototype.custom import _load_design_context - - # Add a discovery.yaml alongside the existing design.json - state_dir = project_with_design / ".prototype" / "state" - discovery = { - "project": {"summary": "Different content from discovery"}, - "confirmed_items": ["Different item"], - "_metadata": {"exchange_count": 1, "created": "2026-01-01T00:00:00", "last_updated": "2026-01-01T00:00:00"}, - } - (state_dir / "discovery.yaml").write_text(_yaml.dump(discovery), encoding="utf-8") - - result = _load_design_context(str(project_with_design)) - assert "Sample architecture" in result # design.json content, not discovery - - def test_falls_back_to_architecture_md(self, project_with_config): - """Source 3: ARCHITECTURE.md used when no state files exist.""" - from azext_prototype.custom import _load_design_context - - arch_md = project_with_config / "concept" / "docs" / "ARCHITECTURE.md" - arch_md.parent.mkdir(parents=True, exist_ok=True) - arch_md.write_text("# Architecture from markdown", encoding="utf-8") - - result = _load_design_context(str(project_with_config)) - assert "Architecture from markdown" in result - - def test_returns_empty_when_nothing(self, project_with_config): - """Returns empty string when no sources exist.""" - from azext_prototype.custom import _load_design_context - - result = _load_design_context(str(project_with_config)) - assert result == "" - - -# ====================================================================== -# Analyze costs — cache behavior -# ====================================================================== - - -class TestAnalyzeCostsCache: - """Test cost analysis caching (deterministic results).""" - - def _make_mock_prep(self, project_dir, mock_registry, mock_context): - """Build a _prepare_command return tuple.""" - from azext_prototype.config import ProjectConfig - - config = ProjectConfig(str(project_dir)) - config.load() - return (str(project_dir), config, mock_registry, mock_context) - - def _make_registry_with_cost_agent(self): - from tests.conftest import make_ai_response - - agent = MagicMock() - agent.name = "cost-analyst" - agent.execute.return_value = make_ai_response("## Cost Report\n| Service | Small | Medium | Large |") - - registry = MagicMock() - registry.find_by_capability.return_value = [agent] - return registry, agent - - @patch(f"{_MOD}._prepare_command") - def test_first_run_calls_agent_and_caches(self, mock_prep, project_with_design): - from azext_prototype.custom import prototype_analyze_costs - - registry, agent = self._make_registry_with_cost_agent() - mock_ctx = MagicMock() - mock_ctx.project_config = {"project": {"location": "eastus"}} - mock_prep.return_value = self._make_mock_prep(project_with_design, registry, mock_ctx) - - cmd = MagicMock() - result = prototype_analyze_costs(cmd, refresh=False, json_output=True) - - assert result["status"] == "analyzed" - agent.execute.assert_called_once() - - # Cache file should exist - cache = project_with_design / ".prototype" / "state" / "cost_analysis.yaml" - assert cache.exists() - - @patch(f"{_MOD}._prepare_command") - def test_second_run_returns_cached(self, mock_prep, project_with_design): - """Cached result returned without calling agent.""" - import yaml as _yaml - - from azext_prototype.custom import prototype_analyze_costs - - registry, agent = self._make_registry_with_cost_agent() - mock_ctx = MagicMock() - mock_ctx.project_config = {"project": {"location": "eastus"}} - mock_prep.return_value = self._make_mock_prep(project_with_design, registry, mock_ctx) - - # Pre-populate cache with matching hash - import hashlib - - from azext_prototype.custom import _load_design_context - - design_context = _load_design_context(str(project_with_design)) - context_hash = hashlib.sha256(design_context.encode("utf-8")).hexdigest()[:16] - - cache_data = { - "context_hash": context_hash, - "content": "Cached cost report content", - "result": {"status": "analyzed", "agent": "cost-analyst"}, - "timestamp": "2026-01-01T00:00:00+00:00", - } - cache_path = project_with_design / ".prototype" / "state" / "cost_analysis.yaml" - cache_path.write_text(_yaml.dump(cache_data, default_flow_style=False), encoding="utf-8") - - cmd = MagicMock() - result = prototype_analyze_costs(cmd, refresh=False, json_output=True) - - assert result["status"] == "analyzed" - agent.execute.assert_not_called() # Should NOT have called the agent - - @patch(f"{_MOD}._prepare_command") - def test_refresh_bypasses_cache(self, mock_prep, project_with_design): - """--refresh forces fresh analysis even when cache matches.""" - import yaml as _yaml - - from azext_prototype.custom import prototype_analyze_costs - - registry, agent = self._make_registry_with_cost_agent() - mock_ctx = MagicMock() - mock_ctx.project_config = {"project": {"location": "eastus"}} - mock_prep.return_value = self._make_mock_prep(project_with_design, registry, mock_ctx) - - # Pre-populate cache with matching hash - import hashlib - - from azext_prototype.custom import _load_design_context - - design_context = _load_design_context(str(project_with_design)) - context_hash = hashlib.sha256(design_context.encode("utf-8")).hexdigest()[:16] - - cache_data = { - "context_hash": context_hash, - "content": "Old cached content", - "result": {"status": "analyzed", "agent": "cost-analyst"}, - } - cache_path = project_with_design / ".prototype" / "state" / "cost_analysis.yaml" - cache_path.write_text(_yaml.dump(cache_data, default_flow_style=False), encoding="utf-8") - - cmd = MagicMock() - result = prototype_analyze_costs(cmd, refresh=True, json_output=True) - - assert result["status"] == "analyzed" - agent.execute.assert_called_once() # Should HAVE called the agent - - @patch(f"{_MOD}._prepare_command") - def test_cache_invalidated_on_design_change(self, mock_prep, project_with_design): - """Different design context hash invalidates the cache.""" - import yaml as _yaml - - from azext_prototype.custom import prototype_analyze_costs - - registry, agent = self._make_registry_with_cost_agent() - mock_ctx = MagicMock() - mock_ctx.project_config = {"project": {"location": "eastus"}} - mock_prep.return_value = self._make_mock_prep(project_with_design, registry, mock_ctx) - - # Pre-populate cache with a DIFFERENT hash - cache_data = { - "context_hash": "stale_hash_0000", - "content": "Stale cached content", - "result": {"status": "analyzed", "agent": "cost-analyst"}, - } - cache_path = project_with_design / ".prototype" / "state" / "cost_analysis.yaml" - cache_path.write_text(_yaml.dump(cache_data, default_flow_style=False), encoding="utf-8") - - cmd = MagicMock() - result = prototype_analyze_costs(cmd, refresh=False, json_output=True) - - assert result["status"] == "analyzed" - agent.execute.assert_called_once() # Stale cache — must re-run - - @patch(f"{_MOD}._prepare_command") - def test_cache_file_written_to_state_dir(self, mock_prep, project_with_design): - """Cache is written to .prototype/state/cost_analysis.yaml.""" - import yaml as _yaml - - from azext_prototype.custom import prototype_analyze_costs - - registry, agent = self._make_registry_with_cost_agent() - mock_ctx = MagicMock() - mock_ctx.project_config = {"project": {"location": "eastus"}} - mock_prep.return_value = self._make_mock_prep(project_with_design, registry, mock_ctx) - - cmd = MagicMock() - prototype_analyze_costs(cmd, refresh=False) - - cache_path = project_with_design / ".prototype" / "state" / "cost_analysis.yaml" - assert cache_path.exists() - cached = _yaml.safe_load(cache_path.read_text(encoding="utf-8")) - assert "context_hash" in cached - assert "content" in cached - assert "timestamp" in cached - - -# ====================================================================== -# Console output — analyze commands -# ====================================================================== - - -class TestAnalyzeConsoleOutput: - """Verify analyze commands use console.* methods (not raw print).""" - - @patch(f"{_MOD}._prepare_command") - @patch(f"{_MOD}.console", create=True) - def test_analyze_error_uses_console(self, mock_console, mock_prep, project_with_design): - from azext_prototype.custom import prototype_analyze_error - from tests.conftest import make_ai_response - - agent = MagicMock() - agent.name = "qa-engineer" - agent.execute.return_value = make_ai_response("## Fix\nDo something") - - registry = MagicMock() - registry.find_by_capability.return_value = [agent] - - config = MagicMock() - mock_prep.return_value = (str(project_with_design), config, registry, MagicMock()) - - cmd = MagicMock() - result = prototype_analyze_error(cmd, input="some error text", json_output=True) - - assert result["status"] == "analyzed" - - @patch(f"{_MOD}._prepare_command") - def test_analyze_error_warns_no_context(self, mock_prep, project_with_config): - """When no design context exists, a warning should be shown.""" - from azext_prototype.custom import prototype_analyze_error - from tests.conftest import make_ai_response - - agent = MagicMock() - agent.name = "qa-engineer" - agent.execute.return_value = make_ai_response("## Fix\nDo something") - - registry = MagicMock() - registry.find_by_capability.return_value = [agent] - - config = MagicMock() - mock_prep.return_value = (str(project_with_config), config, registry, MagicMock()) - - cmd = MagicMock() - - # Patch the module-level console singleton. We must use importlib - # because `import azext_prototype.ui.console` can resolve to the - # `console` variable re-exported in azext_prototype.ui.__init__ - # instead of the submodule (name collision on Python 3.10). - import importlib - - _console_mod = importlib.import_module("azext_prototype.ui.console") - - with patch.object(_console_mod, "console") as mock_console: # noqa: F841 - result = prototype_analyze_error(cmd, input="some error", json_output=True) - - assert result["status"] == "analyzed" - - @patch(f"{_MOD}._prepare_command") - def test_analyze_costs_uses_console(self, mock_prep, project_with_design): - from azext_prototype.custom import prototype_analyze_costs - from tests.conftest import make_ai_response - - agent = MagicMock() - agent.name = "cost-analyst" - agent.execute.return_value = make_ai_response("## Costs\n$100/mo") - - registry = MagicMock() - registry.find_by_capability.return_value = [agent] - - from azext_prototype.config import ProjectConfig - - config = ProjectConfig(str(project_with_design)) - config.load() - - mock_ctx = MagicMock() - mock_ctx.project_config = {"project": {"location": "eastus"}} - mock_prep.return_value = (str(project_with_design), config, registry, mock_ctx) - - cmd = MagicMock() - result = prototype_analyze_costs(cmd, refresh=True, json_output=True) - - assert result["status"] == "analyzed" - - -# ====================================================================== -# Console output — deploy subcommands -# ====================================================================== - - -class TestDeploySubcommandConsole: - """Verify deploy flag sub-actions use console.* methods.""" - - @patch(f"{_MOD}._get_project_dir") - def test_deploy_outputs_empty_warns(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_deploy - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - with patch("azext_prototype.stages.deploy_helpers.DeploymentOutputCapture") as MockCapture: - MockCapture.return_value.get_all.return_value = {} - result = prototype_deploy(cmd, outputs=True, json_output=True) - - assert result["status"] == "empty" - - @patch(f"{_MOD}._get_project_dir") - def test_deploy_rollback_info_empty_warns(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_deploy - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - with patch("azext_prototype.stages.deploy_helpers.RollbackManager") as MockMgr: - MockMgr.return_value.get_last_snapshot.return_value = None - MockMgr.return_value.get_rollback_instructions.return_value = None - result = prototype_deploy(cmd, rollback_info=True, json_output=True) - - assert result["last_deployment"] is None - assert result["rollback_instructions"] is None - - @patch(f"{_MOD}._get_project_dir") - @patch(f"{_MOD}._load_config") - def test_generate_scripts_uses_console(self, mock_config, mock_dir, project_with_config): - from azext_prototype.custom import prototype_deploy - - mock_dir.return_value = str(project_with_config) - mock_config.return_value = MagicMock() - mock_config.return_value.get.return_value = "" - - # Create an apps directory with a subdirectory - apps_dir = project_with_config / "concept" / "apps" - apps_dir.mkdir(parents=True, exist_ok=True) - (apps_dir / "my-app").mkdir() - - cmd = MagicMock() - - with patch("azext_prototype.stages.deploy_helpers.DeployScriptGenerator") as MockGen: # noqa: F841 - result = prototype_deploy(cmd, generate_scripts=True, json_output=True) - - assert result["status"] == "generated" - assert "my-app/deploy.sh" in result["scripts"] - - -# ====================================================================== -# Agent commands — Rich UI, new commands, validation -# ====================================================================== - -_MOD = "azext_prototype.custom" - - -class TestPrototypeAgentListRichUI: - """Test agent list Rich UI, json, and detailed modes.""" - - @patch(f"{_MOD}._get_project_dir") - def test_list_json_returns_list(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_list - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - result = prototype_agent_list(cmd, json_output=True) - assert isinstance(result, list) - assert len(result) >= 8 - - @patch(f"{_MOD}._get_project_dir") - def test_list_console_mode(self, mock_dir, project_with_config): - """Default (non-json) returns list and uses console.""" - from azext_prototype.custom import prototype_agent_list - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - result = prototype_agent_list(cmd, json_output=True) - assert isinstance(result, list) - - @patch(f"{_MOD}._get_project_dir") - def test_list_detailed_mode(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_list - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - result = prototype_agent_list(cmd, detailed=True, json_output=True) - assert isinstance(result, list) - - @patch(f"{_MOD}._get_project_dir") - def test_list_agents_have_source(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_list - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - result = prototype_agent_list(cmd, json_output=True) - for agent in result: - assert "source" in agent - - -class TestPrototypeAgentShowRichUI: - """Test agent show Rich UI, json, and detailed modes.""" - - @patch(f"{_MOD}._get_project_dir") - def test_show_json_returns_dict(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_show - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - result = prototype_agent_show(cmd, name="cloud-architect", json_output=True) - assert isinstance(result, dict) - assert result["name"] == "cloud-architect" - assert "system_prompt_preview" in result - - @patch(f"{_MOD}._get_project_dir") - def test_show_detailed_includes_full_prompt(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_show - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - result = prototype_agent_show(cmd, name="cloud-architect", detailed=True, json_output=True) - assert "system_prompt" in result - # detailed should not have preview - assert "system_prompt_preview" not in result - - @patch(f"{_MOD}._get_project_dir") - def test_show_console_mode(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_show - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - result = prototype_agent_show(cmd, name="cloud-architect", json_output=True) - assert isinstance(result, dict) - - -class TestPrototypeAgentUpdate: - """Test agent update command.""" - - @patch(f"{_MOD}._get_project_dir") - def test_update_description(self, mock_dir, project_with_config): - """Targeted field update — description only.""" - from azext_prototype.custom import prototype_agent_add, prototype_agent_update - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - prototype_agent_add(cmd, name="updatable", definition="cloud_architect") - result = prototype_agent_update(cmd, name="updatable", description="New desc", json_output=True) - assert result["status"] == "updated" - assert result["description"] == "New desc" - - @patch(f"{_MOD}._get_project_dir") - def test_update_capabilities(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_add, prototype_agent_update - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - prototype_agent_add(cmd, name="cap-update", definition="cloud_architect") - result = prototype_agent_update(cmd, name="cap-update", capabilities="architect,deploy", json_output=True) - assert result["status"] == "updated" - assert "architect" in result["capabilities"] - assert "deploy" in result["capabilities"] - - @patch(f"{_MOD}._get_project_dir") - def test_update_system_prompt_from_file(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_add, prototype_agent_update - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - prototype_agent_add(cmd, name="prompt-update", definition="cloud_architect") - - prompt_file = project_with_config / "new_prompt.txt" - prompt_file.write_text("You are an updated agent.", encoding="utf-8") - - result = prototype_agent_update( - cmd, name="prompt-update", system_prompt_file=str(prompt_file), json_output=True - ) - assert result["status"] == "updated" - - import yaml as _yaml - - agent_file = project_with_config / ".prototype" / "agents" / "prompt-update.yaml" - content = _yaml.safe_load(agent_file.read_text(encoding="utf-8")) - assert content["system_prompt"] == "You are an updated agent." - - @patch(f"{_MOD}._get_project_dir") - def test_update_interactive_mode(self, mock_dir, project_with_config): - """Interactive mode with mocked input.""" - from azext_prototype.custom import prototype_agent_add, prototype_agent_update - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - prototype_agent_add(cmd, name="interactive-up", definition="cloud_architect") - - # Mock interactive prompts: description, role, capabilities, constraints (empty), system prompt (empty=keep) - inputs = [ - "Updated description", # description - "architect", # role - "architect", # capabilities - "", # end constraints - "", # system prompt (keep existing - first empty line) - "", # examples (skip) - ] - with patch("builtins.input", side_effect=inputs): - result = prototype_agent_update(cmd, name="interactive-up", json_output=True) - - assert result["status"] == "updated" - assert result["description"] == "Updated description" - - @patch(f"{_MOD}._get_project_dir") - def test_update_manifest_sync(self, mock_dir, project_with_config): - """Manifest entry is updated after field update.""" - from azext_prototype.custom import ( - _load_config, - prototype_agent_add, - prototype_agent_update, - ) - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - prototype_agent_add(cmd, name="manifest-sync", definition="cloud_architect") - prototype_agent_update(cmd, name="manifest-sync", description="Synced desc") - - config = _load_config(str(project_with_config)) - custom = config.get("agents.custom", {}) - assert custom["manifest-sync"]["description"] == "Synced desc" - - def test_update_missing_name_raises(self): - from azext_prototype.custom import prototype_agent_update - - cmd = MagicMock() - with pytest.raises(CLIError, match="--name"): - prototype_agent_update(cmd, name=None) - - @patch(f"{_MOD}._get_project_dir") - def test_update_nonexistent_raises(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_update - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - with pytest.raises(CLIError, match="not found"): - prototype_agent_update(cmd, name="nonexistent-agent") - - @patch(f"{_MOD}._get_project_dir") - def test_update_invalid_capability_raises(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_add, prototype_agent_update - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - prototype_agent_add(cmd, name="bad-cap", definition="cloud_architect") - with pytest.raises(CLIError, match="Unknown capability"): - prototype_agent_update(cmd, name="bad-cap", capabilities="invalid_cap") - - @patch(f"{_MOD}._get_project_dir") - def test_update_prompt_file_not_found_raises(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_add, prototype_agent_update - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - prototype_agent_add(cmd, name="no-prompt", definition="cloud_architect") - with pytest.raises(CLIError, match="not found"): - prototype_agent_update(cmd, name="no-prompt", system_prompt_file="./does_not_exist.txt") - - -class TestPrototypeAgentTest: - """Test agent test command.""" - - @patch(f"{_MOD}._prepare_command") - def test_default_prompt(self, mock_prep, project_with_config, mock_ai_provider): - from azext_prototype.ai.provider import AIResponse - from azext_prototype.custom import prototype_agent_test - - mock_agent = MagicMock() - mock_agent.name = "cloud-architect" - mock_agent.execute.return_value = AIResponse( - content="I am the cloud architect.", - model="gpt-4o", - usage={"prompt_tokens": 50, "completion_tokens": 20, "total_tokens": 70}, - ) - - mock_registry = MagicMock() - mock_registry.get.return_value = mock_agent - mock_prep.return_value = (str(project_with_config), MagicMock(), mock_registry, MagicMock()) - - cmd = MagicMock() - result = prototype_agent_test(cmd, name="cloud-architect", json_output=True) - - assert result["status"] == "tested" - assert result["name"] == "cloud-architect" - assert result["model"] == "gpt-4o" - assert result["tokens"] == 70 - mock_agent.execute.assert_called_once() - - @patch(f"{_MOD}._prepare_command") - def test_custom_prompt(self, mock_prep, project_with_config, mock_ai_provider): - from azext_prototype.ai.provider import AIResponse - from azext_prototype.custom import prototype_agent_test - - mock_agent = MagicMock() - mock_agent.name = "cloud-architect" - mock_agent.execute.return_value = AIResponse( - content="Here is a web app design.", - model="gpt-4o", - usage={"total_tokens": 100}, - ) - - mock_registry = MagicMock() - mock_registry.get.return_value = mock_agent - mock_prep.return_value = (str(project_with_config), MagicMock(), mock_registry, MagicMock()) - - cmd = MagicMock() - result = prototype_agent_test(cmd, name="cloud-architect", prompt="Design a web app", json_output=True) - - assert result["status"] == "tested" - # Verify custom prompt was passed - call_args = mock_agent.execute.call_args - assert "Design a web app" in call_args[0][1] - - def test_test_missing_name_raises(self): - from azext_prototype.custom import prototype_agent_test - - cmd = MagicMock() - with pytest.raises(CLIError, match="--name"): - prototype_agent_test(cmd, name=None) - - -class TestPrototypeAgentExport: - """Test agent export command.""" - - @patch(f"{_MOD}._get_project_dir") - def test_export_builtin(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_export - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - output_path = str(project_with_config / "exported.yaml") - result = prototype_agent_export(cmd, name="cloud-architect", output_file=output_path, json_output=True) - - assert result["status"] == "exported" - assert result["name"] == "cloud-architect" - - import yaml as _yaml - - exported = _yaml.safe_load((project_with_config / "exported.yaml").read_text(encoding="utf-8")) - assert exported["name"] == "cloud-architect" - assert "capabilities" in exported - assert "system_prompt" in exported - - @patch(f"{_MOD}._get_project_dir") - def test_export_custom(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_add, prototype_agent_export - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - prototype_agent_add(cmd, name="export-test", definition="bicep_agent") - output_path = str(project_with_config / "custom_export.yaml") - result = prototype_agent_export(cmd, name="export-test", output_file=output_path, json_output=True) - - assert result["status"] == "exported" - assert (project_with_config / "custom_export.yaml").exists() - - @patch(f"{_MOD}._get_project_dir") - def test_export_default_path(self, mock_dir, project_with_config): - """Default output path is ./{name}.yaml.""" - import os - - from azext_prototype.custom import prototype_agent_export - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - # Change cwd to project dir for default path - original_cwd = os.getcwd() - try: - os.chdir(str(project_with_config)) - result = prototype_agent_export(cmd, name="cloud-architect", json_output=True) - assert result["status"] == "exported" - assert (project_with_config / "cloud-architect.yaml").exists() - finally: - os.chdir(original_cwd) - - @patch(f"{_MOD}._get_project_dir") - def test_export_loadable_by_loader(self, mock_dir, project_with_config): - """Exported YAML is loadable by load_yaml_agent.""" - from azext_prototype.agents.loader import load_yaml_agent - from azext_prototype.custom import prototype_agent_export - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - output_path = str(project_with_config / "loadable.yaml") - prototype_agent_export(cmd, name="cloud-architect", output_file=output_path) - - agent = load_yaml_agent(output_path) - assert agent.name == "cloud-architect" - - def test_export_missing_name_raises(self): - from azext_prototype.custom import prototype_agent_export - - cmd = MagicMock() - with pytest.raises(CLIError, match="--name"): - prototype_agent_export(cmd, name=None) - - -class TestPrototypeAgentOverrideValidation: - """Test override validation enhancements.""" - - @patch(f"{_MOD}._get_project_dir") - def test_override_file_not_found_raises(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_override - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - with pytest.raises(CLIError, match="not found"): - prototype_agent_override(cmd, name="cloud-architect", file="./does_not_exist.yaml") - - @patch(f"{_MOD}._get_project_dir") - def test_override_invalid_yaml_raises(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_override - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - bad_yaml = project_with_config / "bad.yaml" - bad_yaml.write_text("{{invalid yaml::", encoding="utf-8") - - with pytest.raises(CLIError, match="Invalid YAML"): - prototype_agent_override(cmd, name="cloud-architect", file="bad.yaml") - - @patch(f"{_MOD}._get_project_dir") - def test_override_missing_name_field_raises(self, mock_dir, project_with_config): - from azext_prototype.custom import prototype_agent_override - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - no_name = project_with_config / "no_name.yaml" - no_name.write_text("description: test\n", encoding="utf-8") - - with pytest.raises(CLIError, match="name"): - prototype_agent_override(cmd, name="cloud-architect", file="no_name.yaml") - - @patch(f"{_MOD}._get_project_dir") - def test_override_non_builtin_warns(self, mock_dir, project_with_config): - """Overriding a non-builtin name should warn but succeed.""" - from azext_prototype.custom import prototype_agent_override - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - valid_yaml = project_with_config / "valid.yaml" - valid_yaml.write_text( - "name: nonexistent-agent\ndescription: test\ncapabilities:\n - develop\n" "system_prompt: test\n", - encoding="utf-8", - ) - - result = prototype_agent_override(cmd, name="nonexistent-agent", file="valid.yaml", json_output=True) - assert result["status"] == "override_registered" - - @patch(f"{_MOD}._get_project_dir") - def test_override_valid_builtin(self, mock_dir, project_with_config): - """Overriding a known builtin should succeed without warnings.""" - from azext_prototype.custom import prototype_agent_override - - mock_dir.return_value = str(project_with_config) - cmd = MagicMock() - - valid_yaml = project_with_config / "arch_override.yaml" - valid_yaml.write_text( - "name: cloud-architect\ndescription: Custom arch\ncapabilities:\n - architect\n" - "system_prompt: Custom prompt.\n", - encoding="utf-8", - ) - - result = prototype_agent_override(cmd, name="cloud-architect", file="arch_override.yaml", json_output=True) - assert result["status"] == "override_registered" - - -class TestPromptAgentDefinition: - """Test the _prompt_agent_definition interactive helper.""" - - def test_full_walkthrough(self): - from azext_prototype.custom import _prompt_agent_definition - from azext_prototype.ui.console import Console - - console = Console() - inputs = [ - "My agent description", # description - "architect", # role - "architect,deploy", # capabilities - "Must use PaaS only", # constraint 1 - "", # end constraints - "You are a custom agent.", # system prompt line 1 - "END", # end system prompt - "", # no examples - ] - with patch("builtins.input", side_effect=inputs): - result = _prompt_agent_definition(console, "test-agent") - - assert result["name"] == "test-agent" - assert result["description"] == "My agent description" - assert result["role"] == "architect" - assert "architect" in result["capabilities"] - assert "deploy" in result["capabilities"] - assert "Must use PaaS only" in result["constraints"] - assert "You are a custom agent." in result["system_prompt"] - - def test_existing_defaults(self): - from azext_prototype.custom import _prompt_agent_definition - from azext_prototype.ui.console import Console - - console = Console() - existing = { - "description": "Old desc", - "role": "developer", - "capabilities": ["develop"], - "constraints": ["Old constraint"], - "system_prompt": "Old prompt.", - "examples": [{"user": "hello", "assistant": "hi"}], - } - # All empty inputs → keep existing values - inputs = [ - "", # description (keep) - "", # role (keep) - "", # capabilities (keep) - "", # constraints (keep existing) - "", # system prompt (keep existing) - "", # examples (keep existing) - ] - with patch("builtins.input", side_effect=inputs): - result = _prompt_agent_definition(console, "test-agent", existing=existing) - - assert result["description"] == "Old desc" - assert result["role"] == "developer" - assert result["capabilities"] == ["develop"] - assert result["constraints"] == ["Old constraint"] - assert result["system_prompt"] == "Old prompt." - assert result["examples"] == [{"user": "hello", "assistant": "hi"}] - - def test_invalid_capability_skipped(self): - from azext_prototype.custom import _prompt_agent_definition - from azext_prototype.ui.console import Console - - console = Console() - inputs = [ - "desc", # description - "role", # role - "invalid_cap,architect", # capabilities — one invalid - "", # end constraints - "prompt", # system prompt - "END", # end system prompt - "", # no examples - ] - with patch("builtins.input", side_effect=inputs): - result = _prompt_agent_definition(console, "test-agent") - - assert "architect" in result["capabilities"] - assert "invalid_cap" not in result["capabilities"] - - -class TestReadMultilineInput: - """Test _read_multiline_input helper.""" - - def test_reads_until_end(self): - from azext_prototype.custom import _read_multiline_input - - with patch("builtins.input", side_effect=["line 1", "line 2", "END"]): - result = _read_multiline_input() - assert result == "line 1\nline 2" - - def test_empty_first_line_returns_empty(self): - from azext_prototype.custom import _read_multiline_input - - with patch("builtins.input", side_effect=[""]): - result = _read_multiline_input() - assert result == "" diff --git a/tests/test_deploy_helpers.py b/tests/test_deploy_helpers.py deleted file mode 100644 index 085564f..0000000 --- a/tests/test_deploy_helpers.py +++ /dev/null @@ -1,477 +0,0 @@ -"""Tests for azext_prototype.stages.deploy_helpers.""" - -import json -from pathlib import Path -from unittest.mock import MagicMock, patch - -from azext_prototype.stages.deploy_helpers import ( - DEPLOY_ENV_MAPPING, - DeploymentOutputCapture, - DeployScriptGenerator, - RollbackManager, - build_deploy_env, - resolve_stage_secrets, - scan_tf_secret_variables, -) - - -class TestDeploymentOutputCapture: - """Test output capture and environment variable generation.""" - - def test_capture_and_retrieve(self, tmp_project): - capture = DeploymentOutputCapture(str(tmp_project)) - - # Simulate Bicep outputs - bicep_output = json.dumps( - { - "properties": { - "outputs": { - "resource_group_name": {"type": "string", "value": "zd-rg-api-dev-eus"}, - "storage_account_name": {"type": "string", "value": "stzddatadeveus"}, - } - } - } - ) - capture.capture_bicep(bicep_output) - - assert capture.get("resource_group_name") == "zd-rg-api-dev-eus" - assert capture.get("storage_account_name") == "stzddatadeveus" - assert capture.get("nonexistent", "fallback") == "fallback" - - def test_to_env_vars(self, tmp_project): - capture = DeploymentOutputCapture(str(tmp_project)) - - bicep_output = json.dumps( - { - "properties": { - "outputs": { - "resource_group_name": {"type": "string", "value": "rg-test"}, - "app_url": {"type": "string", "value": "https://myapp.azurewebsites.net"}, - } - } - } - ) - capture.capture_bicep(bicep_output) - - env_vars = capture.to_env_vars() - assert env_vars["PROTOTYPE_RESOURCE_GROUP_NAME"] == "rg-test" - assert env_vars["PROTOTYPE_APP_URL"] == "https://myapp.azurewebsites.net" - - def test_persistence(self, tmp_project): - # Write - capture1 = DeploymentOutputCapture(str(tmp_project)) - capture1._outputs["terraform"] = {"foo": "bar"} - capture1._save() - - # Read - capture2 = DeploymentOutputCapture(str(tmp_project)) - assert capture2.get("foo") == "bar" - - def test_get_all(self, tmp_project): - capture = DeploymentOutputCapture(str(tmp_project)) - assert isinstance(capture.get_all(), dict) - - def test_invalid_bicep_output(self, tmp_project): - capture = DeploymentOutputCapture(str(tmp_project)) - result = capture.capture_bicep("not-json") - assert result == {} - - -class TestDeployScriptGenerator: - """Test deploy script generation.""" - - def test_generate_webapp_script(self, tmp_path): - app_dir = tmp_path / "my-api" - app_dir.mkdir() - - script = DeployScriptGenerator.generate( - app_dir=app_dir, - app_name="my-api", - deploy_type="webapp", - resource_group="rg-test", - ) - - assert "#!/usr/bin/env bash" in script - assert "my-api" in script - assert "az webapp deploy" in script - assert (app_dir / "deploy.sh").exists() - - def test_generate_container_app_script(self, tmp_path): - app_dir = tmp_path / "my-app" - app_dir.mkdir() - - script = DeployScriptGenerator.generate( - app_dir=app_dir, - app_name="my-app", - deploy_type="container_app", - resource_group="rg-test", - registry="myregistry.azurecr.io", - ) - - assert "az acr build" in script - assert "az containerapp update" in script - assert "myregistry.azurecr.io" in script - - def test_generate_function_script(self, tmp_path): - app_dir = tmp_path / "my-func" - app_dir.mkdir() - - script = DeployScriptGenerator.generate( - app_dir=app_dir, - app_name="my-func", - deploy_type="function", - resource_group="rg-test", - ) - - assert "func azure functionapp publish" in script - assert "my-func" in script - - -class TestRollbackManager: - """Test rollback tracking and instructions.""" - - def test_snapshot_before_deploy(self, tmp_project): - mgr = RollbackManager(str(tmp_project)) - snapshot = mgr.snapshot_before_deploy("infra", "terraform") - - assert snapshot["scope"] == "infra" - assert snapshot["iac_tool"] == "terraform" - assert "timestamp" in snapshot - - def test_multiple_snapshots(self, tmp_project): - mgr = RollbackManager(str(tmp_project)) - mgr.snapshot_before_deploy("infra", "terraform") - mgr.snapshot_before_deploy("apps", "terraform") - - latest = mgr.get_last_snapshot() - assert latest["scope"] == "apps" - - def test_rollback_instructions_terraform(self, tmp_project): - mgr = RollbackManager(str(tmp_project)) - mgr.snapshot_before_deploy("infra", "terraform") - - instructions = mgr.get_rollback_instructions() - assert any("terraform" in line.lower() for line in instructions) - - def test_rollback_instructions_bicep(self, tmp_project): - mgr = RollbackManager(str(tmp_project)) - mgr.snapshot_before_deploy("infra", "bicep") - - instructions = mgr.get_rollback_instructions() - assert any("bicep" in line.lower() or "deployment" in line.lower() for line in instructions) - - def test_no_snapshots(self, tmp_project): - mgr = RollbackManager(str(tmp_project)) - assert mgr.get_last_snapshot() is None - - instructions = mgr.get_rollback_instructions() - assert len(instructions) >= 1 # Should have "nothing to roll back" message - - def test_persistence(self, tmp_project): - mgr1 = RollbackManager(str(tmp_project)) - mgr1.snapshot_before_deploy("infra", "terraform") - - mgr2 = RollbackManager(str(tmp_project)) - assert mgr2.get_last_snapshot() is not None - assert mgr2.get_last_snapshot()["scope"] == "infra" - - -class TestDeployEnvMapping: - """Tests for DEPLOY_ENV_MAPPING and build_deploy_env().""" - - def test_mapping_covers_all_params(self): - """Every build_deploy_env parameter has a mapping entry.""" - assert "subscription" in DEPLOY_ENV_MAPPING - assert "tenant" in DEPLOY_ENV_MAPPING - assert "client_id" in DEPLOY_ENV_MAPPING - assert "client_secret" in DEPLOY_ENV_MAPPING - - def test_mapping_includes_tf_var(self): - """Each param maps to at least one TF_VAR_* entry.""" - for param, keys in DEPLOY_ENV_MAPPING.items(): - tf_vars = [k for k in keys if k.startswith("TF_VAR_")] - assert tf_vars, f"{param} has no TF_VAR_* mapping" - - def test_mapping_includes_arm(self): - """Each param maps to at least one ARM_* entry.""" - for param, keys in DEPLOY_ENV_MAPPING.items(): - arm_vars = [k for k in keys if k.startswith("ARM_")] - assert arm_vars, f"{param} has no ARM_* mapping" - - def test_all_fields(self): - env = build_deploy_env("sub-123", "tenant-456", "client-id", "secret") - # ARM vars - assert env["ARM_SUBSCRIPTION_ID"] == "sub-123" - assert env["ARM_TENANT_ID"] == "tenant-456" - assert env["ARM_CLIENT_ID"] == "client-id" - assert env["ARM_CLIENT_SECRET"] == "secret" - # TF_VAR vars (auto-resolve HCL variables) - assert env["TF_VAR_subscription_id"] == "sub-123" - assert env["TF_VAR_tenant_id"] == "tenant-456" - assert env["TF_VAR_client_id"] == "client-id" - assert env["TF_VAR_client_secret"] == "secret" - # Legacy - assert env["SUBSCRIPTION_ID"] == "sub-123" - - def test_subscription_only(self): - env = build_deploy_env("sub-123") - assert env["ARM_SUBSCRIPTION_ID"] == "sub-123" - assert env["TF_VAR_subscription_id"] == "sub-123" - assert env["SUBSCRIPTION_ID"] == "sub-123" - assert "ARM_TENANT_ID" not in env - assert "TF_VAR_tenant_id" not in env - assert "ARM_CLIENT_ID" not in env - - def test_inherits_os_environ(self): - env = build_deploy_env("sub-123") - # PATH should be inherited from os.environ - assert "PATH" in env - - def test_empty(self): - env = build_deploy_env() - assert "ARM_SUBSCRIPTION_ID" not in env - assert "TF_VAR_subscription_id" not in env - assert "ARM_TENANT_ID" not in env - # Should still have os.environ entries - assert "PATH" in env - - -class TestDeployEnvPassing: - """Tests that verify env is passed through to subprocess calls.""" - - @patch("subprocess.run") - def test_deploy_terraform_passes_env(self, mock_run): - from azext_prototype.stages.deploy_helpers import deploy_terraform - - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - test_env = build_deploy_env("sub-123", "tenant-456") - - deploy_terraform(Path("/tmp/fake"), "sub-123", env=test_env) - - # All subprocess.run calls should receive env=test_env - for c in mock_run.call_args_list: - assert c.kwargs.get("env") is test_env - - @patch("subprocess.run") - def test_deploy_bicep_adds_tenant_flag(self, mock_run): - from azext_prototype.stages.deploy_helpers import deploy_bicep - - mock_run.return_value = MagicMock(returncode=0, stdout="{}", stderr="") - infra_dir = Path("/tmp/fake") - test_env = build_deploy_env("sub-123", "tenant-456") - - # Create a mock bicep file - with patch.object(Path, "exists", return_value=True), patch.object(Path, "glob", return_value=[]), patch( - "azext_prototype.stages.deploy_helpers.find_bicep_params", return_value=None - ), patch("azext_prototype.stages.deploy_helpers.is_subscription_scoped", return_value=False): - deploy_bicep(infra_dir, "sub-123", "my-rg", env=test_env) - - # Verify --tenant was added to the command - cmd = mock_run.call_args[0][0] - assert "--tenant" in cmd - assert "tenant-456" in cmd - assert mock_run.call_args.kwargs.get("env") is test_env - - @patch("subprocess.run") - def test_deploy_app_stage_merges_env(self, mock_run, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_app_stage - - stage_dir = tmp_path / "app" - stage_dir.mkdir() - deploy_sh = stage_dir / "deploy.sh" - deploy_sh.write_text("#!/bin/bash\necho ok") - - mock_run.return_value = MagicMock(returncode=0, stdout="ok", stderr="") - test_env = build_deploy_env("sub-123", "tenant-456", "cid", "csecret") - - deploy_app_stage(stage_dir, "sub-123", "my-rg", env=test_env) - - passed_env = mock_run.call_args.kwargs.get("env") - assert passed_env is not None - assert passed_env["ARM_SUBSCRIPTION_ID"] == "sub-123" - assert passed_env["ARM_TENANT_ID"] == "tenant-456" - assert passed_env["SUBSCRIPTION_ID"] == "sub-123" - assert passed_env["RESOURCE_GROUP"] == "my-rg" - - @patch("subprocess.run") - def test_deploy_app_sub_dirs_receive_env(self, mock_run, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_app_stage - - stage_dir = tmp_path / "apps" - stage_dir.mkdir() - sub_app = stage_dir / "api" - sub_app.mkdir() - (sub_app / "deploy.sh").write_text("#!/bin/bash\necho ok") - - mock_run.return_value = MagicMock(returncode=0, stdout="ok", stderr="") - test_env = build_deploy_env("sub-123", "tenant-456") - - deploy_app_stage(stage_dir, "sub-123", "my-rg", env=test_env) - - passed_env = mock_run.call_args.kwargs.get("env") - assert passed_env is not None - assert passed_env["ARM_SUBSCRIPTION_ID"] == "sub-123" - assert passed_env["ARM_TENANT_ID"] == "tenant-456" - assert passed_env["RESOURCE_GROUP"] == "my-rg" - - @patch("subprocess.run") - def test_rollback_terraform_passes_env(self, mock_run): - from azext_prototype.stages.deploy_helpers import rollback_terraform - - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - test_env = build_deploy_env("sub-123", "tenant-456") - - rollback_terraform(Path("/tmp/fake"), env=test_env) - - assert mock_run.call_args.kwargs.get("env") is test_env - - @patch("subprocess.run") - def test_plan_terraform_passes_env(self, mock_run): - from azext_prototype.stages.deploy_helpers import plan_terraform - - mock_run.return_value = MagicMock(returncode=0, stdout="Plan: 1 to add", stderr="") - test_env = build_deploy_env("sub-123") - - plan_terraform(Path("/tmp/fake"), "sub-123", env=test_env) - - for c in mock_run.call_args_list: - assert c.kwargs.get("env") is test_env - - @patch("subprocess.run") - def test_rollback_bicep_adds_tenant_flag(self, mock_run): - from azext_prototype.stages.deploy_helpers import rollback_bicep - - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - test_env = build_deploy_env("sub-123", "tenant-456") - - rollback_bicep(Path("/tmp/fake"), "sub-123", "my-rg", env=test_env) - - cmd = mock_run.call_args[0][0] - assert "--tenant" in cmd - assert "tenant-456" in cmd - assert mock_run.call_args.kwargs.get("env") is test_env - - @patch("subprocess.run") - def test_whatif_bicep_adds_tenant_flag(self, mock_run): - from azext_prototype.stages.deploy_helpers import whatif_bicep - - mock_run.return_value = MagicMock(returncode=0, stdout="What-if output", stderr="") - test_env = build_deploy_env("sub-123", "tenant-789") - - with patch.object(Path, "exists", return_value=True), patch.object(Path, "glob", return_value=[]), patch( - "azext_prototype.stages.deploy_helpers.find_bicep_params", return_value=None - ), patch("azext_prototype.stages.deploy_helpers.is_subscription_scoped", return_value=False): - whatif_bicep(Path("/tmp/fake"), "sub-123", "my-rg", env=test_env) - - cmd = mock_run.call_args[0][0] - assert "--tenant" in cmd - assert "tenant-789" in cmd - - @patch("subprocess.run") - def test_deploy_terraform_no_env_still_works(self, mock_run): - """Verify backward compat — env defaults to None.""" - from azext_prototype.stages.deploy_helpers import deploy_terraform - - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - deploy_terraform(Path("/tmp/fake"), "sub-123") - - # env=None is passed (default), which means subprocess inherits os.environ - for c in mock_run.call_args_list: - assert c.kwargs.get("env") is None - - -class TestSecretVariableScanning: - """Tests for scan_tf_secret_variables().""" - - def test_scan_finds_secret_suffix(self, tmp_path): - tf = tmp_path / "main.tf" - tf.write_text('variable "graph_client_secret" {}\n') - result = scan_tf_secret_variables(tmp_path) - assert "graph_client_secret" in result - - def test_scan_finds_password_suffix(self, tmp_path): - tf = tmp_path / "main.tf" - tf.write_text('variable "admin_password" {\n type = string\n}\n') - result = scan_tf_secret_variables(tmp_path) - assert "admin_password" in result - - def test_scan_ignores_known_vars(self, tmp_path): - tf = tmp_path / "main.tf" - tf.write_text('variable "client_secret" {}\n') - result = scan_tf_secret_variables(tmp_path) - assert "client_secret" not in result - - def test_scan_ignores_non_secret_vars(self, tmp_path): - tf = tmp_path / "main.tf" - tf.write_text('variable "location" {}\nvariable "resource_group_name" {}\n') - result = scan_tf_secret_variables(tmp_path) - assert result == [] - - def test_scan_ignores_vars_with_default(self, tmp_path): - tf = tmp_path / "main.tf" - tf.write_text('variable "api_secret" {\n default = "preset-value"\n}\n') - result = scan_tf_secret_variables(tmp_path) - assert result == [] - - def test_scan_multiple_files(self, tmp_path): - (tmp_path / "main.tf").write_text('variable "graph_client_secret" {}\n') - (tmp_path / "variables.tf").write_text('variable "db_password" {}\n') - result = scan_tf_secret_variables(tmp_path) - assert "graph_client_secret" in result - assert "db_password" in result - - def test_scan_empty_dir(self, tmp_path): - result = scan_tf_secret_variables(tmp_path) - assert result == [] - - -class TestResolveStageSecrets: - """Tests for resolve_stage_secrets().""" - - def _make_config(self, tmp_project): - from azext_prototype.config import ProjectConfig - - config = ProjectConfig(str(tmp_project)) - config.create_default() - return config - - def test_generates_new_secret(self, tmp_path, tmp_project): - (tmp_path / "main.tf").write_text('variable "graph_client_secret" {}\n') - config = self._make_config(tmp_project) - - result = resolve_stage_secrets(tmp_path, config) - assert "TF_VAR_graph_client_secret" in result - assert len(result["TF_VAR_graph_client_secret"]) == 64 # token_hex(32) - - def test_reuses_existing_secret(self, tmp_path, tmp_project): - (tmp_path / "main.tf").write_text('variable "graph_client_secret" {}\n') - config = self._make_config(tmp_project) - config.set("deploy.generated_secrets.graph_client_secret", "reused-value") - - result = resolve_stage_secrets(tmp_path, config) - assert result["TF_VAR_graph_client_secret"] == "reused-value" - - def test_persists_generated_secret(self, tmp_path, tmp_project): - (tmp_path / "main.tf").write_text('variable "app_password" {}\n') - config = self._make_config(tmp_project) - - resolve_stage_secrets(tmp_path, config) - - stored = config.get("deploy.generated_secrets.app_password") - assert stored is not None - assert len(stored) == 64 - - def test_multiple_secrets(self, tmp_path, tmp_project): - (tmp_path / "main.tf").write_text('variable "graph_client_secret" {}\nvariable "admin_password" {}\n') - config = self._make_config(tmp_project) - - result = resolve_stage_secrets(tmp_path, config) - assert "TF_VAR_graph_client_secret" in result - assert "TF_VAR_admin_password" in result - - def test_no_secrets_needed(self, tmp_path, tmp_project): - (tmp_path / "main.tf").write_text('variable "location" {}\n') - config = self._make_config(tmp_project) - - result = resolve_stage_secrets(tmp_path, config) - assert result == {} diff --git a/tests/test_discovery_state_scope.py b/tests/test_discovery_state_scope.py deleted file mode 100644 index eeeab8c..0000000 --- a/tests/test_discovery_state_scope.py +++ /dev/null @@ -1,192 +0,0 @@ -"""Tests for discovery_state scope management.""" - -from azext_prototype.stages.discovery_state import ( - DiscoveryState, - _default_discovery_state, -) - - -class TestDiscoveryStateScope: - """Test the scope fields in DiscoveryState.""" - - def test_default_state_has_scope(self): - state = _default_discovery_state() - assert "scope" in state - assert state["scope"] == { - "in_scope": [], - "out_of_scope": [], - "deferred": [], - } - - def test_merge_learnings_with_scope(self, tmp_path): - ds = DiscoveryState(str(tmp_path)) - ds.load() - - learnings = { - "scope": { - "in_scope": ["REST API", "SQL Database"], - "out_of_scope": ["Mobile app"], - "deferred": ["CI/CD pipeline"], - }, - } - ds.merge_learnings(learnings) - - assert ds.state["scope"]["in_scope"] == ["REST API", "SQL Database"] - assert ds.state["scope"]["out_of_scope"] == ["Mobile app"] - assert ds.state["scope"]["deferred"] == ["CI/CD pipeline"] - - def test_merge_learnings_deduplicates_scope(self, tmp_path): - ds = DiscoveryState(str(tmp_path)) - ds.load() - ds.state["scope"]["in_scope"] = ["REST API"] - - learnings = { - "scope": { - "in_scope": ["REST API", "SQL Database"], - }, - } - ds.merge_learnings(learnings) - - assert ds.state["scope"]["in_scope"] == ["REST API", "SQL Database"] - - def test_merge_learnings_partial_scope(self, tmp_path): - ds = DiscoveryState(str(tmp_path)) - ds.load() - - learnings = { - "scope": { - "in_scope": ["API endpoints"], - }, - } - ds.merge_learnings(learnings) - - assert ds.state["scope"]["in_scope"] == ["API endpoints"] - assert ds.state["scope"]["out_of_scope"] == [] - assert ds.state["scope"]["deferred"] == [] - - def test_merge_learnings_without_scope(self, tmp_path): - """Learnings without scope should not break merge.""" - ds = DiscoveryState(str(tmp_path)) - ds.load() - - learnings = { - "project": {"summary": "Test", "goals": ["Goal 1"]}, - } - ds.merge_learnings(learnings) - - assert ds.state["scope"]["in_scope"] == [] - - def test_format_as_context_includes_scope(self, tmp_path): - ds = DiscoveryState(str(tmp_path)) - ds.load() - ds._loaded = True - ds.state["scope"] = { - "in_scope": ["REST API"], - "out_of_scope": ["Mobile app"], - "deferred": ["CI/CD"], - } - - context = ds.format_as_context() - assert "## Prototype Scope" in context - assert "### In Scope" in context - assert "REST API" in context - assert "### Out of Scope" in context - assert "Mobile app" in context - assert "### Deferred / Future Work" in context - assert "CI/CD" in context - - def test_format_as_context_partial_scope(self, tmp_path): - ds = DiscoveryState(str(tmp_path)) - ds.load() - ds._loaded = True - ds.state["scope"]["in_scope"] = ["REST API"] - - context = ds.format_as_context() - assert "### In Scope" in context - assert "### Out of Scope" not in context - assert "### Deferred" not in context - - def test_format_as_context_omits_empty_scope(self, tmp_path): - ds = DiscoveryState(str(tmp_path)) - ds.load() - ds._loaded = True - ds.state["project"]["summary"] = "Test project" - - context = ds.format_as_context() - assert "Prototype Scope" not in context - - def test_format_as_context_falls_back_to_conversation(self, tmp_path): - """When structured fields are empty, format_as_context uses conversation history.""" - ds = DiscoveryState(str(tmp_path)) - ds.load() - ds._loaded = True - # Structured fields are all empty (default), but conversation has content - ds.state["conversation_history"] = [ - {"exchange": 1, "assistant": "Tell me more."}, - { - "exchange": 2, - "assistant": ( - "## Project Summary\nA web app for email drafting.\n\n" - "## Confirmed Functional Requirements\n- Feature A\n\n" - "[READY]" - ), - }, - ] - - context = ds.format_as_context() - assert "## Project Summary" in context - assert "email drafting" in context - assert "Feature A" in context - assert "[READY]" not in context - - def test_format_as_context_prefers_structured_fields(self, tmp_path): - """When structured fields are populated, those are used instead of conversation.""" - ds = DiscoveryState(str(tmp_path)) - ds.load() - ds._loaded = True - ds.state["project"]["summary"] = "Structured summary" - ds.state["conversation_history"] = [ - { - "exchange": 1, - "assistant": "## Project Summary\nConversation summary.\n\n## Confirmed Functional Requirements\n- X", - }, - ] - - context = ds.format_as_context() - assert "Structured summary" in context - assert "Conversation summary" not in context - - def test_extract_conversation_summary(self, tmp_path): - """extract_conversation_summary returns last assistant message with summary headings.""" - ds = DiscoveryState(str(tmp_path)) - ds.load() - ds.state["conversation_history"] = [ - {"exchange": 1, "assistant": "Tell me more."}, - { - "exchange": 2, - "assistant": "## Project Summary\nA web app.\n\n[READY]", - }, - ] - - result = ds.extract_conversation_summary() - assert "## Project Summary" in result - assert "[READY]" not in result - - def test_extract_conversation_summary_empty_history(self, tmp_path): - ds = DiscoveryState(str(tmp_path)) - ds.load() - - assert ds.extract_conversation_summary() == "" - - def test_scope_persists_to_yaml(self, tmp_path): - ds = DiscoveryState(str(tmp_path)) - ds.load() - ds.state["scope"]["in_scope"] = ["API endpoints"] - ds.state["scope"]["out_of_scope"] = ["Mobile app"] - ds.save() - - ds2 = DiscoveryState(str(tmp_path)) - ds2.load() - assert ds2.state["scope"]["in_scope"] == ["API endpoints"] - assert ds2.state["scope"]["out_of_scope"] == ["Mobile app"] - assert ds2.state["scope"]["deferred"] == [] diff --git a/tests/test_escalation.py b/tests/test_escalation.py deleted file mode 100644 index 369b43b..0000000 --- a/tests/test_escalation.py +++ /dev/null @@ -1,636 +0,0 @@ -"""Tests for azext_prototype.stages.escalation — blocker tracking and escalation chain. - -Covers: -- EscalationEntry serialization and defaults -- EscalationTracker state management (record, attempt, resolve, save/load) -- Escalation chain (level 1→2 technical, 1→2 scope, 2→3 web, 3→4 human) -- Auto-escalation timing -- Integration with qa_router -- Edge cases -- Report formatting -- State persistence across sessions -""" - -from __future__ import annotations - -from datetime import datetime, timedelta, timezone -from pathlib import Path -from unittest.mock import MagicMock, patch - -import yaml - -from azext_prototype.stages.escalation import EscalationEntry, EscalationTracker - -# ====================================================================== -# Helpers -# ====================================================================== - - -def _make_entry(**kwargs) -> EscalationEntry: - defaults = { - "task_description": "Build Stage 3: Data Layer", - "blocker": "Cosmos DB requires premium tier", - "source_agent": "terraform-agent", - "source_stage": "build", - "created_at": datetime.now(timezone.utc).isoformat(), - "last_escalated_at": datetime.now(timezone.utc).isoformat(), - } - defaults.update(kwargs) - return EscalationEntry(**defaults) - - -def _make_registry(architect_response=None, pm_response=None): - from azext_prototype.agents.base import AgentCapability - - architect = MagicMock() - architect.name = "cloud-architect" - if architect_response: - architect.execute.return_value = architect_response - else: - architect.execute.return_value = MagicMock(content="Use Standard tier instead") - - pm = MagicMock() - pm.name = "project-manager" - if pm_response: - pm.execute.return_value = pm_response - else: - pm.execute.return_value = MagicMock(content="Descope this item") - - registry = MagicMock() - - def find_by_cap(cap): - if cap == AgentCapability.ARCHITECT: - return [architect] - if cap == AgentCapability.BACKLOG_GENERATION: - return [pm] - return [] - - registry.find_by_capability.side_effect = find_by_cap - - return registry, architect, pm - - -def _make_context(): - from azext_prototype.agents.base import AgentContext - - return AgentContext( - project_config={"project": {"name": "test"}}, - project_dir="/tmp/test", - ai_provider=MagicMock(), - ) - - -# ====================================================================== -# EscalationEntry tests -# ====================================================================== - - -class TestEscalationEntry: - - def test_default_values(self): - entry = EscalationEntry(task_description="task", blocker="blocked") - assert entry.escalation_level == 1 - assert entry.resolved is False - assert entry.resolution == "" - assert entry.attempted_solutions == [] - - def test_to_dict_roundtrip(self): - entry = _make_entry(attempted_solutions=["Try A", "Try B"]) - d = entry.to_dict() - restored = EscalationEntry.from_dict(d) - - assert restored.task_description == entry.task_description - assert restored.blocker == entry.blocker - assert restored.attempted_solutions == ["Try A", "Try B"] - assert restored.escalation_level == entry.escalation_level - assert restored.source_agent == entry.source_agent - - def test_from_dict_missing_keys(self): - entry = EscalationEntry.from_dict({}) - assert entry.task_description == "" - assert entry.blocker == "" - assert entry.escalation_level == 1 - - -# ====================================================================== -# EscalationTracker state management tests -# ====================================================================== - - -class TestEscalationTrackerState: - - def test_record_blocker(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - - entry = tracker.record_blocker( - "Deploy Redis", - "Premium tier required", - "terraform-agent", - "deploy", - ) - - assert entry.task_description == "Deploy Redis" - assert entry.blocker == "Premium tier required" - assert entry.escalation_level == 1 - assert entry.created_at != "" - assert len(tracker.get_active_blockers()) == 1 - - def test_record_attempted_solution(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - - tracker.record_attempted_solution(entry, "Tried standard tier") - tracker.record_attempted_solution(entry, "Tried basic tier") - - assert len(entry.attempted_solutions) == 2 - assert "Tried standard tier" in entry.attempted_solutions - - def test_resolve_blocker(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - - tracker.resolve(entry, "Used standard tier instead") - - assert entry.resolved is True - assert entry.resolution == "Used standard tier instead" - assert len(tracker.get_active_blockers()) == 0 - - def test_get_active_blockers_filters_resolved(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - e1 = tracker.record_blocker("task1", "blocked1", "a1", "s1") - e2 = tracker.record_blocker("task2", "blocked2", "a2", "s2") # noqa: F841 - tracker.resolve(e1, "fixed") - - active = tracker.get_active_blockers() - assert len(active) == 1 - assert active[0].task_description == "task2" - - def test_save_load_roundtrip(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - tracker.record_blocker("task1", "blocked1", "agent1", "stage1") - tracker.record_blocker("task2", "blocked2", "agent2", "stage2") - - tracker2 = EscalationTracker(str(tmp_project)) - tracker2.load() - - assert len(tracker2.get_active_blockers()) == 2 - assert tracker2.get_active_blockers()[0].task_description == "task1" - - def test_save_creates_yaml(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - tracker.record_blocker("task", "blocked", "agent", "stage") - - yaml_path = Path(str(tmp_project)) / ".prototype" / "state" / "escalation.yaml" - assert yaml_path.exists() - - with open(yaml_path) as f: - data = yaml.safe_load(f) - assert len(data["entries"]) == 1 - - def test_exists_property(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - assert not tracker.exists - - tracker.record_blocker("task", "blocked", "agent", "stage") - assert tracker.exists - - def test_empty_load(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - tracker.load() # No file exists - assert tracker.get_active_blockers() == [] - - def test_multiple_records_and_resolves(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - e1 = tracker.record_blocker("t1", "b1", "a", "s") - e2 = tracker.record_blocker("t2", "b2", "a", "s") # noqa: F841 - e3 = tracker.record_blocker("t3", "b3", "a", "s") - - tracker.resolve(e1, "fixed") - tracker.resolve(e3, "workaround") - - assert len(tracker.get_active_blockers()) == 1 - assert tracker.get_active_blockers()[0].task_description == "t2" - - -# ====================================================================== -# Escalation chain tests -# ====================================================================== - - -class TestEscalationChain: - - def test_level_1_to_2_technical(self, tmp_project): - """Technical blocker escalates to architect.""" - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker( - "Deploy Cosmos DB", - "Premium tier required for multi-region", - "terraform-agent", - "build", - ) - - registry, architect, pm = _make_registry() - ctx = _make_context() - printed = [] - - result = tracker.escalate(entry, registry, ctx, printed.append) - - assert result["escalated"] is True - assert result["level"] == 2 - assert entry.escalation_level == 2 - architect.execute.assert_called_once() - pm.execute.assert_not_called() - - def test_level_1_to_2_scope(self, tmp_project): - """Scope blocker escalates to project-manager.""" - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker( - "Backlog items", - "Scope of feature is unclear", - "biz-analyst", - "design", - ) - - registry, architect, pm = _make_registry() - ctx = _make_context() - printed = [] - - result = tracker.escalate(entry, registry, ctx, printed.append) - - assert result["escalated"] is True - assert result["level"] == 2 - pm.execute.assert_called_once() - architect.execute.assert_not_called() - - @patch("azext_prototype.stages.escalation.EscalationTracker._escalate_to_web_search") - def test_level_2_to_3_web_search(self, mock_web, tmp_project): - """Level 2→3 triggers web search.""" - mock_web.return_value = "Found: Azure docs suggest..." - - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - entry.escalation_level = 2 # Already at level 2 - - registry, _, _ = _make_registry() - ctx = _make_context() - printed = [] - - result = tracker.escalate(entry, registry, ctx, printed.append) - - assert result["escalated"] is True - assert result["level"] == 3 - mock_web.assert_called_once() - - def test_level_3_to_4_human(self, tmp_project): - """Level 3→4 flags for human intervention.""" - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - entry.escalation_level = 3 # Already at level 3 - - registry, _, _ = _make_registry() - ctx = _make_context() - printed = [] - - result = tracker.escalate(entry, registry, ctx, printed.append) - - assert result["escalated"] is True - assert result["level"] == 4 - assert any("HUMAN INTERVENTION" in p for p in printed) - - def test_already_at_level_4_no_escalation(self, tmp_project): - """Cannot escalate past level 4.""" - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - entry.escalation_level = 4 - - registry, _, _ = _make_registry() - ctx = _make_context() - printed = [] - - result = tracker.escalate(entry, registry, ctx, printed.append) - - assert result["escalated"] is False - assert result["level"] == 4 - - def test_no_agent_available_for_escalation(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - - registry = MagicMock() - registry.find_by_capability.return_value = [] - ctx = _make_context() - printed = [] - - result = tracker.escalate(entry, registry, ctx, printed.append) - - assert result["level"] == 2 - assert "No cloud-architect available" in result["content"] - - def test_agent_escalation_failure(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - - registry, architect, _ = _make_registry() - architect.execute.side_effect = RuntimeError("AI crashed") - ctx = _make_context() - printed = [] - - result = tracker.escalate(entry, registry, ctx, printed.append) - - assert result["level"] == 2 - assert "failed" in result["content"].lower() - - def test_web_search_failure_graceful(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - entry.escalation_level = 2 - - printed = [] - - with patch("azext_prototype.stages.escalation.EscalationTracker._escalate_to_web_search") as mock_ws: - mock_ws.return_value = "Web search failed: connection error" - - registry, _, _ = _make_registry() - ctx = _make_context() - result = tracker.escalate(entry, registry, ctx, printed.append) - - assert result["level"] == 3 - assert "failed" in result["content"].lower() - - -# ====================================================================== -# Auto-escalation tests -# ====================================================================== - - -class TestAutoEscalation: - - def test_timeout_triggers_escalation(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - - # Set last_escalated_at to 5 minutes ago - old_time = datetime.now(timezone.utc) - timedelta(minutes=5) - entry.last_escalated_at = old_time.isoformat() - - assert tracker.should_auto_escalate(entry, timeout_seconds=120) - - def test_not_yet_timed_out(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - - # Just created, so not timed out - assert not tracker.should_auto_escalate(entry, timeout_seconds=120) - - def test_resolved_stops_escalation(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - tracker.resolve(entry, "fixed") - - old_time = datetime.now(timezone.utc) - timedelta(minutes=5) - entry.last_escalated_at = old_time.isoformat() - - assert not tracker.should_auto_escalate(entry) - - def test_level_4_stops_escalation(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - entry.escalation_level = 4 - - old_time = datetime.now(timezone.utc) - timedelta(minutes=5) - entry.last_escalated_at = old_time.isoformat() - - assert not tracker.should_auto_escalate(entry) - - def test_invalid_timestamp_returns_false(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - entry = tracker.record_blocker("task", "blocked", "agent", "stage") - entry.last_escalated_at = "not-a-timestamp" - - assert not tracker.should_auto_escalate(entry) - - -# ====================================================================== -# Integration with qa_router -# ====================================================================== - - -class TestQARouterIntegration: - - def test_qa_router_records_blocker_on_undiagnosed(self, tmp_project): - from azext_prototype.ai.provider import AIResponse - from azext_prototype.stages.qa_router import route_error_to_qa - - tracker = EscalationTracker(str(tmp_project)) - - # QA returns empty — undiagnosed - qa = MagicMock() - qa.execute.return_value = AIResponse(content="", model="gpt-4o", usage={}) - - ctx = _make_context() - - result = route_error_to_qa( - "Deployment failed", - "Deploy Stage 1", - qa, - ctx, - None, - lambda m: None, - escalation_tracker=tracker, - source_agent="terraform-agent", - source_stage="deploy", - ) - - assert result["diagnosed"] is False - assert len(tracker.get_active_blockers()) == 1 - blocker = tracker.get_active_blockers()[0] - assert blocker.source_agent == "terraform-agent" - assert blocker.source_stage == "deploy" - - def test_qa_router_no_tracker_no_error(self, tmp_project): - from azext_prototype.ai.provider import AIResponse - from azext_prototype.stages.qa_router import route_error_to_qa - - qa = MagicMock() - qa.execute.return_value = AIResponse(content="", model="gpt-4o", usage={}) - - ctx = _make_context() - - # No escalation tracker — should not raise - result = route_error_to_qa( - "error", - "context", - qa, - ctx, - None, - lambda m: None, - escalation_tracker=None, - ) - - assert result["diagnosed"] is False - - @patch("azext_prototype.stages.qa_router._submit_knowledge") - def test_qa_router_diagnosed_no_blocker(self, mock_knowledge, tmp_project): - from azext_prototype.ai.provider import AIResponse - from azext_prototype.stages.qa_router import route_error_to_qa - - tracker = EscalationTracker(str(tmp_project)) - - qa = MagicMock() - qa.execute.return_value = AIResponse(content="Root cause: X", model="gpt-4o", usage={}) - - ctx = _make_context() - - result = route_error_to_qa( - "error", - "context", - qa, - ctx, - None, - lambda m: None, - escalation_tracker=tracker, - ) - - assert result["diagnosed"] is True - # No blocker should be recorded when QA diagnoses successfully - assert len(tracker.get_active_blockers()) == 0 - - def test_build_session_has_escalation_tracker(self, tmp_project): - from azext_prototype.agents.base import AgentContext - from azext_prototype.stages.build_session import BuildSession - - ctx = AgentContext( - project_config={"project": {"name": "test", "location": "eastus"}}, - project_dir=str(tmp_project), - ai_provider=MagicMock(), - ) - - registry = MagicMock() - registry.find_by_capability.return_value = [] - - with patch("azext_prototype.stages.build_session.ProjectConfig") as mock_config: - mock_config.return_value.load.return_value = None - mock_config.return_value.get.side_effect = lambda k, d=None: { - "project.iac_tool": "terraform", - "project.name": "test", - }.get(k, d) - mock_config.return_value.to_dict.return_value = { - "naming": {"strategy": "simple"}, - "project": {"name": "test"}, - } - session = BuildSession(ctx, registry) - - assert hasattr(session, "_escalation_tracker") - assert isinstance(session._escalation_tracker, EscalationTracker) - - def test_deploy_session_has_escalation_tracker(self, tmp_project): - from azext_prototype.agents.base import AgentContext - from azext_prototype.stages.deploy_session import DeploySession - from azext_prototype.stages.deploy_state import DeployState - - ctx = AgentContext( - project_config={"project": {"name": "test", "location": "eastus"}}, - project_dir=str(tmp_project), - ai_provider=MagicMock(), - ) - - registry = MagicMock() - registry.find_by_capability.return_value = [] - - with patch("azext_prototype.stages.deploy_session.ProjectConfig") as mock_config: - mock_config.return_value.load.return_value = None - mock_config.return_value.get.side_effect = lambda k, d=None: { - "project.iac_tool": "terraform", - }.get(k, d) - session = DeploySession(ctx, registry, deploy_state=DeployState(str(tmp_project))) - - assert hasattr(session, "_escalation_tracker") - assert isinstance(session._escalation_tracker, EscalationTracker) - - def test_backlog_session_has_escalation_tracker(self, tmp_project): - from azext_prototype.agents.base import AgentContext - from azext_prototype.stages.backlog_session import BacklogSession - from azext_prototype.stages.backlog_state import BacklogState - - ctx = AgentContext( - project_config={"project": {"name": "test", "location": "eastus"}}, - project_dir=str(tmp_project), - ai_provider=MagicMock(), - ) - - registry = MagicMock() - registry.find_by_capability.return_value = [] - - session = BacklogSession(ctx, registry, backlog_state=BacklogState(str(tmp_project))) - - assert hasattr(session, "_escalation_tracker") - assert isinstance(session._escalation_tracker, EscalationTracker) - - -# ====================================================================== -# Report formatting tests -# ====================================================================== - - -class TestReportFormatting: - - def test_empty_report(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - report = tracker.format_escalation_report() - assert "No blockers recorded" in report - - def test_report_with_active_and_resolved(self, tmp_project): - tracker = EscalationTracker(str(tmp_project)) - e1 = tracker.record_blocker("Deploy Redis", "Premium needed", "tf", "build") # noqa: F841 - e2 = tracker.record_blocker("Deploy Cosmos", "Multi-region", "tf", "build") - tracker.resolve(e2, "Used single region") - - report = tracker.format_escalation_report() - - assert "Active Blockers (1)" in report - assert "Deploy Redis" in report - assert "Resolved (1)" in report - assert "Used single region" in report - - -# ====================================================================== -# State persistence across sessions -# ====================================================================== - - -class TestStatePersistence: - - def test_state_survives_session_restart(self, tmp_project): - tracker1 = EscalationTracker(str(tmp_project)) - tracker1.record_blocker("task1", "b1", "a1", "s1") - e2 = tracker1.record_blocker("task2", "b2", "a2", "s2") - tracker1.record_attempted_solution(e2, "Tried A") - tracker1.resolve(e2, "Used workaround B") - - # Simulate session restart - tracker2 = EscalationTracker(str(tmp_project)) - tracker2.load() - - assert len(tracker2.get_active_blockers()) == 1 - assert tracker2.get_active_blockers()[0].task_description == "task1" - - # Check resolved entry - all_entries = tracker2._entries - resolved = [e for e in all_entries if e.resolved] - assert len(resolved) == 1 - assert resolved[0].resolution == "Used workaround B" - assert resolved[0].attempted_solutions == ["Tried A"] - - def test_escalation_level_persists(self, tmp_project): - tracker1 = EscalationTracker(str(tmp_project)) - entry = tracker1.record_blocker("task", "blocked", "agent", "stage") - - registry, _, _ = _make_registry() - ctx = _make_context() - tracker1.escalate(entry, registry, ctx, lambda m: None) - - # Simulate restart - tracker2 = EscalationTracker(str(tmp_project)) - tracker2.load() - - assert tracker2.get_active_blockers()[0].escalation_level == 2 diff --git a/tests/test_knowledge_contributor.py b/tests/test_knowledge_contributor.py deleted file mode 100644 index 13f5742..0000000 --- a/tests/test_knowledge_contributor.py +++ /dev/null @@ -1,572 +0,0 @@ -"""Tests for knowledge contribution helpers. - -Covers gap detection, formatting, submission via ``gh`` CLI, QA integration, -the fire-and-forget wrapper, and the CLI command ``az prototype knowledge -contribute``. -""" - -from unittest.mock import MagicMock, patch - -import pytest - -_KC_MODULE = "azext_prototype.stages.knowledge_contributor" -_BP_MODULE = "azext_prototype.stages.backlog_push" -_CUSTOM_MODULE = "azext_prototype.custom" - - -# ====================================================================== -# Helpers -# ====================================================================== - - -def _make_finding(**overrides) -> dict: - """Create a minimal finding dict with optional overrides.""" - finding = { - "service": "cosmos-db", - "type": "Pitfall", - "file": "knowledge/services/cosmos-db.md", - "section": "Terraform Patterns", - "context": "RU throughput must be set to at least 400 for serverless", - "rationale": "Setting below 400 causes deployment failure", - "content": "minimum_throughput = 400", - "source": "QA diagnosis", - } - finding.update(overrides) - return finding - - -def _make_loader(service_content: str = "") -> MagicMock: - """Create a mock KnowledgeLoader that returns *service_content*.""" - loader = MagicMock() - loader.load_service.return_value = service_content - return loader - - -# ====================================================================== -# TestFormatContributionBody -# ====================================================================== - - -class TestFormatContributionBody: - """Tests for ``format_contribution_body()``.""" - - def test_basic_format(self): - from azext_prototype.stages.knowledge_contributor import ( - format_contribution_body, - ) - - finding = _make_finding() - body = format_contribution_body(finding) - - assert "## Knowledge Contribution" in body - assert "**Type:** Pitfall" in body - assert "**File:** `knowledge/services/cosmos-db.md`" in body - assert "**Section to update:** Terraform Patterns" in body - assert "### Context" in body - assert "RU throughput" in body - assert "### Rationale" in body - assert "### Content to Add" in body - assert "minimum_throughput = 400" in body - assert "### Source" in body - assert "QA diagnosis" in body - - def test_missing_fields_defaults(self): - from azext_prototype.stages.knowledge_contributor import ( - format_contribution_body, - ) - - finding = {"service": "redis"} - body = format_contribution_body(finding) - - # "redis" doesn't match a knowledge file (it's redis-cache.md), - # so it's auto-upgraded to "New service" - assert "**Type:** New service" in body - assert "`knowledge/services/redis.md`" in body - assert "NEW FILE" in body - assert "No context provided." in body - assert "No rationale provided." in body - assert "No specific content provided" in body - - def test_empty_content(self): - from azext_prototype.stages.knowledge_contributor import ( - format_contribution_body, - ) - - finding = _make_finding(content="") - body = format_contribution_body(finding) - - assert "No specific content provided" in body - - -# ====================================================================== -# TestFormatContributionTitle -# ====================================================================== - - -class TestFormatContributionTitle: - """Tests for ``format_contribution_title()``.""" - - def test_basic_title(self): - from azext_prototype.stages.knowledge_contributor import ( - format_contribution_title, - ) - - finding = _make_finding() - title = format_contribution_title(finding) - - assert title.startswith("[Knowledge] cosmos-db:") - assert "RU throughput" in title - - def test_truncation_at_60(self): - from azext_prototype.stages.knowledge_contributor import ( - format_contribution_title, - ) - - long_context = "A" * 100 - finding = _make_finding(context=long_context) - title = format_contribution_title(finding) - - # Title should contain truncated context + ellipsis - assert "..." in title - # The service prefix + 60 chars + "..." should be in there - assert len(title) < 120 - - def test_missing_service(self): - from azext_prototype.stages.knowledge_contributor import ( - format_contribution_title, - ) - - finding = _make_finding(service="") - # Falls back to "unknown" since service key exists but is empty - # Actually the default in the function is "unknown" for missing key - finding.pop("service") - title = format_contribution_title(finding) - - assert "[Knowledge] unknown:" in title - - def test_description_fallback(self): - from azext_prototype.stages.knowledge_contributor import ( - format_contribution_title, - ) - - finding = _make_finding(context="", description="fallback description") - title = format_contribution_title(finding) - - assert "fallback description" in title - - -# ====================================================================== -# TestCheckKnowledgeGap -# ====================================================================== - - -class TestCheckKnowledgeGap: - """Tests for ``check_knowledge_gap()``.""" - - def test_no_file_is_gap(self): - from azext_prototype.stages.knowledge_contributor import check_knowledge_gap - - loader = _make_loader("") # empty = no file - finding = _make_finding() - - assert check_knowledge_gap(finding, loader) is True - - def test_content_not_found_is_gap(self): - from azext_prototype.stages.knowledge_contributor import check_knowledge_gap - - # Service file exists but doesn't contain the finding's context - loader = _make_loader("Some unrelated content about key vault.") - finding = _make_finding() - - assert check_knowledge_gap(finding, loader) is True - - def test_content_found_is_not_gap(self): - from azext_prototype.stages.knowledge_contributor import check_knowledge_gap - - # The first 80 chars of context appear in the service file - finding = _make_finding() - context_snippet = finding["context"][:80].lower() - loader = _make_loader(f"Some preamble. {context_snippet} and more details.") - - assert check_knowledge_gap(finding, loader) is False - - def test_empty_finding_is_not_gap(self): - from azext_prototype.stages.knowledge_contributor import check_knowledge_gap - - loader = _make_loader("") - assert check_knowledge_gap({}, loader) is False - assert check_knowledge_gap(None, loader) is False - - def test_missing_service_is_not_gap(self): - from azext_prototype.stages.knowledge_contributor import check_knowledge_gap - - loader = _make_loader("") - finding = _make_finding(service="") - assert check_knowledge_gap(finding, loader) is False - - def test_missing_context_is_not_gap(self): - from azext_prototype.stages.knowledge_contributor import check_knowledge_gap - - loader = _make_loader("") - finding = _make_finding(context="") - assert check_knowledge_gap(finding, loader) is False - - def test_loader_exception_treated_as_gap(self): - from azext_prototype.stages.knowledge_contributor import check_knowledge_gap - - loader = MagicMock() - loader.load_service.side_effect = Exception("file not found") - finding = _make_finding() - - # Exception means no content found => gap - assert check_knowledge_gap(finding, loader) is True - - -# ====================================================================== -# TestSubmitContribution -# ====================================================================== - - -class TestSubmitContribution: - """Tests for ``submit_contribution()``.""" - - def test_success(self): - from azext_prototype.stages.knowledge_contributor import submit_contribution - - with patch(f"{_BP_MODULE}.subprocess.run") as mock_auth, patch(f"{_KC_MODULE}.subprocess.run") as mock_create: - mock_auth.return_value = MagicMock(returncode=0) - mock_create.return_value = MagicMock( - returncode=0, - stdout="https://github.com/Azure/az-prototype/issues/42\n", - ) - - result = submit_contribution(_make_finding()) - - assert result["url"] == "https://github.com/Azure/az-prototype/issues/42" - assert result["number"] == "42" - - def test_gh_not_authed(self): - from azext_prototype.stages.knowledge_contributor import submit_contribution - - with patch(f"{_BP_MODULE}.subprocess.run") as mock_auth: - mock_auth.return_value = MagicMock(returncode=1) - - result = submit_contribution(_make_finding()) - assert "error" in result - assert "not authenticated" in result["error"].lower() - - def test_create_fails(self): - from azext_prototype.stages.knowledge_contributor import submit_contribution - - with patch(f"{_BP_MODULE}.subprocess.run") as mock_auth, patch(f"{_KC_MODULE}.subprocess.run") as mock_create: - mock_auth.return_value = MagicMock(returncode=0) - mock_create.return_value = MagicMock( - returncode=1, - stderr="label 'pitfall' not found", - stdout="", - ) - - result = submit_contribution(_make_finding()) - assert "error" in result - - def test_labels_include_service_and_type(self): - from azext_prototype.stages.knowledge_contributor import submit_contribution - - with patch(f"{_BP_MODULE}.subprocess.run") as mock_auth, patch(f"{_KC_MODULE}.subprocess.run") as mock_create: - mock_auth.return_value = MagicMock(returncode=0) - mock_create.return_value = MagicMock( - returncode=0, - stdout="https://github.com/Azure/az-prototype/issues/99\n", - ) - - finding = _make_finding(service="key-vault", type="Service pattern update") - submit_contribution(finding) - - # Check the command args include service and type labels - call_args = mock_create.call_args[0][0] - label_indices = [i for i, a in enumerate(call_args) if a == "--label"] - labels = [call_args[i + 1] for i in label_indices] - assert "knowledge-contribution" in labels - assert "service/key-vault" in labels - assert "pattern-update" in labels - - def test_custom_repo(self): - from azext_prototype.stages.knowledge_contributor import submit_contribution - - with patch(f"{_BP_MODULE}.subprocess.run") as mock_auth, patch(f"{_KC_MODULE}.subprocess.run") as mock_create: - mock_auth.return_value = MagicMock(returncode=0) - mock_create.return_value = MagicMock( - returncode=0, - stdout="https://github.com/myorg/myrepo/issues/1\n", - ) - - result = submit_contribution(_make_finding(), repo="myorg/myrepo") - - call_args = mock_create.call_args[0][0] - repo_idx = call_args.index("--repo") - assert call_args[repo_idx + 1] == "myorg/myrepo" - assert result["url"] == "https://github.com/myorg/myrepo/issues/1" - - def test_gh_not_installed(self): - from azext_prototype.stages.knowledge_contributor import submit_contribution - - # Mock check_gh_auth at its source (both modules share the subprocess object) - with patch(f"{_BP_MODULE}.check_gh_auth", return_value=True), patch( - f"{_KC_MODULE}.subprocess.run" - ) as mock_create: - mock_create.side_effect = FileNotFoundError - - result = submit_contribution(_make_finding()) - assert "error" in result - assert "not found" in result["error"].lower() - - -# ====================================================================== -# TestBuildFindingFromQa -# ====================================================================== - - -class TestBuildFindingFromQa: - """Tests for ``build_finding_from_qa()``.""" - - def test_builds_from_qa_text(self): - from azext_prototype.stages.knowledge_contributor import build_finding_from_qa - - qa_text = "The Cosmos DB RU throughput was set below the minimum of 400." - finding = build_finding_from_qa(qa_text, service="cosmos-db", source="Deploy failure: Stage 2") - - assert finding["service"] == "cosmos-db" - assert finding["type"] == "Pitfall" - assert finding["source"] == "Deploy failure: Stage 2" - assert "cosmos-db" in finding["file"] - assert "400" in finding["context"] - assert "400" in finding["content"] - - def test_truncates_long_content(self): - from azext_prototype.stages.knowledge_contributor import build_finding_from_qa - - long_text = "X" * 1000 - finding = build_finding_from_qa(long_text, service="redis") - - assert len(finding["context"]) <= 500 - assert len(finding["content"]) <= 200 - - def test_empty_qa_text(self): - from azext_prototype.stages.knowledge_contributor import build_finding_from_qa - - finding = build_finding_from_qa("", service="redis") - assert finding["context"] == "" - assert finding["content"] == "" - - def test_defaults(self): - from azext_prototype.stages.knowledge_contributor import build_finding_from_qa - - finding = build_finding_from_qa("some content") - assert finding["service"] == "unknown" - assert finding["source"] == "QA diagnosis" - - -# ====================================================================== -# TestSubmitIfGap -# ====================================================================== - - -class TestSubmitIfGap: - """Tests for ``submit_if_gap()``.""" - - def test_submits_when_gap(self): - from azext_prototype.stages.knowledge_contributor import submit_if_gap - - loader = _make_loader("") # no content = gap - printed: list[str] = [] - - with patch(f"{_BP_MODULE}.subprocess.run") as mock_auth, patch(f"{_KC_MODULE}.subprocess.run") as mock_create: - mock_auth.return_value = MagicMock(returncode=0) - mock_create.return_value = MagicMock( - returncode=0, - stdout="https://github.com/Azure/az-prototype/issues/7\n", - ) - - result = submit_if_gap( - _make_finding(), - loader, - print_fn=printed.append, - ) - - assert result is not None - assert result["url"] == "https://github.com/Azure/az-prototype/issues/7" - assert any("submitted" in p.lower() for p in printed) - - def test_skips_when_no_gap(self): - from azext_prototype.stages.knowledge_contributor import submit_if_gap - - # Content already exists in knowledge file - finding = _make_finding() - loader = _make_loader(finding["context"][:80].lower() + " more details") - printed: list[str] = [] - - result = submit_if_gap(finding, loader, print_fn=printed.append) - - assert result is None - assert len(printed) == 0 - - def test_never_raises(self): - from azext_prototype.stages.knowledge_contributor import submit_if_gap - - # Loader throws an exception - loader = MagicMock() - loader.load_service.side_effect = RuntimeError("kaboom") - - # Even if gap check raises inside, submit_if_gap should not propagate - # Actually check_knowledge_gap catches it and returns True, then - # submit_contribution is called — let's make that fail too - with patch(f"{_KC_MODULE}.submit_contribution") as mock_submit: - mock_submit.side_effect = RuntimeError("double kaboom") - - result = submit_if_gap(_make_finding(), loader) - - # Should return None, not raise - assert result is None - - def test_no_print_when_no_url(self): - from azext_prototype.stages.knowledge_contributor import submit_if_gap - - loader = _make_loader("") # gap - printed: list[str] = [] - - with patch(f"{_BP_MODULE}.subprocess.run") as mock_auth, patch(f"{_KC_MODULE}.subprocess.run") as mock_create: - mock_auth.return_value = MagicMock(returncode=0) - mock_create.return_value = MagicMock( - returncode=1, - stderr="error", - stdout="", - ) - - submit_if_gap( - _make_finding(), - loader, - print_fn=printed.append, - ) - - # Error result, no URL to print - assert len(printed) == 0 - - -# ====================================================================== -# TestKnowledgeContributeCommand -# ====================================================================== - - -class TestKnowledgeContributeCommand: - """Tests for ``prototype_knowledge_contribute()`` CLI command.""" - - def test_draft_mode(self, project_with_config): - from azext_prototype.custom import prototype_knowledge_contribute - - cmd = MagicMock() - with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)): - result = prototype_knowledge_contribute( - cmd, - service="cosmos-db", - description="RU throughput must be >= 400", - draft=True, - json_output=True, - ) - - assert result["status"] == "draft" - assert "cosmos-db" in result["title"] - - def test_noninteractive_submit(self, project_with_config): - from azext_prototype.custom import prototype_knowledge_contribute - - cmd = MagicMock() - with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)), patch( - f"{_BP_MODULE}.subprocess.run" - ) as mock_auth, patch(f"{_KC_MODULE}.subprocess.run") as mock_create: - mock_auth.return_value = MagicMock(returncode=0) - mock_create.return_value = MagicMock( - returncode=0, - stdout="https://github.com/Azure/az-prototype/issues/55\n", - ) - - result = prototype_knowledge_contribute( - cmd, - service="cosmos-db", - description="RU throughput must be >= 400", - json_output=True, - ) - - assert result["status"] == "submitted" - assert result["url"] == "https://github.com/Azure/az-prototype/issues/55" - - def test_gh_not_authed_raises(self, project_with_config): - from knack.util import CLIError - - from azext_prototype.custom import prototype_knowledge_contribute - - cmd = MagicMock() - with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)), patch( - f"{_BP_MODULE}.subprocess.run" - ) as mock_auth: - mock_auth.return_value = MagicMock(returncode=1) - - with pytest.raises(CLIError, match="not authenticated"): - prototype_knowledge_contribute( - cmd, - service="cosmos-db", - description="RU throughput", - ) - - def test_file_input(self, project_with_config): - from azext_prototype.custom import prototype_knowledge_contribute - - # Create a finding file - finding_file = project_with_config / "finding.md" - finding_file.write_text( - "Service: cosmos-db\nContext: RU must be >= 400\nContent: min_ru = 400", - encoding="utf-8", - ) - - cmd = MagicMock() - with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)): - result = prototype_knowledge_contribute( - cmd, - file=str(finding_file), - draft=True, - json_output=True, - ) - - assert result["status"] == "draft" - - def test_file_not_found_raises(self, project_with_config): - from knack.util import CLIError - - from azext_prototype.custom import prototype_knowledge_contribute - - cmd = MagicMock() - with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)): - with pytest.raises(CLIError, match="not found"): - prototype_knowledge_contribute( - cmd, - file="/nonexistent/path/finding.md", - draft=True, - ) - - def test_contribution_type_forwarded(self, project_with_config): - from azext_prototype.custom import prototype_knowledge_contribute - - cmd = MagicMock() - with patch(f"{_CUSTOM_MODULE}._get_project_dir", return_value=str(project_with_config)): - result = prototype_knowledge_contribute( - cmd, - service="redis", - description="Cache eviction pitfall", - contribution_type="Service pattern update", - section="Pitfalls", - draft=True, - json_output=True, - ) - - assert result["status"] == "draft" - assert "Service pattern update" in result["body"] - assert "Pitfalls" in result["body"] diff --git a/tests/test_stages_extended.py b/tests/test_stages_extended.py deleted file mode 100644 index 11d38a2..0000000 --- a/tests/test_stages_extended.py +++ /dev/null @@ -1,558 +0,0 @@ -"""Tests for deploy_stage.py, build_stage.py, and init_stage.py — full coverage.""" - -from unittest.mock import MagicMock, patch - -import pytest -from knack.util import CLIError - -from azext_prototype.ai.provider import AIResponse -from azext_prototype.stages.build_session import BuildResult - -# ====================================================================== -# DeployStage -# ====================================================================== - - -class TestDeployStageExecution: - """Test DeployStage orchestration and deploy_helpers functions.""" - - def _make_stage(self): - from azext_prototype.stages.deploy_stage import DeployStage - - return DeployStage() - - def test_deploy_guards(self): - stage = self._make_stage() - guards = stage.get_guards() - names = [g.name for g in guards] - assert "project_initialized" in names - assert "build_complete" in names - assert "az_logged_in" in names - - @patch("subprocess.run") - def test_check_az_login_true(self, mock_run): - from azext_prototype.stages.deploy_helpers import check_az_login - - mock_run.return_value = MagicMock(returncode=0) - assert check_az_login() is True - - @patch("subprocess.run") - def test_check_az_login_false(self, mock_run): - from azext_prototype.stages.deploy_helpers import check_az_login - - mock_run.return_value = MagicMock(returncode=1) - assert check_az_login() is False - - @patch("subprocess.run", side_effect=FileNotFoundError) - def test_check_az_login_not_installed(self, mock_run): - from azext_prototype.stages.deploy_helpers import check_az_login - - assert check_az_login() is False - - @patch("subprocess.run") - def test_get_current_subscription(self, mock_run): - from azext_prototype.stages.deploy_helpers import get_current_subscription - - mock_run.return_value = MagicMock(returncode=0, stdout="abc-123\n") - result = get_current_subscription() - assert result == "abc-123" - - @patch("subprocess.run", side_effect=FileNotFoundError) - def test_get_current_subscription_not_installed(self, mock_run): - from azext_prototype.stages.deploy_helpers import get_current_subscription - - assert get_current_subscription() == "" - - @patch("subprocess.run") - def test_deploy_terraform_success(self, mock_run, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_terraform - - infra_dir = tmp_path / "tf" - infra_dir.mkdir() - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - - result = deploy_terraform(infra_dir, "sub-123") - assert result["status"] == "deployed" - - @patch("subprocess.run") - def test_deploy_terraform_failure(self, mock_run, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_terraform - - infra_dir = tmp_path / "tf" - infra_dir.mkdir() - mock_run.return_value = MagicMock(returncode=1, stderr="init failed", stdout="") - - result = deploy_terraform(infra_dir, "sub-123") - assert result["status"] == "failed" - - @patch("subprocess.run") - def test_deploy_bicep_failure(self, mock_run, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_bicep - - (tmp_path / "main.bicep").write_text("resource x 'y' = {}", encoding="utf-8") - mock_run.return_value = MagicMock(returncode=1, stderr="Deployment failed", stdout="") - - result = deploy_bicep(tmp_path, "sub-123", "my-rg") - assert result["status"] == "failed" - - def test_deploy_app_stage_with_deploy_script(self, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_app_stage - - app_dir = tmp_path / "app" - app_dir.mkdir() - (app_dir / "deploy.sh").write_text("echo deployed", encoding="utf-8") - - with patch("subprocess.run") as mock_run: - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - result = deploy_app_stage(app_dir, "sub-123", "my-rg") - assert result["status"] == "deployed" - - def test_deploy_app_stage_sub_apps(self, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_app_stage - - stage_dir = tmp_path / "stage" - stage_dir.mkdir() - backend = stage_dir / "backend" - backend.mkdir() - (backend / "deploy.sh").write_text("echo ok", encoding="utf-8") - - with patch("subprocess.run") as mock_run: - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - result = deploy_app_stage(stage_dir, "sub-123", "my-rg") - assert result["status"] == "deployed" - assert "backend" in result["apps"] - - def test_deploy_app_stage_no_scripts(self, tmp_path): - from azext_prototype.stages.deploy_helpers import deploy_app_stage - - empty_dir = tmp_path / "empty" - empty_dir.mkdir() - result = deploy_app_stage(empty_dir, "sub-123", "my-rg") - assert result["status"] == "skipped" - - @patch("subprocess.run") - def test_whatif_bicep_no_files(self, mock_run, tmp_path): - from azext_prototype.stages.deploy_helpers import whatif_bicep - - empty_dir = tmp_path / "empty" - empty_dir.mkdir() - result = whatif_bicep(empty_dir, "sub-123", "my-rg") - assert result["status"] == "skipped" - - @patch("subprocess.run") - def test_whatif_bicep_no_rg_skips(self, mock_run, tmp_path): - from azext_prototype.stages.deploy_helpers import whatif_bicep - - (tmp_path / "main.bicep").write_text("resource x 'y' = {}", encoding="utf-8") - result = whatif_bicep(tmp_path, "sub-123", "") - assert result["status"] == "skipped" - - def test_get_deploy_location_main_params(self, tmp_path): - from azext_prototype.stages.deploy_helpers import get_deploy_location - - (tmp_path / "main.parameters.json").write_text( - '{"parameters": {"location": {"value": "northeurope"}}}', encoding="utf-8" - ) - result = get_deploy_location(tmp_path) - assert result == "northeurope" - - def test_get_deploy_location_string_value(self, tmp_path): - from azext_prototype.stages.deploy_helpers import get_deploy_location - - (tmp_path / "parameters.json").write_text('{"location": "uksouth"}', encoding="utf-8") - result = get_deploy_location(tmp_path) - assert result == "uksouth" - - def test_execute_status(self, project_with_build, mock_agent_context, populated_registry): - """Deploy with --status shows state and returns.""" - stage = self._make_stage() - stage.get_guards = lambda: [] - mock_agent_context.project_dir = str(project_with_build) - - result = stage.execute( - mock_agent_context, - populated_registry, - status=True, - ) - assert result["status"] == "status_displayed" - - def test_execute_reset(self, project_with_build, mock_agent_context, populated_registry): - """Deploy with --reset clears state and returns.""" - stage = self._make_stage() - stage.get_guards = lambda: [] - mock_agent_context.project_dir = str(project_with_build) - - result = stage.execute( - mock_agent_context, - populated_registry, - reset=True, - ) - assert result["status"] == "reset" - - -# ====================================================================== -# BuildStage -# ====================================================================== - - -class TestBuildStageExecution: - """Test BuildStage methods.""" - - def _make_stage(self): - from azext_prototype.stages.build_stage import BuildStage - - return BuildStage() - - def test_build_guards(self): - stage = self._make_stage() - guards = stage.get_guards() - names = [g.name for g in guards] - assert "project_initialized" in names - assert "discovery_complete" in names - assert "design_complete" in names - - def test_load_design(self, project_with_design): - stage = self._make_stage() - design = stage._load_design(str(project_with_design)) - assert "architecture" in design - - def test_load_design_missing(self, tmp_project): - stage = self._make_stage() - result = stage._load_design(str(tmp_project)) - assert result == {} - - def test_execute_no_design_raises(self, project_with_config, mock_agent_context, populated_registry): - stage = self._make_stage() - stage.get_guards = lambda: [] - mock_agent_context.project_dir = str(project_with_config) - - with pytest.raises(CLIError, match="No architecture design"): - stage.execute(mock_agent_context, populated_registry) - - def test_execute_dry_run(self, project_with_design, mock_agent_context, populated_registry): - stage = self._make_stage() - stage.get_guards = lambda: [] - mock_agent_context.project_dir = str(project_with_design) - mock_agent_context.ai_provider.chat.return_value = AIResponse(content="Generated code", model="gpt-4o") - - result = stage.execute(mock_agent_context, populated_registry, scope="docs", dry_run=True) - assert result["status"] == "dry-run" - - def test_execute_all_scopes_dry_run(self, project_with_design, mock_agent_context, populated_registry): - stage = self._make_stage() - stage.get_guards = lambda: [] - mock_agent_context.project_dir = str(project_with_design) - - result = stage.execute(mock_agent_context, populated_registry, scope="all", dry_run=True) - assert result["status"] == "dry-run" - assert result["scope"] == "all" - - @patch("azext_prototype.stages.build_stage.BuildSession") - def test_execute_interactive_delegates_to_session( - self, mock_session_cls, project_with_design, mock_agent_context, populated_registry - ): - stage = self._make_stage() - stage.get_guards = lambda: [] - mock_agent_context.project_dir = str(project_with_design) - - mock_result = BuildResult( - files_generated=["main.tf"], - deployment_stages=[{"stage": 1, "name": "Foundation"}], - policy_overrides=[], - resources=[{"resourceType": "Microsoft.Compute/virtualMachines", "sku": "Standard_B2s"}], - review_accepted=True, - cancelled=False, - ) - mock_session_cls.return_value.run.return_value = mock_result - - result = stage.execute(mock_agent_context, populated_registry, scope="all", dry_run=False) - assert result["status"] == "success" - assert result["scope"] == "all" - assert result["files_generated"] == ["main.tf"] - mock_session_cls.return_value.run.assert_called_once() - - -# ====================================================================== -# InitStage -# ====================================================================== - - -class TestInitStageExecution: - """Test InitStage methods.""" - - def _make_stage(self): - from azext_prototype.stages.init_stage import InitStage - - return InitStage() - - def test_init_guards(self): - """Init has no unconditional guards; gh check is conditional inside execute().""" - stage = self._make_stage() - guards = stage.get_guards() - assert len(guards) == 0 - - @patch("subprocess.run") - def test_check_gh_true(self, mock_run): - stage = self._make_stage() - mock_run.return_value = MagicMock(returncode=0) - assert stage._check_gh() is True - - @patch("subprocess.run", side_effect=FileNotFoundError) - def test_check_gh_false(self, mock_run): - stage = self._make_stage() - assert stage._check_gh() is False - - def test_create_scaffold(self, tmp_path): - stage = self._make_stage() - project_dir = tmp_path / "my-project" - stage._create_scaffold(project_dir) - - assert (project_dir / "concept" / "docs").is_dir() - assert (project_dir / ".prototype" / "agents").is_dir() - # infra, apps, db dirs are NOT created at init — only during build - assert not (project_dir / "concept" / "apps").exists() - assert not (project_dir / "concept" / "infra").exists() - assert not (project_dir / "concept" / "db").exists() - - def test_create_gitignore(self, tmp_path): - stage = self._make_stage() - stage._create_gitignore(tmp_path) - gi = tmp_path / ".gitignore" - assert gi.exists() - content = gi.read_text() - assert ".terraform/" in content - assert "__pycache__/" in content - - def test_create_gitignore_no_overwrite(self, tmp_path): - stage = self._make_stage() - gi = tmp_path / ".gitignore" - gi.write_text("custom content", encoding="utf-8") - stage._create_gitignore(tmp_path) - assert gi.read_text() == "custom content" - - @patch("azext_prototype.auth.copilot_license.CopilotLicenseValidator") - @patch("azext_prototype.auth.github_auth.GitHubAuthManager") - @patch("azext_prototype.stages.init_stage.InitStage._check_gh", return_value=True) - def test_execute_full(self, mock_gh, mock_auth_cls, mock_lic_cls, tmp_path): - stage = self._make_stage() - stage.get_guards = lambda: [] - - mock_auth = MagicMock() - mock_auth.ensure_authenticated.return_value = {"login": "devuser"} - mock_auth_cls.return_value = mock_auth - mock_lic = MagicMock() - mock_lic.validate_license.return_value = {"plan": "business"} - mock_lic_cls.return_value = mock_lic - - from azext_prototype.agents.base import AgentContext - from azext_prototype.agents.registry import AgentRegistry - - ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) - registry = AgentRegistry() - - out = tmp_path / "test-proj" - result = stage.execute( - ctx, - registry, - name="test-proj", - location="westus2", - iac_tool="bicep", - ai_provider="github-models", - output_dir=str(out), - ) - assert result["status"] == "success" - assert (out / "prototype.yaml").exists() - - @patch("azext_prototype.auth.copilot_license.CopilotLicenseValidator") - @patch("azext_prototype.auth.github_auth.GitHubAuthManager") - @patch("azext_prototype.stages.init_stage.InitStage._check_gh", return_value=True) - def test_execute_license_failure_continues(self, mock_gh, mock_auth_cls, mock_lic_cls, tmp_path): - """License validation failure should warn but continue.""" - stage = self._make_stage() - stage.get_guards = lambda: [] - - mock_auth = MagicMock() - mock_auth.ensure_authenticated.return_value = {"login": "devuser"} - mock_auth_cls.return_value = mock_auth - mock_lic = MagicMock() - mock_lic.validate_license.side_effect = CLIError("No license") - mock_lic_cls.return_value = mock_lic - - from azext_prototype.agents.base import AgentContext - from azext_prototype.agents.registry import AgentRegistry - - ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) - registry = AgentRegistry() - - result = stage.execute( - ctx, - registry, - name="lic-test", - location="eastus", - ai_provider="github-models", - output_dir=str(tmp_path / "lic-test"), - ) - assert result["status"] == "success" - assert result["copilot_license"]["status"] == "unverified" - - def test_execute_no_name_raises(self, tmp_path): - stage = self._make_stage() - stage.get_guards = lambda: [] - - from azext_prototype.agents.base import AgentContext - from azext_prototype.agents.registry import AgentRegistry - - ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) - registry = AgentRegistry() - - with pytest.raises(CLIError, match="Project name"): - stage.execute(ctx, registry, name="", output_dir=str(tmp_path / "empty-name")) - - def test_execute_no_location_raises(self, tmp_path): - stage = self._make_stage() - stage.get_guards = lambda: [] - - from azext_prototype.agents.base import AgentContext - from azext_prototype.agents.registry import AgentRegistry - - ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) - registry = AgentRegistry() - - with pytest.raises(CLIError, match="region is required"): - stage.execute( - ctx, - registry, - name="test-proj", - location=None, - output_dir=str(tmp_path / "test-proj"), - ) - - def test_execute_azure_openai_skips_auth(self, tmp_path): - """azure-openai provider should skip GitHub auth entirely.""" - stage = self._make_stage() - stage.get_guards = lambda: [] - - from azext_prototype.agents.base import AgentContext - from azext_prototype.agents.registry import AgentRegistry - - ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) - registry = AgentRegistry() - - result = stage.execute( - ctx, - registry, - name="aoai-test", - location="eastus", - ai_provider="azure-openai", - output_dir=str(tmp_path / "aoai-test"), - ) - assert result["status"] == "success" - assert result["github_user"] is None - assert "copilot_license" not in result - - def test_execute_environment_stored(self, tmp_path): - """--environment should be persisted in config.""" - stage = self._make_stage() - stage.get_guards = lambda: [] - - from azext_prototype.agents.base import AgentContext - from azext_prototype.agents.registry import AgentRegistry - from azext_prototype.config import ProjectConfig - - ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) - registry = AgentRegistry() - - out = tmp_path / "env-test" - stage.execute( - ctx, - registry, - name="env-test", - location="westus2", - ai_provider="azure-openai", - environment="prod", - output_dir=str(out), - ) - config = ProjectConfig(str(out)) - config.load() - assert config.get("project.environment") == "prod" - assert config.get("naming.env") == "prd" - assert config.get("naming.zone_id") == "zp" - - def test_execute_model_override(self, tmp_path): - """Explicit --model should override provider default.""" - stage = self._make_stage() - stage.get_guards = lambda: [] - - from azext_prototype.agents.base import AgentContext - from azext_prototype.agents.registry import AgentRegistry - from azext_prototype.config import ProjectConfig - - ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) - registry = AgentRegistry() - - out = tmp_path / "model-test" - stage.execute( - ctx, - registry, - name="model-test", - location="eastus", - ai_provider="azure-openai", - model="gpt-4o-mini", - output_dir=str(out), - ) - config = ProjectConfig(str(out)) - config.load() - assert config.get("ai.model") == "gpt-4o-mini" - - def test_execute_idempotency_cancel(self, tmp_path): - """Existing project + user declining should cancel.""" - stage = self._make_stage() - stage.get_guards = lambda: [] - - from azext_prototype.agents.base import AgentContext - from azext_prototype.agents.registry import AgentRegistry - - # Pre-create project directory with config - proj = tmp_path / "idem-test" - proj.mkdir() - (proj / "prototype.yaml").write_text("project:\n name: old\n") - - ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) - registry = AgentRegistry() - - with patch("builtins.input", return_value="n"): - result = stage.execute( - ctx, - registry, - name="idem-test", - location="eastus", - ai_provider="azure-openai", - output_dir=str(proj), - ) - assert result["status"] == "cancelled" - - def test_execute_marks_init_complete(self, tmp_path): - """Init stage should set stages.init.completed and timestamp.""" - stage = self._make_stage() - stage.get_guards = lambda: [] - - from azext_prototype.agents.base import AgentContext - from azext_prototype.agents.registry import AgentRegistry - from azext_prototype.config import ProjectConfig - - ctx = AgentContext(project_config={}, project_dir=str(tmp_path), ai_provider=None) - registry = AgentRegistry() - - out = tmp_path / "complete-test" - stage.execute( - ctx, - registry, - name="complete-test", - location="eastus", - ai_provider="azure-openai", - output_dir=str(out), - ) - config = ProjectConfig(str(out)) - config.load() - assert config.get("stages.init.completed") is True - assert config.get("stages.init.timestamp") is not None diff --git a/tests/tracking/__init__.py b/tests/tracking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_tracking.py b/tests/tracking/test___init__.py similarity index 100% rename from tests/test_tracking.py rename to tests/tracking/test___init__.py diff --git a/tests/ui/__init__.py b/tests/ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_console.py b/tests/ui/test_console.py similarity index 100% rename from tests/test_console.py rename to tests/ui/test_console.py diff --git a/tests/test_prompt_input.py b/tests/ui/test_prompt_input.py similarity index 100% rename from tests/test_prompt_input.py rename to tests/ui/test_prompt_input.py diff --git a/tests/test_stage_orchestrator.py b/tests/ui/test_stage_orchestrator.py similarity index 100% rename from tests/test_stage_orchestrator.py rename to tests/ui/test_stage_orchestrator.py diff --git a/tests/test_tui_adapter.py b/tests/ui/test_tui_adapter.py similarity index 100% rename from tests/test_tui_adapter.py rename to tests/ui/test_tui_adapter.py diff --git a/tests/test_tui_widgets.py b/tests/ui/test_tui_widgets.py similarity index 81% rename from tests/test_tui_widgets.py rename to tests/ui/test_tui_widgets.py index 46f95b9..202ca4b 100644 --- a/tests/test_tui_widgets.py +++ b/tests/ui/test_tui_widgets.py @@ -195,27 +195,6 @@ async def test_info_bar_updates(): # No exception = success -@pytest.mark.asyncio -async def test_prompt_input_disable(): - """PromptInput should be disabled by default.""" - app = PrototypeApp() - async with app.run_test() as pilot: # noqa: F841 - prompt = app.prompt_input - assert prompt._enabled is False - assert prompt.read_only is True - - -@pytest.mark.asyncio -async def test_prompt_input_enable(): - """PromptInput should allow enabling for input.""" - app = PrototypeApp() - async with app.run_test() as pilot: # noqa: F841 - prompt = app.prompt_input - prompt.enable() - assert prompt._enabled is True - assert prompt.read_only is False - - @pytest.mark.asyncio async def test_file_list(): """ConsoleView should render file lists.""" @@ -249,40 +228,3 @@ async def test_console_view_write_markup_invalid_falls_back(): app.console_view.write_markup("[invalid_tag_that_wont_parse") -# -------------------------------------------------------------------- # -# PromptInput allow_empty tests -# -------------------------------------------------------------------- # - - -@pytest.mark.asyncio -async def test_prompt_input_allow_empty(): - """PromptInput with allow_empty=True should submit empty string.""" - app = PrototypeApp() - async with app.run_test() as pilot: # noqa: F841 - prompt = app.prompt_input - prompt.enable(allow_empty=True) - assert prompt._allow_empty is True - assert prompt._enabled is True - - -@pytest.mark.asyncio -async def test_prompt_input_default_no_allow_empty(): - """PromptInput defaults to allow_empty=False.""" - app = PrototypeApp() - async with app.run_test() as pilot: # noqa: F841 - prompt = app.prompt_input - prompt.enable() - assert prompt._allow_empty is False - - -@pytest.mark.asyncio -async def test_prompt_input_input_mode(): - """In input mode (default), text has '> ' prefix and placeholder is empty.""" - app = PrototypeApp() - async with app.run_test() as pilot: # noqa: F841 - prompt = app.prompt_input - prompt.enable() - assert prompt._allow_empty is False - assert prompt._enabled is True - assert prompt.text == "> " - assert prompt.placeholder == ""