Skip to content

Commit 1b14a52

Browse files
committed
fix gameplay heuristic fallback, add per-bot model runtime, and normalize web sfx asset paths
1 parent c68b018 commit 1b14a52

15 files changed

Lines changed: 534 additions & 20 deletions

File tree

src/api/modules/gameplay/router.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Depends,
1111
HTTPException,
1212
Query,
13+
Request,
1314
WebSocket,
1415
WebSocketDisconnect,
1516
status,
@@ -49,6 +50,15 @@
4950
logger = logging.getLogger(__name__)
5051

5152

53+
def _resolve_inference_service(request: Request) -> InferenceService:
54+
"""Resolve inference service honoring FastAPI dependency overrides in tests."""
55+
provider = request.app.dependency_overrides.get(
56+
get_inference_service_dep,
57+
get_inference_service_dep,
58+
)
59+
return provider()
60+
61+
5262
async def _to_game_response(
5363
gameplay_service: GameplayService,
5464
game: Game,
@@ -115,7 +125,7 @@ async def _broadcast_move_applied(
115125
)
116126
def post_move(
117127
request: MoveRequest,
118-
inference_service: InferenceService = INFERENCE_SERVICE_DEP,
128+
http_request: Request,
119129
) -> MoveResponse:
120130
try:
121131
board = board_from_state(request.board.model_dump())
@@ -127,6 +137,7 @@ def post_move(
127137

128138
mode = request.mode
129139
if mode in {"fast", "strong"}:
140+
inference_service = _resolve_inference_service(http_request)
130141
result = inference_service.predict(board=board, mode=mode)
131142
move_payload: MovePayload | None = None
132143
if result.move is not None:

src/api/modules/identity/repository.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,23 @@ async def create(self, user: User) -> User:
2323
async def get_by_id(self, user_id: UUID) -> User | None:
2424
return await self.session.get(User, user_id)
2525

26+
async def save_user(self, user: User) -> User:
27+
self.session.add(user)
28+
await self.session.commit()
29+
await self.session.refresh(user)
30+
return user
31+
32+
async def get_bot_profile(self, user_id: UUID) -> BotProfile | None:
33+
stmt = select(BotProfile).where(col(BotProfile.user_id) == user_id)
34+
result = await self.session.execute(stmt)
35+
return result.scalars().first()
36+
37+
async def save_bot_profile(self, profile: BotProfile) -> BotProfile:
38+
self.session.add(profile)
39+
await self.session.commit()
40+
await self.session.refresh(profile)
41+
return profile
42+
2643
async def count_users(self) -> int:
2744
stmt = select(func.count()).select_from(User)
2845
result = await self.session.execute(stmt)

src/api/modules/identity/router.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from api.deps.identity import get_identity_service_dep
1010
from api.modules.identity.schemas import (
1111
BotProfileListResponse,
12+
BotProfileResponse,
13+
BotProfileUpsertRequest,
1214
PublicPlayerListResponse,
1315
UserCreateRequest,
1416
UserListResponse,
@@ -51,6 +53,34 @@ async def post_user(
5153
return UserResponse.model_validate(user)
5254

5355

56+
@router.post(
57+
"/bot-profiles",
58+
response_model=BotProfileResponse,
59+
status_code=status.HTTP_201_CREATED,
60+
summary="Create/Update Bot Profile (Admin)",
61+
description="Upserts bot behavior for a user account (heuristic/model). Requires admin privileges.",
62+
responses={
63+
201: {"description": "Bot profile stored successfully."},
64+
401: {"description": "Missing or invalid access token."},
65+
403: {"description": "Admin privileges required."},
66+
404: {"description": "User not found."},
67+
},
68+
)
69+
async def post_bot_profile(
70+
request: BotProfileUpsertRequest,
71+
identity_service: IdentityService = IDENTITY_SERVICE_DEP,
72+
admin_user: User = ADMIN_USER_DEP,
73+
) -> BotProfileResponse:
74+
del admin_user
75+
try:
76+
return await identity_service.upsert_bot_profile(request)
77+
except LookupError as exc:
78+
raise HTTPException(
79+
status_code=status.HTTP_404_NOT_FOUND,
80+
detail=str(exc),
81+
) from exc
82+
83+
5484
@router.get(
5585
"/users/{user_id}",
5686
response_model=UserResponse,

src/api/modules/identity/schemas.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from datetime import datetime
4+
from typing import Literal
45
from uuid import UUID
56

67
from pydantic import BaseModel, ConfigDict, Field
@@ -111,6 +112,7 @@ class BotProfileResponse(BaseModel):
111112
"agent_type": "heuristic",
112113
"heuristic_level": "normal",
113114
"model_mode": None,
115+
"model_version_id": None,
114116
"enabled": True,
115117
}
116118
}
@@ -122,9 +124,31 @@ class BotProfileResponse(BaseModel):
122124
agent_type: str
123125
heuristic_level: str | None
124126
model_mode: str | None
127+
model_version_id: UUID | None
125128
enabled: bool
126129

127130

131+
class BotProfileUpsertRequest(BaseModel):
132+
model_config = ConfigDict(
133+
json_schema_extra={
134+
"example": {
135+
"user_id": "5f6e8d34-292d-434f-a8ff-f48f4f3040f9",
136+
"agent_type": "model",
137+
"model_mode": "fast",
138+
"model_version_id": "1932ac4a-5dcf-4dc9-8f99-fdf7ef20cc99",
139+
"enabled": True,
140+
}
141+
}
142+
)
143+
144+
user_id: UUID
145+
agent_type: Literal["heuristic", "model"]
146+
heuristic_level: Literal["easy", "normal", "hard"] | None = None
147+
model_mode: Literal["fast", "strong"] | None = None
148+
model_version_id: UUID | None = None
149+
enabled: bool = True
150+
151+
128152
class BotProfileListResponse(BaseModel):
129153
model_config = ConfigDict(
130154
json_schema_extra={
@@ -156,6 +180,7 @@ class PublicPlayerResponse(BaseModel):
156180
"agent_type": "heuristic",
157181
"heuristic_level": "hard",
158182
"model_mode": None,
183+
"model_version_id": None,
159184
"enabled": True,
160185
}
161186
}
@@ -168,6 +193,7 @@ class PublicPlayerResponse(BaseModel):
168193
agent_type: str | None
169194
heuristic_level: str | None
170195
model_mode: str | None
196+
model_version_id: UUID | None
171197
enabled: bool | None
172198

173199

src/api/modules/identity/service.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
from sqlalchemy.exc import IntegrityError
77

8-
from api.db.models import User
8+
from api.db.enums import AgentType, BotKind
9+
from api.db.models import BotProfile, User
910
from api.modules.identity.repository import UserRepository
1011
from api.modules.identity.schemas import (
1112
BotProfileResponse,
13+
BotProfileUpsertRequest,
1214
PublicPlayerResponse,
1315
UserCreateRequest,
1416
)
@@ -68,12 +70,54 @@ async def list_playable_bots(
6870
agent_type=profile.agent_type.value,
6971
heuristic_level=profile.heuristic_level,
7072
model_mode=profile.model_mode,
73+
model_version_id=user.model_version_id,
7174
enabled=profile.enabled,
7275
)
7376
for user, profile in rows
7477
]
7578
return total, items
7679

80+
async def upsert_bot_profile(self, payload: BotProfileUpsertRequest) -> BotProfileResponse:
81+
user = await self.user_repository.get_by_id(payload.user_id)
82+
if user is None:
83+
raise LookupError(f"User not found: {payload.user_id}")
84+
85+
profile = await self.user_repository.get_bot_profile(payload.user_id)
86+
if profile is None:
87+
profile = BotProfile(user_id=payload.user_id, agent_type=AgentType.HEURISTIC)
88+
89+
# Promote the account into bot mode so matchmaking/listing sees it.
90+
user.is_bot = True
91+
user.is_hidden_bot = False
92+
93+
if payload.agent_type == "heuristic":
94+
profile.agent_type = AgentType.HEURISTIC
95+
profile.heuristic_level = payload.heuristic_level or "normal"
96+
profile.model_mode = None
97+
user.bot_kind = BotKind.HEURISTIC
98+
if payload.model_version_id is not None:
99+
user.model_version_id = payload.model_version_id
100+
else:
101+
profile.agent_type = AgentType.MODEL
102+
profile.model_mode = payload.model_mode or "fast"
103+
profile.heuristic_level = None
104+
user.bot_kind = BotKind.MODEL
105+
user.model_version_id = payload.model_version_id
106+
107+
profile.enabled = payload.enabled
108+
await self.user_repository.save_user(user)
109+
profile = await self.user_repository.save_bot_profile(profile)
110+
return BotProfileResponse(
111+
user_id=user.id,
112+
username=user.username,
113+
bot_kind=user.bot_kind,
114+
agent_type=profile.agent_type.value,
115+
heuristic_level=profile.heuristic_level,
116+
model_mode=profile.model_mode,
117+
model_version_id=user.model_version_id,
118+
enabled=profile.enabled,
119+
)
120+
77121
async def list_public_players(
78122
self, *, limit: int = 50, offset: int = 0, query: str | None = None
79123
) -> tuple[int, list[PublicPlayerResponse]]:
@@ -95,6 +139,7 @@ async def list_public_players(
95139
agent_type=profile.agent_type.value if profile else None,
96140
heuristic_level=profile.heuristic_level if profile else None,
97141
model_mode=profile.model_mode if profile else None,
142+
model_version_id=user.model_version_id,
98143
enabled=profile.enabled if profile else None,
99144
)
100145
for user, profile in rows
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from __future__ import annotations
2+
3+
from functools import lru_cache
4+
from urllib.parse import unquote, urlparse
5+
6+
from api.db.models import ModelVersion
7+
from inference.service import InferenceService
8+
9+
10+
def _normalize_local_artifact_path(uri: str | None) -> str | None:
11+
if uri is None:
12+
return None
13+
cleaned = uri.strip()
14+
if cleaned == "":
15+
return None
16+
17+
parsed = urlparse(cleaned)
18+
if parsed.scheme in {"", "file"}:
19+
if parsed.scheme == "file":
20+
if parsed.netloc not in {"", "localhost"}:
21+
raise ValueError(f"Unsupported artifact URI host: {cleaned}")
22+
path = unquote(parsed.path)
23+
if path == "":
24+
raise ValueError(f"Invalid artifact URI (empty path): {cleaned}")
25+
return path
26+
return cleaned
27+
28+
raise ValueError(
29+
f"Unsupported artifact URI scheme '{parsed.scheme}'. Use local paths or file:// URIs."
30+
)
31+
32+
33+
def _runtime_config_from_base(base_service: InferenceService | None) -> tuple[str, int, float, bool]:
34+
if base_service is None:
35+
return "auto", 160, 1.5, True
36+
return (
37+
base_service.device,
38+
base_service.mcts_sims,
39+
base_service.c_puct,
40+
base_service.prefer_onnx,
41+
)
42+
43+
44+
@lru_cache(maxsize=32)
45+
def _build_cached_inference_service(
46+
checkpoint_path: str,
47+
onnx_path: str,
48+
device: str,
49+
mcts_sims: int,
50+
c_puct: float,
51+
prefer_onnx: bool,
52+
) -> InferenceService:
53+
return InferenceService(
54+
checkpoint_path=checkpoint_path,
55+
onnx_path=onnx_path if onnx_path else None,
56+
device=device,
57+
mcts_sims=mcts_sims,
58+
c_puct=c_puct,
59+
prefer_onnx=prefer_onnx,
60+
)
61+
62+
63+
def resolve_model_inference_service(
64+
*,
65+
version: ModelVersion,
66+
base_service: InferenceService | None,
67+
) -> InferenceService:
68+
checkpoint_path = _normalize_local_artifact_path(version.checkpoint_uri)
69+
onnx_path = _normalize_local_artifact_path(version.onnx_uri)
70+
if checkpoint_path is None and onnx_path is None:
71+
raise ValueError(
72+
f"Model version '{version.name}' has no local checkpoint_uri/onnx_uri configured."
73+
)
74+
75+
# Keep runtime knobs aligned with the API default inference service so
76+
# per-account model selection does not silently change move quality.
77+
device, mcts_sims, c_puct, prefer_onnx = _runtime_config_from_base(base_service)
78+
return _build_cached_inference_service(
79+
checkpoint_path=checkpoint_path or "",
80+
onnx_path=onnx_path or "",
81+
device=device,
82+
mcts_sims=mcts_sims,
83+
c_puct=c_puct,
84+
prefer_onnx=prefer_onnx,
85+
)
86+
87+
88+
__all__ = ["resolve_model_inference_service"]

src/api/modules/matches/repository.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sqlmodel import desc, select
88

99
from api.db.enums import AgentType, GameStatus, QueueType
10-
from api.db.models import BotProfile, Game, GameMove, User
10+
from api.db.models import BotProfile, Game, GameMove, ModelVersion, User
1111

1212

1313
class MatchesRepository:
@@ -26,6 +26,9 @@ async def get_game(self, game_id: UUID) -> Game | None:
2626
async def get_user(self, user_id: UUID) -> User | None:
2727
return await self.session.get(User, user_id)
2828

29+
async def get_model_version(self, version_id: UUID) -> ModelVersion | None:
30+
return await self.session.get(ModelVersion, version_id)
31+
2932
async def get_bot_profile(self, user_id: UUID) -> BotProfile | None:
3033
stmt = select(BotProfile).where(BotProfile.user_id == user_id)
3134
result = await self.session.execute(stmt)

0 commit comments

Comments
 (0)