Skip to content

Commit 35d328b

Browse files
authored
Merge pull request #17 from ibelem/softmax-4-d-update-2
Playground - softmax() on N-D tensors - update 2
2 parents cd3fd81 + 9b1fe06 commit 35d328b

1 file changed

Lines changed: 75 additions & 37 deletions

File tree

app/[lang]/playground/editor-files-webnn.js

Lines changed: 75 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -976,53 +976,82 @@ button {
976976
"static": {
977977
'/webnn.js': {
978978
active: true,
979-
code: `async function webnn() {
979+
code: `let resultArray = null;
980+
let currentIndex = 0;
981+
let chunkSize = 512;
982+
983+
async function webnn() {
980984
// Compute softmax() on N-D tensors (N > 2)
981-
const context = await navigator.ml.createContext();
982-
await navigator.ml.createContext({deviceType: 'cpu'});
985+
const context = await navigator.ml.createContext({ deviceType: 'gpu' });
983986
const builder = new MLGraphBuilder(context);
984-
const input = builder.input('input', {dataType: 'float32', shape: [1, 5, 1024, 1024]});
987+
const input = builder.input('input', { dataType: 'float32', shape: [1, 5, 1024, 1024] });
985988
const out = builder.softmax(input, 1);
986-
const graph = await builder.build({'output': out});
987-
const [inputTensor, outputTensor] = await Promise.all([
988-
context.createTensor({dataType: input.dataType, shape: input.shape, writable: true}),
989-
context.createTensor({dataType: input.dataType, shape: input.shape, readable: true})
989+
const graph = await builder.build({ 'output': out });
990+
991+
const [inputTensor, outputTensor] = await Promise.all([
992+
context.createTensor({ dataType: input.dataType, shape: input.shape, writable: true }),
993+
context.createTensor({ dataType: input.dataType, shape: input.shape, readable: true })
990994
]);
991-
992-
context.writeTensor(inputTensor, new Float32Array(1*5*1024*1024).fill(1.0));
993-
994-
const inputs = {'input': inputTensor};
995-
const outputs = {'output': outputTensor};
996-
context.dispatch(graph, inputs, outputs);
995+
996+
const size = 1 * 5 * 1024 * 1024;
997+
const randomData = new Float32Array(size);
998+
for (let i = 0; i < size; i++) {
999+
randomData[i] = Math.random();
1000+
}
1001+
context.writeTensor(inputTensor, randomData);
1002+
1003+
const inputs = { 'input': inputTensor };
1004+
const outputs = { 'output': outputTensor };
1005+
await context.dispatch(graph, inputs, outputs);
9971006
const result = await context.readTensor(outputTensor);
9981007
return new Float32Array(result);
9991008
}
1000-
1001-
// Debounce utility to prevent rapid successive clicks
1002-
function debounce(func, wait) {
1003-
let timeout;
1004-
return function (...args) {
1005-
clearTimeout(timeout);
1006-
timeout = setTimeout(() => func(...args), wait);
1007-
};
1009+
1010+
function showChunk() {
1011+
const output = document.querySelector("#output");
1012+
if (!resultArray) return;
1013+
const end = Math.min(currentIndex + chunkSize, resultArray.length);
1014+
const chunk = Array.from(resultArray.slice(currentIndex, end));
1015+
output.textContent += (currentIndex === 0 ? "" : "\n") + chunk.join(', ');
1016+
currentIndex = end;
1017+
// Hide button if all data is shown
1018+
if (currentIndex >= resultArray.length) {
1019+
document.getElementById("showMoreBtn").style.display = "none";
1020+
}
10081021
}
10091022
1010-
// Event listener with debouncing
1011-
document.querySelector("#run").addEventListener(
1012-
"click",
1013-
debounce(async () => {
1014-
const output = document.querySelector("#output");
1015-
output.textContent = "Inferencing...";
1016-
try {
1017-
const { result } = await webnn(); // Use default values
1018-
console.log(result);
1019-
output.innerHTML = "Output value:" + result;
1020-
} catch (error) {
1021-
console.log(error.message);
1022-
output.textContent = "Error: " + error.message;
1023+
async function main() {
1024+
const output = document.querySelector("#output");
1025+
const btn = document.getElementById("showMoreBtn");
1026+
const chunkInput = document.getElementById("chunkInput");
1027+
chunkSize = parseInt(chunkInput.value, 10) || 256;
1028+
output.textContent = "Inferencing...";
1029+
btn.style.display = "none";
1030+
try {
1031+
resultArray = await webnn();
1032+
currentIndex = 0;
1033+
output.textContent = "";
1034+
showChunk();
1035+
if (resultArray.length > chunkSize) {
1036+
btn.style.display = "inline-block";
10231037
}
1024-
}, 300) // 300ms debounce
1025-
);`
1038+
} catch (error) {
1039+
output.textContent = "Error: " + error.message;
1040+
}
1041+
}
1042+
1043+
document.getElementById("showMoreBtn").addEventListener("click", showChunk);
1044+
document.getElementById("chunkInput").addEventListener("change", function () {
1045+
chunkSize = parseInt(this.value, 10) || 256;
1046+
currentIndex = 0;
1047+
document.querySelector("#output").textContent = "";
1048+
showChunk();
1049+
if (resultArray && resultArray.length > chunkSize) {
1050+
document.getElementById("showMoreBtn").style.display = "inline-block";
1051+
}
1052+
});
1053+
document.addEventListener("DOMContentLoaded", main, false);
1054+
`
10261055
},
10271056
'/index.html': {
10281057
code: `<!DOCTYPE html>
@@ -1037,7 +1066,10 @@ document.querySelector("#run").addEventListener(
10371066
10381067
<body>
10391068
<h1>Compute softmax() on N-D tensors (N > 2)</h1>
1069+
<label for="chunkInput">Chunk size:</label>
1070+
<input type="number" id="chunkInput" value="512" min="1" style="width:80px;">
10401071
<div id="output"></div>
1072+
<button id="showMore" style="display:none;">Show next chunk</button>
10411073
<script src="./webnn.js"></script>
10421074
</body>
10431075
@@ -1050,10 +1082,16 @@ document.querySelector("#run").addEventListener(
10501082
10511083
h1 {
10521084
color: #E44D26;
1085+
font-size: 0.8rem;
10531086
}
10541087
10551088
button {
10561089
margin: 0.5rem 0;
1090+
padding: 0.5rem 1rem;
1091+
}
1092+
1093+
#output {
1094+
font-size: 0.6rem;
10571095
}`}
10581096
},
10591097
},

0 commit comments

Comments
 (0)