Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions internal/guard/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ var logWasm = logger.New("guard:wasm")
// JIT compilation when multiple guards load the same WASM binary.
var globalCompilationCache = wazero.NewCompilationCache()

// CloseGlobalCompilationCache releases JIT resources held by the shared
// compilation cache. It must be called once during graceful shutdown, after
// all WasmGuard runtimes have been closed (i.e., after Registry.Close()).
// Calling it while guards are still active or calling it more than once leads
// to undefined behavior. It is not safe to call concurrently.
func CloseGlobalCompilationCache(ctx context.Context) error {
return globalCompilationCache.Close(ctx)
}

// WasmGuardOptions configures optional settings for WASM guard creation
type WasmGuardOptions struct {
// Stdout is the writer for WASM stdout output. Defaults to os.Stdout if nil.
Expand Down Expand Up @@ -310,8 +319,10 @@ func (g *WasmGuard) hostLog(ctx context.Context, m api.Module, stack []uint64) {
logWasm.Printf("%sINFO: %s", prefix, msg)
case logLevelWarn:
logWasm.Printf("%sWARN: %s", prefix, msg)
logger.LogWarn("guard", "[%s] %s", g.name, msg)
case logLevelError:
logWasm.Printf("%sERROR: %s", prefix, msg)
logger.LogError("guard", "[%s] %s", g.name, msg)
default:
logWasm.Printf("%s%s", prefix, msg)
}
Expand Down Expand Up @@ -972,11 +983,8 @@ func (g *WasmGuard) tryCallWasmFunction(ctx context.Context, fn api.Function, me

resultLen := int32(results[0])
if resultLen == -2 {
if sizeBytes, ok := mem.Read(outputPtr, 4); ok && len(sizeBytes) == 4 {
requiredSize := uint32(sizeBytes[0]) | uint32(sizeBytes[1])<<8 | uint32(sizeBytes[2])<<16 | uint32(sizeBytes[3])<<24
if requiredSize > 0 {
return nil, requiredSize, nil
}
if requiredSize, ok := mem.ReadUint32Le(outputPtr); ok && requiredSize > 0 {
return nil, requiredSize, nil
}
return nil, 0, nil
}
Expand Down Expand Up @@ -1039,11 +1047,8 @@ func (g *WasmGuard) tryCallWasmFunction(ctx context.Context, fn api.Function, me
// The guard can optionally return the required size in the output buffer as a uint32
if resultLen == -2 {
// Try to read the required size from the output buffer (first 4 bytes as uint32)
if sizeBytes, ok := mem.Read(outputPtr, 4); ok && len(sizeBytes) == 4 {
requiredSize := uint32(sizeBytes[0]) | uint32(sizeBytes[1])<<8 | uint32(sizeBytes[2])<<16 | uint32(sizeBytes[3])<<24
if requiredSize > 0 {
return nil, requiredSize, nil
}
if requiredSize, ok := mem.ReadUint32Le(outputPtr); ok && requiredSize > 0 {
return nil, requiredSize, nil
}
// Guard didn't specify size, return 0 to trigger doubling
return nil, 0, nil
Expand Down
5 changes: 5 additions & 0 deletions internal/server/unified.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,11 @@ func (us *UnifiedServer) InitiateShutdown() int {
us.guardRegistry.Close(context.Background())
}

// Release JIT resources held by the shared WASM compilation cache
if err := guard.CloseGlobalCompilationCache(context.Background()); err != nil {
logger.LogError("shutdown", "Failed to close WASM compilation cache: %v", err)
}

logger.LogInfo("shutdown", "Backend servers terminated successfully")
})
return serversTerminated
Expand Down
Loading