Skip to content

Commit 16ba003

Browse files
committed
Add middleware
1 parent fa0583a commit 16ba003

3 files changed

Lines changed: 233 additions & 7 deletions

File tree

lightbug_api/__init__.mojo

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,75 @@ from lightbug_api.context import Context
33
from lightbug_api.response import Response
44
from lightbug_api.routing import (
55
BaseRequest,
6+
ErrorHandler,
67
FromReq,
78
Handler,
89
HandlerResponse,
10+
Middleware,
11+
MiddlewareEntry,
12+
MiddlewareResult,
913
PathPattern,
1014
RouteMatch,
1115
RootRouter,
1216
Router,
17+
abort,
18+
next,
1319
)
1420

1521

22+
# ---------------------------------------------------------- startup hook types
23+
24+
comptime StartupHook = fn () raises
25+
26+
struct StartupHookEntry(Copyable):
27+
var hook: StartupHook
28+
29+
fn __init__(out self, hook: StartupHook):
30+
self.hook = hook
31+
32+
fn __init__(out self, *, copy: Self):
33+
self.hook = copy.hook
34+
35+
36+
# ----------------------------------------------------------------------- App
37+
1638
struct App:
1739
"""The top-level application — register routes then call ``run()``.
1840
19-
Example::
41+
**Quick-start**::
2042
2143
fn main() raises:
2244
var app = App()
2345
46+
# Routes
2447
app.get("/", index)
2548
app.get("/users/{id}", get_user)
2649
app.post("/users", create_user)
2750
app.delete("/users/{id}", delete_user)
2851
52+
# Sub-router mounted at /v1
2953
var api = Router("v1")
3054
api.get("status", health)
31-
app.add_router(api^) # mounts at /v1/status
55+
app.add_router(api^)
56+
57+
# Middleware — runs before every handler
58+
app.use(require_auth)
59+
60+
# Error handler — catches unhandled exceptions from handlers
61+
app.on_error(my_error_handler)
62+
63+
# Startup hook — runs once before the server starts
64+
app.on_startup(connect_db)
3265
3366
app.run()
3467
"""
3568

3669
var router: RootRouter
70+
var startup_hooks: List[StartupHookEntry]
3771

3872
def __init__(out self) raises:
3973
self.router = RootRouter()
74+
self.startup_hooks = List[StartupHookEntry]()
4075

4176
# ------------------------------------------ route registration
4277

@@ -72,22 +107,78 @@ struct App:
72107
"""Mount a sub-router under its path fragment."""
73108
self.router.add_router(router^)
74109

110+
# ------------------------------------------ middleware
111+
112+
def use(mut self, middleware: Middleware) -> None:
113+
"""Add a middleware function that runs before every handler.
114+
115+
Middleware runs in registration order. Use ``next()`` to continue to
116+
the next middleware / handler, or ``abort(response)`` to short-circuit.
117+
118+
Example::
119+
120+
fn log_requests(ctx: Context) raises -> MiddlewareResult:
121+
print(ctx.method(), ctx.path())
122+
return next()
123+
124+
fn require_token(ctx: Context) raises -> MiddlewareResult:
125+
if not ctx.header("X-API-Key"):
126+
return abort(Response.unauthorized("missing X-API-Key"))
127+
return next()
128+
129+
app.use(log_requests)
130+
app.use(require_token)
131+
"""
132+
self.router.use(middleware)
133+
134+
# ------------------------------------------ lifecycle
135+
136+
def on_startup(mut self, hook: StartupHook) -> None:
137+
"""Register a function to run once before the server starts.
138+
139+
Useful for opening database connections, loading config, etc.
140+
141+
Example::
142+
143+
fn connect_db() raises:
144+
print("DB connected")
145+
146+
app.on_startup(connect_db)
147+
"""
148+
self.startup_hooks.append(StartupHookEntry(hook))
149+
150+
def on_error(mut self, handler: ErrorHandler) -> None:
151+
"""Register a custom error handler for unhandled exceptions from handlers.
152+
153+
The default handler logs the error and returns 500 Internal Server Error.
154+
155+
Example::
156+
157+
fn my_errors(ctx: Context, e: Error) raises -> HTTPResponse:
158+
print("Oops:", String(e))
159+
return Response.internal_error(String(e))
160+
161+
app.on_error(my_errors)
162+
"""
163+
self.router.error_handler = handler
164+
75165
# ------------------------------------------ start server
76166

77167
def run(mut self, host: String = "0.0.0.0", port: Int = 8080) raises:
78168
"""Start the HTTP server.
79169
170+
Runs all startup hooks, then begins listening for connections.
171+
80172
Args:
81173
host: Bind address (default ``0.0.0.0``).
82174
port: TCP port (default ``8080``).
83175
"""
176+
for i in range(len(self.startup_hooks)):
177+
self.startup_hooks[i].hook()
84178
var server = Server()
85179
server.listen_and_serve(String(host, ":", port), self.router)
86180

87181
def start_server(mut self, address: String = "0.0.0.0:8080") raises:
88-
"""Start the HTTP server.
89-
90-
Deprecated: use ``run(host, port)`` instead.
91-
"""
182+
"""Deprecated: use ``run(host, port)`` instead."""
92183
var server = Server()
93184
server.listen_and_serve(address, self.router)

lightbug_api/context.mojo

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,56 @@ struct Context(Copyable):
145145
return result.value()
146146
return default
147147

148+
# ------------------------------------------------------- typed path params
149+
150+
fn path_int(self, name: String) -> Optional[Int]:
151+
"""Parse a path parameter as ``Int``.
152+
153+
Returns ``None`` if the parameter is absent or not a valid integer.
154+
"""
155+
var s = self.path_param(name)
156+
if s:
157+
try:
158+
return Optional(atol(s.value()))
159+
except:
160+
pass
161+
return Optional[Int]()
162+
163+
fn path_int(self, name: String, default: Int) -> Int:
164+
"""Parse a path parameter as ``Int``, falling back to *default*."""
165+
var result = self.path_int(name)
166+
if result:
167+
return result.value()
168+
return default
169+
170+
# ------------------------------------------------------ typed query params
171+
172+
fn query_int(self, name: String, default: Int = 0) -> Int:
173+
"""Parse a query parameter as ``Int``, falling back to *default*.
174+
175+
Example: ``ctx.query_int("page", 1)``
176+
"""
177+
var s = self.query(name)
178+
if s:
179+
try:
180+
return atol(s.value())
181+
except:
182+
pass
183+
return default
184+
185+
fn query_bool(self, name: String, default: Bool = False) -> Bool:
186+
"""Parse a query parameter as ``Bool``, falling back to *default*.
187+
188+
Truthy string values: ``"true"``, ``"1"``, ``"yes"``.
189+
190+
Example: ``ctx.query_bool("verbose")``
191+
"""
192+
var s = self.query(name)
193+
if s:
194+
var v = s.value()
195+
return v == "true" or v == "1" or v == "yes"
196+
return default
197+
148198
# ---------------------------------------------------------------- headers
149199

150200
fn header(self, name: String) -> Optional[String]:

lightbug_api/routing.mojo

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from std.utils import Variant
33

44
from lightbug_http import HTTPRequest, HTTPResponse, HTTPService, NotFound, OK
55
from lightbug_http.http import RequestMethod
6+
from lightbug_http.http.common_response import InternalError
67
from lightbug_http.uri import URIDelimiters
78

89
from lightbug_api.context import Context
@@ -30,6 +31,51 @@ comptime HandlerResponse = Variant[HTTPResponse, String]
3031
# Use Context to access the request, path params, query params, headers, body.
3132
comptime Handler = fn (Context) raises -> HandlerResponse
3233

34+
# ------------------------------------------------------------ middleware types
35+
36+
# Middleware return type:
37+
# HTTPResponse — short-circuit: send this response immediately, skip handler
38+
# Bool — continue to the next middleware / handler (value is ignored)
39+
#
40+
# Use the helpers ``next()`` and ``abort(response)`` to return these cleanly.
41+
comptime MiddlewareResult = Variant[HTTPResponse, Bool]
42+
43+
# Every middleware shares this non-capturing function-pointer signature.
44+
comptime Middleware = fn (Context) raises -> MiddlewareResult
45+
46+
47+
fn next() -> MiddlewareResult:
48+
"""Signal that processing should continue to the next middleware / handler."""
49+
return MiddlewareResult(True)
50+
51+
52+
fn abort(var response: HTTPResponse) -> MiddlewareResult:
53+
"""Short-circuit the request with *response*, skipping all further processing."""
54+
return MiddlewareResult(response^)
55+
56+
57+
# Wrapper so Middleware function pointers can live in a List[MiddlewareEntry].
58+
struct MiddlewareEntry(Copyable):
59+
var handler: Middleware
60+
61+
fn __init__(out self, handler: Middleware):
62+
self.handler = handler
63+
64+
fn __init__(out self, *, copy: Self):
65+
self.handler = copy.handler
66+
67+
68+
# ------------------------------------------------------------ error handler
69+
70+
# Called when a route handler raises an unhandled error.
71+
# Return an appropriate HTTPResponse; the exception is consumed.
72+
comptime ErrorHandler = fn (Context, Error) raises -> HTTPResponse
73+
74+
75+
fn _default_error_handler(ctx: Context, e: Error) raises -> HTTPResponse:
76+
print("lightbug_api error:", String(e))
77+
return InternalError()
78+
3379

3480
# --------------------------------------------------------------- path matching
3581

@@ -229,6 +275,8 @@ struct RouterBase[is_main_app: Bool = False](HTTPService, Copyable):
229275
var path_fragment: String
230276
var sub_routers: List[RouterBase[False]]
231277
var routes: List[RouteEntry]
278+
var middleware: List[MiddlewareEntry]
279+
var error_handler: ErrorHandler
232280

233281
# ------------------------------------------------------------------ init
234282

@@ -238,11 +286,15 @@ struct RouterBase[is_main_app: Bool = False](HTTPService, Copyable):
238286
self.path_fragment = "/"
239287
self.sub_routers = List[RouterBase[False]]()
240288
self.routes = List[RouteEntry]()
289+
self.middleware = List[MiddlewareEntry]()
290+
self.error_handler = _default_error_handler
241291

242292
def __init__(out self: Self, path_fragment: String) raises:
243293
self.path_fragment = path_fragment
244294
self.sub_routers = List[RouterBase[False]]()
245295
self.routes = List[RouteEntry]()
296+
self.middleware = List[MiddlewareEntry]()
297+
self.error_handler = _default_error_handler
246298

247299
if not self._validate_path_fragment(path_fragment):
248300
raise Error(RouterErrors.INVALID_PATH_FRAGMENT_ERROR)
@@ -251,6 +303,8 @@ struct RouterBase[is_main_app: Bool = False](HTTPService, Copyable):
251303
self.path_fragment = copy.path_fragment
252304
self.sub_routers = copy.sub_routers.copy()
253305
self.routes = copy.routes.copy()
306+
self.middleware = copy.middleware.copy()
307+
self.error_handler = copy.error_handler
254308

255309
# -------------------------------------------------------- route registration
256310

@@ -296,6 +350,23 @@ struct RouterBase[is_main_app: Bool = False](HTTPService, Copyable):
296350
"""Mount a sub-router, nesting all its routes under its path fragment."""
297351
self.sub_routers.append(router^)
298352

353+
def use(mut self, middleware: Middleware) -> None:
354+
"""Add a middleware function that runs before every handler on this router.
355+
356+
Middleware runs in registration order. Return ``next()`` to continue,
357+
or ``abort(response)`` to short-circuit.
358+
359+
Example::
360+
361+
fn require_auth(ctx: Context) raises -> MiddlewareResult:
362+
if not ctx.header("Authorization"):
363+
return abort(Response.unauthorized())
364+
return next()
365+
366+
app.use(require_auth)
367+
"""
368+
self.middleware.append(MiddlewareEntry(middleware))
369+
299370
# ---------------------------------------------------------------- dispatch
300371

301372
def func(mut self, req: HTTPRequest) raises -> HTTPResponse:
@@ -310,7 +381,21 @@ struct RouterBase[is_main_app: Bool = False](HTTPService, Copyable):
310381
raise e^
311382

312383
var ctx = Context(req.copy(), route_match.path_params.copy())
313-
var res = route_match.handler(ctx)
384+
385+
# Run middleware chain — any middleware may short-circuit with a response.
386+
for i in range(len(self.middleware)):
387+
var mw_result = self.middleware[i].handler(ctx)
388+
if mw_result.isa[HTTPResponse]():
389+
return mw_result.unsafe_take[HTTPResponse]()
390+
# else Bool → continue to next middleware / handler
391+
392+
# Dispatch to the matched handler; convert unhandled errors to responses.
393+
var res: HandlerResponse
394+
try:
395+
res = route_match.handler(ctx)
396+
except e:
397+
return self.error_handler(ctx, e)
398+
314399
return self._encode_response(res^)
315400

316401
# --------------------------------------------------------------- internals

0 commit comments

Comments
 (0)