Skip to content

Commit 1b07849

Browse files
committed
feat: 3-class REM-optimized classifier with temporal weighting
- Add remOptimizedClassifier.ts: 3-class model (Awake/NREM/REM) - Heavily weight 90-minute ultradian cycle temporal features - Track local HR variability as discriminative feature for REM - Update SleepStageGraph to display 3-class stages - Remove legacy 4-class classifier from settings UI - Simplify hybridClassifier to only use REM-optimized model Key improvements over previous classifier: - Temporal features weighted at 50%+ (vs HR which has nearly identical means) - Local HR std used to detect stable REM periods - Explicit REM boost during cycle windows (last 30% of 90-min cycle) - No REM detection in first 70 minutes (biological constraint)
1 parent b65aa9e commit 1b07849

4 files changed

Lines changed: 1392 additions & 216 deletions

File tree

app/(tabs)/settings.tsx

Lines changed: 62 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ import {
2929
formatHealthConnectDebugReport,
3030
} from '@/services/healthConnect';
3131
import {
32-
trainAndValidate,
33-
loadEnhancedModel,
34-
clearEnhancedModel,
35-
type EnhancedModel,
36-
type ValidationResult,
37-
} from '@/services/enhancedSleepClassifier';
32+
trainRemOptimizedModel,
33+
loadRemOptimizedModel,
34+
clearRemOptimizedModel,
35+
formatTrainingReport,
36+
type TrainingReport,
37+
} from '@/services/remOptimizedClassifier';
3838

3939
export default function SettingsScreen() {
4040
const [showSleepDebug, setShowSleepDebug] = useState(false);
@@ -47,9 +47,9 @@ export default function SettingsScreen() {
4747
const [showHCDebug, setShowHCDebug] = useState(false);
4848
const [hcDebugReport, setHcDebugReport] = useState<string | null>(null);
4949
const [isLoadingHCDebug, setIsLoadingHCDebug] = useState(false);
50-
const [showEnhancedTraining, setShowEnhancedTraining] = useState(false);
51-
const [enhancedModelStatus, setEnhancedModelStatus] = useState<string | null>(null);
52-
const [isTrainingEnhanced, setIsTrainingEnhanced] = useState(false);
50+
const [showRemOptimized, setShowRemOptimized] = useState(false);
51+
const [remOptimizedReport, setRemOptimizedReport] = useState<string | null>(null);
52+
const [isTrainingRemOptimized, setIsTrainingRemOptimized] = useState(false);
5353

5454
const health = useHealth();
5555

@@ -95,80 +95,47 @@ export default function SettingsScreen() {
9595
}
9696
};
9797

98-
const formatEnhancedModelStatus = (
99-
model: EnhancedModel | null,
100-
validation?: ValidationResult
101-
) => {
102-
if (!model) return 'No enhanced model trained yet.';
103-
104-
let status = `ENHANCED MODEL STATUS\n`;
105-
status += `━━━━━━━━━━━━━━━━━━━━━\n`;
106-
status += `Nights analyzed: ${model.nightsAnalyzed}\n`;
107-
status += `Last updated: ${new Date(model.lastUpdated).toLocaleString()}\n`;
108-
status += `Temporal smoothing: ${model.temporalSmoothingStrength.toFixed(2)}\n\n`;
109-
110-
status += `FEATURE WEIGHTS\n`;
111-
status += `HR: ${model.featureWeights.hr.toFixed(2)}, HRV: ${model.featureWeights.hrv.toFixed(2)}\n`;
112-
status += `HRV Est: ${model.featureWeights.hrvEst.toFixed(2)}, RR: ${model.featureWeights.rr.toFixed(2)}\n\n`;
113-
114-
if (model.validationAccuracy !== null) {
115-
status += `VALIDATION ACCURACY\n`;
116-
status += `━━━━━━━━━━━━━━━━━━━━━\n`;
117-
status += `Overall: ${(model.validationAccuracy * 100).toFixed(1)}%\n\n`;
118-
status += `Per-Stage:\n`;
119-
status += ` Awake: ${(model.perStageAccuracy.awake * 100).toFixed(1)}%\n`;
120-
status += ` Light: ${(model.perStageAccuracy.light * 100).toFixed(1)}%\n`;
121-
status += ` Deep: ${(model.perStageAccuracy.deep * 100).toFixed(1)}%\n`;
122-
status += ` REM: ${(model.perStageAccuracy.rem * 100).toFixed(1)}%\n`;
123-
}
124-
125-
if (validation) {
126-
status += `\nCONFUSION MATRIX (rows=actual, cols=predicted)\n`;
127-
status += `━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n`;
128-
const stages: Array<'awake' | 'light' | 'deep' | 'rem'> = ['awake', 'light', 'deep', 'rem'];
129-
status += ` awake light deep rem\n`;
130-
for (const actual of stages) {
131-
status += `${actual.padEnd(6)}`;
132-
for (const predicted of stages) {
133-
status += `${validation.confusionMatrix[actual][predicted].toString().padStart(6)}`;
134-
}
135-
status += '\n';
136-
}
137-
status += `\nTotal samples: ${validation.totalSamples}\n`;
98+
const handleTrainRemOptimized = async () => {
99+
setIsTrainingRemOptimized(true);
100+
setRemOptimizedReport('Starting REM-optimized (3-class) training...');
101+
try {
102+
const { model, report } = await trainRemOptimizedModel(48);
103+
setRemOptimizedReport(formatTrainingReport(report));
104+
} catch (error) {
105+
setRemOptimizedReport(`Training failed: ${error}`);
106+
} finally {
107+
setIsTrainingRemOptimized(false);
138108
}
139-
140-
return status;
141109
};
142110

143-
const handleLoadEnhancedModel = async () => {
144-
const model = await loadEnhancedModel();
145-
setEnhancedModelStatus(formatEnhancedModelStatus(model));
146-
};
111+
const handleLoadRemOptimized = async () => {
112+
const model = await loadRemOptimizedModel();
113+
if (model) {
114+
let status = '╔══════════════════════════════════════════╗\n';
115+
status += '║ REM-OPTIMIZED MODEL STATUS (3-CLASS) ║\n';
116+
status += '╚══════════════════════════════════════════╝\n\n';
117+
status += `Nights analyzed: ${model.nightsAnalyzed}\n`;
118+
status += `Last updated: ${new Date(model.lastUpdated).toLocaleString()}\n\n`;
119+
120+
if (model.validationAccuracy !== null) {
121+
status += `Overall Accuracy: ${(model.validationAccuracy * 100).toFixed(1)}%\n`;
122+
status += `REM Sensitivity: ${((model.remSensitivity ?? 0) * 100).toFixed(1)}%\n`;
123+
status += `REM Specificity: ${((model.remSpecificity ?? 0) * 100).toFixed(1)}%\n\n`;
124+
status += `Per-Stage Accuracy:\n`;
125+
status += ` Awake: ${(model.perStageAccuracy.awake * 100).toFixed(1)}%\n`;
126+
status += ` NREM: ${(model.perStageAccuracy.nrem * 100).toFixed(1)}%\n`;
127+
status += ` REM: ${(model.perStageAccuracy.rem * 100).toFixed(1)}%\n`;
128+
}
147129

148-
const handleTrainEnhanced = async () => {
149-
setIsTrainingEnhanced(true);
150-
setEnhancedModelStatus('Starting training...');
151-
try {
152-
const { model, validation, bestParams } = await trainAndValidate((message) => {
153-
setEnhancedModelStatus(message);
154-
});
155-
let status = formatEnhancedModelStatus(model, validation);
156-
status += `\nBEST PARAMETERS FOUND\n`;
157-
status += `━━━━━━━━━━━━━━━━━━━━━\n`;
158-
status += `Temporal smoothing: ${bestParams.temporalSmoothing}\n`;
159-
status += `HR weight: ${bestParams.hrWeight}\n`;
160-
status += `HRV weight: ${bestParams.hrvWeight}\n`;
161-
setEnhancedModelStatus(status);
162-
} catch (error) {
163-
setEnhancedModelStatus(`Training failed: ${error}`);
164-
} finally {
165-
setIsTrainingEnhanced(false);
130+
setRemOptimizedReport(status);
131+
} else {
132+
setRemOptimizedReport('No REM-optimized model trained yet.\nTap "Train 3-Class" to begin.');
166133
}
167134
};
168135

169-
const handleClearEnhanced = async () => {
170-
await clearEnhancedModel();
171-
setEnhancedModelStatus('Enhanced model cleared.');
136+
const handleClearRemOptimized = async () => {
137+
await clearRemOptimizedModel();
138+
setRemOptimizedReport('REM-optimized model cleared.');
172139
};
173140

174141
useEffect(() => {
@@ -440,58 +407,58 @@ export default function SettingsScreen() {
440407
{(Platform.OS === 'android' || Platform.OS === 'ios') && (
441408
<>
442409
<MenuRow
443-
icon="fitness-outline"
444-
label="Sleep Stage Classifier Training"
410+
icon="moon-outline"
411+
label="Sleep Stage Classifier"
445412
onPress={() => {
446-
setShowEnhancedTraining(!showEnhancedTraining);
447-
if (!showEnhancedTraining && !enhancedModelStatus) {
448-
handleLoadEnhancedModel();
413+
setShowRemOptimized(!showRemOptimized);
414+
if (!showRemOptimized && !remOptimizedReport) {
415+
handleLoadRemOptimized();
449416
}
450417
}}
451418
/>
452-
{showEnhancedTraining && (
419+
{showRemOptimized && (
453420
<View style={styles.expandedSection}>
454421
<Text variant="caption" color="muted" style={{ marginBottom: spacing.sm }}>
455-
Learn your personal sleep patterns from your wearable data to improve REM
456-
detection accuracy.
422+
Train a personalized REM detector using your wearable sleep data. Uses 90-minute
423+
ultradian cycles and HR variability patterns to predict REM windows.
457424
</Text>
458425
<View style={styles.debugButtonRow}>
459426
<Pressable
460427
style={styles.debugButton}
461-
onPress={handleLoadEnhancedModel}
462-
disabled={isTrainingEnhanced}
428+
onPress={handleLoadRemOptimized}
429+
disabled={isTrainingRemOptimized}
463430
>
464431
<Ionicons name="information-circle" size={16} color={colors.primary[500]} />
465432
<Text variant="caption" color="primary">
466-
Check Status
433+
Status
467434
</Text>
468435
</Pressable>
469436
<Pressable
470437
style={[styles.debugButton, styles.trainButton]}
471-
onPress={handleTrainEnhanced}
472-
disabled={isTrainingEnhanced}
438+
onPress={handleTrainRemOptimized}
439+
disabled={isTrainingRemOptimized}
473440
>
474-
{isTrainingEnhanced ? (
441+
{isTrainingRemOptimized ? (
475442
<ActivityIndicator size="small" color={colors.gray[950]} />
476443
) : (
477444
<Ionicons name="flash" size={16} color={colors.gray[950]} />
478445
)}
479446
<Text variant="caption" style={{ color: colors.gray[950] }}>
480-
Train Model
447+
Train
481448
</Text>
482449
</Pressable>
483450
<Pressable
484451
style={styles.debugButton}
485-
onPress={handleClearEnhanced}
486-
disabled={isTrainingEnhanced}
452+
onPress={handleClearRemOptimized}
453+
disabled={isTrainingRemOptimized}
487454
>
488455
<Ionicons name="trash-outline" size={16} color={colors.error} />
489456
<Text variant="caption" color="primary">
490457
Clear
491458
</Text>
492459
</Pressable>
493460
</View>
494-
{enhancedModelStatus && (
461+
{remOptimizedReport && (
495462
<ScrollView
496463
style={styles.debugOutput}
497464
horizontal={false}
@@ -503,7 +470,7 @@ export default function SettingsScreen() {
503470
style={styles.debugOutputText}
504471
selectable={true}
505472
>
506-
{enhancedModelStatus}
473+
{remOptimizedReport}
507474
</Text>
508475
</ScrollView>
509476
)}

components/SleepStageGraph.tsx

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ import { View, StyleSheet, Platform } from 'react-native';
44
import { Text } from '@/components/ui/Text';
55
import { colors, spacing, borderRadius } from '@/theme/tokens';
66
import type { SleepStage } from '@/types/database';
7-
import { getSleepStageColor } from '@/services/sleep';
7+
8+
type DisplayStage = 'awake' | 'nrem' | 'rem';
89

910
interface StageHistoryEntry {
1011
stage: SleepStage;
@@ -20,15 +21,37 @@ interface SleepStageGraphProps {
2021
onStagePress?: (stage: SleepStage, timestamp: number) => void;
2122
}
2223

23-
const STAGE_ORDER: SleepStage[] = ['awake', 'light', 'rem', 'deep'];
24-
const STAGE_Y_POSITIONS: Record<SleepStage, number> = {
24+
const STAGE_ORDER: DisplayStage[] = ['awake', 'nrem', 'rem'];
25+
const STAGE_Y_POSITIONS: Record<DisplayStage, number> = {
2526
awake: 0,
26-
light: 1,
27+
nrem: 1,
2728
rem: 2,
28-
deep: 3,
29-
any: 2,
3029
};
3130

31+
function toDisplayStage(stage: SleepStage): DisplayStage {
32+
switch (stage) {
33+
case 'awake':
34+
return 'awake';
35+
case 'light':
36+
case 'deep':
37+
case 'any':
38+
return 'nrem';
39+
case 'rem':
40+
return 'rem';
41+
}
42+
}
43+
44+
function getDisplayStageColor(stage: DisplayStage): string {
45+
switch (stage) {
46+
case 'awake':
47+
return colors.warning;
48+
case 'nrem':
49+
return colors.accent.cyan;
50+
case 'rem':
51+
return colors.primary[500];
52+
}
53+
}
54+
3255
const GRAPH_HEIGHT = 160;
3356
const GRAPH_PADDING_TOP = 20;
3457
const GRAPH_PADDING_BOTTOM = 20;
@@ -43,8 +66,14 @@ const LINE_THICKNESS = 4;
4366
const TRANSITION_COLOR = colors.gray[500];
4467

4568
function getYForStage(stage: SleepStage): number {
69+
const displayStage = toDisplayStage(stage);
70+
const position = STAGE_Y_POSITIONS[displayStage];
71+
return GRAPH_PADDING_TOP + (position / 2) * USABLE_HEIGHT;
72+
}
73+
74+
function getYForDisplayStage(stage: DisplayStage): number {
4675
const position = STAGE_Y_POSITIONS[stage];
47-
return GRAPH_PADDING_TOP + (position / 3) * USABLE_HEIGHT;
76+
return GRAPH_PADDING_TOP + (position / 2) * USABLE_HEIGHT;
4877
}
4978

5079
interface DataPoint {
@@ -212,7 +241,7 @@ export function SleepStageGraph({
212241
{
213242
left: p.x - 6,
214243
top: p.y - 6,
215-
backgroundColor: getSleepStageColor(p.stage),
244+
backgroundColor: getDisplayStageColor(toDisplayStage(p.stage)),
216245
},
217246
]}
218247
/>
@@ -227,8 +256,10 @@ export function SleepStageGraph({
227256
const prev = computedPoints[i - 1];
228257
const curr = computedPoints[i];
229258

230-
const isTransition = prev.stage !== curr.stage;
231-
const segmentColor = isTransition ? TRANSITION_COLOR : getSleepStageColor(prev.stage);
259+
const prevDisplay = toDisplayStage(prev.stage);
260+
const currDisplay = toDisplayStage(curr.stage);
261+
const isTransition = prevDisplay !== currDisplay;
262+
const segmentColor = isTransition ? TRANSITION_COLOR : getDisplayStageColor(prevDisplay);
232263

233264
if (Platform.OS === 'web') {
234265
const midX = (prev.x + curr.x) / 2;
@@ -280,7 +311,7 @@ export function SleepStageGraph({
280311
{
281312
left: last.x - 6,
282313
top: last.y - 6,
283-
backgroundColor: getSleepStageColor(last.stage),
314+
backgroundColor: getDisplayStageColor(toDisplayStage(last.stage)),
284315
},
285316
]}
286317
/>
@@ -304,18 +335,18 @@ export function SleepStageGraph({
304335

305336
const renderGridLines = useCallback(() => {
306337
return STAGE_ORDER.map((stage) => {
307-
const y = getYForStage(stage);
338+
const y = getYForDisplayStage(stage);
308339
return <View key={stage} style={[styles.gridLine, { top: y }]} />;
309340
});
310341
}, []);
311342

312343
const renderStageLabels = useCallback(() => {
313344
return STAGE_ORDER.map((stage) => {
314-
const y = getYForStage(stage);
315-
const label = stage === 'rem' ? 'REM' : stage.charAt(0).toUpperCase() + stage.slice(1);
345+
const y = getYForDisplayStage(stage);
346+
const label = stage === 'rem' ? 'REM' : stage === 'nrem' ? 'NREM' : 'Awake';
316347
return (
317348
<View key={stage} style={[styles.stageLabel, { top: y - 8 }]}>
318-
<View style={[styles.stageDot, { backgroundColor: getSleepStageColor(stage) }]} />
349+
<View style={[styles.stageDot, { backgroundColor: getDisplayStageColor(stage) }]} />
319350
<Text variant="caption" color="muted" style={styles.stageLabelText}>
320351
{label}
321352
</Text>

0 commit comments

Comments
 (0)