Skip to content

Commit ea427a9

Browse files
committed
add code boxes to all elementwise scenes in chapter 6.
1 parent d0a65cc commit ea427a9

11 files changed

Lines changed: 460 additions & 222 deletions

File tree

Chapter6/Scene1.py

Lines changed: 23 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
load_dotenv()
88

99
import numpy as np
10-
from scene_utils import setup_scene, create_sparse_matrix, create_small_graph_from_matrix
10+
from scene_utils import setup_scene, create_sparse_matrix, create_square_digraph
1111

1212

1313
# Shared data from notebook
@@ -17,8 +17,9 @@
1717
# Graph B: edges (1→2), (2→3), (3→0)
1818
B_DATA = [[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0]]
1919

20-
# Union result: edges (0→1), (1→2), (2→3), (3→0)
21-
UNION_DATA = [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0]]
20+
# Union result with binary.plus: edges sum where both exist
21+
# (0→1)=1, (1→2)=2, (2→3)=2, (3→0)=1
22+
UNION_DATA = [[0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2], [1, 0, 0, 0]]
2223

2324
# Node positions (square layout)
2425
POS = {0: np.array([0, 1, 0]), 1: np.array([1, 1, 0]), 2: np.array([1, 0, 0]), 3: np.array([0, 0, 0])}
@@ -41,7 +42,7 @@ def construct(self):
4142
self.play(Write(title))
4243

4344
# Create Graph A (left, blue)
44-
graph_a = self.create_graph(A_DATA, BLUE)
45+
graph_a = create_square_digraph(A_DATA, BLUE)
4546
graph_a.scale(1.2).shift(LEFT * 4 + DOWN * 0.5)
4647
label_a = Text("Graph A", font_size=24, color=BLUE).next_to(graph_a, UP)
4748

@@ -50,7 +51,7 @@ def construct(self):
5051
mat_a.next_to(graph_a, DOWN, buff=0.5)
5152

5253
# Create Graph B (right, green)
53-
graph_b = self.create_graph(B_DATA, GREEN)
54+
graph_b = create_square_digraph(B_DATA, GREEN)
5455
graph_b.scale(1.2).shift(RIGHT * 4 + DOWN * 0.5)
5556
label_b = Text("Graph B", font_size=24, color=GREEN).next_to(graph_b, UP)
5657

@@ -93,31 +94,35 @@ def construct(self):
9394
self.wait(1)
9495

9596
with self.voiceover(
96-
"""When we compute eWiseAdd, we get all edges from both: 0 to 1 from A,
97-
1 to 2 and 2 to 3 from both, and 3 to 0 from B. The result has four
98-
distinct edge positions representing the union."""
97+
"""When we compute eWiseAdd with binary plus, we get all edges from both
98+
graphs. Edge 0 to 1 comes from A, edge 3 to 0 comes from B. The shared
99+
edges 1 to 2 and 2 to 3 sum to 2 in the result."""
99100
):
100-
# Show formula
101-
formula = MathTex(r"A \cup B", font_size=48, color=YELLOW)
102-
formula.move_to(ORIGIN + UP * 0.5)
103-
self.play(Write(formula))
101+
# Show Python code below title
102+
code = Code(
103+
code_string="A.ewise_add(B, binary.plus)",
104+
language="python",
105+
background="window",
106+
).scale(0.7)
107+
code.next_to(title, DOWN, buff=0.3)
108+
self.play(FadeIn(code))
104109
self.wait(0.5)
105110

106111
# Fade out highlights and prepare for result
107112
self.play(FadeOut(common_highlight_a), FadeOut(common_highlight_b))
108113

109-
# Create union result graph (center, yellow)
110-
graph_union = self.create_graph(UNION_DATA, YELLOW)
114+
# Create union result graph (center, yellow) with edge weights
115+
graph_union = create_square_digraph(UNION_DATA, YELLOW, show_weights=True)
111116
graph_union.scale(1.2).move_to(ORIGIN + DOWN * 0.5)
112117
label_union = Text("A ∪ B", font_size=24, color=YELLOW).next_to(graph_union, UP)
113118

114119
# Create union matrix
115120
mat_union = create_sparse_matrix(UNION_DATA, scale=0.5)
116121
mat_union.next_to(graph_union, DOWN, buff=0.5)
117122

118-
# Animate combining: fade formula, bring in result
123+
# Animate bringing in result
119124
self.play(
120-
ReplacementTransform(formula, label_union),
125+
Write(label_union),
121126
Create(graph_union),
122127
FadeIn(mat_union),
123128
)
@@ -131,48 +136,14 @@ def construct(self):
131136

132137
# Cleanup
133138
self.play(
134-
FadeOut(title), FadeOut(graph_a), FadeOut(label_a), FadeOut(mat_a),
139+
FadeOut(title), FadeOut(code),
140+
FadeOut(graph_a), FadeOut(label_a), FadeOut(mat_a),
135141
FadeOut(graph_b), FadeOut(label_b), FadeOut(mat_b),
136142
FadeOut(graph_union), FadeOut(label_union), FadeOut(mat_union),
137143
FadeOut(edge_count),
138144
)
139145
self.wait(0.5)
140146

141-
def create_graph(self, matrix_data, color):
142-
"""Create a directed graph from adjacency matrix with square layout."""
143-
n = len(matrix_data)
144-
positions = {
145-
0: np.array([-0.7, 0.7, 0]),
146-
1: np.array([0.7, 0.7, 0]),
147-
2: np.array([0.7, -0.7, 0]),
148-
3: np.array([-0.7, -0.7, 0]),
149-
}
150-
151-
# Create vertices
152-
vertices = {}
153-
for i in range(n):
154-
label = MathTex(str(i), color=BLACK).scale(0.5)
155-
dot = LabeledDot(label, radius=0.2, fill_color=WHITE, fill_opacity=1)
156-
dot.move_to(positions[i])
157-
vertices[i] = dot
158-
159-
# Create edges
160-
edges = VGroup()
161-
for i in range(n):
162-
for j in range(n):
163-
if matrix_data[i][j] != 0:
164-
arrow = Arrow(
165-
positions[i], positions[j],
166-
color=color, buff=0.25, stroke_width=3,
167-
tip_length=0.15, max_tip_length_to_length_ratio=0.25
168-
)
169-
edges.add(arrow)
170-
171-
graph = VGroup(edges, *vertices.values())
172-
graph.vertices = vertices
173-
graph.edges = edges
174-
return graph
175-
176147
def get_nearest_node(self, pos):
177148
"""Get the nearest node index for a position."""
178149
positions = {

Chapter6/Scene2.py

Lines changed: 19 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
load_dotenv()
88

99
import numpy as np
10-
from scene_utils import setup_scene, create_sparse_matrix
10+
from scene_utils import setup_scene, create_sparse_matrix, create_square_digraph
1111

1212

1313
# Shared data from notebook
@@ -37,7 +37,7 @@ def construct(self):
3737
self.play(Write(title))
3838

3939
# Create Graph A (left, blue)
40-
graph_a = self.create_graph(A_DATA, BLUE)
40+
graph_a = create_square_digraph(A_DATA, BLUE)
4141
graph_a.scale(1.2).shift(LEFT * 4 + DOWN * 0.5)
4242
label_a = Text("Graph A", font_size=24, color=BLUE).next_to(graph_a, UP)
4343

@@ -46,7 +46,7 @@ def construct(self):
4646
mat_a.next_to(graph_a, DOWN, buff=0.5)
4747

4848
# Create Graph B (right, green)
49-
graph_b = self.create_graph(B_DATA, GREEN)
49+
graph_b = create_square_digraph(B_DATA, GREEN)
5050
graph_b.scale(1.2).shift(RIGHT * 4 + DOWN * 0.5)
5151
label_b = Text("Graph B", font_size=24, color=GREEN).next_to(graph_b, UP)
5252

@@ -61,14 +61,18 @@ def construct(self):
6161
self.wait(1)
6262

6363
with self.voiceover(
64-
"""The result contains just two edges: 1 to 2 and 2 to 3. Contrast this
65-
with union which gave us four edges. Intersection keeps only what
66-
both graphs share."""
64+
"""The result contains just two edges: 1 to 2 and 2 to 3, each with
65+
value 1 times 1 equals 1. Contrast this with union which gave us four
66+
edges. Intersection keeps only what both graphs share."""
6767
):
68-
# Show formula
69-
formula = MathTex(r"A \cap B", font_size=48, color=RED_C)
70-
formula.move_to(ORIGIN + UP * 0.5)
71-
self.play(Write(formula))
68+
# Show Python code below title
69+
code = Code(
70+
code_string="A.ewise_mult(B, binary.times)",
71+
language="python",
72+
background="window",
73+
).scale(0.7)
74+
code.next_to(title, DOWN, buff=0.3)
75+
self.play(FadeIn(code))
7276
self.wait(0.5)
7377

7478
# Highlight common edges and fade unique ones
@@ -110,8 +114,8 @@ def construct(self):
110114
)
111115
self.wait(0.5)
112116

113-
# Create intersection result graph (center, red)
114-
graph_int = self.create_graph(INTERSECTION_DATA, RED_C)
117+
# Create intersection result graph (center, red) with edge weights
118+
graph_int = create_square_digraph(INTERSECTION_DATA, RED_C, show_weights=True)
115119
graph_int.scale(1.2).move_to(ORIGIN + DOWN * 0.5)
116120
label_int = Text("A ∩ B", font_size=24, color=RED_C).next_to(graph_int, UP)
117121

@@ -120,7 +124,7 @@ def construct(self):
120124
mat_int.next_to(graph_int, DOWN, buff=0.5)
121125

122126
self.play(
123-
ReplacementTransform(formula, label_int),
127+
Write(label_int),
124128
Create(graph_int),
125129
FadeIn(mat_int),
126130
)
@@ -138,48 +142,14 @@ def construct(self):
138142

139143
# Cleanup
140144
self.play(
141-
FadeOut(title), FadeOut(graph_a), FadeOut(label_a), FadeOut(mat_a),
145+
FadeOut(title), FadeOut(code),
146+
FadeOut(graph_a), FadeOut(label_a), FadeOut(mat_a),
142147
FadeOut(graph_b), FadeOut(label_b), FadeOut(mat_b),
143148
FadeOut(graph_int), FadeOut(label_int), FadeOut(mat_int),
144149
FadeOut(comparison),
145150
)
146151
self.wait(0.5)
147152

148-
def create_graph(self, matrix_data, color):
149-
"""Create a directed graph from adjacency matrix with square layout."""
150-
n = len(matrix_data)
151-
positions = {
152-
0: np.array([-0.7, 0.7, 0]),
153-
1: np.array([0.7, 0.7, 0]),
154-
2: np.array([0.7, -0.7, 0]),
155-
3: np.array([-0.7, -0.7, 0]),
156-
}
157-
158-
# Create vertices
159-
vertices = {}
160-
for i in range(n):
161-
label = MathTex(str(i), color=BLACK).scale(0.5)
162-
dot = LabeledDot(label, radius=0.2, fill_color=WHITE, fill_opacity=1)
163-
dot.move_to(positions[i])
164-
vertices[i] = dot
165-
166-
# Create edges
167-
edges = VGroup()
168-
for i in range(n):
169-
for j in range(n):
170-
if matrix_data[i][j] != 0:
171-
arrow = Arrow(
172-
positions[i], positions[j],
173-
color=color, buff=0.25, stroke_width=3,
174-
tip_length=0.15, max_tip_length_to_length_ratio=0.25
175-
)
176-
edges.add(arrow)
177-
178-
graph = VGroup(edges, *vertices.values())
179-
graph.vertices = vertices
180-
graph.edges = edges
181-
return graph
182-
183153
def get_nearest_node(self, pos):
184154
"""Get the nearest node index for a position."""
185155
positions = {

Chapter6/Scene3.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,25 @@ def construct(self):
3434

3535
# Create weighted graph
3636
graph = self.create_weighted_graph(W_DATA)
37-
graph.scale(1.5).shift(LEFT * 3 + DOWN * 0.3)
37+
graph.scale(1.5).shift(LEFT * 3 + UP * 0.3)
3838
label = Text("Weighted Graph W", font_size=24).next_to(graph, UP)
3939

4040
# Create matrix
4141
mat = create_sparse_matrix(W_DATA, scale=0.55)
42-
mat.shift(RIGHT * 3 + UP * 0.5)
42+
mat.shift(RIGHT * 3 + UP * 1.0)
4343
mat_label = Text("W", font_size=28).next_to(mat, UP)
4444

4545
self.play(Create(graph), Write(label), FadeIn(mat), Write(mat_label))
4646
self.wait(0.5)
4747

48-
# Show threshold condition
49-
threshold = MathTex(r"\text{weight} > 3", font_size=36, color=YELLOW)
50-
threshold.next_to(mat, DOWN, buff=0.5)
51-
self.play(Write(threshold))
48+
# Show Python code for select
49+
code = Code(
50+
code_string='W.select(">", 3)',
51+
language="python",
52+
background="window",
53+
).scale(0.7)
54+
code.next_to(mat, DOWN, buff=0.4)
55+
self.play(FadeIn(code))
5256
self.wait(1)
5357

5458
with self.voiceover(
@@ -110,7 +114,7 @@ def construct(self):
110114

111115
# Show result matrix
112116
result_mat = create_sparse_matrix(FILTERED_DATA, scale=0.55)
113-
result_mat.next_to(threshold, DOWN, buff=0.5)
117+
result_mat.next_to(code, DOWN, buff=0.4)
114118
result_label = Text("Result: 2 edges", font_size=20, color=GREEN)
115119
result_label.next_to(result_mat, DOWN, buff=0.3)
116120

@@ -120,7 +124,7 @@ def construct(self):
120124
# Cleanup
121125
self.play(
122126
FadeOut(title), FadeOut(graph), FadeOut(label),
123-
FadeOut(mat), FadeOut(mat_label), FadeOut(threshold),
127+
FadeOut(mat), FadeOut(mat_label), FadeOut(code),
124128
FadeOut(result_mat), FadeOut(result_label),
125129
)
126130
self.wait(0.5)

Chapter6/Scene4.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,20 @@ def construct(self):
3434

3535
# Create original matrix
3636
mat_m = create_sparse_matrix(M_DATA, scale=0.6)
37-
mat_m.shift(LEFT * 3)
37+
mat_m.shift(LEFT * 4.5)
3838
mat_label = MathTex("M", font_size=36).next_to(mat_m, UP, buff=0.3)
3939

4040
self.play(FadeIn(mat_m), Write(mat_label))
4141
self.wait(0.5)
4242

43-
# Show formula
44-
formula = MathTex(r"\sqrt{M}", font_size=48, color=GREEN)
45-
formula.move_to(ORIGIN)
46-
self.play(Write(formula))
43+
# Show Python code
44+
code = Code(
45+
code_string="M.apply(unary.sqrt)",
46+
language="python",
47+
background="window",
48+
).scale(0.7)
49+
code.move_to(ORIGIN + UP * 0.5)
50+
self.play(FadeIn(code))
4751
self.wait(0.5)
4852

4953
with self.voiceover(
@@ -53,8 +57,8 @@ def construct(self):
5357
):
5458
# Create result matrix
5559
mat_result = create_sparse_matrix(SQRT_DATA, scale=0.6)
56-
mat_result.shift(RIGHT * 3)
57-
result_label = MathTex(r"\sqrt{M}", font_size=36, color=GREEN)
60+
mat_result.shift(RIGHT * 4.5)
61+
result_label = Text("Result", font_size=24, color=GREEN)
5862
result_label.next_to(mat_result, UP, buff=0.3)
5963

6064
# Animate sqrt transformation for each non-zero entry
@@ -107,7 +111,7 @@ def construct(self):
107111
# Cleanup
108112
self.play(
109113
FadeOut(title), FadeOut(mat_m), FadeOut(mat_label),
110-
FadeOut(formula), FadeOut(mat_result), FadeOut(result_label),
114+
FadeOut(code), FadeOut(mat_result), FadeOut(result_label),
111115
FadeOut(note),
112116
)
113117
self.wait(0.5)

0 commit comments

Comments
 (0)