Skip to content

Commit bbc604a

Browse files
committed
extratrees: Add a missing test file
1 parent 453f350 commit bbc604a

1 file changed

Lines changed: 239 additions & 0 deletions

File tree

tests/test_extratrees_xor.py

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
# Final comprehensive XOR test suite
2+
import array
3+
import emlearn_extratrees
4+
5+
def test_xor_comprehensive():
    """Train a 10-tree ensemble on repeated XOR samples and score the corners.

    The four XOR corner inputs (encoded with 0 and 100) are repeated eight
    times to form the training set. After training, each corner is run
    through ``predict_proba`` and its return value is compared with the
    expected label. A handful of intermediate inputs is also printed, but
    those are informational only and do not influence the result.

    Returns:
        True when all four corner cases are predicted as expected.
    """
    print("=== Comprehensive XOR Test ===")

    # One row per XOR corner: (x1, x2, label).
    corners = [
        (0, 0, 0),        # (0,0) -> 0
        (0, 100, 1),      # (0,1) -> 1
        (100, 0, 1),      # (1,0) -> 1
        (100, 100, 0),    # (1,1) -> 0
    ]

    # Repeat the pattern so the ensemble has more rows to sample from.
    repeats = 8  # 8 x 4 = 32 samples
    flat_features = [
        value
        for _ in range(repeats)
        for x1, x2, _label in corners
        for value in (x1, x2)
    ]
    labels = [label for _ in range(repeats) for _x1, _x2, label in corners]

    X = array.array('h', flat_features)
    y = array.array('h', labels)

    print(f"Training data: {len(labels)} samples (8x XOR pattern)")

    # Hyper-parameters, passed positionally to emlearn_extratrees.new().
    n_features = 2
    n_classes = 2
    n_trees = 10
    max_depth = 8
    min_samples_leaf = 1
    n_thresholds = 10
    subsample_ratio = 0.8           # 80% for diversity
    feature_subsample_ratio = 1.0   # use both features
    max_nodes = 500
    max_samples = 100
    rng_seed = 42

    model = emlearn_extratrees.new(
        n_features,
        n_classes,
        n_trees,
        max_depth,
        min_samples_leaf,
        n_thresholds,
        subsample_ratio,
        feature_subsample_ratio,
        max_nodes,
        max_samples,
        rng_seed
    )
    model.train(X, y)

    print(f"Model: {model.get_n_trees()} trees, {model.get_n_nodes_used()} nodes total")

    # Output buffer handed to predict_proba; read back for the confidence.
    probabilities = array.array('f', [0.0, 0.0])

    # --- Core XOR corners -------------------------------------------------
    print("\nCore XOR Results:")
    test_cases = [([x1, x2], label) for x1, x2, label in corners]
    correct = 0
    for features, expected in test_cases:
        predicted = model.predict_proba(array.array('h', features), probabilities)
        is_correct = predicted == expected
        correct += 1 if is_correct else 0

        confidence = max(probabilities[0], probabilities[1])
        mark = '✓' if is_correct else '✗'
        print(f" {features} -> pred={predicted}, exp={expected}, conf={confidence:.2f} {mark}")

    core_accuracy = 100.0 * correct / 4
    print(f"Core XOR Accuracy: {core_accuracy:.0f}%")

    # --- Intermediate values (printed only, not scored) -------------------
    print("\nInterpolation Test:")
    interpolation_cases = [
        ([25, 25], "?"),   # between (0,0) and (100,100) - ambiguous
        ([25, 75], "?"),   # between (0,100) and (100,0) - ambiguous
        ([10, 90], 1),     # closer to (0,100) -> should be 1
        ([90, 10], 1),     # closer to (100,0) -> should be 1
        ([90, 90], 0),     # closer to (100,100) -> should be 0
        ([10, 10], 0),     # closer to (0,0) -> should be 0
    ]
    for features, expected in interpolation_cases:
        predicted = model.predict_proba(array.array('h', features), probabilities)
        confidence = max(probabilities[0], probabilities[1])

        # "?" marks cases with no defined expectation.
        if expected == "?":
            marker = "?"
        elif predicted == expected:
            marker = "✓"
        else:
            marker = "✗"

        print(f" {features} -> pred={predicted}, exp={expected}, conf={confidence:.2f} {marker}")

    return core_accuracy >= 100
98+
99+
def test_xor_robustness():
    """Check XOR accuracy for several (n_trees, max_depth) configurations.

    A fresh model is built and trained on 24 XOR samples for each
    configuration, then scored on the four XOR corner inputs.

    Returns:
        True when the accuracy averaged over all configurations is at
        least 75%.
    """
    print("\n=== XOR Robustness Test ===")

    # Four XOR corners, flattened as x1, x2 pairs, repeated 6 times (24 samples).
    repeats = 6
    X = array.array('h', [0, 0, 0, 100, 100, 0, 100, 100] * repeats)
    y = array.array('h', [0, 1, 1, 0] * repeats)

    # The evaluation inputs are identical for every configuration.
    test_cases = [([0, 0], 0), ([0, 100], 1), ([100, 0], 1), ([100, 100], 0)]

    configs = [
        (5, 6, "5 trees, depth 6"),
        (15, 10, "15 trees, depth 10"),
        (20, 12, "20 trees, depth 12"),
    ]

    results = []
    for n_trees, max_depth, desc in configs:
        print(f"\nTesting {desc}:")

        model = emlearn_extratrees.new(2, 2, n_trees, max_depth, 1, 8, 0.9, 1.0, 1000, 100, 123)
        model.train(X, y)

        # Output buffer handed to predict_proba; its contents are not used here.
        probabilities = array.array('f', [0.0, 0.0])
        correct = sum(
            1
            for features, expected in test_cases
            if model.predict_proba(array.array('h', features), probabilities) == expected
        )

        accuracy = 100.0 * correct / 4
        results.append(accuracy)
        print(f" Accuracy: {accuracy:.0f}% ({correct}/4 correct)")

    avg_accuracy = sum(results) / len(results)
    print(f"\nAverage accuracy across configs: {avg_accuracy:.0f}%")

    return avg_accuracy >= 75
143+
144+
def test_xor_different_values():
    """Check XOR accuracy when the two input levels use different values.

    For each (low, high) pair a fresh model is trained on 32 XOR samples
    built from those two values, then scored on the four corner inputs.

    Returns:
        True when the accuracy averaged over all value ranges is at
        least 75%.
    """
    print("\n=== XOR with Different Value Ranges ===")

    # Each entry: ([low, high], description printed in the log).
    test_ranges = [
        ([0, 1], "Binary"),
        ([0, 10], "0-10"),
        ([0, 1000], "0-1000"),
        ([-50, 50], "-50 to 50"),
    ]

    repeats = 8  # 8 x 4 = 32 samples
    results = []

    for (low, high), desc in test_ranges:
        print(f"\nTesting {desc} range:")

        # XOR corners expressed with this range's low/high values.
        test_cases = [
            ([low, low], 0),
            ([low, high], 1),
            ([high, low], 1),
            ([high, high], 0),
        ]

        one_round = [low, low, low, high, high, low, high, high]
        X = array.array('h', one_round * repeats)
        y = array.array('h', [0, 1, 1, 0] * repeats)

        model = emlearn_extratrees.new(2, 2, 12, 10, 1, 10, 0.8, 1.0, 800, 100, 456)
        model.train(X, y)

        # Output buffer handed to predict_proba; its contents are not used here.
        probabilities = array.array('f', [0.0, 0.0])
        correct = sum(
            1
            for features, expected in test_cases
            if model.predict_proba(array.array('h', features), probabilities) == expected
        )

        accuracy = 100.0 * correct / 4
        results.append(accuracy)
        print(f" Accuracy: {accuracy:.0f}%")

    avg_accuracy = sum(results) / len(results)
    print(f"\nAverage across value ranges: {avg_accuracy:.0f}%")

    return avg_accuracy >= 75
201+
202+
if __name__ == "__main__":
    # Script entry point: run the three XOR tests in order and print a summary.
    #
    # Hard failures (the core XOR test failing, or any exception) make the
    # script exit with status 1 so a test runner can detect them. The
    # robustness / value-range checks stay soft: they only affect the
    # printed summary.
    import sys

    print("🔥 FIXED XOR TEST SUITE 🔥")
    print("=" * 60)

    hard_failure = False

    try:
        # Test 1: Comprehensive XOR
        success1 = test_xor_comprehensive()

        if success1:
            print("\n✅ COMPREHENSIVE XOR TEST PASSED!")

            # Test 2: Robustness
            success2 = test_xor_robustness()

            # Test 3: Different value ranges
            success3 = test_xor_different_values()

            if success2:
                print("\n✅ ROBUSTNESS TEST PASSED!")
            if success3:
                print("\n✅ VALUE RANGE TEST PASSED!")

            if success1 and success2 and success3:
                print("\n🎉🎉🎉 ALL XOR TESTS PASSED! 🎉🎉🎉")
                print("Your Extra Trees implementation is WORKING PERFECTLY!")
                print("The algorithm can now learn complex non-linear patterns like XOR!")
            else:
                print("\n🔥 Core XOR works! Some edge cases may need fine-tuning.")

        else:
            print("\n❌ Something is still wrong with the core algorithm")
            hard_failure = True

    except Exception as e:
        print(f"❌ Error: {e}")
        # sys.print_exception only exists on MicroPython; calling it
        # unconditionally raises AttributeError on CPython and hides the
        # real error. Fall back to the traceback module there.
        print_exception = getattr(sys, "print_exception", None)
        if print_exception is not None:
            print_exception(e)
        else:
            import traceback
            traceback.print_exc()
        hard_failure = True

    print("\n" + "="*60)

    if hard_failure:
        sys.exit(1)

0 commit comments

Comments
 (0)