From c0b815c82d4cca8d9cb5fa92a51567c3ad9cdef6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6ssler?= Date: Fri, 24 Apr 2026 10:46:33 +0200 Subject: [PATCH 1/6] Add callbackOnBlock setting to wrapExport --- library/agent/hooks/wrapExport.test.ts | 60 ++++++++++++++++++- library/agent/hooks/wrapExport.ts | 40 +++++++++---- .../getCallbackFunctionFromArgs.test.ts | 55 +++++++++++++++++ .../helpers/getCallbackFunctionFromArgs.ts | 9 +++ 4 files changed, 152 insertions(+), 12 deletions(-) create mode 100644 library/helpers/getCallbackFunctionFromArgs.test.ts create mode 100644 library/helpers/getCallbackFunctionFromArgs.ts diff --git a/library/agent/hooks/wrapExport.test.ts b/library/agent/hooks/wrapExport.test.ts index 64da86886..cc8dd51f5 100644 --- a/library/agent/hooks/wrapExport.test.ts +++ b/library/agent/hooks/wrapExport.test.ts @@ -2,7 +2,7 @@ import * as t from "tap"; import { wrapExport } from "./wrapExport"; import { LoggerForTesting } from "../logger/LoggerForTesting"; import { Token } from "../api/Token"; -import { bindContext } from "../Context"; +import { bindContext, runWithContext } from "../Context"; import { createTestAgent } from "../../helpers/createTestAgent"; const logger = new LoggerForTesting(); @@ -10,6 +10,7 @@ const logger = new LoggerForTesting(); createTestAgent({ logger, token: new Token("123"), + block: true, }); t.test("Inspect args", async (t) => { @@ -205,3 +206,60 @@ t.test("Wrap default export", async (t) => { t.same(patched("input"), "input"); t.ok(executedCallback); }); + +t.test("it calls callback on block if callbackOnBlock is set", async (t) => { + const toWrap = { + test(input: string, callback: (err: Error | null) => void) { + callback(null); + }, + }; + + wrapExport( + toWrap, + "test", + { name: "test", type: "external" }, + { + kind: "outgoing_http_op", + inspectArgs: () => { + return { + operation: "http.get", + kind: "ssrf", + source: "body", + pathsToPayload: [""], + metadata: {}, + payload: "foo", + }; + }, + callbackOnBlock: true, + } + ); + + await runWithContext( + { + remoteAddress: "::1", + method: "POST", + url: "http://localhost:4000", + query: {}, + body: undefined, + headers: {}, + cookies: {}, + routeParams: {}, + source: "express", + route: "/posts/:id", + }, + async () => { + await new Promise((resolve) => { + toWrap.test("input", (err) => { + t.ok(err instanceof Error); + if (err instanceof Error) { + t.match( + err.message, + "Zen has blocked a server-side request forgery: http.get(...)" + ); + } + resolve(null); + }); + }); + } + ); +}); diff --git a/library/agent/hooks/wrapExport.ts b/library/agent/hooks/wrapExport.ts index 27d9b561f..51385c58d 100644 --- a/library/agent/hooks/wrapExport.ts +++ b/library/agent/hooks/wrapExport.ts @@ -6,6 +6,7 @@ import type { InterceptorResult } from "./InterceptorResult"; import type { PartialWrapPackageInfo } from "./WrapPackageInfo"; import { wrapDefaultOrNamed } from "./wrapDefaultOrNamed"; import { onInspectionInterceptorResult } from "./onInspectionInterceptorResult"; +import { getCallbackFunctionFromArgs } from "../../helpers/getCallbackFunctionFromArgs"; export type InspectArgsInterceptor = ( args: unknown[], @@ -34,6 +35,11 @@ export type InterceptorObject = { // This will be used to collect stats // For sources, this will often be undefined kind: OperationKind | undefined; + // When true, if blocking is triggered and the last argument is a function, + // call it with the error instead of throwing synchronously. + // Needed for callback-based APIs (e.g. pg.Client.query(sql, params, cb)) + // where a synchronous throw escapes Promise chains and crashes the process. + callbackOnBlock?: boolean; }; /** @@ -69,17 +75,29 @@ export function wrapExport( } } - inspectArgs.call( - // @ts-expect-error We don't now the type of this - this, - args, - interceptors.inspectArgs, - context, - agent, - pkgInfo, - methodName || "", - interceptors.kind - ); + try { + inspectArgs.call( + // @ts-expect-error We don't now the type of this + this, + args, + interceptors.inspectArgs, + context, + agent, + pkgInfo, + methodName || "", + interceptors.kind + ); + } catch (error) { + if (interceptors.callbackOnBlock) { + // Find the last function argument and call it with the error. + const cbFunc = getCallbackFunctionFromArgs(args); + if (cbFunc) { + process.nextTick(() => cbFunc(error)); + return undefined; + } + } + throw error; + } } // Run modifyArgs interceptor if provided diff --git a/library/helpers/getCallbackFunctionFromArgs.test.ts b/library/helpers/getCallbackFunctionFromArgs.test.ts new file mode 100644 index 000000000..79a1c4c4b --- /dev/null +++ b/library/helpers/getCallbackFunctionFromArgs.test.ts @@ -0,0 +1,55 @@ +import * as t from "tap"; +import { getCallbackFunctionFromArgs } from "./getCallbackFunctionFromArgs"; + +t.test( + "getCallbackFunctionFromArgs should return the last function argument", + (t) => { + const callback = () => {}; + const args = [1, "string", callback, () => {}]; + const result = getCallbackFunctionFromArgs(args); + t.equal(result, args[3]); + t.end(); + } +); + +t.test( + "getCallbackFunctionFromArgs should return undefined if no function argument is found", + (t) => { + const args = [1, "string", true, null]; + const result = getCallbackFunctionFromArgs(args); + t.equal(result, undefined); + t.end(); + } +); + +t.test( + "getCallbackFunctionFromArgs should return the last function argument even if there are multiple", + (t) => { + const callback1 = () => {}; + const callback2 = () => {}; + const args = [1, "string", callback1, callback2]; + const result = getCallbackFunctionFromArgs(args); + t.equal(result, callback2); + t.end(); + } +); + +t.test( + "getCallbackFunctionFromArgs should return undefined for an empty array", + (t) => { + const args: any[] = []; + const result = getCallbackFunctionFromArgs(args); + t.equal(result, undefined); + t.end(); + } +); + +t.test( + "getCallbackFunctionFromArgs should return undefined if all arguments are non-functions", + (t) => { + const args = [1, "string", true, null, {}]; + const result = getCallbackFunctionFromArgs(args); + t.equal(result, undefined); + t.end(); + } +); diff --git a/library/helpers/getCallbackFunctionFromArgs.ts b/library/helpers/getCallbackFunctionFromArgs.ts new file mode 100644 index 000000000..33a44bf1f --- /dev/null +++ b/library/helpers/getCallbackFunctionFromArgs.ts @@ -0,0 +1,9 @@ +// Finds the last function argument in the provided array and returns it. If no function is found, it returns undefined. +export function getCallbackFunctionFromArgs(args: any[]): Function | undefined { + for (let i = args.length - 1; i >= 0; i--) { + if (typeof args[i] === "function") { + return args[i] as Function; + } + } + return undefined; +} From 1a06afdc85c92f49ead09ef9d4b4265ec41ab7cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6ssler?= Date: Fri, 24 Apr 2026 11:47:55 +0200 Subject: [PATCH 2/6] Add callbackOnBlock option for ESM instrumentation --- instrumentation-wasm/Cargo.lock | 162 ++++---- instrumentation-wasm/Cargo.toml | 16 +- .../src/js_transformer/helpers/insert_code.rs | 19 +- .../helpers/insert_instrument_method_calls.rs | 1 + .../src/js_transformer/instructions.rs | 1 + .../codeTransformation.benchmark.test.ts | 2 + .../codeTransformation.test.ts | 363 ++++++++++++++++++ .../instrumentation/injectedFunctions.ts | 52 ++- .../hooks/instrumentation/instructions.ts | 2 + library/agent/hooks/instrumentation/types.ts | 7 + 10 files changed, 518 insertions(+), 107 deletions(-) diff --git a/instrumentation-wasm/Cargo.lock b/instrumentation-wasm/Cargo.lock index 8632316d5..a049136a6 100644 --- a/instrumentation-wasm/Cargo.lock +++ b/instrumentation-wasm/Cargo.lock @@ -26,15 +26,15 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.10.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" [[package]] name = "bumpalo" -version = "3.19.1" +version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" [[package]] name = "castaway" @@ -73,9 +73,9 @@ checksum = "417bef24afe1460300965a25ff4a24b8b45ad011948302ec221e8a0a81eb2c79" [[package]] name = "dragonbox_ecma" -version = "0.1.0" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a5577f010d4e1bb3f3c4d6081e05718eb6992cf20119cab4d3abadff198b5ae" +checksum = "fd8e701084c37e7ef62d3f9e453b618130cbc0ef3573847785952a3ac3f746bf" [[package]] name = "either" @@ -85,15 +85,15 @@ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "fastrand" -version = "2.3.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" [[package]] name = "hashbrown" -version = "0.16.1" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" dependencies = [ "allocator-api2", ] @@ -109,21 +109,21 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "json-escape-simd" -version = "3.0.1" +version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3c2a6c0b4b5637c41719973ef40c6a1cf564f9db6958350de6193fbee9c23f5" +checksum = "35e770254dd7802184595b1d30da2a15cb72569e2aca2b177aef8d22eac8a693" [[package]] name = "memchr" -version = "2.7.6" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" [[package]] name = "node_code_instrumentation" @@ -189,15 +189,15 @@ checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" [[package]] name = "owo-colors" -version = "4.2.3" +version = "4.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c6901729fa79e91a0913333229e9ca5dc725089d1c363b2f4b4760709dc4a52" +checksum = "d211803b9b6b570f68772237e415a029d5a50c65d382910b879fb19d3271f94d" [[package]] name = "oxc-miette" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a7ba54c704edefead1f44e9ef09c43e5cfae666bdc33516b066011f0e6ebf7" +checksum = "4356a61f2ed4c9b3610245215fbf48970eb277126919f87db9d0efa93a74245c" dependencies = [ "cfg-if", "owo-colors", @@ -210,9 +210,9 @@ dependencies = [ [[package]] name = "oxc-miette-derive" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4faecb54d0971f948fbc1918df69b26007e6f279a204793669542e1e8b75eb3" +checksum = "b237422b014f8f8fff75bb9379e697d13f8d57551a22c88bebb39f073c1bf696" dependencies = [ "proc-macro2", "quote", @@ -221,9 +221,9 @@ dependencies = [ [[package]] name = "oxc_allocator" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e6fc6ce99f6a28fd477c6df500bbc9bf1c39db166952e15bea218459cc0db0c" +checksum = "bd3b8bfef454857d3d9ca08fb84c8955da8591b5a82a21bb34a7ebbf94da7b0f" dependencies = [ "allocator-api2", "hashbrown", @@ -233,9 +233,9 @@ dependencies = [ [[package]] name = "oxc_ast" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49fa0813bf9fcff5a4e48fc186ee15a0d276b30b0b575389a34a530864567819" +checksum = "381ae8356082431bd7e217dd78c7179bfc379dbbe7a32494e28be4fc678812c7" dependencies = [ "bitflags", "oxc_allocator", @@ -245,14 +245,15 @@ dependencies = [ "oxc_estree", "oxc_regular_expression", "oxc_span", + "oxc_str", "oxc_syntax", ] [[package]] name = "oxc_ast_macros" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a2b2a2e09ff0dd4790a5ceb4a93349e0ea769d4d98d778946de48decb763b18" +checksum = "c50246449a5fa669debd2debeb90be4c30f0a3a2e954f852ec40e5ef49701285" dependencies = [ "phf", "proc-macro2", @@ -262,9 +263,9 @@ dependencies = [ [[package]] name = "oxc_ast_visit" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef6d2304cb25dbbd028440591bf289ef16e3df98517930e79dcc304be64b3045" +checksum = "82466fd1885834078becf1385380c40624bf511723b695104b21f293c7dc5271" dependencies = [ "oxc_allocator", "oxc_ast", @@ -274,9 +275,9 @@ dependencies = [ [[package]] name = "oxc_codegen" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce92b24319ee9fbfa14a5cc488a5ba91bb04bac070c4bad0ba18c772060d19c0" +checksum = "b69f394fa01810f99943a9191dde9d5757bdccd5c06347af62663eea671a5153" dependencies = [ "bitflags", "cow-utils", @@ -289,21 +290,22 @@ dependencies = [ "oxc_semantic", "oxc_sourcemap", "oxc_span", + "oxc_str", "oxc_syntax", "rustc-hash", ] [[package]] name = "oxc_data_structures" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e8f59bed9522098da177d894dc8635fb3eae218ff97d9c695900cb11fd10a2" +checksum = "1defc2fd17ee94f2c8511b0c4a4756d5868fbee891478953f2354ef444b1962f" [[package]] name = "oxc_diagnostics" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0476859d4319f2b063f7c4a3120ee5b7e3e48032865ca501f8545ff44badcff" +checksum = "1d7ccb0e8e7c9f1fb75e0700b2c75d9d854e534a7a356b13d2936893651f2b98" dependencies = [ "cow-utils", "oxc-miette", @@ -312,9 +314,9 @@ dependencies = [ [[package]] name = "oxc_ecmascript" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bcf46e5b1a6f8ea3797e887a9db4c79ed15894ca8685eb628da462d4c4e913f" +checksum = "1904566c4e725c1511c88166ec203ae97bebb62887441b4a29b1e7757ec39859" dependencies = [ "cow-utils", "num-bigint", @@ -328,9 +330,9 @@ dependencies = [ [[package]] name = "oxc_estree" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2251e6b61eab7b96f0e9d140b68b0f0d8a851c7d260725433e18b1babdcb9430" +checksum = "e87cd0e290bab4cb5d81377bbc1ebd414f01a7af72d7f8e5ccbb4a9a157d71df" [[package]] name = "oxc_index" @@ -344,9 +346,9 @@ dependencies = [ [[package]] name = "oxc_parser" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "439d2580047b77faf6e60d358b48e5292e0e026b9cfc158d46ddd0175244bb26" +checksum = "c71acdb67749ff68bfbbd346da7dd2fe4947964be49ac9ec34d73d10a2396dcd" dependencies = [ "bitflags", "cow-utils", @@ -360,6 +362,7 @@ dependencies = [ "oxc_ecmascript", "oxc_regular_expression", "oxc_span", + "oxc_str", "oxc_syntax", "rustc-hash", "seq-macro", @@ -367,15 +370,16 @@ dependencies = [ [[package]] name = "oxc_regular_expression" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fb5669d3298a92d440afec516943745794cb4cf977911728cd73e3438db87b9" +checksum = "08a1273168ec6d8083e161565d264847249b9aad51c430d92d344303ede058b2" dependencies = [ "bitflags", "oxc_allocator", "oxc_ast_macros", "oxc_diagnostics", "oxc_span", + "oxc_str", "phf", "rustc-hash", "unicode-id-start", @@ -383,9 +387,9 @@ dependencies = [ [[package]] name = "oxc_semantic" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "487e9ef54375b23b159eef73746a02b505c3ae70b9c302610680d3c68a3bb62c" +checksum = "63d4f8a0d3eb4e8e03aa413f54300235cb0314dc26649d2ff19f609b7b478272" dependencies = [ "itertools", "memchr", @@ -396,6 +400,7 @@ dependencies = [ "oxc_ecmascript", "oxc_index", "oxc_span", + "oxc_str", "oxc_syntax", "rustc-hash", "self_cell", @@ -403,9 +408,9 @@ dependencies = [ [[package]] name = "oxc_sourcemap" -version = "6.0.1" +version = "6.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36801dbbd025f2fa133367494e38eef75a53d334ae6746ba0c889fc4e76fa3a3" +checksum = "6d378eb8bad20e89d66276aebab51f6a5408571092cac94abdd3eabb773713d6" dependencies = [ "base64-simd", "json-escape-simd", @@ -416,9 +421,9 @@ dependencies = [ [[package]] name = "oxc_span" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1d452f6a664627bdd0f1f1586f9258f81cd7edc5c83e9ef50019f701ef1722d" +checksum = "9af84474452c3caa7aca1bcaca04b6e16552fe29472059b7921ae7a69790dccf" dependencies = [ "compact_str", "oxc-miette", @@ -430,9 +435,9 @@ dependencies = [ [[package]] name = "oxc_str" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c7a27c4371f69387f3d6f8fa56f70e4c6fa6aedc399285de6ec02bb9fd148d7" +checksum = "136bcc6bed1182df0b9c529e478da55a490b38ba5f1189abf2e7a9b13f46f0b1" dependencies = [ "compact_str", "hashbrown", @@ -442,9 +447,9 @@ dependencies = [ [[package]] name = "oxc_syntax" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d60d91023aafc256ab99c3dbf6181473e495695029c0152d2093e87df18ffe2" +checksum = "b448a086623714675f66b79271e25fa2b51708255fa6af7dad83be88cc6e8726" dependencies = [ "bitflags", "cow-utils", @@ -455,15 +460,16 @@ dependencies = [ "oxc_estree", "oxc_index", "oxc_span", + "oxc_str", "phf", "unicode-id-start", ] [[package]] name = "oxc_traverse" -version = "0.123.0" +version = "0.127.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c31aba1910999e2f9a1cc9c47a490caaed828bb119351abe20a2a7851d554963" +checksum = "648c4e7c8ee0a8d2ff28751cc9dc8d5502a7d3b2b96d4fa73de7fb31b46d54c6" dependencies = [ "itoa", "oxc_allocator", @@ -538,18 +544,18 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.44" +version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] [[package]] name = "rustc-hash" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" [[package]] name = "rustversion" @@ -559,9 +565,9 @@ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" -version = "1.0.22" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" [[package]] name = "self_cell" @@ -638,9 +644,9 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "syn" -version = "2.0.114" +version = "2.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" dependencies = [ "proc-macro2", "quote", @@ -686,9 +692,9 @@ checksum = "81b79ad29b5e19de4260020f8919b443b2ef0277d242ce532ec7b7a2cc8b6007" [[package]] name = "unicode-ident" -version = "1.0.22" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-linebreak" @@ -698,9 +704,9 @@ checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" [[package]] name = "unicode-segmentation" -version = "1.12.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" [[package]] name = "unicode-width" @@ -716,9 +722,9 @@ checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" [[package]] name = "wasm-bindgen" -version = "0.2.108" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +checksum = "0bf938a0bacb0469e83c1e148908bd7d5a6010354cf4fb73279b7447422e3a89" dependencies = [ "cfg-if", "once_cell", @@ -729,9 +735,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.108" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +checksum = "eeff24f84126c0ec2db7a449f0c2ec963c6a49efe0698c4242929da037ca28ed" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -739,9 +745,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.108" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +checksum = "9d08065faf983b2b80a79fd87d8254c409281cf7de75fc4b773019824196c904" dependencies = [ "bumpalo", "proc-macro2", @@ -752,15 +758,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.108" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +checksum = "5fd04d9e306f1907bd13c6361b5c6bfc7b3b3c095ed3f8a9246390f8dbdee129" dependencies = [ "unicode-ident", ] [[package]] name = "zmij" -version = "1.0.19" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ff05f8caa9038894637571ae6b9e29466c1f4f829d26c9b28f869a29cbe3445" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/instrumentation-wasm/Cargo.toml b/instrumentation-wasm/Cargo.toml index ffb52c89c..c283b3cfc 100644 --- a/instrumentation-wasm/Cargo.toml +++ b/instrumentation-wasm/Cargo.toml @@ -10,16 +10,16 @@ name = "node_code_instrumentation" crate-type = ["cdylib", "rlib"] [dependencies] -oxc_allocator = "0.123.0" -oxc_ast = "0.123.0" -oxc_codegen = "0.123.0" -oxc_parser = "0.123.0" -oxc_semantic = "0.123.0" -oxc_span = "0.123.0" -oxc_traverse = "0.123.0" +oxc_allocator = "0.127.0" +oxc_ast = "0.127.0" +oxc_codegen = "0.127.0" +oxc_parser = "0.127.0" +oxc_semantic = "0.127.0" +oxc_span = "0.127.0" +oxc_traverse = "0.127.0" serde = "1.0.228" serde_json = "1.0.149" -wasm-bindgen = "0.2.108" +wasm-bindgen = "0.2.118" [profile.release] strip = true diff --git a/instrumentation-wasm/src/js_transformer/helpers/insert_code.rs b/instrumentation-wasm/src/js_transformer/helpers/insert_code.rs index 24ee375de..0ae33fff8 100644 --- a/instrumentation-wasm/src/js_transformer/helpers/insert_code.rs +++ b/instrumentation-wasm/src/js_transformer/helpers/insert_code.rs @@ -3,7 +3,7 @@ use oxc_ast::{ AstBuilder, NONE, ast::{ Argument, ArrayExpressionElement, AssignmentOperator, AssignmentTarget, Expression, - FunctionBody, Statement, + FunctionBody, Statement, UnaryOperator, }, }; use oxc_span::SPAN; @@ -16,6 +16,7 @@ pub fn insert_inspect_args<'a>( pkg_version: &'a str, body: &mut Box<'a, FunctionBody<'a>>, is_constructor: bool, + callback_on_block: bool, ) { let mut inspect_args: OxcVec<'a, Argument<'a>> = builder.vec_with_capacity(4); @@ -48,10 +49,20 @@ pub fn insert_inspect_args<'a>( false, ); - let stmt_expression = builder.statement_expression(SPAN, call_expr); - let insert_pos = get_insert_pos(body, is_constructor); - body.statements.insert(insert_pos, stmt_expression); + + if !callback_on_block { + let stmt_expression = builder.statement_expression(SPAN, call_expr); + + body.statements.insert(insert_pos, stmt_expression); + return; + } + + let return_stmt = builder.statement_return(SPAN, None); + let test = builder.expression_unary(SPAN, UnaryOperator::LogicalNot, call_expr); + let if_stmt = builder.statement_if(SPAN, test, return_stmt, None); + + body.statements.insert(insert_pos, if_stmt); } // Modify the arguments by adding a statement to the beginning of the function diff --git a/instrumentation-wasm/src/js_transformer/helpers/insert_instrument_method_calls.rs b/instrumentation-wasm/src/js_transformer/helpers/insert_instrument_method_calls.rs index a2f3bfe84..1475e0d78 100644 --- a/instrumentation-wasm/src/js_transformer/helpers/insert_instrument_method_calls.rs +++ b/instrumentation-wasm/src/js_transformer/helpers/insert_instrument_method_calls.rs @@ -38,6 +38,7 @@ pub fn insert_instrument_method_calls<'a>( pkg_version, body, is_constructor, + instruction.callback_on_block, ); } diff --git a/instrumentation-wasm/src/js_transformer/instructions.rs b/instrumentation-wasm/src/js_transformer/instructions.rs index b6890ab6f..f99ed1c0e 100644 --- a/instrumentation-wasm/src/js_transformer/instructions.rs +++ b/instrumentation-wasm/src/js_transformer/instructions.rs @@ -21,4 +21,5 @@ pub struct FunctionInstructions { pub modify_return_value: bool, pub modify_arguments_object: bool, pub class_name: Option, + pub callback_on_block: bool, } diff --git a/library/agent/hooks/instrumentation/codeTransformation.benchmark.test.ts b/library/agent/hooks/instrumentation/codeTransformation.benchmark.test.ts index 16a404d24..f9b168494 100644 --- a/library/agent/hooks/instrumentation/codeTransformation.benchmark.test.ts +++ b/library/agent/hooks/instrumentation/codeTransformation.benchmark.test.ts @@ -50,6 +50,7 @@ t.test("Benchmark: Small code transformation", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -96,6 +97,7 @@ t.test("Benchmark: Large code transformation", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], diff --git a/library/agent/hooks/instrumentation/codeTransformation.test.ts b/library/agent/hooks/instrumentation/codeTransformation.test.ts index 5c2b7f5e6..32d99e208 100644 --- a/library/agent/hooks/instrumentation/codeTransformation.test.ts +++ b/library/agent/hooks/instrumentation/codeTransformation.test.ts @@ -51,6 +51,7 @@ t.test("add inspectArgs to method definition (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -109,6 +110,7 @@ t.test("add inspectArgs to method definition (CJS)", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -167,6 +169,7 @@ t.test("wrong function name", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -224,6 +227,7 @@ t.test("typescript code", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -283,6 +287,7 @@ t.test("typescript code in a js file", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -315,6 +320,7 @@ t.test("empty code", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -360,6 +366,7 @@ t.test("add modifyArgs to method definition (ESM)", async (t) => { modifyArgs: true, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -420,6 +427,7 @@ t.test( modifyArgs: true, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -477,6 +485,7 @@ t.test("modify rest parameter args", async (t) => { modifyArgs: true, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -523,6 +532,7 @@ t.test("modify rest parameter args", async (t) => { modifyArgs: true, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -578,6 +588,7 @@ t.test("add inspectArgs to method definition (unambiguous)", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -634,6 +645,7 @@ t.test("add inspectArgs to method definition (unambiguous)", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -687,6 +699,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -736,6 +749,7 @@ t.test( modifyArgs: true, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -786,6 +800,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -837,6 +852,7 @@ t.test( modifyArgs: true, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -888,6 +904,7 @@ t.test( modifyArgs: true, modifyReturnValue: false, modifyArgumentsObject: true, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -936,6 +953,7 @@ t.test("does not modify code if function name is not found", async (t) => { modifyArgs: true, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -990,6 +1008,7 @@ t.test("add modifyArgs to method definition (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: true, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -1411,6 +1430,7 @@ t.test("it adds all imports if necessary (CJS)", async (t) => { modifyArgs: true, modifyReturnValue: true, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -1448,6 +1468,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -1481,6 +1502,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -1524,6 +1546,7 @@ t.test("Modify function declaration (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: true, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -1579,6 +1602,7 @@ t.test("Modify function declaration (CJS)", async (t) => { modifyArgs: true, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -1629,6 +1653,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -1683,6 +1708,7 @@ t.test("Modify function expression (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: true, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -1735,6 +1761,7 @@ t.test("Modify constructor (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -1784,6 +1811,7 @@ t.test("Modify constructor with super (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -1834,6 +1862,7 @@ t.test("Modify constructor with super (ESM)", async (t) => { modifyArgs: true, modifyReturnValue: true, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -2026,6 +2055,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -2075,6 +2105,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -2122,6 +2153,7 @@ t.test("Modify async return value (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: true, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -2165,6 +2197,7 @@ t.test("Modify async arrow function variable declaration (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -2209,6 +2242,7 @@ t.test("Modify function variable declaration (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -2253,6 +2287,7 @@ t.test("Do not modify function variable declaration (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -2300,6 +2335,7 @@ t.test("Test codegen comment behavior", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -2354,6 +2390,7 @@ t.test("it works with mts extension (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, className: undefined, }, ], @@ -2412,6 +2449,7 @@ t.test("it works with cts extension (ESM)", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -2468,6 +2506,7 @@ t.test("Does not instrument if class name does not match", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, className: "NonMatchingClassName", }, ], @@ -2523,6 +2562,7 @@ t.test("It does instrument if class name matches", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, className: "Test", }, ], @@ -2582,6 +2622,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, className: "TestClass", }, ], @@ -2642,6 +2683,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, className: "TestNonMatching", }, ], @@ -2701,6 +2743,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, className: undefined, }, ], @@ -2756,6 +2799,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -2809,6 +2853,7 @@ t.test("It does instrument private methods in classes", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, className: undefined, }, ], @@ -2862,6 +2907,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, }, ], accessLocalVariables: [], @@ -2917,6 +2963,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, className: "Server", }, ], @@ -2979,6 +3026,7 @@ t.test( modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, className: "Server", }, ], @@ -3041,6 +3089,7 @@ t.test("Two classes with same name in different block scopes", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, className: "Server", }, ], @@ -3070,3 +3119,317 @@ t.test("Two classes with same name in different block scopes", async (t) => { }` ); }); + +t.test( + "use inspectArgs with callbackOnBlock for method definition (CJS)", + async (t) => { + const result = transformCode( + "testpkg", + "1.0.0", + "test.js", + ` + const { test } = require("test"); + class Test { + + private testValue = 42; + + constructor() { + this.testFunction(testValue); + } + testFunction(arg1, cb) { + cb(arg1); + } + } + `, + "commonjs", + { + path: "test.js", + versionRange: "^1.0.0", + identifier: "testpkg.test.js.^1.0.0", + functions: [ + { + nodeType: "MethodDefinition", + name: "testFunction", + identifier: "testpkg.test.js.testFunction.MethodDefinition.v1.0.0", + inspectArgs: true, + modifyArgs: false, + modifyReturnValue: false, + modifyArgumentsObject: false, + callbackOnBlock: true, + }, + ], + accessLocalVariables: [], + } + ); + + isSameCode( + t, + result, + `const { __instrumentInspectArgs } = require("@aikidosec/firewall/instrument/internals"); + const { test } = require("test"); + class Test { + private testValue = 42; + + constructor() { + this.testFunction(testValue); + } + testFunction(arg1, cb) { + if (!__instrumentInspectArgs("testpkg.test.js.testFunction.MethodDefinition.v1.0.0", arguments, "1.0.0", this)) return; + cb(arg1); + } + }` + ); + } +); + +t.test( + "use inspectArgs with callbackOnBlock for function declaration (ESM)", + async (t) => { + const result = transformCode( + "testpkg", + "1.0.0", + "test.js", + ` + import { test } from "test"; + + export function testFunction(arg1, cb) { + + try { + someFunctionThatMightThrow(arg1); + cb(arg1); + } catch (e) { + cb(e); + } + } + `, + "module", + { + path: "test.js", + versionRange: "^1.0.0", + identifier: "testpkg.test.js.^1.0.0", + functions: [ + { + nodeType: "FunctionDeclaration", + name: "testFunction", + identifier: + "testpkg.test.js.testFunction.FunctionDeclaration.v1.0.0", + inspectArgs: true, + modifyArgs: false, + modifyReturnValue: false, + modifyArgumentsObject: false, + callbackOnBlock: true, + }, + ], + accessLocalVariables: [], + } + ); + + isSameCode( + t, + result, + `import { __instrumentInspectArgs } from "@aikidosec/firewall/instrument/internals"; + import { test } from "test"; + + export function testFunction(arg1, cb) { + if (!__instrumentInspectArgs("testpkg.test.js.testFunction.FunctionDeclaration.v1.0.0", arguments, "1.0.0", this)) return; + try { + someFunctionThatMightThrow(arg1); + cb(arg1); + } catch (e) { + cb(e); + } + }` + ); + } +); + +t.test( + "use inspectArgs with callbackOnBlock for method definition (ESM)", + async (t) => { + const result = transformCode( + "testpkg", + "1.0.0", + "test.js", + ` + import { test } from "test"; + class Test { + testFunction(arg1, cb) { + cb(arg1); + } + } + `, + "module", + { + path: "test.js", + versionRange: "^1.0.0", + identifier: "testpkg.test.js.^1.0.0", + functions: [ + { + nodeType: "MethodDefinition", + name: "testFunction", + identifier: "testpkg.test.js.testFunction.MethodDefinition.v1.0.0", + inspectArgs: true, + modifyArgs: false, + modifyReturnValue: false, + modifyArgumentsObject: false, + callbackOnBlock: true, + }, + ], + accessLocalVariables: [], + } + ); + + isSameCode( + t, + result, + `import { __instrumentInspectArgs } from "@aikidosec/firewall/instrument/internals"; + import { test } from "test"; + class Test { + testFunction(arg1, cb) { + if (!__instrumentInspectArgs("testpkg.test.js.testFunction.MethodDefinition.v1.0.0", arguments, "1.0.0", this)) return; + cb(arg1); + } + }` + ); + } +); + +t.test( + "use inspectArgs with callbackOnBlock for function assignment (CJS)", + async (t) => { + const result = transformCode( + "express", + "1.0.0", + "application.js", + ` + const app = require("example"); + app.use = function (fn, cb) { + console.log("test"); + }; + `, + "commonjs", + { + path: "application.js", + versionRange: "^1.0.0", + identifier: "testpkg.test.js.^1.0.0", + functions: [ + { + nodeType: "FunctionAssignment", + name: "app.use", + identifier: + "express.application.js.app.use.FunctionAssignment.v1.0.0", + inspectArgs: true, + modifyArgs: false, + modifyReturnValue: false, + modifyArgumentsObject: false, + callbackOnBlock: true, + }, + ], + accessLocalVariables: [], + } + ); + + isSameCode( + t, + result, + `const { __instrumentInspectArgs } = require("@aikidosec/firewall/instrument/internals"); + const app = require("example"); + app.use = function (fn, cb) { + if (!__instrumentInspectArgs("express.application.js.app.use.FunctionAssignment.v1.0.0", arguments, "1.0.0", this)) return; + console.log("test"); + };` + ); + } +); + +t.test( + "use inspectArgs with callbackOnBlock for function expression (ESM)", + async (t) => { + const result = transformCode( + "testpkg", + "1.0.0", + "application.js", + ` + const y = function testFunction(arg1, cb) { + console.log("test"); + } + `, + "module", + { + path: "application.js", + versionRange: "^1.0.0", + identifier: "testpkg.test.js.^1.0.0", + functions: [ + { + nodeType: "FunctionExpression", + name: "testFunction", + identifier: + "testpkg.application.js.testFunction.FunctionExpression.v1.0.0", + inspectArgs: true, + modifyArgs: false, + modifyReturnValue: false, + modifyArgumentsObject: false, + callbackOnBlock: true, + }, + ], + accessLocalVariables: [], + } + ); + + isSameCode( + t, + result, + `import { __instrumentInspectArgs } from "@aikidosec/firewall/instrument/internals"; + const y = function testFunction(arg1, cb) { + if (!__instrumentInspectArgs("testpkg.application.js.testFunction.FunctionExpression.v1.0.0", arguments, "1.0.0", this)) return; + console.log("test"); + };` + ); + } +); + +t.test( + "use inspectArgs with callbackOnBlock for function variable declaration (ESM)", + async (t) => { + const result = transformCode( + "testpkg", + "1.0.0", + "application.js", + ` + const test = function (arg1, cb) { + console.log("test"); + } + `, + "module", + { + path: "application.js", + versionRange: "^1.0.0", + identifier: "testpkg.test.js.^1.0.0", + functions: [ + { + nodeType: "FunctionVariableDeclaration", + name: "test", + identifier: + "testpkg.application.js.test.FunctionVariableDeclaration.v1.0.0", + inspectArgs: true, + modifyArgs: false, + modifyReturnValue: false, + modifyArgumentsObject: false, + callbackOnBlock: true, + }, + ], + accessLocalVariables: [], + } + ); + + isSameCode( + t, + result, + `import { __instrumentInspectArgs } from "@aikidosec/firewall/instrument/internals"; + const test = function(arg1, cb) { + if (!__instrumentInspectArgs("testpkg.application.js.test.FunctionVariableDeclaration.v1.0.0", arguments, "1.0.0", this)) return; + console.log("test"); + };` + ); + } +); diff --git a/library/agent/hooks/instrumentation/injectedFunctions.ts b/library/agent/hooks/instrumentation/injectedFunctions.ts index e64a5be6f..4aea41029 100644 --- a/library/agent/hooks/instrumentation/injectedFunctions.ts +++ b/library/agent/hooks/instrumentation/injectedFunctions.ts @@ -1,42 +1,60 @@ import { getInstance } from "../../AgentSingleton"; import { bindContext, getContext } from "../../Context"; +import { getCallbackFunctionFromArgs } from "../../../helpers/getCallbackFunctionFromArgs"; import { inspectArgs } from "../wrapExport"; import { getFileCallbackInfo, getFunctionCallbackInfo } from "./instructions"; +/** + * Returns false when a block was handled via callback (caller should return early), true otherwise. + * Throws when blocking and callbackOnBlock is false. + */ export function __instrumentInspectArgs( id: string, args: IArguments | unknown[], pkgVersion: string, subject: unknown // "This" of the method being called -) { +): boolean { const agent = getInstance(); if (!agent) { - return; + return true; } const context = getContext(); const cbInfo = getFunctionCallbackInfo(id); if (!cbInfo) { - return; + return true; } if (typeof cbInfo.funcs.inspectArgs === "function") { - inspectArgs.call( - subject, - Array.from(args), - cbInfo.funcs.inspectArgs, - context, - agent, - { - name: cbInfo.pkgName, - version: pkgVersion, - type: "external", - }, - cbInfo.methodName, - cbInfo.operationKind - ); + try { + inspectArgs.call( + subject, + Array.from(args), + cbInfo.funcs.inspectArgs, + context, + agent, + { + name: cbInfo.pkgName, + version: pkgVersion, + type: "external", + }, + cbInfo.methodName, + cbInfo.operationKind + ); + } catch (error) { + if (cbInfo.funcs.callbackOnBlock) { + const cbFunc = getCallbackFunctionFromArgs(Array.from(args)); + if (cbFunc) { + process.nextTick(() => cbFunc(error)); + return false; + } + } + throw error; + } } + + return true; } export function __instrumentModifyArgs( diff --git a/library/agent/hooks/instrumentation/instructions.ts b/library/agent/hooks/instrumentation/instructions.ts index 5c6bef869..a2d90f00d 100644 --- a/library/agent/hooks/instrumentation/instructions.ts +++ b/library/agent/hooks/instrumentation/instructions.ts @@ -81,6 +81,7 @@ export function setPackagesToInstrument(_packages: Package[]) { modifyArgs: func.modifyArgs, modifyReturnValue: func.modifyReturnValue, bindContext: func.bindContext ?? false, + callbackOnBlock: func.callbackOnBlock ?? false, }, }); @@ -92,6 +93,7 @@ export function setPackagesToInstrument(_packages: Package[]) { modifyArgs: !!func.modifyArgs, modifyReturnValue: !!func.modifyReturnValue, modifyArgumentsObject: func.modifyArgumentsObject ?? false, + callbackOnBlock: func.callbackOnBlock ?? false, className: func.className, }; }), diff --git a/library/agent/hooks/instrumentation/types.ts b/library/agent/hooks/instrumentation/types.ts index 56156e4b6..bcd91b843 100644 --- a/library/agent/hooks/instrumentation/types.ts +++ b/library/agent/hooks/instrumentation/types.ts @@ -60,6 +60,7 @@ export type IntereptorFunctionsObj = { modifyArgs?: ModifyArgsInterceptor; modifyReturnValue?: ModifyReturnValueInterceptor; bindContext: boolean; + callbackOnBlock: boolean; }; export type IntereptorCallbackInfoObj = { @@ -116,6 +117,11 @@ export type PackageFunctionInstrumentationInstruction = { * If enabled, the bindContext function will be called for all callbacks that are passed to the function. */ bindContext?: boolean; + /** + * If true, when a block occurs the last function argument is called with the error instead of throwing. + * Useful for libraries that use error-first callbacks instead of promises/throws. + */ + callbackOnBlock?: boolean; /** * Can be used to specify the class name to limit the instrumentation to a specific method of a class. @@ -150,6 +156,7 @@ export type PackageFileInstrumentationInstructionJSON = { modifyArgs: boolean; modifyReturnValue: boolean; modifyArgumentsObject: boolean; + callbackOnBlock: boolean; className?: string; }[]; }; From e27b8f05b9c2b5c81a1f0c33c82a8c1a3a9f713b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6ssler?= Date: Fri, 24 Apr 2026 12:30:04 +0200 Subject: [PATCH 3/6] Fix and add tests --- .../src/js_transformer/helpers/insert_code.rs | 2 + .../instrumentation/instructions.test.ts | 60 +++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/instrumentation-wasm/src/js_transformer/helpers/insert_code.rs b/instrumentation-wasm/src/js_transformer/helpers/insert_code.rs index 0ae33fff8..ab842d05b 100644 --- a/instrumentation-wasm/src/js_transformer/helpers/insert_code.rs +++ b/instrumentation-wasm/src/js_transformer/helpers/insert_code.rs @@ -9,6 +9,8 @@ use oxc_ast::{ use oxc_span::SPAN; // Add a statement to the beginning of the function: __instrumentInspectArgs('function_identifier', arguments, "{pkg_version}", this); +// In case of callback_on_block being true, we add an if statement that checks the result of the callback and returns early if the callback returns false: +// if (!__instrumentInspectArgs('function_identifier', arguments, "{pkg_version}", this)) return; pub fn insert_inspect_args<'a>( allocator: &'a Allocator, builder: &'a AstBuilder, diff --git a/library/agent/hooks/instrumentation/instructions.test.ts b/library/agent/hooks/instrumentation/instructions.test.ts index d796ec1d0..05b37b2b9 100644 --- a/library/agent/hooks/instrumentation/instructions.test.ts +++ b/library/agent/hooks/instrumentation/instructions.test.ts @@ -65,6 +65,7 @@ t.test("it works", async (t) => { modifyArgs: false, modifyReturnValue: false, modifyArgumentsObject: false, + callbackOnBlock: false, className: undefined, }, ], @@ -270,6 +271,64 @@ t.test("it works using injected functions", async (t) => { t.equal(wrapped.test, 42); }); +t.test("callbackOnBlock calls callback instead of throwing", async (t) => { + let callbackError: unknown; + + const pkg = new Package("foo"); + pkg.withVersion("^1.0.0").addFileInstrumentation({ + path: "bar.js", + functions: [ + { + nodeType: "MethodDefinition", + name: "baz", + operationKind: "outgoing_http_op", + callbackOnBlock: true, + inspectArgs: () => { + return { + operation: "http.get", + hostname: "example.com", + }; + }, + }, + ], + }); + + setPackagesToInstrument([pkg]); + createTestAgent(); + + t.equal( + getPackageFileInstrumentationInstructions("foo", "1.0.0", "bar.js") + ?.functions[0].callbackOnBlock, + true + ); + + const shouldContinue = __instrumentInspectArgs( + "foo.bar.js.baz.MethodDefinition.^1.0.0", + [ + "input", + (err: unknown) => { + callbackError = err; + }, + ], + "1.0.0", + this + ); + + t.equal(shouldContinue, false); + + await new Promise((resolve) => process.nextTick(resolve)); + + t.ok(callbackError instanceof Error); + if (callbackError instanceof Error) { + t.match( + callbackError.message, + "Zen has blocked an outbound connection: http.get(...)" + ); + } + + setPackagesToInstrument([]); +}); + t.test("modifyArgs always returns a array", async (t) => { const pkg = new Package("foo"); pkg.withVersion("^1.0.0").addFileInstrumentation({ @@ -639,6 +698,7 @@ t.test( modifyReturnValue: false, modifyArgumentsObject: false, className: "MyClass", + callbackOnBlock: false, }, ], accessLocalVariables: [], From 8dc260dd81b9f7037eb91d3fbd8e8121bcc4ae48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6ssler?= Date: Fri, 24 Apr 2026 13:47:20 +0200 Subject: [PATCH 4/6] Modify instrumentation of packages supporting callbacks --- .../getCallbackFunctionFromArgs.test.ts | 19 +++- .../helpers/getCallbackFunctionFromArgs.ts | 10 +-- library/sinks/MariaDB.tests.ts | 88 ++++++++----------- library/sinks/MariaDB.ts | 4 + library/sinks/MySQL.test.ts | 29 ++++++ library/sinks/MySQL.ts | 2 + library/sinks/MySQL2.tests.ts | 13 +++ library/sinks/MySQL2.ts | 4 + library/sinks/Postgres.test.ts | 31 ++++--- library/sinks/Postgres.ts | 3 + library/sinks/SQLite3.test.ts | 13 +++ library/sinks/SQLite3.ts | 2 + 12 files changed, 141 insertions(+), 77 deletions(-) diff --git a/library/helpers/getCallbackFunctionFromArgs.test.ts b/library/helpers/getCallbackFunctionFromArgs.test.ts index 79a1c4c4b..cfbec6433 100644 --- a/library/helpers/getCallbackFunctionFromArgs.test.ts +++ b/library/helpers/getCallbackFunctionFromArgs.test.ts @@ -2,12 +2,12 @@ import * as t from "tap"; import { getCallbackFunctionFromArgs } from "./getCallbackFunctionFromArgs"; t.test( - "getCallbackFunctionFromArgs should return the last function argument", + "getCallbackFunctionFromArgs should return the last argument if it is a function", (t) => { const callback = () => {}; - const args = [1, "string", callback, () => {}]; + const args = [1, "string", callback]; const result = getCallbackFunctionFromArgs(args); - t.equal(result, args[3]); + t.equal(result, callback); t.end(); } ); @@ -23,7 +23,7 @@ t.test( ); t.test( - "getCallbackFunctionFromArgs should return the last function argument even if there are multiple", + "getCallbackFunctionFromArgs should return the last argument if multiple functions are passed", (t) => { const callback1 = () => {}; const callback2 = () => {}; @@ -53,3 +53,14 @@ t.test( t.end(); } ); + +t.test( + "getCallbackFunctionFromArgs should return undefined if last argument is not a function, even if an earlier argument is", + (t) => { + const callback = () => {}; + const args = [1, callback, { options: true }]; + const result = getCallbackFunctionFromArgs(args); + t.equal(result, undefined); + t.end(); + } +); diff --git a/library/helpers/getCallbackFunctionFromArgs.ts b/library/helpers/getCallbackFunctionFromArgs.ts index 33a44bf1f..30a463b53 100644 --- a/library/helpers/getCallbackFunctionFromArgs.ts +++ b/library/helpers/getCallbackFunctionFromArgs.ts @@ -1,9 +1,5 @@ -// Finds the last function argument in the provided array and returns it. If no function is found, it returns undefined. +// Returns the last argument if it is a function, otherwise undefined. export function getCallbackFunctionFromArgs(args: any[]): Function | undefined { - for (let i = args.length - 1; i >= 0; i--) { - if (typeof args[i] === "function") { - return args[i] as Function; - } - } - return undefined; + const last = args[args.length - 1]; + return typeof last === "function" ? last : undefined; } diff --git a/library/sinks/MariaDB.tests.ts b/library/sinks/MariaDB.tests.ts index 27943e4e0..ba4ec2a8d 100644 --- a/library/sinks/MariaDB.tests.ts +++ b/library/sinks/MariaDB.tests.ts @@ -228,61 +228,43 @@ export async function createMariadbTests(versionPkgName: string) { t.same(Number(meta.insertId), 0); connection.execute("TRUNCATE cats"); - try { - runWithContext(dangerousContext, () => { - pool.query("-- should be blocked", () => { - t.fail("Should not be called"); + // With callbackOnBlock, errors are routed to the callback instead of thrown + runWithContext(dangerousContext, () => { + pool.query("-- should be blocked", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) { + t.same( + err.message, + "Zen has blocked an SQL injection: mariadb.query(...) originating from body.myTitle" + ); + } + runWithContext(dangerousContext, () => { + connection.query("-- should be blocked", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) { + t.same( + err.message, + "Zen has blocked an SQL injection: mariadb.query(...) originating from body.myTitle" + ); + } + runWithContext(dangerousContext, () => { + connection.execute("-- should be blocked", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) { + t.same( + err.message, + "Zen has blocked an SQL injection: mariadb.execute(...) originating from body.myTitle" + ); + } + connection.end(); + pool.end(() => { + t.end(); + }); + }); + }); + }); }); - t.fail("Should not be called"); }); - } catch (error) { - t.ok(error instanceof Error); - if (error instanceof Error) { - t.same( - error.message, - "Zen has blocked an SQL injection: mariadb.query(...) originating from body.myTitle" - ); - } - } - - try { - runWithContext(dangerousContext, () => { - connection.query("-- should be blocked", () => { - t.fail("Should not be called"); - }); - t.fail("Should not be called"); - }); - } catch (error) { - t.ok(error instanceof Error); - if (error instanceof Error) { - t.same( - error.message, - "Zen has blocked an SQL injection: mariadb.query(...) originating from body.myTitle" - ); - } - } - - try { - runWithContext(dangerousContext, () => { - connection.execute("-- should be blocked", () => { - t.fail("Should not be called"); - }); - t.fail("Should not be called"); - }); - } catch (error) { - t.ok(error instanceof Error); - if (error instanceof Error) { - t.same( - error.message, - "Zen has blocked an SQL injection: mariadb.execute(...) originating from body.myTitle" - ); - } - } - - connection.end(); - - pool.end(() => { - t.end(); }); } ); diff --git a/library/sinks/MariaDB.ts b/library/sinks/MariaDB.ts index 4b72b68cc..0795777a5 100644 --- a/library/sinks/MariaDB.ts +++ b/library/sinks/MariaDB.ts @@ -64,6 +64,7 @@ export class MariaDB implements Wrapper { wrapExport(exports.prototype, fn, pkgInfo, { kind: "sql_op", inspectArgs: (args) => this.inspectQuery(args, fn), + callbackOnBlock: true, }); } } @@ -73,6 +74,7 @@ export class MariaDB implements Wrapper { wrapExport(exports.prototype, fn, pkgInfo, { kind: "sql_op", inspectArgs: (args) => this.inspectQuery(args, fn), + callbackOnBlock: true, }); } } @@ -101,6 +103,7 @@ export class MariaDB implements Wrapper { operationKind: "sql_op", bindContext: true, inspectArgs: (args) => this.inspectQuery(args, fn), + callbackOnBlock: true, })) ) .addMultiFileInstrumentation( @@ -111,6 +114,7 @@ export class MariaDB implements Wrapper { operationKind: "sql_op", bindContext: true, inspectArgs: (args) => this.inspectQuery(args, fn), + callbackOnBlock: true, })) ); } diff --git a/library/sinks/MySQL.test.ts b/library/sinks/MySQL.test.ts index 92112d373..9e576bec7 100644 --- a/library/sinks/MySQL.test.ts +++ b/library/sinks/MySQL.test.ts @@ -148,6 +148,35 @@ t.test("it detects SQL injections", async () => { t.same(getContext(), context); }); }); + + await new Promise((resolve) => { + runWithContext(context, () => { + connection.query("-- should be blocked", (err: Error | null) => { + t.ok(err instanceof Error); + t.same( + err?.message, + "Zen has blocked an SQL injection: MySQL.query(...) originating from body.myTitle" + ); + resolve(); + }); + }); + }); + + await new Promise((resolve) => { + runWithContext(context, () => { + connection.query( + { sql: "-- should be blocked" }, + (err: Error | null) => { + t.ok(err instanceof Error); + t.same( + err?.message, + "Zen has blocked an SQL injection: MySQL.query(...) originating from body.myTitle" + ); + resolve(); + } + ); + }); + }); } catch (error: any) { t.fail(error); } finally { diff --git a/library/sinks/MySQL.ts b/library/sinks/MySQL.ts index ad15ed7af..96cb47b22 100644 --- a/library/sinks/MySQL.ts +++ b/library/sinks/MySQL.ts @@ -105,6 +105,7 @@ export class MySQL implements Wrapper { wrapExport(exports.prototype, "query", pkgInfo, { kind: "sql_op", inspectArgs: (args) => this.inspectQuery(args), + callbackOnBlock: true, }); }) .addFileInstrumentation({ @@ -116,6 +117,7 @@ export class MySQL implements Wrapper { operationKind: "sql_op", bindContext: true, inspectArgs: (args) => this.inspectQuery(args), + callbackOnBlock: true, }, ], }); diff --git a/library/sinks/MySQL2.tests.ts b/library/sinks/MySQL2.tests.ts index cf6287e11..6320cbe0c 100644 --- a/library/sinks/MySQL2.tests.ts +++ b/library/sinks/MySQL2.tests.ts @@ -159,6 +159,19 @@ export function createMySQL2Tests(versionPkgName: string) { runWithContext(safeContext, () => { connection2!.query("-- This is a comment"); }); + + await new Promise((resolve) => { + runWithContext(dangerousContext, () => { + connection2!.execute("-- should be blocked", (err: any) => { + t.ok(err instanceof Error); + t.same( + err?.message, + "Zen has blocked an SQL injection: mysql2.execute(...) originating from body.myTitle" + ); + resolve(); + }); + }); + }); } catch (error: any) { t.fail(error); } finally { diff --git a/library/sinks/MySQL2.ts b/library/sinks/MySQL2.ts index 2db4bae74..3f2b59294 100644 --- a/library/sinks/MySQL2.ts +++ b/library/sinks/MySQL2.ts @@ -130,6 +130,7 @@ export class MySQL2 implements Wrapper { inspectArgs: (args) => this.inspectQuery("mysql2.query", args), operationKind: "sql_op", bindContext: true, + callbackOnBlock: true, }, { nodeType: "MethodDefinition", @@ -137,6 +138,7 @@ export class MySQL2 implements Wrapper { inspectArgs: (args) => this.inspectQuery("mysql2.execute", args), operationKind: "sql_op", bindContext: true, + callbackOnBlock: true, }, ]; } @@ -156,6 +158,7 @@ export class MySQL2 implements Wrapper { wrapExport(connectionPrototype, "query", pkgInfo, { kind: "sql_op", inspectArgs: (args) => this.inspectQuery("mysql2.query", args), + callbackOnBlock: true, }); } @@ -164,6 +167,7 @@ export class MySQL2 implements Wrapper { wrapExport(connectionPrototype, "execute", pkgInfo, { kind: "sql_op", inspectArgs: (args) => this.inspectQuery("mysql2.execute", args), + callbackOnBlock: true, }); } }; diff --git a/library/sinks/Postgres.test.ts b/library/sinks/Postgres.test.ts index f1ae02803..ddacd0c78 100644 --- a/library/sinks/Postgres.test.ts +++ b/library/sinks/Postgres.test.ts @@ -117,19 +117,24 @@ t.test("it inspects query method calls and blocks if needed", async (t) => { } ); - // Check if context is available in the callback - runWithContext(context, () => { - client.query("SELECT petname FROM cats;", (error, result) => { - t.match(getContext(), context); - - try { - client.query("-- should be blocked", () => {}); - } catch (error: any) { - t.match( - error.message, - /Zen has blocked an SQL injection: pg.query\(\.\.\.\) originating from body\.myTitle/ - ); - } + // Check that context is available in callback and error is routed to callback on block + await new Promise((resolve, reject) => { + runWithContext(context, () => { + client.query("SELECT petname FROM cats;", (error, result) => { + if (error) { + reject(error); + return; + } + t.match(getContext(), context); + client.query("-- should be blocked", (err: Error | null) => { + t.ok(err instanceof Error); + t.match( + err?.message, + /Zen has blocked an SQL injection: pg\.query\(\.\.\.\) originating from body\.myTitle/ + ); + resolve(); + }); + }); }); }); } catch (error: any) { diff --git a/library/sinks/Postgres.ts b/library/sinks/Postgres.ts index 895aab6c5..01a822076 100644 --- a/library/sinks/Postgres.ts +++ b/library/sinks/Postgres.ts @@ -129,6 +129,7 @@ export class Postgres implements Wrapper { wrapExport(exports.Client.prototype, "query", pkgInfo, { kind: "sql_op", inspectArgs: (args) => this.inspectQuery(args), + callbackOnBlock: true, }); }) .addFileInstrumentation({ @@ -140,6 +141,7 @@ export class Postgres implements Wrapper { operationKind: "sql_op", bindContext: true, inspectArgs: (args) => this.inspectQuery(args), + callbackOnBlock: true, }, ], }); @@ -205,6 +207,7 @@ export class Postgres implements Wrapper { wrapExport(pool.Client.prototype, "query", pkgInfo, { kind: "sql_op", inspectArgs: (args) => this.inspectQuery(args), + callbackOnBlock: true, }); return args; diff --git a/library/sinks/SQLite3.test.ts b/library/sinks/SQLite3.test.ts index 69e42ff50..8692f2007 100644 --- a/library/sinks/SQLite3.test.ts +++ b/library/sinks/SQLite3.test.ts @@ -152,6 +152,19 @@ t.test( ); } delete process.env.AIKIDO_BLOCK_INVALID_SQL; + + await new Promise((resolve) => { + runWithContext(dangerousContext, () => { + db.run("SELECT 1;-- should be blocked", (err: Error | null) => { + t.ok(err instanceof Error); + t.same( + err?.message, + "Zen has blocked an SQL injection: sqlite3.run(...) originating from body.myTitle" + ); + resolve(); + }); + }); + }); } catch (error: any) { t.fail(error); } finally { diff --git a/library/sinks/SQLite3.ts b/library/sinks/SQLite3.ts index e0c10bbc3..ae5631a23 100644 --- a/library/sinks/SQLite3.ts +++ b/library/sinks/SQLite3.ts @@ -79,6 +79,7 @@ export class SQLite3 implements Wrapper { inspectArgs: (args) => { return this.inspectQuery(`sqlite3.${func}`, args); }, + callbackOnBlock: true, }); } @@ -87,6 +88,7 @@ export class SQLite3 implements Wrapper { inspectArgs: (args) => { return this.inspectPath(`sqlite3.backup`, args); }, + callbackOnBlock: true, }); } From 1ab8627f417b6d8067df005206ac6442ba7be42f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6ssler?= Date: Fri, 24 Apr 2026 14:42:11 +0200 Subject: [PATCH 5/6] Also respect callbacks for built-ins --- library/sinks/ChildProcess.test.ts | 94 ++++++------ library/sinks/ChildProcess.ts | 1 + library/sinks/FileSystem.test.ts | 236 ++++++++++++++++++----------- library/sinks/FileSystem.ts | 85 +++++++---- 4 files changed, 255 insertions(+), 161 deletions(-) diff --git a/library/sinks/ChildProcess.test.ts b/library/sinks/ChildProcess.test.ts index dac2043ba..65311bc0e 100644 --- a/library/sinks/ChildProcess.test.ts +++ b/library/sinks/ChildProcess.test.ts @@ -93,17 +93,26 @@ t.test("it works", async (t) => { }); runWithContext(unsafeContext, () => { - throws( - () => exec("ls `echo .`", (err, stdout, stderr) => {}).unref(), - "Zen has blocked a shell injection: child_process.execFile(...) originating from body.file.matches" - ); - throws( () => execSync("ls `echo .`"), "Zen has blocked a shell injection: child_process.execSync(...) originating from body.file.matches" ); }); + await runWithContext(unsafeContext, async () => { + await new Promise((resolve) => { + exec("ls `echo .`", (err) => { + t.ok(err instanceof Error); + if (err instanceof Error) + t.match( + err.message, + "Zen has blocked a shell injection: child_process.execFile(...) originating from body.file.matches" + ); + resolve(); + }); + }); + }); + runWithContext(unsafeContext, () => { throws( () => spawn("ls `echo .`", [], { shell: true }).unref(), @@ -156,49 +165,11 @@ t.test("it works", async (t) => { }); runWithContext(unsafeContext, () => { - throws( - () => - execFile( - "ls `echo .`", - [], - { shell: true }, - (err, stdout, stderr) => {} - ).unref(), - "Zen has blocked a shell injection: child_process.execFile(...) originating from body.file.matches" - ); - throws( () => execFileSync("ls `echo .`", [], { shell: true }), "Zen has blocked a shell injection: child_process.execFileSync(...) originating from body.file.matches" ); - throws( - () => - execFile( - "ls", - ["`echo .`"], - { shell: true }, - (err, stdout, stderr) => {} - ).unref(), - "Zen has blocked a shell injection: child_process.execFile(...) originating from body.file.matches" - ); - - throws( - () => - execFile("sh", ["-c", "`echo .`"], (err, stdout, stderr) => {}).unref(), - "Zen has blocked a shell injection: child_process.execFile(...) originating from body.file.matches" - ); - - throws( - () => - execFile( - "/bin/sh", - ["-c", "`echo .`"], - (err, stdout, stderr) => {} - ).unref(), - "Zen has blocked a shell injection: child_process.execFile(...) originating from body.file.matches" - ); - throws( () => execFileSync("/bin/sh", ["-c", "`echo .`"]), "Zen has blocked a shell injection: child_process.execFileSync(...) originating from body.file.matches" @@ -210,6 +181,43 @@ t.test("it works", async (t) => { ); }); + await runWithContext(unsafeContext, async () => { + const msg = + "Zen has blocked a shell injection: child_process.execFile(...) originating from body.file.matches"; + + await new Promise((resolve) => { + execFile("ls `echo .`", [], { shell: true }, (err) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, msg); + resolve(); + }); + }); + + await new Promise((resolve) => { + execFile("ls", ["`echo .`"], { shell: true }, (err) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, msg); + resolve(); + }); + }); + + await new Promise((resolve) => { + execFile("sh", ["-c", "`echo .`"], (err) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, msg); + resolve(); + }); + }); + + await new Promise((resolve) => { + execFile("/bin/sh", ["-c", "`echo .`"], (err) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, msg); + resolve(); + }); + }); + }); + runWithContext( { ...unsafeContext, body: { file: { matches: "/../rce.js" } } }, () => { diff --git a/library/sinks/ChildProcess.ts b/library/sinks/ChildProcess.ts index 4120606f6..51456fcb4 100644 --- a/library/sinks/ChildProcess.ts +++ b/library/sinks/ChildProcess.ts @@ -43,6 +43,7 @@ export class ChildProcess implements Wrapper { inspectArgs: (args) => { return this.inspectExecFile(args, "execFile"); }, + callbackOnBlock: true, }); wrapExport(exports, "execFileSync", pkgInfo, { kind: "exec_op", diff --git a/library/sinks/FileSystem.test.ts b/library/sinks/FileSystem.test.ts index 85394993d..3b01ad7db 100644 --- a/library/sinks/FileSystem.test.ts +++ b/library/sinks/FileSystem.test.ts @@ -106,16 +106,23 @@ t.test("it works", async (t) => { }); await runWithContext(unsafeContext, async () => { - throws( - () => - writeFile( - "../../test.txt", - "some file content to test with", - { encoding: "utf-8" }, - () => {} - ), - "Zen has blocked a path traversal attack: fs.writeFile(...) originating from body.file.matches" - ); + const writeFileMsg = + "Zen has blocked a path traversal attack: fs.writeFile(...) originating from body.file.matches"; + const renameMsg = + "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches"; + + await new Promise((resolve) => { + writeFile( + "../../test.txt", + "some file content to test with", + { encoding: "utf-8" }, + (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, writeFileMsg); + resolve(); + } + ); + }); throws( () => @@ -136,10 +143,7 @@ t.test("it works", async (t) => { ); t.ok(error instanceof Error); if (error instanceof Error) { - t.match( - error.message, - "Zen has blocked a path traversal attack: fs.writeFile(...) originating from body.file.matches" - ); + t.match(error.message, writeFileMsg); t.same(error.stack!.includes("wrapExport.ts"), false); } @@ -152,80 +156,126 @@ t.test("it works", async (t) => { ); t.ok(error2 instanceof Error); if (error2 instanceof Error) { + t.match(error2.message, writeFileMsg); + } + + const error3 = await t.rejects(() => + fsDotPromise.readFile("../../test.txt", { encoding: "utf-8" }) + ); + t.ok(error3 instanceof Error); + if (error3 instanceof Error) { t.match( - error2.message, - "Zen has blocked a path traversal attack: fs.writeFile(...) originating from body.file.matches" + error3.message, + "Zen has blocked a path traversal attack: fs.readFile(...) originating from body.file.matches" ); } - throws( - () => rename("../../test.txt", "./test2.txt", () => {}), - "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" - ); + await new Promise((resolve) => { + rename("../../test.txt", "./test2.txt", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, renameMsg); + resolve(); + }); + }); - throws( - () => rename("./test.txt", "../../test.txt", () => {}), - "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" - ); + await new Promise((resolve) => { + rename("./test.txt", "../../test.txt", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, renameMsg); + resolve(); + }); + }); - throws( - () => rename(new URL("file:///../test.txt"), "../test2.txt", () => {}), - "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" - ); + await new Promise((resolve) => { + rename(new URL("file:///../test.txt"), "../test2.txt", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, renameMsg); + resolve(); + }); + }); - throws( - () => rename(new URL("file:///./../test.txt"), "../test2.txt", () => {}), - "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" - ); + await new Promise((resolve) => { + rename(new URL("file:///./../test.txt"), "../test2.txt", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, renameMsg); + resolve(); + }); + }); - throws( - () => rename(new URL("file:///../../test.txt"), "../test2.txt", () => {}), - "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" - ); + await new Promise((resolve) => { + rename(new URL("file:///../../test.txt"), "../test2.txt", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, renameMsg); + resolve(); + }); + }); - throws( - () => rename(Buffer.from("../test.txt"), "../test2.txt", () => {}), - "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" - ); + await new Promise((resolve) => { + rename(Buffer.from("../test.txt"), "../test2.txt", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, renameMsg); + resolve(); + }); + }); }); - runWithContext(unsafeContextAbsolute, () => { - throws( - () => rename(new URL("file:///etc/passwd"), "../test123.txt", () => {}), - "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" - ); - throws( - () => - rename(new URL("file:///../etc/passwd"), "../test123.txt", () => {}), - "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" - ); + await runWithContext(unsafeContextAbsolute, async () => { + const msg = + "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches"; - throws( - () => rename("/etc/passwd", "../test123.txt", () => {}), - "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" - ); + await new Promise((resolve) => { + rename(new URL("file:///etc/passwd"), "../test123.txt", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, msg); + resolve(); + }); + }); + + await new Promise((resolve) => { + rename(new URL("file:///../etc/passwd"), "../test123.txt", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, msg); + resolve(); + }); + }); + + await new Promise((resolve) => { + rename("/etc/passwd", "../test123.txt", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, msg); + resolve(); + }); + }); }); - runWithContext( + await runWithContext( { ...unsafeContextAbsolute, body: { file: { matches: "//etc/passwd" } }, }, - () => { - throws( - () => - rename(new URL("file:////etc/passwd"), "../test123.txt", () => {}), - "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" - ); + async () => { + const msg = + "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches"; - throws( - () => rename("//etc/passwd", "../test123.txt", () => {}), - "Zen has blocked a path traversal attack: fs.rename(...) originating from body.file.matches" - ); + await new Promise((resolve) => { + rename(new URL("file:////etc/passwd"), "../test123.txt", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, msg); + resolve(); + }); + }); + + await new Promise((resolve) => { + rename("//etc/passwd", "../test123.txt", (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) t.match(err.message, msg); + resolve(); + }); + }); } ); - runWithContext( + await runWithContext( { remoteAddress: "::1", method: "POST", @@ -240,16 +290,22 @@ t.test("it works", async (t) => { source: "express", route: "/posts/:id", }, - () => { - throws( - () => - rename( - new URL("file:///.\t./etc/passwd"), - "../test123.txt", - () => {} - ), - "Zen has blocked a path traversal attack: fs.rename(...) originating from query.q" - ); + async () => { + await new Promise((resolve) => { + rename( + new URL("file:///.\t./etc/passwd"), + "../test123.txt", + (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) + t.match( + err.message, + "Zen has blocked a path traversal attack: fs.rename(...) originating from query.q" + ); + resolve(); + } + ); + }); } ); @@ -273,7 +329,7 @@ t.test("it works", async (t) => { } ); - runWithContext( + await runWithContext( { remoteAddress: "::1", method: "POST", @@ -288,16 +344,22 @@ t.test("it works", async (t) => { source: "express", route: "/posts/:id", }, - () => { - throws( - () => - rename( - new URL("file:///.\t\t./etc/passwd"), - "../test123.txt", - () => {} - ), - "Zen has blocked a path traversal attack: fs.rename(...) originating from query.q" - ); + async () => { + await new Promise((resolve) => { + rename( + new URL("file:///.\t\t./etc/passwd"), + "../test123.txt", + (err: any) => { + t.ok(err instanceof Error); + if (err instanceof Error) + t.match( + err.message, + "Zen has blocked a path traversal attack: fs.rename(...) originating from query.q" + ); + resolve(); + } + ); + }); } ); diff --git a/library/sinks/FileSystem.ts b/library/sinks/FileSystem.ts index fec0e0a45..eb7761166 100644 --- a/library/sinks/FileSystem.ts +++ b/library/sinks/FileSystem.ts @@ -12,6 +12,7 @@ type FileSystemFunction = { pathsArgs: number; // The amount of arguments that are paths sync: boolean; // Whether the function has a synchronous version (e.g. fs.accessSync) promise: boolean; // Whether the function has a promise version (e.g. fs.promises.access) + callback: boolean; // Whether the async version accepts an error-first callback as last arg }; export class FileSystem implements Wrapper { @@ -51,44 +52,64 @@ export class FileSystem implements Wrapper { private getFunctions(): Record { const functions: Record = { - appendFile: { pathsArgs: 1, sync: true, promise: true }, - chmod: { pathsArgs: 1, sync: true, promise: true }, - chown: { pathsArgs: 1, sync: true, promise: true }, - createReadStream: { pathsArgs: 1, sync: false, promise: false }, - createWriteStream: { pathsArgs: 1, sync: false, promise: false }, - lchown: { pathsArgs: 1, sync: true, promise: true }, - lutimes: { pathsArgs: 1, sync: true, promise: true }, - mkdir: { pathsArgs: 1, sync: true, promise: true }, - open: { pathsArgs: 1, sync: true, promise: true }, - opendir: { pathsArgs: 1, sync: true, promise: true }, - readdir: { pathsArgs: 1, sync: true, promise: true }, - readFile: { pathsArgs: 1, sync: true, promise: true }, - readlink: { pathsArgs: 1, sync: true, promise: true }, - unlink: { pathsArgs: 1, sync: true, promise: true }, - realpath: { pathsArgs: 1, sync: true, promise: true }, - rename: { pathsArgs: 2, sync: true, promise: true }, - rmdir: { pathsArgs: 1, sync: true, promise: true }, - rm: { pathsArgs: 1, sync: true, promise: true }, - symlink: { pathsArgs: 2, sync: true, promise: true }, - truncate: { pathsArgs: 1, sync: true, promise: true }, - utimes: { pathsArgs: 1, sync: true, promise: true }, - writeFile: { pathsArgs: 1, sync: true, promise: true }, - copyFile: { pathsArgs: 2, sync: true, promise: true }, - cp: { pathsArgs: 2, sync: true, promise: true }, - link: { pathsArgs: 2, sync: true, promise: true }, - watch: { pathsArgs: 1, sync: false, promise: false }, - watchFile: { pathsArgs: 1, sync: false, promise: false }, - mkdtemp: { pathsArgs: 1, sync: true, promise: true }, + appendFile: { pathsArgs: 1, sync: true, promise: true, callback: true }, + chmod: { pathsArgs: 1, sync: true, promise: true, callback: true }, + chown: { pathsArgs: 1, sync: true, promise: true, callback: true }, + createReadStream: { + pathsArgs: 1, + sync: false, + promise: false, + callback: false, + }, + createWriteStream: { + pathsArgs: 1, + sync: false, + promise: false, + callback: false, + }, + lchown: { pathsArgs: 1, sync: true, promise: true, callback: true }, + lutimes: { pathsArgs: 1, sync: true, promise: true, callback: true }, + mkdir: { pathsArgs: 1, sync: true, promise: true, callback: true }, + open: { pathsArgs: 1, sync: true, promise: true, callback: true }, + opendir: { pathsArgs: 1, sync: true, promise: true, callback: true }, + readdir: { pathsArgs: 1, sync: true, promise: true, callback: true }, + readFile: { pathsArgs: 1, sync: true, promise: true, callback: true }, + readlink: { pathsArgs: 1, sync: true, promise: true, callback: true }, + unlink: { pathsArgs: 1, sync: true, promise: true, callback: true }, + realpath: { pathsArgs: 1, sync: true, promise: true, callback: true }, + rename: { pathsArgs: 2, sync: true, promise: true, callback: true }, + rmdir: { pathsArgs: 1, sync: true, promise: true, callback: true }, + rm: { pathsArgs: 1, sync: true, promise: true, callback: true }, + symlink: { pathsArgs: 2, sync: true, promise: true, callback: true }, + truncate: { pathsArgs: 1, sync: true, promise: true, callback: true }, + utimes: { pathsArgs: 1, sync: true, promise: true, callback: true }, + writeFile: { pathsArgs: 1, sync: true, promise: true, callback: true }, + copyFile: { pathsArgs: 2, sync: true, promise: true, callback: true }, + cp: { pathsArgs: 2, sync: true, promise: true, callback: true }, + link: { pathsArgs: 2, sync: true, promise: true, callback: true }, + watch: { pathsArgs: 1, sync: false, promise: false, callback: false }, + watchFile: { pathsArgs: 1, sync: false, promise: false, callback: false }, + mkdtemp: { pathsArgs: 1, sync: true, promise: true, callback: true }, }; // Added in v19.8.0 if (isVersionGreaterOrEqual("19.8.0", getSemverNodeVersion())) { - functions.openAsBlob = { pathsArgs: 1, sync: false, promise: false }; + functions.openAsBlob = { + pathsArgs: 1, + sync: false, + promise: false, + callback: false, + }; } // Only available on macOS if (process.platform === "darwin") { - functions.lchmod = { pathsArgs: 1, sync: true, promise: true }; + functions.lchmod = { + pathsArgs: 1, + sync: true, + promise: true, + callback: true, + }; } return functions; @@ -121,13 +142,14 @@ export class FileSystem implements Wrapper { const functions = this.getFunctions(); Object.keys(functions).forEach((name) => { - const { pathsArgs, sync } = functions[name]; + const { pathsArgs, sync, callback } = functions[name]; wrapExport(exports, name, pkgInfo, { kind: "fs_op", inspectArgs: (args) => { return this.inspectPath(args, name, pathsArgs); }, + callbackOnBlock: callback, }); if (sync) { @@ -146,6 +168,7 @@ export class FileSystem implements Wrapper { inspectArgs: (args) => { return this.inspectPath(args, "realpath.native", 1); }, + callbackOnBlock: true, }); wrapExport(exports.realpathSync, "native", pkgInfo, { From 0cbea4e157661aa025437410bdf3f324582cf241 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6ssler?= Date: Fri, 24 Apr 2026 15:09:16 +0200 Subject: [PATCH 6/6] Fix some more tests --- library/sources/Express.tests.ts | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/library/sources/Express.tests.ts b/library/sources/Express.tests.ts index d6398862f..331c3d165 100644 --- a/library/sources/Express.tests.ts +++ b/library/sources/Express.tests.ts @@ -121,8 +121,10 @@ export async function createExpressTests(expressPackageName: string) { }); app.use("/attack-in-middleware", (req, res, next) => { - readdir(req.query.directory as string, () => {}); - next(); + readdir(req.query.directory as string, (err) => { + if (err) return next(err); + next(); + }); }); function apiMiddleware(req: Request, res: Response, next: NextFunction) { @@ -189,16 +191,18 @@ export async function createExpressTests(expressPackageName: string) { res.send(context); }); - app.get("/files", (req, res) => { - readdir(req.query.directory as string, () => {}); - - res.send(getContext()); + app.get("/files", (req, res, next) => { + readdir(req.query.directory as string, (err) => { + if (err) return next(err); + res.send(getContext()); + }); }); - app.get("/files-subdomains", (req, res) => { - readdir(req.subdomains[2], () => {}); - - res.send(getContext()); + app.get("/files-subdomains", (req, res, next) => { + readdir(req.subdomains[2], (err) => { + if (err) return next(err); + res.send(getContext()); + }); }); app.get("/attack-in-middleware", (req, res) => { @@ -778,18 +782,22 @@ export async function createExpressTests(expressPackageName: string) { t.test("it detects path traversal with double encoding", async (t) => { const app = express(); - app.get("/search", (req, res) => { + app.get("/search", (req, res, next) => { const searchTerm = req.query.q as string; const fileUrl = new URL(`file:///public/${searchTerm}`); readFile(fileUrl, "utf-8", (err, data) => { if (err) { - return res.status(500).send("Error reading file"); + return next(err); } res.send(`File content of /public/${searchTerm} : ${data}`); }); }); + app.use((error: Error, req: Request, res: Response, next: NextFunction) => { + res.status(500).send(String(error)); + }); + const blockedResponse = await request(app).get( "/search?q=.%252E/etc/passwd" );