Skip to content

Commit 19b7ce1

Browse files
committed
feat(tasks): register tasks in authorization graph on create/delete
1 parent b6e6004 commit 19b7ce1

9 files changed

Lines changed: 466 additions & 165 deletions

File tree

agentex/src/domain/services/task_service.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
from fastapi import Depends
55

6+
from src.adapters.crud_store.exceptions import ItemDoesNotExist
67
from src.adapters.streams.adapter_redis import DRedisStreamRepository
8+
from src.api.schemas.authorization_types import AgentexResource
79
from src.domain.entities.agents import ACPType, AgentEntity
810
from src.domain.entities.events import EventEntity
911
from src.domain.entities.task_message_updates import TaskMessageUpdateEntity
@@ -14,6 +16,7 @@
1416
from src.domain.repositories.task_repository import DTaskRepository
1517
from src.domain.repositories.task_state_repository import DTaskStateRepository
1618
from src.domain.services.agent_acp_service import DAgentACPService
19+
from src.domain.services.authorization_service import DAuthorizationService
1720
from src.utils.ids import orm_id
1821
from src.utils.logging import make_logger
1922
from src.utils.stream_topics import get_task_event_stream_topic
@@ -33,12 +36,14 @@ def __init__(
3336
task_repository: DTaskRepository,
3437
event_repository: DEventRepository,
3538
stream_repository: DRedisStreamRepository,
39+
authorization_service: DAuthorizationService,
3640
):
3741
self.acp_client = acp_client
3842
self.task_state_repository = task_state_repository
3943
self.task_repository = task_repository
4044
self.event_repository = event_repository
4145
self.stream_repository = stream_repository
46+
self.authorization_service = authorization_service
4247

4348
async def create_task(
4449
self,
@@ -59,19 +64,33 @@ async def create_task(
5964
Returns:
6065
Task containing the created task info
6166
"""
62-
63-
task_entity = await self.task_repository.create(
64-
agent_id=agent.id,
65-
task=TaskEntity(
66-
id=orm_id(),
67-
name=task_name,
68-
status=TaskStatus.RUNNING,
69-
status_reason="Task created, forwarding to ACP server",
70-
params=task_params,
71-
task_metadata=task_metadata,
72-
),
67+
# Register in the authorization service before persisting: a registration
68+
# failure aborts the request with no orphaned row. If the persist fails
69+
# after a successful registration, the compensating deregister_resource
70+
# below prevents a dangling authorization entry. Both calls are no-ops
71+
# when the authorization service is disabled for this account.
72+
task_entity = TaskEntity(
73+
id=orm_id(),
74+
name=task_name,
75+
status=TaskStatus.RUNNING,
76+
status_reason="Task created, forwarding to ACP server",
77+
params=task_params,
78+
task_metadata=task_metadata,
79+
)
80+
await self.authorization_service.register_resource(
81+
AgentexResource.task(task_entity.id),
82+
parent=AgentexResource.agent(agent.id),
7383
)
74-
return task_entity
84+
try:
85+
return await self.task_repository.create(
86+
agent_id=agent.id,
87+
task=task_entity,
88+
)
89+
except Exception:
90+
await self.authorization_service.deregister_resource(
91+
AgentexResource.task(task_entity.id),
92+
)
93+
raise
7594

7695
async def create_task_and_forward_to_acp(
7796
self,
@@ -91,7 +110,9 @@ async def create_task_and_forward_to_acp(
91110
Task containing the created task info
92111
"""
93112
task_entity = await self.create_task(
94-
agent=agent, task_name=task_name, task_params=task_params
113+
agent=agent,
114+
task_name=task_name,
115+
task_params=task_params,
95116
)
96117

97118
if agent.acp_type == ACPType.SYNC:
@@ -214,8 +235,35 @@ async def delete_task(self, id: str | None = None, name: str | None = None) -> N
214235
"""
215236
Delete a task from the repository.
216237
"""
238+
# Delete first (Postgres is the source of truth for existence), then
239+
# deregister best-effort: a deregister failure is logged and swallowed
240+
# rather than failing a delete that already succeeded.
241+
# Resolve the id before the delete so we can pass it to deregister_resource;
242+
# looking it up by name afterwards would race. If the name doesn't resolve,
243+
# swallow ItemDoesNotExist and let delete() surface its own native error
244+
# so the missing-task error contract is unchanged.
245+
task_id_for_deregister: str | None = id
246+
if task_id_for_deregister is None and name is not None:
247+
try:
248+
task = await self.task_repository.get(name=name)
249+
task_id_for_deregister = task.id
250+
except ItemDoesNotExist:
251+
task_id_for_deregister = None
252+
217253
await self.task_repository.delete(id=id, name=name)
218254

255+
if task_id_for_deregister is not None:
256+
try:
257+
await self.authorization_service.deregister_resource(
258+
AgentexResource.task(task_id_for_deregister),
259+
)
260+
except Exception:
261+
logger.exception(
262+
"task authorization deregister failed for task %s after successful delete; "
263+
"the deregistration failure has been swallowed",
264+
task_id_for_deregister,
265+
)
266+
219267
async def list_tasks(
220268
self,
221269
*,

agentex/tests/fixtures/services.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Provides factory functions and specific fixtures for creating services with test repositories.
44
"""
55

6-
from unittest.mock import MagicMock, Mock
6+
from unittest.mock import AsyncMock, MagicMock, Mock
77

88
import pytest
99

@@ -12,6 +12,24 @@
1212
# =============================================================================
1313

1414

15+
def make_noop_authorization_service() -> Mock:
16+
"""Shared noop AuthorizationService mock for tests that don't exercise authz.
17+
18+
``principal_context`` is ``None``, and
19+
``grant``/``revoke``/``register_resource``/``deregister_resource`` are async
20+
no-ops returning ``None`` — matching the real service signature. Use this
21+
anywhere a test just needs to construct ``AgentTaskService`` without caring
22+
about authorization behavior.
23+
"""
24+
svc = Mock()
25+
svc.principal_context = None
26+
svc.grant = AsyncMock(return_value=None)
27+
svc.revoke = AsyncMock(return_value=None)
28+
svc.register_resource = AsyncMock(return_value=None)
29+
svc.deregister_resource = AsyncMock(return_value=None)
30+
return svc
31+
32+
1533
def create_task_message_service(task_message_repository):
1634
"""Factory function to create TaskMessageService with given repository"""
1735
from src.domain.services.task_message_service import TaskMessageService
@@ -52,16 +70,21 @@ def create_task_service(
5270
event_repository,
5371
agent_acp_service,
5472
redis_stream_repository,
73+
authorization_service=None,
5574
):
56-
"""Factory function to create AgentTaskService with given repositories and services"""
75+
"""Factory function to create AgentTaskService with given repositories and services."""
5776
from src.domain.services.task_service import AgentTaskService
5877

78+
if authorization_service is None:
79+
authorization_service = make_noop_authorization_service()
80+
5981
return AgentTaskService(
6082
task_repository=task_repository,
6183
task_state_repository=task_state_repository,
6284
event_repository=event_repository,
6385
acp_client=agent_acp_service,
6486
stream_repository=redis_stream_repository,
87+
authorization_service=authorization_service,
6588
)
6689

6790

agentex/tests/integration/fixtures/integration_client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from src.config.dependencies import GlobalDependencies
2323
from src.config.environment_variables import EnvironmentVariables
2424

25+
from tests.fixtures.services import make_noop_authorization_service
26+
2527

2628
@pytest.fixture(scope="session")
2729
def event_loop():
@@ -455,6 +457,7 @@ async def send_message(self, *args, **kwargs):
455457
task_repository=isolated_repositories["task_repository"],
456458
event_repository=isolated_repositories["event_repository"],
457459
stream_repository=isolated_repositories["redis_stream_repository"],
460+
authorization_service=make_noop_authorization_service(),
458461
)
459462

460463
return TasksUseCase(task_service=task_service)

agentex/tests/integration/test_task_stream.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from src.domain.use_cases.tasks_use_case import TasksUseCase
88
from src.utils.ids import orm_id
99

10+
from tests.fixtures.services import make_noop_authorization_service
11+
1012

1113
@pytest.mark.asyncio
1214
@pytest.mark.integration
@@ -76,6 +78,7 @@ async def send_message(self, *args, **kwargs):
7678
task_repository=isolated_repositories["task_repository"],
7779
event_repository=isolated_repositories["event_repository"],
7880
stream_repository=isolated_repositories["redis_stream_repository"],
81+
authorization_service=make_noop_authorization_service(),
7982
)
8083

8184
return TasksUseCase(task_service=task_service)
@@ -103,6 +106,7 @@ async def send_message(self, *args, **kwargs):
103106
task_repository=isolated_repositories["task_repository"],
104107
event_repository=isolated_repositories["event_repository"],
105108
stream_repository=isolated_repositories["redis_stream_repository"],
109+
authorization_service=make_noop_authorization_service(),
106110
)
107111

108112
environment_variables = EnvironmentVariables.refresh()
@@ -194,17 +198,17 @@ async def collect_stream_events():
194198
pass
195199

196200
# Then - Verify the stream event was received
197-
assert (
198-
len(stream_events) >= 1
199-
), f"Expected at least 1 stream event, got {len(stream_events)}"
201+
assert len(stream_events) >= 1, (
202+
f"Expected at least 1 stream event, got {len(stream_events)}"
203+
)
200204

201205
# Find the task_updated event
202206
task_updated_events = [
203207
e for e in stream_events if e.get("type") == "task_updated"
204208
]
205-
assert (
206-
len(task_updated_events) >= 1
207-
), f"Expected task_updated event, got events: {[e.get('type') for e in stream_events]}"
209+
assert len(task_updated_events) >= 1, (
210+
f"Expected task_updated event, got events: {[e.get('type') for e in stream_events]}"
211+
)
208212

209213
task_updated_event = task_updated_events[0]
210214

@@ -389,9 +393,9 @@ async def collect_stream_events():
389393
task_updated_events = [
390394
e for e in stream_events if e.get("type") == "task_updated"
391395
]
392-
assert (
393-
len(task_updated_events) >= 3
394-
), f"Expected at least 3 task_updated events, got {len(task_updated_events)}"
396+
assert len(task_updated_events) >= 3, (
397+
f"Expected at least 3 task_updated events, got {len(task_updated_events)}"
398+
)
395399

396400
# Verify each event has the correct metadata for its update
397401
versions = [
@@ -599,8 +603,8 @@ async def collect_stream_data():
599603
pass
600604

601605
# Then - Verify we received at least 2 pings
602-
assert (
603-
ping_count >= 2
604-
), f"Expected at least 2 ping messages during idle period, got {ping_count}"
606+
assert ping_count >= 2, (
607+
f"Expected at least 2 ping messages during idle period, got {ping_count}"
608+
)
605609

606610
print(f"✅ Stream sent {ping_count} keepalive pings during idle period")

agentex/tests/integration/use_cases/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)