Skip to content

Commit d429040

Browse files
committed
Agent service filtering + ReDoS fix in transform handlers
- Add stage_services field to AgentContext, populated by _agent_build_context() and passed through _apply_governance_check() to reduce false positive anti-pattern warnings for irrelevant service namespaces - Extract _find_azapi_blocks() shared brace-counting helper and rewrite _add_response_export_values, _add_resource_group_parent_id, and _remove_private_endpoint_resources to eliminate nested-quantifier regex - 5 new tests (2 service filtering, 3 brace counting safety)
1 parent 699ee7d commit d429040

6 files changed

Lines changed: 186 additions & 81 deletions

File tree

HISTORY.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,22 @@ Generation prompt improvements
4141
now appears before the architecture context in the generation prompt,
4242
reducing unused ``terraform_remote_state`` data sources.
4343

44+
Agent-level service filtering
45+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
46+
* **Agent governance checks now filter by service namespace** — added
47+
``stage_services`` field to ``AgentContext``, populated by
48+
``_agent_build_context()``. ``_apply_governance_check()`` now passes
49+
stage services to ``validate_response()``, reducing false positive
50+
anti-pattern warnings for irrelevant service namespaces.
51+
52+
ReDoS fix in transform handlers
53+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
54+
* **Replaced nested-quantifier regex with brace counting** — extracted
55+
shared ``_find_azapi_blocks()`` helper and rewrote
56+
``_add_response_export_values``, ``_add_resource_group_parent_id``,
57+
and ``_remove_private_endpoint_resources`` to use it. Eliminates
58+
potential exponential backtracking on pathological input.
59+
4460
Test suite consolidation
4561
~~~~~~~~~~~~~~~~~~~~~~~~~~
4662
* **Consolidated and enhanced unit test coverage** — migrated flat test

azext_prototype/agents/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class AgentContext:
8181
artifacts: dict[str, Any] = field(default_factory=dict)
8282
shared_state: dict[str, Any] = field(default_factory=dict)
8383
mcp_manager: Any = None # MCPManager | None — typed as Any to avoid circular import
84+
stage_services: list[str] | None = None # ARM namespaces for service filtering
8485

8586
def add_artifact(self, key: str, value: Any):
8687
"""Store an artifact for other agents to reference."""
@@ -299,7 +300,7 @@ def _apply_governance_check(self, response: AIResponse, context: AgentContext) -
299300
avoid duplicating the governance warning block.
300301
"""
301302
iac_tool = context.project_config.get("project", {}).get("iac_tool") if context.project_config else None
302-
warnings = self.validate_response(response.content, iac_tool=iac_tool, services=None)
303+
warnings = self.validate_response(response.content, iac_tool=iac_tool, services=context.stage_services)
303304
if warnings:
304305
for w in warnings:
305306
logger.warning("Governance: %s", w)

azext_prototype/governance/transforms/__init__.py

Lines changed: 73 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,32 @@ def _remove_unused_remote_state(content: str, stage_content: str | None = None)
268268
return result
269269

270270

271+
def _find_azapi_blocks(content: str) -> list[tuple[int, int, str, str]]:
272+
"""Find all ``azapi_resource`` blocks using brace counting.
273+
274+
Returns a list of ``(start, end, resource_name, block_text)`` tuples
275+
where *start*/*end* are character offsets into *content*.
276+
"""
277+
pattern = re.compile(r'resource\s+"azapi_resource"\s+"(\w+)"\s*\{')
278+
blocks: list[tuple[int, int, str, str]] = []
279+
for match in pattern.finditer(content):
280+
name = match.group(1)
281+
start = match.start()
282+
brace_start = match.end() - 1
283+
depth = 1
284+
pos = brace_start + 1
285+
while pos < len(content) and depth > 0:
286+
if content[pos] == "{":
287+
depth += 1
288+
elif content[pos] == "}":
289+
depth -= 1
290+
pos += 1
291+
if depth != 0:
292+
continue # malformed block
293+
blocks.append((start, pos, name, content[start:pos]))
294+
return blocks
295+
296+
271297
def _remove_private_endpoint_resources(content: str) -> str:
272298
"""Remove private endpoint and DNS zone resources from non-networking stages.
273299
@@ -287,42 +313,20 @@ def _remove_private_endpoint_resources(content: str) -> str:
287313
"virtualnetworklinks",
288314
)
289315

290-
# Find resource block starts and use brace counting to find the end
291-
block_start_pattern = re.compile(
292-
r'resource\s+"azapi_resource"\s+"(\w+)"\s*\{',
293-
)
294-
295316
removed_names: list[str] = []
296317
result = content
297318

298-
for match in reversed(list(block_start_pattern.finditer(result))):
299-
resource_name = match.group(1)
300-
# Find the matching closing brace using brace counting
301-
start = match.start()
302-
brace_start = match.end() - 1 # position of opening {
303-
depth = 1
304-
pos = brace_start + 1
305-
while pos < len(result) and depth > 0:
306-
if result[pos] == "{":
307-
depth += 1
308-
elif result[pos] == "}":
309-
depth -= 1
310-
pos += 1
311-
if depth != 0:
312-
continue # malformed block, skip
313-
314-
block_text = result[start:pos]
315-
# Check if this block's type is a PE/DNS type
319+
for start, end, resource_name, block_text in reversed(_find_azapi_blocks(result)):
316320
type_match = re.search(r'type\s*=\s*"([^"]+)"', block_text)
317321
if not type_match:
318322
continue
319323
resource_type = type_match.group(1).lower()
320324
if any(pt in resource_type for pt in pe_types):
321325
# Remove the block plus any trailing whitespace/newlines
322-
end = pos
323-
while end < len(result) and result[end] in ("\n", "\r", " "):
324-
end += 1
325-
result = result[:start] + result[end:]
326+
trim_end = end
327+
while trim_end < len(result) and result[trim_end] in ("\n", "\r", " "):
328+
trim_end += 1
329+
result = result[:start] + result[trim_end:]
326330
removed_names.append(resource_name)
327331
logger.debug("Removed PE/DNS resource: azapi_resource.%s", resource_name)
328332

@@ -351,29 +355,26 @@ def _remove_private_endpoint_resources(content: str) -> str:
351355
def _add_response_export_values(content: str) -> str:
352356
"""Add ``response_export_values = ["*"]`` to azapi_resource blocks missing it.
353357
354-
Finds each ``resource "azapi_resource" "name" { ... }`` block and checks
355-
if ``response_export_values`` appears inside it. If missing, inserts it
356-
after the ``parent_id`` line (or after ``type`` if no ``parent_id``).
358+
Uses brace-counting via :func:`_find_azapi_blocks` to avoid nested-quantifier
359+
regex (ReDoS risk). Inserts after ``parent_id``, ``location``, or ``type``.
357360
"""
358-
# Match azapi_resource blocks
359-
block_pattern = re.compile(
360-
r'(resource\s+"azapi_resource"\s+"\w+"\s*\{)(.*?\n)((?:.*?\n)*?)(})',
361-
re.DOTALL,
362-
)
363-
364-
def _inject(match: re.Match) -> str: # type: ignore[type-arg]
365-
full = match.group(0)
366-
if "response_export_values" in full:
367-
return full # already has it
361+
result = content
362+
for start, end, _name, block_text in reversed(_find_azapi_blocks(result)):
363+
if "response_export_values" in block_text:
364+
continue
368365

369-
header = match.group(1)
370-
first_line = match.group(2)
371-
body = match.group(3)
372-
closing = match.group(4)
366+
# Split block body (after the opening { line) into lines
367+
header_end = block_text.index("{") + 1
368+
header = block_text[:header_end]
369+
body_plus_close = block_text[header_end:]
370+
# Remove the final closing brace
371+
body = body_plus_close.rstrip()
372+
if body.endswith("}"):
373+
body = body[:-1]
374+
closing = "}"
373375

374-
# Find insertion point: after parent_id, or after location, or after type
375-
lines = (first_line + body).splitlines(keepends=True)
376-
insert_idx = len(lines) # fallback: before closing brace
376+
lines = body.splitlines(keepends=True)
377+
insert_idx = len(lines)
377378
for i, line in enumerate(lines):
378379
stripped = line.strip()
379380
if stripped.startswith("parent_id"):
@@ -384,69 +385,62 @@ def _inject(match: re.Match) -> str: # type: ignore[type-arg]
384385
elif stripped.startswith("type") and insert_idx == len(lines):
385386
insert_idx = i + 1
386387

387-
# Detect indentation from the type/parent_id line
388388
indent = " "
389-
if insert_idx > 0 and insert_idx <= len(lines):
389+
if 0 < insert_idx <= len(lines):
390390
prev_line = lines[insert_idx - 1]
391391
leading = len(prev_line) - len(prev_line.lstrip())
392392
indent = " " * leading
393393

394394
lines.insert(insert_idx, f'\n{indent}response_export_values = ["*"]\n')
395-
return header + "".join(lines) + closing
395+
new_block = header + "".join(lines) + closing
396+
result = result[:start] + new_block + result[end:]
396397

397-
new_content = block_pattern.sub(_inject, content)
398-
if new_content != content:
398+
if result != content:
399399
logger.debug("Added response_export_values to azapi_resource blocks")
400-
return new_content
400+
return result
401401

402402

403403
def _add_resource_group_parent_id(content: str) -> str:
404404
"""Add ``parent_id`` to resource group azapi_resource blocks missing it.
405405
406-
Finds ``azapi_resource`` blocks whose type contains
407-
``Microsoft.Resources/resourceGroups`` and injects
408-
``parent_id = "/subscriptions/${var.subscription_id}"``
409-
after the ``name`` line.
406+
Uses brace-counting via :func:`_find_azapi_blocks` to avoid nested-quantifier
407+
regex (ReDoS risk). Injects after the ``name`` line.
410408
"""
411-
# Match azapi_resource blocks with resourceGroups type
412-
block_pattern = re.compile(
413-
r'(resource\s+"azapi_resource"\s+"\w+"\s*\{)(.*?)(})',
414-
re.DOTALL,
415-
)
416-
417-
def _inject(match: re.Match) -> str: # type: ignore[type-arg]
418-
full = match.group(0)
419-
if "resourcegroups" not in full.lower():
420-
return full
421-
if "parent_id" in full:
422-
return full # already has it
409+
result = content
410+
for start, end, _name, block_text in reversed(_find_azapi_blocks(result)):
411+
if "resourcegroups" not in block_text.lower():
412+
continue
413+
if "parent_id" in block_text:
414+
continue
423415

424-
header = match.group(1)
425-
body = match.group(2)
426-
closing = match.group(3)
416+
header_end = block_text.index("{") + 1
417+
header = block_text[:header_end]
418+
body_plus_close = block_text[header_end:]
419+
body = body_plus_close.rstrip()
420+
if body.endswith("}"):
421+
body = body[:-1]
422+
closing = "}"
427423

428-
# Insert after the name line
429424
lines = body.splitlines(keepends=True)
430425
insert_idx = len(lines)
431426
for i, line in enumerate(lines):
432427
if line.strip().startswith("name"):
433428
insert_idx = i + 1
434429
break
435430

436-
# Detect indentation
437431
indent = " "
438-
if insert_idx > 0 and insert_idx <= len(lines):
432+
if 0 < insert_idx <= len(lines):
439433
prev_line = lines[insert_idx - 1]
440434
leading = len(prev_line) - len(prev_line.lstrip())
441435
indent = " " * leading
442436

443437
lines.insert(insert_idx, f'{indent}parent_id = "/subscriptions/${{var.subscription_id}}"\n')
444-
return header + "".join(lines) + closing
438+
new_block = header + "".join(lines) + closing
439+
result = result[:start] + new_block + result[end:]
445440

446-
new_content = block_pattern.sub(_inject, content)
447-
if new_content != content:
441+
if result != content:
448442
logger.debug("Added parent_id to resource group azapi_resource")
449-
return new_content
443+
return result
450444

451445

452446
_STRUCTURED_HANDLERS: dict[str, Callable] = {

azext_prototype/stages/build_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,10 +1843,13 @@ def _agent_build_context(self, agent: Any, stage: dict) -> Iterator[Any]:
18431843
layer = stage.get("layer", "")
18441844
self._apply_governor_brief(agent, stage.get("name", ""), stage.get("services", []), layer)
18451845
self._apply_stage_knowledge(agent, stage)
1846+
svc_types = [s.get("resource_type", "") for s in stage.get("services", []) if s.get("resource_type")]
1847+
self._context.stage_services = svc_types or None
18461848
try:
18471849
yield agent
18481850
finally:
18491851
agent.set_knowledge_override("")
1852+
self._context.stage_services = None
18501853

18511854
def _apply_stage_knowledge(self, agent: Any, stage: dict) -> None:
18521855
"""Set stage-specific knowledge on the agent.

tests/agents/test_agents.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for azext_prototype.agents — registry, loader, base."""
22

3-
from unittest.mock import MagicMock
3+
from unittest.mock import MagicMock, patch
44

55
import pytest
66
import yaml
@@ -507,6 +507,39 @@ def test_cloud_architect_injects_azure_api_version_for_bicep(self):
507507
assert "deployment-language-bicep" in joined
508508

509509

510+
# ------------------------------------------------------------------
511+
# Agent-level service filtering
512+
# ------------------------------------------------------------------
513+
514+
515+
class TestAgentServiceFiltering:
    """Stage service namespaces must travel from AgentContext into governance validation."""

    def test_agent_context_has_stage_services_field(self):
        """A freshly built context carries no stage services."""
        from azext_prototype.agents.base import AgentContext

        context = AgentContext(project_config={}, project_dir="/tmp", ai_provider=None)

        assert context.stage_services is None, "Default stage_services should be None"

    def test_apply_governance_check_passes_stage_services(self):
        """The governance check forwards ``stage_services`` as the ``services`` keyword."""
        from azext_prototype.agents.base import AgentContext

        stub = StubAgent()
        context = AgentContext(
            project_config={"project": {"iac_tool": "terraform"}},
            project_dir="/tmp",
            ai_provider=MagicMock(),
        )
        context.stage_services = ["Microsoft.KeyVault/vaults"]
        reply = AIResponse(content="resource content", model="test", usage={})

        # Spy on validate_response so only the forwarded arguments are checked.
        with patch.object(stub, "validate_response", return_value=[]) as validate_spy:
            stub._apply_governance_check(reply, context)

        validate_spy.assert_called_once()
        _positional, keywords = validate_spy.call_args
        assert keywords["services"] == ["Microsoft.KeyVault/vaults"], (
            f"Expected stage_services to be passed through, got: {keywords.get('services')}"
        )
541+
542+
510543
# ------------------------------------------------------------------
511544
# QA checklist content requirements
512545
# ------------------------------------------------------------------

tests/governance/transforms/test_transforms.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,3 +393,61 @@ def test_cross_file_unused_remote_state_removed(self):
393393
)
394394
assert "TFM-TF-001" in ids
395395
assert "terraform_remote_state" not in result
396+
397+
398+
# ------------------------------------------------------------------
399+
# ReDoS safety: brace-counting replaces nested quantifier regex
400+
# ------------------------------------------------------------------
401+
402+
403+
class TestBraceCountingSafety:
    """Transform handlers must scan blocks by brace counting, never by nested-quantifier regex."""

    def test_response_export_values_pathological_input(self):
        """A 50,000-character single-line body must be handled in under one second."""
        import time

        filler = "x" * 50000
        # Pathological for the old regex: one huge line before the closing brace.
        content = (
            'resource "azapi_resource" "kv" {\n'
            ' type = "Microsoft.KeyVault/vaults@2023-07-01"\n'
            f" {filler}\n"
            "}\n"
        )

        began = time.monotonic()
        result = _add_response_export_values(content)
        elapsed = time.monotonic() - began

        assert elapsed < 1.0, f"_add_response_export_values took {elapsed:.2f}s on pathological input (ReDoS?)"
        assert 'response_export_values = ["*"]' in result

    def test_resource_group_parent_id_pathological_input(self):
        """A 50,000-character single-line body must be handled in under one second."""
        import time

        filler = "x" * 50000
        content = (
            'resource "azapi_resource" "rg" {\n'
            ' type = "Microsoft.Resources/resourceGroups@2024-03-01"\n'
            " name = var.rg\n"
            f" {filler}\n"
            "}\n"
        )

        began = time.monotonic()
        result = _add_resource_group_parent_id(content)
        elapsed = time.monotonic() - began

        assert elapsed < 1.0, f"_add_resource_group_parent_id took {elapsed:.2f}s on pathological input (ReDoS?)"
        assert "parent_id" in result

    def test_find_azapi_blocks_nested_braces(self):
        """Nested object literals must not end the resource block early."""
        from azext_prototype.governance.transforms import _find_azapi_blocks

        content = (
            'resource "azapi_resource" "kv" {\n'
            'type = "Microsoft.KeyVault/vaults@2023-07-01"\n'
            "body = {\n"
            "properties = {\n"
            "tenantId = var.tenant_id\n"
            "}\n"
            "}\n"
            "}\n"
        )

        blocks = _find_azapi_blocks(content)

        assert len(blocks) == 1
        start, end, name, block_text = blocks[0]
        assert name == "kv"
        assert block_text.startswith('resource "azapi_resource"')
        assert block_text.rstrip().endswith("}")

0 commit comments

Comments
 (0)