Skip to content

Commit 162c67e

Browse files
Track 2: 2D-aware broadcasting on elementwise array ops
arr_add / arr_sub / arr_mul / arr_div_int now broadcast 2D-against-2D (same shape, cell-wise) and 2D-against-1D row-vector (replicate the vector across every row). The 1D / scalar paths are unchanged, so existing callers see no behaviour change. Before this commit, feeding a 2D array to arr_add silently produced zeros: each row Value coerced via to_int() → 0 in the flat path. Replacing that with a proper broadcast is strictly more correct. Shape mismatches raise as typed errors, catchable with try/except just like Python's ValueError. The dense-layer test demonstrates the natural composition: Z = arr_matmul(X, W) Y = arr_add(Z, bias) # bias as 1D vector, broadcast per row Tests: 9 cases — 2D+2D add/sub/mul, 2D+1D row broadcast (both orderings), composition into a dense-layer forward pass, and shape mismatches caught via the typed-exception path. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 9233606 commit 162c67e

2 files changed

Lines changed: 278 additions & 0 deletions

File tree

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Track 2: 2D-aware broadcasting on arr_add / arr_sub / arr_mul.
2+
#
3+
# Previous behaviour:
4+
# - same-length 1D arrays: element-wise
5+
# - scalar ↔ 1D: scalar broadcast
6+
# - everything else: silently wrong (each row coerced via to_int → 0)
7+
#
8+
# Now adds:
9+
# - 2D ↔ 2D (same shape): element-wise per cell
10+
# - 2D ↔ 1D row-vector: vector replicated across every row
11+
# - 1D row-vector ↔ 2D: same, reversed
12+
13+
# Records a test failure when `actual` differs from `expected`. The failure
# text is `msg` followed by both values rendered through to_string.
fn assert_eq(actual, expected, msg) {
    if actual != expected {
        h detail = msg + ": expected " + to_string(expected) + " got " + to_string(actual);
        test_record_failure(detail);
    }
}
18+
19+
# Records `msg` as a test failure when `cond` is false.
fn assert_true(cond, msg) {
    if !cond {
        test_record_failure(msg);
    }
}
22+
23+
# ---- 2D + 2D same shape ----
24+
25+
# Cell-wise addition of two 2x3 matrices: checks the row and column counts of
# the result and the first/last cell of each row.
fn test_2d_2d_add() {
    h lhs = [[1, 2, 3], [4, 5, 6]];
    h rhs = [[10, 20, 30], [40, 50, 60]];
    h out = arr_add(lhs, rhs);

    # Shape first: two rows, three columns.
    assert_eq(arr_len(out), 2, "2 rows");
    h row_top = arr_get(out, 0);
    h row_bottom = arr_get(out, 1);
    assert_eq(arr_len(row_top), 3, "3 cols");

    # Corner cells of each row.
    assert_eq(arr_get(row_top, 0), 11, "(0,0) = 11");
    assert_eq(arr_get(row_top, 2), 33, "(0,2) = 33");
    assert_eq(arr_get(row_bottom, 0), 44, "(1,0) = 44");
    assert_eq(arr_get(row_bottom, 2), 66, "(1,2) = 66");
}
38+
39+
# Cell-wise subtraction of two 2x2 matrices; every cell of the result is checked.
fn test_2d_2d_sub() {
    h minuend = [[10, 20], [30, 40]];
    h subtrahend = [[1, 2], [3, 4]];
    h diff = arr_sub(minuend, subtrahend);

    h row_top = arr_get(diff, 0);
    assert_eq(arr_get(row_top, 0), 9, "10-1");
    assert_eq(arr_get(row_top, 1), 18, "20-2");

    h row_bottom = arr_get(diff, 1);
    assert_eq(arr_get(row_bottom, 0), 27, "30-3");
    assert_eq(arr_get(row_bottom, 1), 36, "40-4");
}
50+
51+
# Cell-wise (not matrix) multiplication of two 2x2 matrices; every cell of the
# result is checked.
fn test_2d_2d_mul() {
    h lhs = [[1, 2], [3, 4]];
    h rhs = [[5, 6], [7, 8]];
    h product = arr_mul(lhs, rhs);

    h row_top = arr_get(product, 0);
    assert_eq(arr_get(row_top, 0), 5, "1*5");
    assert_eq(arr_get(row_top, 1), 12, "2*6");

    h row_bottom = arr_get(product, 1);
    assert_eq(arr_get(row_bottom, 0), 21, "3*7");
    assert_eq(arr_get(row_bottom, 1), 32, "4*8");
}
62+
63+
# ---- 2D + 1D row broadcast ----
64+
65+
# Adds a length-3 vector to a 3x3 matrix and checks that the vector was applied
# to every row (first and last cell of each row).
fn test_2d_1d_add_row_broadcast() {
    h matrix = [[1, 2, 3], [10, 20, 30], [100, 200, 300]];
    h offsets = [1, 2, 3];
    h out = arr_add(matrix, offsets);

    h first = arr_get(out, 0);
    h second = arr_get(out, 1);
    h third = arr_get(out, 2);

    assert_eq(arr_get(first, 0), 2, "1+1");
    assert_eq(arr_get(first, 2), 6, "3+3");
    assert_eq(arr_get(second, 0), 11, "10+1");
    assert_eq(arr_get(second, 2), 33, "30+3");
    assert_eq(arr_get(third, 0), 101, "100+1");
    assert_eq(arr_get(third, 2), 303, "300+3");
}
79+
80+
# Same row broadcast as the 2D + 1D case, but with the 1D vector passed as the
# LEFT operand of arr_add.
fn test_1d_2d_add_symmetric() {
    h matrix = [[1, 2, 3], [4, 5, 6]];
    h offsets = [10, 20, 30];

    # Vector first, matrix second.
    h out = arr_add(offsets, matrix);

    h row_top = arr_get(out, 0);
    h row_bottom = arr_get(out, 1);
    assert_eq(arr_get(row_top, 0), 11, "10+1");
    assert_eq(arr_get(row_top, 2), 33, "30+3");
    assert_eq(arr_get(row_bottom, 1), 25, "20+5");
}
91+
92+
# ---- Composition: matmul + bias broadcast (a "dense layer") ----
93+
94+
# Forward pass of a tiny dense layer: the (1x2) output of arr_matmul plus a
# 1D bias that arr_add broadcasts across the single row.
fn test_dense_layer_forward() {
    # Single example through a linear layer with bias.
    # X (1x3) @ W (3x2) + b (2,) = (1x2)
    h X = [[1, 2, 3]];
    h W = [[1, 0], [0, 1], [1, 1]];
    h b = [10, 20];
    h Z = arr_matmul(X, W);   # (1x2) = [[1+3, 2+3]] = [[4, 5]]
    h Y = arr_add(Z, b);      # broadcast b across the one row

    # The broadcast must keep the matmul output's 1x2 shape.
    assert_eq(arr_len(Y), 1, "dense output has 1 row");
    h r0 = arr_get(Y, 0);
    assert_eq(arr_len(r0), 2, "dense output has 2 cols");

    # Pin the exact activations; a lower bound would also accept wrong sums.
    assert_eq(arr_get(r0, 0), 14, "first activation = 4+10");
    assert_eq(arr_get(r0, 1), 25, "second activation = 5+20");
}
106+
107+
# ---- Scalar broadcasting still works (regression) ----
108+
109+
# Scalar broadcast currently applies only to 1D arrays; 2D + scalar would need
# a separate path and is NOT exercised here. This test only pins the 1D scalar
# path through arr_scale, so a future commit that adds 2D + scalar support
# should add its own assertions alongside.
fn test_scalar_into_2d_still_unsupported() {
    # arr_scale stays a separate builtin and works fine.
    h C = arr_scale([1, 2, 3], 10);
    assert_eq(arr_get(C, 0), 10, "scalar broadcast on 1D");
    assert_eq(arr_get(C, 2), 30, "scalar broadcast on 1D");
}
119+
120+
# ---- Shape mismatch raises ----
121+
122+
# arr_add on a 2x3 and a 2x2 matrix must raise; the test passes when the
# caught value is a non-empty string.
fn test_shape_mismatch_2d_2d_caught() {
    # Stays empty unless the catch branch runs.
    h caught = "";
    try {
        h wide = [[1, 2, 3], [4, 5, 6]];
        h narrow = [[1, 2], [3, 4]];
        h _ = arr_add(wide, narrow);
    } catch e {
        caught = e;
    }
    assert_true(str_len(caught) > 0, "shape mismatch raised");
}
133+
134+
# arr_add on a 2x3 matrix and a length-2 vector must raise; the test passes
# when the caught value is a non-empty string.
fn test_shape_mismatch_2d_1d_caught() {
    # Stays empty unless the catch branch runs.
    h caught = "";
    try {
        h matrix = [[1, 2, 3], [4, 5, 6]];
        h short_vec = [1, 2];   # length 2 but 3 cols needed
        h _ = arr_add(matrix, short_vec);
    } catch e {
        caught = e;
    }
    assert_true(str_len(caught) > 0, "row-broadcast mismatch raised");
}

omnimcode-core/src/interpreter.rs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8337,12 +8337,146 @@ fn values_equal(a: &Value, b: &Value) -> bool {
83378337
/// `op` takes (i64, i64) and returns i64; the helper wraps the
83388338
/// result in HInt so per-element substrate resonance gets recomputed
83398339
/// from the arithmetic output.
8340+
/// Detect whether `a` is a 2D array (every element is itself an array).
8341+
/// Empty rows count as malformed and return None — callers fall back to
8342+
/// the 1D path. Returns (rows, cols) of the first row when 2D.
8343+
fn array_2d_shape(v: &Value) -> Option<(usize, usize)> {
8344+
if let Value::Array(outer) = v {
8345+
let rows = outer.items.borrow();
8346+
if rows.is_empty() { return None; }
8347+
let first_cols = match &rows[0] {
8348+
Value::Array(r) => r.items.borrow().len(),
8349+
_ => return None,
8350+
};
8351+
for r in rows.iter() {
8352+
match r {
8353+
Value::Array(row) if row.items.borrow().len() == first_cols => {}
8354+
_ => return None,
8355+
}
8356+
}
8357+
Some((rows.len(), first_cols))
8358+
} else {
8359+
None
8360+
}
8361+
}
8362+
8363+
/// 2D-aware broadcast paths for elementwise ops. Returns Some(result)
8364+
/// when both operands fit one of the broadcasting shapes; None lets the
8365+
/// caller fall through to the flat 1D path.
8366+
///
8367+
/// (NxM, NxM) — element-wise, returns NxM
8368+
/// (NxM, M-vector) — row broadcast: vector added to every row
8369+
/// (M-vector, NxM) — same, reversed
8370+
fn try_2d_broadcast<F: Fn(i64, i64) -> i64>(
8371+
a: &Value,
8372+
b: &Value,
8373+
name: &str,
8374+
op: &F,
8375+
) -> Result<Option<Value>, String> {
8376+
let a_shape = array_2d_shape(a);
8377+
let b_shape = array_2d_shape(b);
8378+
8379+
// Case 1: both 2D — must match shapes element-wise.
8380+
if let (Some((ar, ac)), Some((br, bc))) = (a_shape, b_shape) {
8381+
if ar != br || ac != bc {
8382+
return Err(format!(
8383+
"{}: 2D shape mismatch ({}x{} vs {}x{})", name, ar, ac, br, bc
8384+
));
8385+
}
8386+
if let (Value::Array(a_rows), Value::Array(b_rows)) = (a, b) {
8387+
let ar_b = a_rows.items.borrow();
8388+
let br_b = b_rows.items.borrow();
8389+
let mut out_rows: Vec<Value> = Vec::with_capacity(ar);
8390+
for (ra, rb) in ar_b.iter().zip(br_b.iter()) {
8391+
let (Value::Array(ra), Value::Array(rb)) = (ra, rb) else {
8392+
return Ok(None);
8393+
};
8394+
let raw_a = ra.items.borrow();
8395+
let raw_b = rb.items.borrow();
8396+
let row: Vec<Value> = raw_a.iter().zip(raw_b.iter())
8397+
.map(|(x, y)| Value::HInt(HInt::new(op(x.to_int(), y.to_int()))))
8398+
.collect();
8399+
out_rows.push(Value::Array(HArray::from_vec(row)));
8400+
}
8401+
return Ok(Some(Value::Array(HArray::from_vec(out_rows))));
8402+
}
8403+
}
8404+
8405+
// Case 2: 2D + 1D row-vector — broadcast vector across every row.
8406+
if let (Some((ar, ac)), None) = (a_shape, b_shape) {
8407+
if let (Value::Array(a_rows), Value::Array(b_vec)) = (a, b) {
8408+
let vec_b = b_vec.items.borrow();
8409+
// Reject when b is itself a non-1D shape (e.g., array of dicts);
8410+
// a true 1D vector has length == ac.
8411+
if vec_b.len() != ac {
8412+
// Could be a length mismatch — surface a clear error.
8413+
// But only when b looks like a 1D numeric vector; otherwise
8414+
// fall through to None and let the caller handle.
8415+
if vec_b.iter().any(|v| matches!(v, Value::Array(_))) {
8416+
return Ok(None);
8417+
}
8418+
return Err(format!(
8419+
"{}: row-broadcast length mismatch ({} cols vs {} vec)",
8420+
name, ac, vec_b.len()
8421+
));
8422+
}
8423+
let ar_b = a_rows.items.borrow();
8424+
let mut out_rows: Vec<Value> = Vec::with_capacity(ar);
8425+
for ra in ar_b.iter() {
8426+
let Value::Array(ra) = ra else { return Ok(None); };
8427+
let raw_a = ra.items.borrow();
8428+
let row: Vec<Value> = raw_a.iter().zip(vec_b.iter())
8429+
.map(|(x, y)| Value::HInt(HInt::new(op(x.to_int(), y.to_int()))))
8430+
.collect();
8431+
out_rows.push(Value::Array(HArray::from_vec(row)));
8432+
}
8433+
return Ok(Some(Value::Array(HArray::from_vec(out_rows))));
8434+
}
8435+
}
8436+
8437+
// Case 3: 1D + 2D — symmetric.
8438+
if let (None, Some((br, bc))) = (a_shape, b_shape) {
8439+
if let (Value::Array(a_vec), Value::Array(b_rows)) = (a, b) {
8440+
let vec_a = a_vec.items.borrow();
8441+
if vec_a.len() != bc {
8442+
if vec_a.iter().any(|v| matches!(v, Value::Array(_))) {
8443+
return Ok(None);
8444+
}
8445+
return Err(format!(
8446+
"{}: row-broadcast length mismatch ({} vec vs {} cols)",
8447+
name, vec_a.len(), bc
8448+
));
8449+
}
8450+
let br_b = b_rows.items.borrow();
8451+
let mut out_rows: Vec<Value> = Vec::with_capacity(br);
8452+
for rb in br_b.iter() {
8453+
let Value::Array(rb) = rb else { return Ok(None); };
8454+
let raw_b = rb.items.borrow();
8455+
let row: Vec<Value> = vec_a.iter().zip(raw_b.iter())
8456+
.map(|(x, y)| Value::HInt(HInt::new(op(x.to_int(), y.to_int()))))
8457+
.collect();
8458+
out_rows.push(Value::Array(HArray::from_vec(row)));
8459+
}
8460+
return Ok(Some(Value::Array(HArray::from_vec(out_rows))));
8461+
}
8462+
}
8463+
8464+
Ok(None)
8465+
}
8466+
83408467
pub(crate) fn elementwise_op<F: Fn(i64, i64) -> i64>(
83418468
a: &Value,
83428469
b: &Value,
83438470
name: &str,
83448471
op: F,
83458472
) -> Result<Value, String> {
8473+
// 2D-aware broadcasting shortcut — runs before the standard flat-array
8474+
// path so callers don't have to switch to a separate builtin. Two
8475+
// 2D operands element-wise; (2D, 1D) row-broadcast (the 1D vector
8476+
// gets added to every row); (1D, 2D) same in reverse.
8477+
if let Some(out) = try_2d_broadcast(a, b, name, &op)? {
8478+
return Ok(out);
8479+
}
83468480
match (a, b) {
83478481
(Value::Array(arr_a), Value::Array(arr_b)) => {
83488482
let ai = arr_a.items.borrow();

0 commit comments

Comments
 (0)