Skip to content

Commit a4e8fc8

Browse files
committed
a0222
1 parent f7fa2fd commit a4e8fc8

23 files changed

Lines changed: 84 additions & 116 deletions

app/core/config.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,7 @@ class Settings(BaseSettings):
4444
# ================================
4545
# Device configuration
4646
# ================================
47-
DEVICE: str = Field(
48-
default="cuda:0",
49-
description="e.g., 'cuda:0', 'cpu', 'mps' (macOS), 'cuda:1', etc."
50-
)
47+
DEVICE: str = Field(default="cuda:0", description="e.g., 'cuda:0', 'cpu', 'mps' (macOS), 'cuda:1', etc.")
5148

5249
# ================================
5350
# Model cache paths
@@ -87,12 +84,7 @@ class Settings(BaseSettings):
8784
# ================================
8885
# pydantic-settings configuration
8986
# ================================
90-
model_config = SettingsConfigDict(
91-
env_file=".env",
92-
env_prefix="APP_",
93-
case_sensitive=False,
94-
extra="ignore"
95-
)
87+
model_config = SettingsConfigDict(env_file=".env", env_prefix="APP_", case_sensitive=False, extra="ignore")
9688

9789
# ================================
9890
# Validators
@@ -114,8 +106,9 @@ def default_cache_dirs(cls, v, info):
114106
return Path(root) / "huggingface" / "hub"
115107
return Path(v) if isinstance(v, str) else v
116108

117-
@field_validator("MODEL_CACHE_ROOT", "STATIC_DIR", "TEMPLATES_DIR", "UPLOAD_DIR",
118-
"ERROR_LOG_FILE", "PLUGINS_LOG_FILE")
109+
@field_validator(
110+
"MODEL_CACHE_ROOT", "STATIC_DIR", "TEMPLATES_DIR", "UPLOAD_DIR", "ERROR_LOG_FILE", "PLUGINS_LOG_FILE"
111+
)
119112
@classmethod
120113
def ensure_path(cls, v: Path):
121114
"""

app/core/errors.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from fastapi import FastAPI, Request
88
from fastapi.exceptions import RequestValidationError
99
from fastapi.responses import HTMLResponse, JSONResponse
10+
from fastapi.templating import Jinja2Templates
1011
from jinja2 import TemplateNotFound
1112
from pydantic import ValidationError
1213
from starlette.exceptions import HTTPException as StarletteHTTPException
@@ -25,7 +26,6 @@
2526
HTTP_501_NOT_IMPLEMENTED,
2627
HTTP_503_SERVICE_UNAVAILABLE,
2728
)
28-
from fastapi.templating import Jinja2Templates
2929

3030
from app.core.config import get_settings
3131

@@ -114,8 +114,7 @@ def _render(
114114
html = (
115115
f"<h1>{status_code}{message}</h1>"
116116
f"<p><strong>Path:</strong> {payload['path']}</p>"
117-
f"<p><strong>Method:</strong> {payload['method']}</p>"
118-
+ (f"<pre>{details}</pre>" if details else "")
117+
f"<p><strong>Method:</strong> {payload['method']}</p>" + (f"<pre>{details}</pre>" if details else "")
119118
)
120119
return HTMLResponse(content=html, status_code=status_code)
121120

app/core/logging_.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from app.core.config import get_settings
55

6+
67
def setup_logging() -> None:
78
"""
89
Set up logging configuration using the LOG_LEVEL from application settings.

app/main.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1-
from pathlib import Path
1+
import uuid
2+
23
from fastapi import FastAPI, Request
34
from fastapi.middleware.cors import CORSMiddleware
4-
from fastapi.responses import HTMLResponse, FileResponse
5+
from fastapi.responses import FileResponse, HTMLResponse
56
from fastapi.staticfiles import StaticFiles
67
from fastapi.templating import Jinja2Templates
8+
from starlette.middleware.base import BaseHTTPMiddleware
79

810
from app.core.config import get_settings
911
from app.core.errors import register_exception_handlers
1012
from app.core.logging_ import setup_logging
1113
from app.routes import plugins as plugins_routes
1214

13-
import uuid
14-
from starlette.middleware.base import BaseHTTPMiddleware
15-
1615
# Initialize settings and logging
1716
settings = get_settings()
1817
setup_logging()
@@ -35,6 +34,7 @@
3534
allow_credentials=settings.CORS_ALLOW_CREDENTIALS,
3635
)
3736

37+
3838
# Middleware: unique request ID
3939
class RequestIDMiddleware(BaseHTTPMiddleware):
4040
async def dispatch(self, request, call_next):
@@ -44,26 +44,32 @@ async def dispatch(self, request, call_next):
4444
response.headers["X-Request-ID"] = rid
4545
return response
4646

47+
4748
app.add_middleware(RequestIDMiddleware)
4849

4950
# Register exception handlers
5051
register_exception_handlers(app)
5152

53+
5254
@app.get("/", response_class=HTMLResponse)
5355
def index(request: Request):
5456
return templates.TemplateResponse("index.html", {"request": request, "title": settings.APP_NAME})
5557

58+
5659
@app.get("/health")
5760
def health():
5861
return {"status": "ok"}
5962

63+
6064
@app.get("/env")
6165
def env():
6266
return settings.summary()
6367

68+
6469
@app.get("/favicon.ico", include_in_schema=False)
6570
def favicon():
6671
return FileResponse(str(settings.STATIC_DIR / "favicon.ico"))
6772

73+
6874
# Include plugin routes
6975
app.include_router(plugins_routes.router, tags=["plugins"])

app/plugins/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ def infer(self, payload: Dict[str, Any]) -> Dict[str, Any]:
3434
Returns:
3535
Dict[str, Any]: Inference result.
3636
"""
37-
...
37+
...

app/plugins/dummy/plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ def load(self) -> None:
88
print("[plugin] dummy service ready")
99

1010
def infer(self, payload: dict) -> dict:
11-
return {"task": "ping", "message": "✅ Dummy service is working", "payload_received": payload}
11+
return {"task": "ping", "message": "✅ Dummy service is working", "payload_received": payload}

app/plugins/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,4 @@ def get(name: str) -> AIPlugin | None:
7676

7777

7878
def all_meta() -> Dict[str, Dict[str, Any]]:
79-
return _meta
79+
return _meta

app/plugins/neu_server/plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ def load(self) -> None:
88
pass
99

1010
def infer(self, payload: dict) -> dict:
11-
pass
11+
pass

app/routes/plugins.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
router = APIRouter()
1010

11+
1112
def _ensure_discovered(request: Request) -> None:
1213
"""
1314
Ensure plugin registry and metadata are discovered and cached in app state.
@@ -38,12 +39,7 @@ def list_plugins(request: Request) -> Dict[str, Any]:
3839

3940

4041
@router.post("/plugins/{name}/{task}", summary="Run a task on a plugin")
41-
def run_plugin_task(
42-
name: str,
43-
task: str,
44-
payload: Dict[str, Any],
45-
request: Request
46-
) -> Dict[str, Any]:
42+
def run_plugin_task(name: str, task: str, payload: Dict[str, Any], request: Request) -> Dict[str, Any]:
4743
"""
4844
Execute a specific task on a given plugin.
4945

app/runtime.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
DEVICE = os.getenv("DEVICE", "cuda:0")
2020

21+
2122
def pick_device() -> torch.device:
2223
"""
2324
Select a valid device based on the DEVICE environment variable.
@@ -40,6 +41,7 @@ def pick_device() -> torch.device:
4041

4142
return torch.device("cpu")
4243

44+
4345
def pick_dtype(device: str | None = None) -> torch.dtype:
4446
"""
4547
Select an appropriate torch dtype based on the selected device.
@@ -64,6 +66,7 @@ def pick_dtype(device: str | None = None) -> torch.dtype:
6466

6567
return torch.float32
6668

69+
6770
def cuda_info() -> dict:
6871
"""
6972
Retrieve CUDA and GPU device information.
@@ -85,14 +88,15 @@ def cuda_info() -> dict:
8588
info.update(
8689
{
8790
"gpu_name": props.name,
88-
"total_memory_gb": round(props.total_memory / (1024 ** 3), 2),
91+
"total_memory_gb": round(props.total_memory / (1024**3), 2),
8992
"device_index": idx,
9093
"driver": getattr(torch.cuda, "driver_version", None),
9194
}
9295
)
9396

9497
return info
9598

99+
96100
def warmup() -> dict:
97101
"""
98102
Perform a matrix multiplication to warm up the selected device.
@@ -120,4 +124,4 @@ def warmup() -> dict:
120124
"shape": list(z.shape),
121125
"elapsed_sec": round(dt, 4),
122126
"device": str(dev),
123-
}
127+
}

0 commit comments

Comments
 (0)