Skip to content

Commit ab22b27

Browse files
xenovaJason-Shen2
andauthored
Fix custom cache in Node.js when cache returns non-paths (e.g., Response) (#1617)
* Add custom cache unit test * assign cache key on cache hit * also pass externalData in node.js when not using paths (e.g,. custom cache) * Update custom_cache.test.js --------- Co-authored-by: zxshen <zshen339@gatech.edu>
1 parent d834093 commit ab22b27

3 files changed

Lines changed: 113 additions & 2 deletions

File tree

packages/transformers/src/models/session.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ async function getSession(
115115
session_options,
116116
);
117117

118-
if (externalData.length > 0 && !apis.IS_NODE_ENV) {
118+
if (externalData.length > 0 && (!apis.IS_NODE_ENV || externalData.some((data) => typeof data !== 'string'))) {
119119
session_options.externalData = externalData;
120120
}
121121

packages/transformers/src/utils/hub.js

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ export async function loadResourceFile(
270270
response = await checkCachedResource(cache, localPath, proposedCacheKey);
271271

272272
const cacheHit = response !== undefined;
273-
if (!cacheHit) {
273+
if (cacheHit) {
274+
cacheKey = proposedCacheKey;
275+
} else {
274276
// Caching not available, or file is not cached, so we perform the request
275277

276278
if (env.allowLocalModels) {
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import { env, LlamaForCausalLM, AutoTokenizer } from "../../src/transformers.js";
2+
import { init, MAX_TEST_EXECUTION_TIME, DEFAULT_MODEL_OPTIONS } from "../init.js";
3+
4+
// Initialise the testing environment
5+
init();
6+
7+
/**
8+
* A naive custom cache implementation that fetches files directly from the
9+
* Hugging Face Hub and stores them in an internal (in-memory) map.
10+
* This satisfies the CacheInterface contract (`match` + `put`).
11+
*/
12+
class NaiveFetchCache {
13+
constructor() {
14+
/** @type {Map<string, Response>} */
15+
this.cache = new Map();
16+
}
17+
18+
async match(request) {
19+
const cached = this.cache.get(request);
20+
if (cached) {
21+
return cached.clone();
22+
}
23+
24+
// Not in cache — attempt a fresh fetch from the URL.
25+
try {
26+
const response = await fetch(request);
27+
if (response.ok) {
28+
this.cache.set(request, response);
29+
return response.clone();
30+
}
31+
} catch {
32+
// Ignore fetch errors (e.g., invalid URLs like local paths) — treat as cache miss
33+
}
34+
return undefined;
35+
}
36+
37+
async put(request, response) {
38+
if (!this.cache.has(request)) {
39+
this.cache.set(request, response);
40+
}
41+
}
42+
}
43+
44+
describe("Custom cache", () => {
45+
// Store original env values so we can restore them after tests
46+
const originalUseCustomCache = env.useCustomCache;
47+
const originalCustomCache = env.customCache;
48+
const originalUseBrowserCache = env.useBrowserCache;
49+
const originalUseFSCache = env.useFSCache;
50+
const originalAllowLocalModels = env.allowLocalModels;
51+
52+
beforeAll(() => {
53+
// Disable all other caching mechanisms so only customCache is used
54+
env.useCustomCache = true;
55+
env.customCache = new NaiveFetchCache();
56+
env.useBrowserCache = false;
57+
env.useFSCache = false;
58+
env.allowLocalModels = false;
59+
});
60+
61+
afterAll(() => {
62+
// Restore original env values
63+
env.useCustomCache = originalUseCustomCache;
64+
env.customCache = originalCustomCache;
65+
env.useBrowserCache = originalUseBrowserCache;
66+
env.useFSCache = originalUseFSCache;
67+
env.allowLocalModels = originalAllowLocalModels;
68+
});
69+
70+
it(
71+
"should load a model using custom cache (standard)",
72+
async () => {
73+
const model_id = "onnx-internal-testing/tiny-random-LlamaForCausalLM-ONNX";
74+
75+
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
76+
const model = await LlamaForCausalLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
77+
78+
const inputs = await tokenizer("Hello");
79+
const output = await model(inputs);
80+
81+
expect(output.logits).toBeDefined();
82+
const expected_shape = [...inputs.input_ids.dims, model.config.vocab_size];
83+
expect(output.logits.dims).toEqual(expected_shape);
84+
85+
await model.dispose();
86+
},
87+
MAX_TEST_EXECUTION_TIME,
88+
);
89+
90+
it(
91+
"should load a model using custom cache (external data)",
92+
async () => {
93+
const model_id = "onnx-internal-testing/tiny-random-LlamaForCausalLM-ONNX_external";
94+
95+
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
96+
const model = await LlamaForCausalLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
97+
98+
const inputs = await tokenizer("Hello");
99+
const output = await model(inputs);
100+
101+
expect(output.logits).toBeDefined();
102+
const expected_shape = [...inputs.input_ids.dims, model.config.vocab_size];
103+
expect(output.logits.dims).toEqual(expected_shape);
104+
105+
await model.dispose();
106+
},
107+
MAX_TEST_EXECUTION_TIME,
108+
);
109+
});

0 commit comments

Comments
 (0)