11#!/usr/bin/env python3
2- # Wine Quality test for MicroPython
2+ # Wine dataset test for MicroPython (sklearn Wine - 3 class classification)
33
44import array
55import gc
66import time
77import npyfile
88import emlearn_extratrees
99
10- def load_npy_int16 (filename ):
11- """Load .npy file and convert to int16 array"""
10+ DATA_DIR = 'examples/datasets/wine/'
11+ DATA_FILES = {
12+ 'X_train' : DATA_DIR + 'X_train.npy' ,
13+ 'y_train' : DATA_DIR + 'y_train.npy' ,
14+ 'X_test' : DATA_DIR + 'X_test.npy' ,
15+ 'y_test' : DATA_DIR + 'y_test.npy' ,
16+ }
17+
18+ def load_npy_features_int16 (filename ):
19+ """Load .npy file and convert to int16 array (scaled from float32)"""
20+ shape , data = npyfile .load (filename )
21+ # Scale float32 data to int16 range (multiply by 1000 and convert)
22+ scaled = [int (v * 1000 ) for v in data ]
23+ return array .array ('h' , scaled )
24+
25+ def load_npy_labels_int16 (filename ):
26+ """Load .npy file and convert to int16 array (labels, no scaling)"""
1227 shape , data = npyfile .load (filename )
13- return array .array ('h' , data )
28+ # Labels are already integers (0.0, 1.0, 2.0), just convert directly
29+ labels = [int (v ) for v in data ]
30+ return array .array ('h' , labels )
1431
15- def test_wine_quality ():
16- print ("=== WINE QUALITY DATASET TEST ===" )
32+ def test_wine ():
33+ print ("=== WINE DATASET TEST (3-class) ===" )
1734
1835 # Load preprocessed data
1936 try :
20- X_train_flat = load_npy_int16 ( 'X_train.npy' )
21- y_train = load_npy_int16 ( 'y_train.npy' )
22- X_test_flat = load_npy_int16 ( 'X_test.npy' )
23- y_test = load_npy_int16 ( 'y_test.npy' )
37+ X_train_flat = load_npy_features_int16 ( DATA_FILES [ 'X_train' ] )
38+ y_train = load_npy_labels_int16 ( DATA_FILES [ 'y_train' ] )
39+ X_test_flat = load_npy_features_int16 ( DATA_FILES [ 'X_test' ] )
40+ y_test = load_npy_labels_int16 ( DATA_FILES [ 'y_test' ] )
2441 except :
25- print ("Error: Run wine_quality_prep .py first" )
42+ print ("Error: Run wine/prepare .py first" )
2643 return
2744
28- n_features = 12 # 11 wine features + wine_type
45+ n_features = 13 # 13 wine features (alcohol, malic_acid, ash, etc.)
2946 n_train = len (y_train )
3047 n_test = len (y_test )
3148
49+ # Determine number of classes from data
50+ n_classes = int (max (y_train )) + 1
51+
3252 print (f"Loaded: { n_train } train, { n_test } test samples" )
33- print (f"Features: { n_features } (alcohol, acidity, etc. + wine_type)" )
34- print ("Task: Predict good wine (quality >= 6) vs poor wine" )
53+ print (f"Features: { n_features } " )
54+ print (f"Classes: { n_classes } (wine cultivars 0, 1, 2)" )
55+ print ("Task: Classify wine cultivar" )
3556
36- # Create model - adjusted for large dataset constraints
57+ # Create model
3758 model = emlearn_extratrees .new (
38- 12 , # n_features
39- 2 , # n_classes
40- 5 , # n_trees
41- 10 , # max_depth
42- 3 , # min_samples_leaf
43- 20 , # n_thresholds
44- 0.20 , # subsample_ratio (much smaller: 15% of 5197 = ~780 samples)
45- 0.8 , # feature_subsample_ratio
46- 2000 , # max_nodes
47- 10000 , # max_samples (matches subsample size)
48- 42 # rng_seed
59+ n_features , # n_features
60+ n_classes , # n_classes
61+ 20 , # n_trees
62+ 12 , # max_depth
63+ 2 , # min_samples_leaf
64+ 15 , # n_thresholds
65+ 0.8 , # subsample_ratio
66+ 1.0 , # feature_subsample_ratio
67+ 3000 , # max_nodes
68+ 500 , # max_samples
69+ 42 # rng_seed
4970 )
5071
5172 train_start = time .ticks_ms ()
@@ -57,8 +78,11 @@ def test_wine_quality():
5778
5879 # Test
5980 correct = 0
60- tp , tn , fp , fn = 0 , 0 , 0 , 0
61- probabilities = array .array ('f' , [0.0 , 0.0 ])
81+ probabilities = array .array ('f' , [0.0 ] * n_classes )
82+
83+ # Track class-wise accuracy
84+ class_correct = [0 ] * n_classes
85+ class_total = [0 ] * n_classes
6286
6387 for i in range (n_test ):
6488 start_idx = i * n_features
@@ -68,48 +92,40 @@ def test_wine_quality():
6892 predicted = model .predict_proba (features , probabilities )
6993 actual = y_test [i ]
7094
95+ # Track per-class stats
96+ class_total [actual ] += 1
7197 if predicted == actual :
7298 correct += 1
73-
74- # Confusion matrix
75- if predicted == 1 and actual == 1 :
76- tp += 1
77- elif predicted == 0 and actual == 0 :
78- tn += 1
79- elif predicted == 1 and actual == 0 :
80- fp += 1
81- elif predicted == 0 and actual == 1 :
82- fn += 1
99+ class_correct [actual ] += 1
83100
84101 if i < 5 :
85- conf = max (probabilities [0 ], probabilities [1 ])
86- wine_quality = "good" if actual == 1 else "poor"
87- pred_quality = "good" if predicted == 1 else "poor"
88- print (f"Sample { i } : pred={ pred_quality } , actual={ wine_quality } , conf={ conf :.3f} " )
102+ conf = max (probabilities )
103+ print (f"Sample { i } : pred={ predicted } , actual={ actual } , conf={ conf :.3f} " )
89104
90105 accuracy = correct / n_test
91- precision = tp / (tp + fp ) if (tp + fp ) > 0 else 0
92- recall = tp / (tp + fn ) if (tp + fn ) > 0 else 0
93- f1 = 2 * precision * recall / (precision + recall ) if (precision + recall ) > 0 else 0
94106
95107 print (f"\n Results:" )
96108 print (f"Accuracy: { accuracy :.3f} ({ correct } /{ n_test } )" )
97- print (f"Precision: { precision :.3f} " )
98- print (f"Recall: { recall :.3f} " )
99- print (f"F1-Score: { f1 :.3f} " )
100- print (f"Confusion: TP={ tp } , TN={ tn } , FP={ fp } , FN={ fn } " )
101- print (f"Target (sklearn): ~0.80" )
102109
103- if accuracy >= 0.78 :
104- print ("✅ EXCELLENT: Great performance on wine quality!" )
105- elif accuracy >= 0.75 :
110+ # Per-class accuracy
111+ print (f"\n Per-class accuracy:" )
112+ for c in range (n_classes ):
113+ if class_total [c ] > 0 :
114+ class_acc = class_correct [c ] / class_total [c ]
115+ print (f" Class { c } : { class_acc :.3f} ({ class_correct [c ]} /{ class_total [c ]} )" )
116+
117+ print (f"Target (sklearn ExtraTrees): ~0.95" )
118+
119+ if accuracy >= 0.90 :
120+ print ("✅ EXCELLENT: Great wine classification!" )
121+ elif accuracy >= 0.85 :
106122 print ("✅ VERY GOOD: Strong wine classification!" )
123+ elif accuracy >= 0.80 :
124+ print ("✅ GOOD: Solid wine classification!" )
107125 elif accuracy >= 0.70 :
108- print ("✅ GOOD: Solid wine quality prediction!" )
109- elif accuracy >= 0.65 :
110126 print ("⚠️ FAIR: Working but could improve" )
111127 else :
112128 print ("❌ POOR: Needs significant improvement" )
113129
114130if __name__ == "__main__" :
115- test_wine_quality ()
131+ test_wine ()
0 commit comments