Skip to content

Commit a196db3

Browse files
committed
refactor: use web worker to stop blocking main thread
1 parent 0abcb75 commit a196db3

File tree

2 files changed

+263
-134
lines changed

2 files changed

+263
-134
lines changed

src/lib/formFieldDetection.ts

Lines changed: 84 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
import * as ort from "onnxruntime-web";
21
import * as pdfjsLib from "pdfjs-dist";
3-
import { applyNonMaximumSuppression } from "./utils";
2+
import InferenceWorker from "../workers/inference.worker.ts?worker";
3+
4+
pdfjsLib.GlobalWorkerOptions.workerSrc = `https://cdn.jsdelivr.net/npm/pdfjs-dist@${pdfjsLib.version}/build/pdf.worker.min.mjs`;
5+
6+
const TARGET_SIZE = 1216;
47

58
interface DetectedField {
69
type: string;
@@ -44,11 +47,6 @@ export type DetectionResult =
4447
| { success: true; data: DetectionData }
4548
| { success: false; error: { code: ErrorCode; message: string } };
4649

47-
const CLASS_NAMES = ["TextBox", "ChoiceButton", "Signature"];
48-
const IOU_THRESHOLD = 0.45;
49-
const TARGET_SIZE = 1216;
50-
const ADJUSTED_HEIGHT_FACTOR = 1;
51-
5250
const COLORS = {
5351
TextBox: {
5452
label: "#3B82F6",
@@ -64,18 +62,6 @@ const COLORS = {
6462
},
6563
};
6664

67-
const sortFieldsByReadingOrder = (fields: DetectedField[]): DetectedField[] => {
68-
return [...fields].sort((a, b) => {
69-
const [aX, aY] = a.bbox;
70-
const [bX, bY] = b.bbox;
71-
const yDiff = aY - bY;
72-
if (Math.abs(yDiff) > 0.01) {
73-
return yDiff;
74-
}
75-
return aX - bX;
76-
});
77-
};
78-
7965
const drawDetections = (
8066
canvas: HTMLCanvasElement,
8167
fields: DetectedField[]
@@ -105,125 +91,52 @@ const drawDetections = (
10591
});
10692
};
10793

108-
const processPage = async (
109-
page: pdfjsLib.PDFPageProxy,
110-
session: ort.InferenceSession,
111-
confidenceThreshold: number
112-
): Promise<PageDetectionData> => {
94+
const renderPdfPageToImageData = async (
95+
page: pdfjsLib.PDFPageProxy
96+
): Promise<{
97+
imageData: ImageData;
98+
pdfMetadata: {
99+
originalWidth: number;
100+
originalHeight: number;
101+
canvasSize: number;
102+
offsetX: number;
103+
offsetY: number;
104+
};
105+
}> => {
113106
const viewport = page.getViewport({ scale: 1.0 });
114107
const scale = Math.min(
115108
TARGET_SIZE / viewport.width,
116109
TARGET_SIZE / viewport.height
117110
);
118111
const scaledViewport = page.getViewport({ scale });
119112

120-
const tempCanvas = document.createElement("canvas");
121-
const tempContext = tempCanvas.getContext("2d")!;
122-
tempCanvas.height = scaledViewport.height;
123-
tempCanvas.width = scaledViewport.width;
113+
const canvas = document.createElement("canvas");
114+
canvas.width = scaledViewport.width;
115+
canvas.height = scaledViewport.height;
116+
const context = canvas.getContext("2d")!;
124117

125118
await page.render({
126-
canvasContext: tempContext,
119+
canvasContext: context,
127120
viewport: scaledViewport,
128-
canvas: tempCanvas,
121+
canvas,
129122
}).promise;
130123

131-
const canvas = document.createElement("canvas");
132-
const context = canvas.getContext("2d")!;
133-
canvas.width = TARGET_SIZE;
134-
canvas.height = TARGET_SIZE;
135-
136-
context.fillStyle = "white";
137-
context.fillRect(0, 0, TARGET_SIZE, TARGET_SIZE);
124+
const finalCanvas = document.createElement("canvas");
125+
finalCanvas.width = TARGET_SIZE;
126+
finalCanvas.height = TARGET_SIZE;
127+
const finalContext = finalCanvas.getContext("2d")!;
138128

139-
const offsetX = (TARGET_SIZE - tempCanvas.width) / 2;
140-
const offsetY = (TARGET_SIZE - tempCanvas.height) / 2;
141-
context.drawImage(tempCanvas, offsetX, offsetY);
129+
finalContext.fillStyle = "white";
130+
finalContext.fillRect(0, 0, TARGET_SIZE, TARGET_SIZE);
142131

143-
const imageData = context.getImageData(0, 0, TARGET_SIZE, TARGET_SIZE);
132+
const offsetX = (TARGET_SIZE - canvas.width) / 2;
133+
const offsetY = (TARGET_SIZE - canvas.height) / 2;
134+
finalContext.drawImage(canvas, offsetX, offsetY);
144135

145-
const rgbData = new Float32Array(3 * canvas.height * canvas.width);
146-
147-
for (let i = 0; i < imageData.data.length / 4; i++) {
148-
const r = imageData.data[i * 4] / 255.0;
149-
const g = imageData.data[i * 4 + 1] / 255.0;
150-
const b = imageData.data[i * 4 + 2] / 255.0;
151-
152-
rgbData[i] = r;
153-
rgbData[canvas.height * canvas.width + i] = g;
154-
rgbData[2 * canvas.height * canvas.width + i] = b;
155-
}
156-
157-
const tensor = new ort.Tensor("float32", rgbData, [
158-
1,
159-
3,
160-
canvas.height,
161-
canvas.width,
162-
]);
163-
164-
const feeds = { images: tensor };
165-
const output = await session.run(feeds);
166-
167-
const outputTensor = output["output0"];
168-
const outputData = outputTensor.data as Float32Array;
169-
const outputDims = outputTensor.dims as number[];
170-
171-
const numPredictions = outputDims[2];
172-
const detections: Array<{
173-
box: [number, number, number, number];
174-
classId: number;
175-
confidence: number;
176-
}> = [];
177-
178-
for (let i = 0; i < numPredictions; i++) {
179-
const cx = outputData[i];
180-
const cy = outputData[numPredictions + i];
181-
const w = outputData[2 * numPredictions + i];
182-
const h = outputData[3 * numPredictions + i];
183-
184-
const class0Score = outputData[4 * numPredictions + i];
185-
const class1Score = outputData[5 * numPredictions + i];
186-
const class2Score = outputData[6 * numPredictions + i];
187-
188-
const scores = [class0Score, class1Score, class2Score];
189-
const maxScore = Math.max(...scores);
190-
const classId = scores.indexOf(maxScore);
191-
192-
if (maxScore > confidenceThreshold) {
193-
detections.push({
194-
box: [
195-
cx / TARGET_SIZE,
196-
cy / TARGET_SIZE,
197-
w / TARGET_SIZE,
198-
h / TARGET_SIZE,
199-
],
200-
classId,
201-
confidence: maxScore,
202-
});
203-
}
204-
}
205-
206-
const nmsDetections = applyNonMaximumSuppression(detections, IOU_THRESHOLD);
207-
208-
const unsortedFields: DetectedField[] = nmsDetections.map((det) => {
209-
const [cx, cy, w, h] = det.box;
210-
const adjustedH = h * ADJUSTED_HEIGHT_FACTOR;
211-
const x0 = cx - w / 2;
212-
const y0 = cy + h / 2 - adjustedH;
213-
return {
214-
type: CLASS_NAMES[det.classId],
215-
bbox: [x0, y0, w, adjustedH],
216-
confidence: det.confidence,
217-
};
218-
});
219-
220-
const fields = sortFieldsByReadingOrder(unsortedFields);
221-
222-
drawDetections(canvas, fields);
136+
const imageData = finalContext.getImageData(0, 0, TARGET_SIZE, TARGET_SIZE);
223137

224138
return {
225-
fields,
226-
imageData: canvas.toDataURL(),
139+
imageData,
227140
pdfMetadata: {
228141
originalWidth: viewport.width,
229142
originalHeight: viewport.height,
@@ -244,6 +157,7 @@ export const detectFormFields = async (
244157
const startTime = performance.now();
245158

246159
onUpdateDetectionStatus("Loading PDF...");
160+
247161
const arrayBuffer = await pdfFile.arrayBuffer();
248162
const pdf = await pdfjsLib.getDocument({ data: arrayBuffer }).promise;
249163

@@ -257,26 +171,62 @@ export const detectFormFields = async (
257171
`Running form field detection using ${modelName} model...`
258172
);
259173

260-
const session = await ort.InferenceSession.create(modelPath, {
261-
executionProviders: ["wasm"],
262-
});
263-
174+
const worker = new InferenceWorker();
264175
const pages: PageDetectionData[] = [];
265176

266177
for (let pageNum = 1; pageNum <= pdf.numPages; pageNum++) {
267-
onUpdateDetectionStatus(
268-
`Processing page ${pageNum} of ${pdf.numPages}...`
269-
);
178+
onUpdateDetectionStatus(`Processing page ${pageNum} of ${pdf.numPages}...`);
179+
270180
const page = await pdf.getPage(pageNum);
271-
const pageResult = await processPage(page, session, confidenceThreshold);
272-
pages.push(pageResult);
181+
const { imageData, pdfMetadata } = await renderPdfPageToImageData(page);
182+
183+
const inferenceResult = await new Promise<{
184+
fields: DetectedField[];
185+
}>((resolve, reject) => {
186+
const messageHandler = (event: MessageEvent) => {
187+
const { type, data } = event.data;
188+
189+
if (type === "result") {
190+
worker.removeEventListener("message", messageHandler);
191+
if (!data.success) {
192+
reject(new Error(data.error.message));
193+
return;
194+
}
195+
resolve({ fields: data.fields });
196+
}
197+
};
198+
199+
worker.addEventListener("message", messageHandler);
200+
201+
worker.postMessage({
202+
imageDataArray: imageData.data,
203+
imageWidth: imageData.width,
204+
imageHeight: imageData.height,
205+
modelPath,
206+
confidenceThreshold,
207+
isFirstPage: pageNum === 1,
208+
});
209+
});
210+
211+
const canvas = document.createElement("canvas");
212+
canvas.width = TARGET_SIZE;
213+
canvas.height = TARGET_SIZE;
214+
const ctx = canvas.getContext("2d")!;
215+
ctx.putImageData(imageData, 0, 0);
216+
217+
drawDetections(canvas, inferenceResult.fields);
218+
219+
pages.push({
220+
fields: inferenceResult.fields,
221+
imageData: canvas.toDataURL(),
222+
pdfMetadata,
223+
});
273224
}
274225

226+
worker.terminate();
227+
275228
const endTime = performance.now();
276-
const totalFields = pages.reduce(
277-
(sum, page) => sum + page.fields.length,
278-
0
279-
);
229+
const totalFields = pages.reduce((sum, page) => sum + page.fields.length, 0);
280230

281231
return {
282232
success: true,

0 commit comments

Comments
 (0)