Skip to content

Commit b29e84b

Browse files
committed
extratrees: Update to use shared datasets
1 parent b4de882 commit b29e84b

2 files changed

Lines changed: 96 additions & 63 deletions

File tree

tests/test_extratrees_cancer.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,35 @@
44
import gc
55
import npyfile
66

7-
def load_npy_int16(filename):
8-
"""Load .npy file and convert to int16 array"""
7+
# Location of the preprocessed cancer dataset (written by the prepare script).
DATA_DIR = 'examples/datasets/cancer/'

# Map each data split name to its .npy file path under DATA_DIR.
DATA_FILES = {
    split: DATA_DIR + split + '.npy'
    for split in ('X_train', 'y_train', 'X_test', 'y_test')
}
14+
15+
def load_npy_features_int16(filename):
    """Load a .npy file of float features as a fixed-point int16 array.

    Each value is scaled by 1000 (three decimal digits of precision) and
    truncated to int. The result is clamped to the signed 16-bit range so
    that an out-of-range feature (|v| > 32.767) cannot raise OverflowError
    when stored into the 'h' (int16) array.

    :param filename: path to a .npy file readable by npyfile.load
    :return: array.array('h') of scaled feature values (shape is flattened)
    """
    shape, data = npyfile.load(filename)
    # Scale to fixed-point and clamp to [-32768, 32767]; the original code
    # would crash on features outside +/-32.767 after scaling.
    scaled = [min(32767, max(-32768, int(v * 1000))) for v in data]
    return array.array('h', scaled)
21+
22+
def load_npy_labels_int16(filename):
    """Load a .npy file of class labels into an int16 array.

    Labels are stored as whole-number floats (0.0, 1.0), so each value is
    truncated to int directly — no fixed-point scaling is applied.

    :param filename: path to a .npy file readable by npyfile.load
    :return: array.array('h') of integer class labels
    """
    shape, values = npyfile.load(filename)
    return array.array('h', (int(v) for v in values))
1128

1229
def test_real_dataset():
1330
print("=== REAL DATASET TEST ===")
1431

15-
X_train_flat = load_npy_int16('X_train.npy')
16-
y_train = load_npy_int16('y_train.npy')
17-
X_test_flat = load_npy_int16('X_test.npy')
18-
y_test = load_npy_int16('y_test.npy')
32+
X_train_flat = load_npy_features_int16(DATA_FILES['X_train'])
33+
y_train = load_npy_labels_int16(DATA_FILES['y_train'])
34+
X_test_flat = load_npy_features_int16(DATA_FILES['X_test'])
35+
y_test = load_npy_labels_int16(DATA_FILES['y_test'])
1936

2037

2138
n_features = 30

tests/test_extratrees_wine.py

Lines changed: 72 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,72 @@
11
#!/usr/bin/env python3
2-
# Wine Quality test for MicroPython
2+
# Wine dataset test for MicroPython (sklearn Wine - 3 class classification)
33

44
import array
55
import gc
66
import time
77
import npyfile
88
import emlearn_extratrees
99

10-
def load_npy_int16(filename):
11-
"""Load .npy file and convert to int16 array"""
10+
# Location of the preprocessed wine dataset (written by the prepare script).
DATA_DIR = 'examples/datasets/wine/'

# Map each data split name to its .npy file path under DATA_DIR.
DATA_FILES = {
    split: DATA_DIR + split + '.npy'
    for split in ('X_train', 'y_train', 'X_test', 'y_test')
}
17+
18+
def load_npy_features_int16(filename):
    """Load a .npy file of float features as a fixed-point int16 array.

    Each value is scaled by 1000 (three decimal digits of precision) and
    truncated to int. The result is clamped to the signed 16-bit range so
    that an out-of-range feature (|v| > 32.767) cannot raise OverflowError
    when stored into the 'h' (int16) array.

    :param filename: path to a .npy file readable by npyfile.load
    :return: array.array('h') of scaled feature values (shape is flattened)
    """
    shape, data = npyfile.load(filename)
    # Scale to fixed-point and clamp to [-32768, 32767]; the original code
    # would crash on features outside +/-32.767 after scaling.
    scaled = [min(32767, max(-32768, int(v * 1000))) for v in data]
    return array.array('h', scaled)
24+
25+
def load_npy_labels_int16(filename):
    """Load a .npy file of class labels into an int16 array.

    Labels are stored as whole-number floats (0.0, 1.0, 2.0), so each value
    is truncated to int directly — no fixed-point scaling is applied.

    :param filename: path to a .npy file readable by npyfile.load
    :return: array.array('h') of integer class labels
    """
    shape, values = npyfile.load(filename)
    return array.array('h', (int(v) for v in values))
1431

15-
def test_wine_quality():
16-
print("=== WINE QUALITY DATASET TEST ===")
32+
def test_wine():
33+
print("=== WINE DATASET TEST (3-class) ===")
1734

1835
# Load preprocessed data
1936
try:
20-
X_train_flat = load_npy_int16('X_train.npy')
21-
y_train = load_npy_int16('y_train.npy')
22-
X_test_flat = load_npy_int16('X_test.npy')
23-
y_test = load_npy_int16('y_test.npy')
37+
X_train_flat = load_npy_features_int16(DATA_FILES['X_train'])
38+
y_train = load_npy_labels_int16(DATA_FILES['y_train'])
39+
X_test_flat = load_npy_features_int16(DATA_FILES['X_test'])
40+
y_test = load_npy_labels_int16(DATA_FILES['y_test'])
2441
except:
25-
print("Error: Run wine_quality_prep.py first")
42+
print("Error: Run wine/prepare.py first")
2643
return
2744

28-
n_features = 12 # 11 wine features + wine_type
45+
n_features = 13 # 13 wine features (alcohol, malic_acid, ash, etc.)
2946
n_train = len(y_train)
3047
n_test = len(y_test)
3148

49+
# Determine number of classes from data
50+
n_classes = int(max(y_train)) + 1
51+
3252
print(f"Loaded: {n_train} train, {n_test} test samples")
33-
print(f"Features: {n_features} (alcohol, acidity, etc. + wine_type)")
34-
print("Task: Predict good wine (quality >= 6) vs poor wine")
53+
print(f"Features: {n_features}")
54+
print(f"Classes: {n_classes} (wine cultivars 0, 1, 2)")
55+
print("Task: Classify wine cultivar")
3556

36-
# Create model - adjusted for large dataset constraints
57+
# Create model
3758
model = emlearn_extratrees.new(
38-
12, # n_features
39-
2, # n_classes
40-
5, # n_trees
41-
10, # max_depth
42-
3, # min_samples_leaf
43-
20, # n_thresholds
44-
0.20, # subsample_ratio (much smaller: 15% of 5197 = ~780 samples)
45-
0.8, # feature_subsample_ratio
46-
2000, # max_nodes
47-
10000, # max_samples (matches subsample size)
48-
42 # rng_seed
59+
n_features, # n_features
60+
n_classes, # n_classes
61+
20, # n_trees
62+
12, # max_depth
63+
2, # min_samples_leaf
64+
15, # n_thresholds
65+
0.8, # subsample_ratio
66+
1.0, # feature_subsample_ratio
67+
3000, # max_nodes
68+
500, # max_samples
69+
42 # rng_seed
4970
)
5071

5172
train_start = time.ticks_ms()
@@ -57,8 +78,11 @@ def test_wine_quality():
5778

5879
# Test
5980
correct = 0
60-
tp, tn, fp, fn = 0, 0, 0, 0
61-
probabilities = array.array('f', [0.0, 0.0])
81+
probabilities = array.array('f', [0.0] * n_classes)
82+
83+
# Track class-wise accuracy
84+
class_correct = [0] * n_classes
85+
class_total = [0] * n_classes
6286

6387
for i in range(n_test):
6488
start_idx = i * n_features
@@ -68,48 +92,40 @@ def test_wine_quality():
6892
predicted = model.predict_proba(features, probabilities)
6993
actual = y_test[i]
7094

95+
# Track per-class stats
96+
class_total[actual] += 1
7197
if predicted == actual:
7298
correct += 1
73-
74-
# Confusion matrix
75-
if predicted == 1 and actual == 1:
76-
tp += 1
77-
elif predicted == 0 and actual == 0:
78-
tn += 1
79-
elif predicted == 1 and actual == 0:
80-
fp += 1
81-
elif predicted == 0 and actual == 1:
82-
fn += 1
99+
class_correct[actual] += 1
83100

84101
if i < 5:
85-
conf = max(probabilities[0], probabilities[1])
86-
wine_quality = "good" if actual == 1 else "poor"
87-
pred_quality = "good" if predicted == 1 else "poor"
88-
print(f"Sample {i}: pred={pred_quality}, actual={wine_quality}, conf={conf:.3f}")
102+
conf = max(probabilities)
103+
print(f"Sample {i}: pred={predicted}, actual={actual}, conf={conf:.3f}")
89104

90105
accuracy = correct / n_test
91-
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
92-
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
93-
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
94106

95107
print(f"\nResults:")
96108
print(f"Accuracy: {accuracy:.3f} ({correct}/{n_test})")
97-
print(f"Precision: {precision:.3f}")
98-
print(f"Recall: {recall:.3f}")
99-
print(f"F1-Score: {f1:.3f}")
100-
print(f"Confusion: TP={tp}, TN={tn}, FP={fp}, FN={fn}")
101-
print(f"Target (sklearn): ~0.80")
102109

103-
if accuracy >= 0.78:
104-
print("✅ EXCELLENT: Great performance on wine quality!")
105-
elif accuracy >= 0.75:
110+
# Per-class accuracy
111+
print(f"\nPer-class accuracy:")
112+
for c in range(n_classes):
113+
if class_total[c] > 0:
114+
class_acc = class_correct[c] / class_total[c]
115+
print(f" Class {c}: {class_acc:.3f} ({class_correct[c]}/{class_total[c]})")
116+
117+
print(f"Target (sklearn ExtraTrees): ~0.95")
118+
119+
if accuracy >= 0.90:
120+
print("✅ EXCELLENT: Great wine classification!")
121+
elif accuracy >= 0.85:
106122
print("✅ VERY GOOD: Strong wine classification!")
123+
elif accuracy >= 0.80:
124+
print("✅ GOOD: Solid wine classification!")
107125
elif accuracy >= 0.70:
108-
print("✅ GOOD: Solid wine quality prediction!")
109-
elif accuracy >= 0.65:
110126
print("⚠️ FAIR: Working but could improve")
111127
else:
112128
print("❌ POOR: Needs significant improvement")
113129

114130
if __name__ == "__main__":
115-
test_wine_quality()
131+
test_wine()

0 commit comments

Comments
 (0)