77load_dotenv ()
88
99import numpy as np
10- from scene_utils import setup_scene , create_sparse_matrix , create_small_graph_from_matrix
10+ from scene_utils import setup_scene , create_sparse_matrix , create_square_digraph
1111
1212
1313# Shared data from notebook
1717# Graph B: edges (1→2), (2→3), (3→0)
1818B_DATA = [[0 , 0 , 0 , 0 ], [0 , 0 , 1 , 0 ], [0 , 0 , 0 , 1 ], [1 , 0 , 0 , 0 ]]
1919
20- # Union result: edges (0→1), (1→2), (2→3), (3→0)
21- UNION_DATA = [[0 , 1 , 0 , 0 ], [0 , 0 , 1 , 0 ], [0 , 0 , 0 , 1 ], [1 , 0 , 0 , 0 ]]
20+ # Union result with binary.plus: edges sum where both exist
21+ # (0→1)=1, (1→2)=2, (2→3)=2, (3→0)=1
22+ UNION_DATA = [[0 , 1 , 0 , 0 ], [0 , 0 , 2 , 0 ], [0 , 0 , 0 , 2 ], [1 , 0 , 0 , 0 ]]
2223
2324# Node positions (square layout)
2425POS = {0 : np .array ([0 , 1 , 0 ]), 1 : np .array ([1 , 1 , 0 ]), 2 : np .array ([1 , 0 , 0 ]), 3 : np .array ([0 , 0 , 0 ])}
@@ -41,7 +42,7 @@ def construct(self):
4142 self .play (Write (title ))
4243
4344 # Create Graph A (left, blue)
44- graph_a = self . create_graph (A_DATA , BLUE )
45+ graph_a = create_square_digraph (A_DATA , BLUE )
4546 graph_a .scale (1.2 ).shift (LEFT * 4 + DOWN * 0.5 )
4647 label_a = Text ("Graph A" , font_size = 24 , color = BLUE ).next_to (graph_a , UP )
4748
@@ -50,7 +51,7 @@ def construct(self):
5051 mat_a .next_to (graph_a , DOWN , buff = 0.5 )
5152
5253 # Create Graph B (right, green)
53- graph_b = self . create_graph (B_DATA , GREEN )
54+ graph_b = create_square_digraph (B_DATA , GREEN )
5455 graph_b .scale (1.2 ).shift (RIGHT * 4 + DOWN * 0.5 )
5556 label_b = Text ("Graph B" , font_size = 24 , color = GREEN ).next_to (graph_b , UP )
5657
@@ -93,31 +94,35 @@ def construct(self):
9394 self .wait (1 )
9495
9596 with self .voiceover (
96- """When we compute eWiseAdd, we get all edges from both: 0 to 1 from A,
97- 1 to 2 and 2 to 3 from both, and 3 to 0 from B. The result has four
98- distinct edge positions representing the union ."""
97+ """When we compute eWiseAdd with binary plus , we get all edges from both
98+ graphs. Edge 0 to 1 comes from A, edge 3 to 0 comes from B. The shared
99+ edges 1 to 2 and 2 to 3 sum to 2 in the result ."""
99100 ):
100- # Show formula
101- formula = MathTex (r"A \cup B" , font_size = 48 , color = YELLOW )
102- formula .move_to (ORIGIN + UP * 0.5 )
103- self .play (Write (formula ))
101+ # Show Python code below title
102+ code = Code (
103+ code_string = "A.ewise_add(B, binary.plus)" ,
104+ language = "python" ,
105+ background = "window" ,
106+ ).scale (0.7 )
107+ code .next_to (title , DOWN , buff = 0.3 )
108+ self .play (FadeIn (code ))
104109 self .wait (0.5 )
105110
106111 # Fade out highlights and prepare for result
107112 self .play (FadeOut (common_highlight_a ), FadeOut (common_highlight_b ))
108113
109- # Create union result graph (center, yellow)
110- graph_union = self . create_graph (UNION_DATA , YELLOW )
114+ # Create union result graph (center, yellow) with edge weights
115+ graph_union = create_square_digraph (UNION_DATA , YELLOW , show_weights = True )
111116 graph_union .scale (1.2 ).move_to (ORIGIN + DOWN * 0.5 )
112117 label_union = Text ("A ∪ B" , font_size = 24 , color = YELLOW ).next_to (graph_union , UP )
113118
114119 # Create union matrix
115120 mat_union = create_sparse_matrix (UNION_DATA , scale = 0.5 )
116121 mat_union .next_to (graph_union , DOWN , buff = 0.5 )
117122
118- # Animate combining: fade formula, bring in result
123+ # Animate bringing in result
119124 self .play (
120- ReplacementTransform ( formula , label_union ),
125+ Write ( label_union ),
121126 Create (graph_union ),
122127 FadeIn (mat_union ),
123128 )
@@ -131,48 +136,14 @@ def construct(self):
131136
132137 # Cleanup
133138 self .play (
134- FadeOut (title ), FadeOut (graph_a ), FadeOut (label_a ), FadeOut (mat_a ),
139+ FadeOut (title ), FadeOut (code ),
140+ FadeOut (graph_a ), FadeOut (label_a ), FadeOut (mat_a ),
135141 FadeOut (graph_b ), FadeOut (label_b ), FadeOut (mat_b ),
136142 FadeOut (graph_union ), FadeOut (label_union ), FadeOut (mat_union ),
137143 FadeOut (edge_count ),
138144 )
139145 self .wait (0.5 )
140146
141- def create_graph (self , matrix_data , color ):
142- """Create a directed graph from adjacency matrix with square layout."""
143- n = len (matrix_data )
144- positions = {
145- 0 : np .array ([- 0.7 , 0.7 , 0 ]),
146- 1 : np .array ([0.7 , 0.7 , 0 ]),
147- 2 : np .array ([0.7 , - 0.7 , 0 ]),
148- 3 : np .array ([- 0.7 , - 0.7 , 0 ]),
149- }
150-
151- # Create vertices
152- vertices = {}
153- for i in range (n ):
154- label = MathTex (str (i ), color = BLACK ).scale (0.5 )
155- dot = LabeledDot (label , radius = 0.2 , fill_color = WHITE , fill_opacity = 1 )
156- dot .move_to (positions [i ])
157- vertices [i ] = dot
158-
159- # Create edges
160- edges = VGroup ()
161- for i in range (n ):
162- for j in range (n ):
163- if matrix_data [i ][j ] != 0 :
164- arrow = Arrow (
165- positions [i ], positions [j ],
166- color = color , buff = 0.25 , stroke_width = 3 ,
167- tip_length = 0.15 , max_tip_length_to_length_ratio = 0.25
168- )
169- edges .add (arrow )
170-
171- graph = VGroup (edges , * vertices .values ())
172- graph .vertices = vertices
173- graph .edges = edges
174- return graph
175-
176147 def get_nearest_node (self , pos ):
177148 """Get the nearest node index for a position."""
178149 positions = {
0 commit comments