Skip to content

Commit 7aa91ed

Browse files
committed
Add KaTeX support for rendering equations and enhance UI for equation display
- Integrated KaTeX for rendering mathematical equations in the neural network playground. - Added a new equation legend section to provide human-readable mappings for network equations. - Updated the UI to include a toggle for displaying model equations and their legends. - Enhanced the layout and styling of the equation panel for better visibility and usability. - Updated package dependencies to include KaTeX and modified build scripts accordingly. Made-with: Cursor
1 parent 02469bd commit 7aa91ed

9 files changed

Lines changed: 552 additions & 28 deletions

File tree

index.html

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@
3636

3737
<meta name="author" content="Daniel Smilkov and Shan Carter">
3838
<title>A Neural Network Playground</title>
39+
<link rel="stylesheet" href="katex.min.css" type="text/css">
3940
<link rel="stylesheet" href="bundle.css" type="text/css">
4041
<link href="https://fonts.googleapis.com/css?family=Roboto:300,400,500|Material+Icons" rel="stylesheet" type="text/css">
4142
<script src="lib.js"></script>
43+
<script src="katex.min.js"></script>
4244
</head>
4345
<body>
4446
<!-- GitHub link -->
@@ -142,6 +144,7 @@ <h1 class="l--page">Tinker With a <b>Neural Network</b> <span class="optional">R
142144

143145
<!-- Main Part -->
144146
<div id="main-part" class="l--page">
147+
<div class="main-part-columns">
145148

146149
<!-- Data Column-->
147150
<div class="column data">
@@ -308,6 +311,16 @@ <h4>Output</h4>
308311
</div>
309312
</div>
310313

314+
</div>
315+
316+
<div id="nn-equation-panel" class="nn-equation-panel ui-equationPanel">
317+
<div class="nn-equation-header">
318+
<h5 class="nn-equation-title">Model equations</h5>
319+
</div>
320+
<div id="nn-equation" class="nn-equation"></div>
321+
<div id="nn-equation-legend" class="nn-equation-legend" aria-label="Equation symbol reference"></div>
322+
</div>
323+
311324
</div>
312325

313326
<!-- More -->

package-lock.json

Lines changed: 39 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"scripts": {
77
"clean": "rimraf dist",
88
"start": "npm run serve-watch",
9-
"prep": "copyfiles analytics.js dist && concat node_modules/material-design-lite/material.min.js node_modules/seedrandom/seedrandom.min.js > dist/lib.js",
9+
"prep": "copyfiles analytics.js dist && copyfiles -u 3 node_modules/katex/dist/katex.min.css dist && copyfiles -u 3 node_modules/katex/dist/katex.min.js dist && copyfiles -u 3 \"node_modules/katex/dist/fonts/*.woff2\" dist && concat node_modules/material-design-lite/material.min.js node_modules/seedrandom/seedrandom.min.js > dist/lib.js",
1010
"build-css": "concat node_modules/material-design-lite/material.min.css styles.css > dist/bundle.css",
1111
"watch-css": "concat node_modules/material-design-lite/material.min.css styles.css -o dist/bundle.css",
1212
"build-html": "copyfiles index.html dist",
@@ -32,6 +32,7 @@
3232
},
3333
"dependencies": {
3434
"d3": "^3.5.16",
35+
"katex": "^0.16.11",
3536
"material-design-lite": "^1.3.0",
3637
"seedrandom": "^2.4.3"
3738
}

src/equation-legend.ts

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/* Copyright 2016 Google Inc. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
import * as nn from "./nn";
17+
18+
/** e.g. 6 -> "6th", 21 -> "21st" (for hidden-layer column wording). */
19+
function ordinalSuffix(n: number): string {
20+
let mod100 = n % 100;
21+
if (mod100 >= 11 && mod100 <= 13) {
22+
return n + "th";
23+
}
24+
switch (n % 10) {
25+
case 1:
26+
return n + "st";
27+
case 2:
28+
return n + "nd";
29+
case 3:
30+
return n + "rd";
31+
default:
32+
return n + "th";
33+
}
34+
}
35+
36+
export interface LegendRow {
37+
/** Small KaTeX fragment; omit for note-only rows. */
38+
symbolTex?: string;
39+
detail: string;
40+
}
41+
42+
export interface LegendSection {
43+
title: string;
44+
rows: LegendRow[];
45+
}
46+
47+
const FEATURE_GUIDE: {[id: string]: string} = {
48+
"x": "Horizontal coordinate of each data point (matches a node in the Features column, top to bottom).",
49+
"y": "Vertical coordinate of each data point.",
50+
"xSquared": "Square of the horizontal coordinate.",
51+
"ySquared": "Square of the vertical coordinate.",
52+
"xTimesY": "Product of horizontal and vertical coordinates.",
53+
"sinX": "Sine of the horizontal coordinate.",
54+
"sinY": "Sine of the vertical coordinate.",
55+
"cosX": "Cosine of the horizontal coordinate.",
56+
"cosY": "Cosine of the vertical coordinate."
57+
};
58+
59+
/**
60+
* Human-readable mapping between equation symbols and the diagram (columns and
61+
* neuron order top-to-bottom) for the layered equation view.
62+
*/
63+
export function buildEquationLegendSections(
64+
network: nn.Node[][], inputIds: string[], inputSymbols: string[],
65+
hiddenActivationSummary: string,
66+
outputActivationSummary: string): LegendSection[] {
67+
let sections: LegendSection[] = [];
68+
69+
if (network == null || network.length < 2) {
70+
return sections;
71+
}
72+
73+
let inputRows: LegendRow[] = [];
74+
for (let i = 0; i < inputIds.length; i++) {
75+
let id = inputIds[i];
76+
let guide = FEATURE_GUIDE[id] != null ?
77+
FEATURE_GUIDE[id] :
78+
"Enabled input feature in the Features column.";
79+
inputRows.push({
80+
symbolTex: inputSymbols[i],
81+
detail: guide
82+
});
83+
}
84+
sections.push({
85+
title: "Input layer (left column in the diagram)",
86+
rows: inputRows
87+
});
88+
89+
let numHidden = network.length - 2;
90+
for (let layerIdx = 1; layerIdx <= numHidden; layerIdx++) {
91+
let layer = network[layerIdx];
92+
let colOrdinal = layerIdx === 1 ? "first" :
93+
layerIdx === 2 ? "second" :
94+
layerIdx === 3 ? "third" :
95+
layerIdx === 4 ? "fourth" :
96+
layerIdx === 5 ? "fifth" : ordinalSuffix(layerIdx);
97+
let hiddenRows: LegendRow[] = [];
98+
for (let i = 0; i < layer.length; i++) {
99+
let node = layer[i];
100+
let ordinal = i + 1;
101+
hiddenRows.push({
102+
symbolTex: "h_{" + ordinal + "}^{(" + layerIdx + ")}",
103+
detail: "Neuron " + ordinal + " from the top in the " + colOrdinal +
104+
" hidden column after inputs (square node id " + node.id +
105+
" in the diagram)."
106+
});
107+
}
108+
sections.push({
109+
title: "Hidden layer " + layerIdx + " (" + colOrdinal +
110+
" column of weighted neurons; " + hiddenActivationSummary + ")",
111+
rows: hiddenRows
112+
});
113+
}
114+
115+
sections.push({
116+
title: "Output layer (rightmost column)",
117+
rows: [{
118+
symbolTex: "\\hat{y}",
119+
detail: "Network output after the output activation (" +
120+
outputActivationSummary + ")."
121+
}]
122+
});
123+
124+
return sections;
125+
}

src/network-equation.ts

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/* Copyright 2016 Google Inc. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
import * as nn from "./nn";
17+
18+
/** Formats a scalar for LaTeX (fixed decimals, trim negative zero). */
19+
function fmt(n: number): string {
20+
let s = n.toFixed(3);
21+
if (s === "-0.000") {
22+
return "0.000";
23+
}
24+
return s;
25+
}
26+
27+
/** Builds bias + sum_k w_k * symbol_k with clean + / - spacing. */
28+
function affineTex(node: nn.Node, prevTex: string[]): string {
29+
let parts: string[] = [fmt(node.bias)];
30+
for (let j = 0; j < node.inputLinks.length; j++) {
31+
parts.push(fmt(node.inputLinks[j].weight) + " \\cdot " + prevTex[j]);
32+
}
33+
return parts.join(" + ").replace(/\+ -/g, "- ");
34+
}
35+
36+
/** Maps activation keys (as in state.activations) to LaTeX. */
37+
function actWrap(activationKey: string, body: string): string {
38+
switch (activationKey) {
39+
case "relu":
40+
return "\\operatorname{ReLU}\\left(" + body + "\\right)";
41+
case "tanh":
42+
return "\\tanh\\left(" + body + "\\right)";
43+
case "sigmoid":
44+
return "\\sigma\\left(" + body + "\\right)";
45+
case "linear":
46+
return body;
47+
default:
48+
return "\\left(" + body + "\\right)";
49+
}
50+
}
51+
52+
function sigmaLegendRow(): string {
53+
return "&\\text{with } \\sigma(t)=\\dfrac{1}{1+e^{-t}}";
54+
}
55+
56+
function needsSigmaLegend(hiddenKey: string, outputKey: string): boolean {
57+
return hiddenKey === "sigmoid" || outputKey === "sigmoid";
58+
}
59+
60+
/**
61+
* Layer-wise definition: equations reference h_j^{(l)} symbols.
62+
*/
63+
function buildLayeredTex(
64+
network: nn.Node[][], inputSymbols: string[],
65+
hiddenActivationKey: string, outputActivationKey: string): string {
66+
let lines: string[] = [];
67+
let prevTex = inputSymbols.slice();
68+
69+
for (let layerIdx = 1; layerIdx < network.length; layerIdx++) {
70+
let layer = network[layerIdx];
71+
let isOutput = layerIdx === network.length - 1;
72+
let actKey = isOutput ? outputActivationKey : hiddenActivationKey;
73+
let nextPrev: string[] = [];
74+
75+
for (let i = 0; i < layer.length; i++) {
76+
let node = layer[i];
77+
let inner = affineTex(node, prevTex);
78+
let rhs = actWrap(actKey, inner);
79+
80+
if (isOutput) {
81+
lines.push("\\hat{y} &= " + rhs);
82+
} else {
83+
let h = "h_{" + (i + 1) + "}^{(" + layerIdx + ")}";
84+
nextPrev.push(h);
85+
lines.push(h + " &= " + rhs);
86+
}
87+
}
88+
if (!isOutput) {
89+
prevTex = nextPrev;
90+
}
91+
}
92+
93+
if (needsSigmaLegend(hiddenActivationKey, outputActivationKey)) {
94+
lines.push(sigmaLegendRow());
95+
}
96+
return "\\begin{aligned}\n" + lines.join("\\\\\n") + "\n\\end{aligned}";
97+
}
98+
99+
/**
100+
* Layer-wise MLP formula with numeric weights and biases.
101+
* inputSymbols must match network[0] order (same as constructInput / buildNetwork).
102+
*/
103+
export function buildNetworkEquationTex(
104+
network: nn.Node[][], inputSymbols: string[],
105+
hiddenActivationKey: string, outputActivationKey: string): string {
106+
if (network == null || network.length < 2) {
107+
return "";
108+
}
109+
if (inputSymbols.length !== network[0].length) {
110+
return "";
111+
}
112+
return buildLayeredTex(
113+
network, inputSymbols, hiddenActivationKey, outputActivationKey);
114+
}

src/nn.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,9 @@ export function backProp(network: Node[][], target: number,
321321
// Compute the error derivative with respect to each node's output.
322322
node.outputDer = 0;
323323
for (let j = 0; j < node.outputs.length; j++) {
324-
let output = node.outputs[j];
325-
node.outputDer += output.weight * output.dest.inputDer;
324+
let link = node.outputs[j];
325+
let dest = link.dest;
326+
node.outputDer += link.weight * dest.inputDer;
326327
}
327328
}
328329
}

0 commit comments

Comments
 (0)