forked from MorinWang/SONATA
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils_martingale_coreset.py
More file actions
1286 lines (1020 loc) · 54.4 KB
/
utils_martingale_coreset.py
File metadata and controls
1286 lines (1020 loc) · 54.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Martingale Coreset Selection Utilities
Enhanced coreset selection mechanism with:
- Martingale increment scoring
- Optimal stopping strategies
- Bellman equation optimization
- Multi-scale time weighting with Ito formula
- Improved replacement strategy with fill ratio threshold
"""
import numpy as np
import torch
import math
import logging
from collections import deque
import heapq
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class DataPointCoreSetManager:
    """
    Core data point selection manager for tensor factorization.

    Maintains a bounded coreset of data points scored by a weighted mix of
    uncertainty, influence, and novelty. The coreset grows freely up to
    ``size_threshold`` (= ``max_size * fill_ratio_threshold``); beyond that,
    new candidates may only replace lower-scoring members.
    """

    def __init__(self, max_size=100, initial_threshold=0.5, adaptive_threshold=True,
                 importance_weights=(0.4, 0.3, 0.3), device=torch.device("cpu"),
                 exploration_rate=0.9, decay_rate=0.1, batch_replace_size=5,
                 fill_ratio_threshold=0.8):
        """
        Initialize the data point coreset manager.

        Args:
            max_size: Maximum coreset size.
            initial_threshold: Initial importance threshold.
            adaptive_threshold: Whether to adapt the threshold to observed scores.
            importance_weights: Weights for (uncertainty, influence, novelty).
            device: Computation device.
            exploration_rate: Initial exploration rate.
            decay_rate: Exploration rate decay coefficient.
            batch_replace_size: Number of data points to replace in a batch.
            fill_ratio_threshold: Fraction of capacity at which replacement
                starts (e.g. 0.8 means start at 80% capacity).
        """
        self.coreset = []  # stored entries: (indices, y, time_ind, score)
        self.coreset_indices = set()  # hashed index tuples for O(1) membership tests
        self.max_size = max_size
        self.threshold = initial_threshold
        self.adaptive = adaptive_threshold
        self.weights = importance_weights  # (alpha, beta, gamma)
        self.device = device
        self.batch_replace_size = batch_replace_size
        self.fill_ratio_threshold = fill_ratio_threshold
        # Size at which the replacement strategy kicks in.
        self.size_threshold = int(self.max_size * self.fill_ratio_threshold)
        # Exploration-exploitation balance parameters.
        self.epsilon_0 = exploration_rate  # initial exploration rate
        self.lambda_decay = decay_rate     # exploration rate decay coefficient
        self.confidence = 0.0              # current model confidence, clamped to [0, 1]
        # Bounded history of observed points. A deque with maxlen evicts the
        # oldest entry in O(1); the previous list + pop(0) was O(n) per insert.
        self.max_history_size = 1000
        self.historical_points = deque(maxlen=self.max_history_size)
        logger.info(f"Data point coreset manager initialized: max_size={max_size}, "
                    f"threshold={initial_threshold}, fill_ratio_threshold={fill_ratio_threshold}")

    def compute_importance_score(self, indices, y, time_ind, model):
        """
        Calculate importance score for a data point.

        Args:
            indices: Data point indices (one index per tensor mode).
            y: Observed value.
            time_ind: Timestamp index.
            model: DCTF model used for uncertainty/influence metrics.

        Returns:
            Importance score as a Python float.
        """
        w_u, w_i, w_n = self.weights
        uncertainty = self._compute_uncertainty(indices, time_ind, model)
        influence = self._compute_influence(indices, time_ind, model)
        novelty = self._compute_novelty(indices, y, time_ind)
        # Weighted combination of the three criteria.
        score = w_u * uncertainty + w_i * influence + w_n * novelty
        return score.item() if isinstance(score, torch.Tensor) else score

    def _compute_uncertainty(self, indices, time_ind, model):
        """
        Calculate data point uncertainty from the model's posterior variance.

        Averages the diagonal factor variances across modes at the given
        time index. Returns 1.0 (maximal uncertainty) when variances are
        unavailable or on error.
        """
        try:
            variances = []
            for mode, idx in enumerate(indices):
                # post_U_v is assumed indexed [entity, rank, rank, time] -- TODO confirm
                if hasattr(model, 'post_U_v') and model.post_U_v:
                    var = torch.diagonal(model.post_U_v[mode][idx, :, :, time_ind]).mean()
                    variances.append(var.item())
            if not variances:
                return 1.0
            # Average variance across modes serves as the point's uncertainty.
            return np.mean(variances)
        except Exception as e:
            logger.error(f"Error computing uncertainty: {e}")
            return 1.0  # default on error

    def _compute_influence(self, indices, time_ind, model):
        """
        Calculate data point influence: mean similarity of this point's
        factor embedding to the embeddings of existing coreset points.

        Returns 0.5 (medium influence) when embeddings are unavailable
        or on error.
        """
        try:
            influence_score = 0.0
            count = 0
            # Embedding of the candidate point: flattened factor means per mode.
            point_embedding = []
            for mode, idx in enumerate(indices):
                if hasattr(model, 'post_U_m') and model.post_U_m:
                    emb = model.post_U_m[mode][idx, :, :, time_ind]
                    point_embedding.append(emb.flatten().detach().cpu().numpy())
            if not point_embedding:
                return 0.5
            # Average cosine similarity against every other coreset point.
            for core_indices, _, core_time, _ in self.coreset:
                if tuple(indices) != tuple(core_indices):  # skip self-comparison
                    core_embedding = []
                    for mode, idx in enumerate(core_indices):
                        if hasattr(model, 'post_U_m') and model.post_U_m:
                            emb = model.post_U_m[mode][idx, :, :, core_time]
                            core_embedding.append(emb.flatten().detach().cpu().numpy())
                    if core_embedding:
                        similarity = self._compute_point_similarity(point_embedding, core_embedding)
                        influence_score += similarity
                        count += 1
            return influence_score / max(1, count)
        except Exception as e:
            logger.error(f"Error computing influence: {e}")
            return 0.5

    def _compute_point_similarity(self, emb1, emb2):
        """Cosine similarity between two per-mode embedding lists, clamped to >= 0."""
        try:
            vec1 = np.concatenate([e.flatten() for e in emb1])
            vec2 = np.concatenate([e.flatten() for e in emb2])
            # Small epsilon guards against zero-norm vectors.
            similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2) + 1e-8)
            return max(0, similarity)
        except Exception as e:
            logger.error(f"Error computing point similarity: {e}")
            return 0.0

    def _compute_novelty(self, indices, y, time_ind):
        """
        Calculate data point novelty relative to the current coreset.

        Combines temporal novelty (distance to the nearest coreset time)
        with index novelty (fraction of per-mode indices not yet present).
        Returns 1.0 for an empty coreset, 0.0 for an exact index duplicate,
        and 0.5 on error.
        """
        try:
            if not self.coreset:
                return 1.0  # empty coreset: maximal novelty
            indices_list = [int(i) for i in indices]
            # Exact duplicate check against every coreset entry.
            already_exists = False
            for core_indices, _, _, _ in self.coreset:
                if len(indices_list) == len(core_indices):
                    if all(int(a) == int(b) for a, b in zip(indices_list, core_indices)):
                        already_exists = True
                        break
            if already_exists:
                return 0.0
            # Temporal novelty: distance to the nearest coreset timestamp,
            # squashed so far-away times approach 1.
            time_diffs = [abs(int(time_ind) - int(core_time))
                          for _, _, core_time, _ in self.coreset]
            nearest_time_diff = min(time_diffs) if time_diffs else 1.0
            time_novelty = 1.0 - math.exp(-0.1 * nearest_time_diff)
            # Index novelty: fraction of modes whose index is unseen so far.
            unique_indices_per_mode = [set() for _ in range(len(indices_list))]
            for core_indices, _, _, _ in self.coreset:
                for mode, idx in enumerate(core_indices):
                    if mode < len(unique_indices_per_mode):
                        unique_indices_per_mode[mode].add(int(idx))
            new_mode_count = 0
            for mode, idx in enumerate(indices_list):
                if mode < len(unique_indices_per_mode):
                    if int(idx) not in unique_indices_per_mode[mode]:
                        new_mode_count += 1
            index_novelty = new_mode_count / len(indices_list) if indices_list else 0.0
            # Fixed 60/40 blend of temporal and index novelty, clamped to [0, 1].
            novelty = 0.6 * time_novelty + 0.4 * index_novelty
            return max(0.0, min(1.0, novelty))
        except Exception as e:
            logger.error(f"Error computing novelty: {e}")
            return 0.5

    def compute_exploration_rate(self):
        """
        Current exploration rate: decays exponentially with model confidence,
        clamped to [0.01, 1.0].
        """
        epsilon = self.epsilon_0 * math.exp(-self.lambda_decay * self.confidence)
        return min(1.0, max(0.01, epsilon))

    def update_confidence(self, new_confidence):
        """Update model confidence, clamped to [0, 1] (controls exploration rate)."""
        self.confidence = max(0.0, min(1.0, new_confidence))

    def update_threshold(self, scores):
        """
        Dynamically update the importance threshold from the score distribution.

        Sets the threshold to mean - 0.5*std (so enough samples still pass),
        clamped to [0.1, 0.9]. No-op when adaptation is disabled or scores
        are empty.

        Returns:
            The (possibly updated) threshold.
        """
        if not self.adaptive or not scores:
            return self.threshold
        scores_array = np.array(scores)
        mean_score = np.mean(scores_array)
        std_score = np.std(scores_array)
        new_threshold = mean_score - 0.5 * std_score
        # Clamp to avoid excessive threshold fluctuation.
        self.threshold = max(0.1, min(0.9, new_threshold))
        return self.threshold

    def select_points_to_replace(self, candidates, batch_size=None):
        """
        Select which candidates should replace existing coreset points.

        NOTE: sorts ``self.coreset`` ascending by score in place; callers
        (update_coreset) rely on the returned removal indices referring to
        this sorted order.

        Args:
            candidates: Candidate points [(indices, y, time_ind, score), ...].
            batch_size: Max replacements at once (default: self.batch_replace_size).

        Returns:
            (points_to_add, indices_to_remove) where indices_to_remove are
            positions in the score-sorted coreset.
        """
        if batch_size is None:
            batch_size = self.batch_replace_size
        sorted_candidates = sorted(candidates, key=lambda x: x[3], reverse=True)
        # Below the fill threshold there is nothing to replace: just top up.
        if len(self.coreset) < self.size_threshold:
            return sorted_candidates[:self.size_threshold - len(self.coreset)], []
        # Ascending sort puts the weakest coreset members first.
        self.coreset.sort(key=lambda x: x[3])
        replace_count = min(batch_size, len(sorted_candidates))
        points_to_add = []
        indices_to_remove = []
        for i in range(replace_count):
            if i >= len(sorted_candidates):
                break
            if i >= len(self.coreset):
                break
            # Pair the i-th best candidate against the i-th worst member.
            if sorted_candidates[i][3] > self.coreset[i][3]:
                points_to_add.append(sorted_candidates[i])
                indices_to_remove.append(i)
            else:
                # Candidates are sorted descending, so no later one can win.
                break
        return points_to_add, indices_to_remove

    def should_start_replacement(self):
        """Return True once the coreset has reached the replacement threshold."""
        return len(self.coreset) >= self.size_threshold

    def update_coreset(self, data_batch, model):
        """
        Update the coreset with new data points.

        Scores each batch point, filters by threshold with an epsilon-greedy
        exploration step, then either appends (below the fill threshold) or
        swaps out the weakest members (at/above it).

        Args:
            data_batch: Iterable of (indices, y, time_ind) tuples.
            model: DCTF model used for importance metrics.

        Returns:
            (added, removed) lists of (indices, time_ind) tuples.
        """
        if not data_batch:
            return [], []
        candidates = []
        scores = []
        # Score every point not already in the coreset.
        for batch_idx, (indices, y, time_ind) in enumerate(data_batch):
            try:
                indices_tuple = tuple(int(idx) for idx in indices)
                if indices_tuple in self.coreset_indices:
                    continue
                score = self.compute_importance_score(indices, y, time_ind, model)
                candidates.append((indices, y, time_ind, score))
                scores.append(score)
            except Exception as e:
                logger.error(f"Error calculating data point importance score: {e}")
        current_threshold = self.update_threshold(scores)
        epsilon = self.compute_exploration_rate()
        # Epsilon-greedy filter: explore (random 30% acceptance) with
        # probability epsilon, otherwise exploit via the score threshold.
        filtered_candidates = []
        for indices, y, time_ind, score in candidates:
            if np.random.random() < epsilon:
                if np.random.random() < 0.3:
                    filtered_candidates.append((indices, y, time_ind, score))
            else:
                if score > current_threshold:
                    filtered_candidates.append((indices, y, time_ind, score))
        added = []
        removed = []
        if self.should_start_replacement():
            logger.info(f"Coreset size ({len(self.coreset)}) has reached threshold ({self.size_threshold}). "
                        f"Starting replacement strategy.")
            points_to_add, indices_to_remove = self.select_points_to_replace(filtered_candidates)
            # Remove from highest index down so earlier removals don't
            # shift later positions.
            for idx in sorted(indices_to_remove, reverse=True):
                if idx < len(self.coreset):
                    indices, y, time_ind, _ = self.coreset[idx]
                    indices_tuple = tuple(int(i) for i in indices)
                    self.coreset.pop(idx)
                    if indices_tuple in self.coreset_indices:
                        self.coreset_indices.remove(indices_tuple)
                    removed.append((indices, time_ind))
            for indices, y, time_ind, score in points_to_add:
                try:
                    indices_tuple = tuple(int(idx) for idx in indices)
                    if indices_tuple not in self.coreset_indices:
                        self.coreset.append((indices, y, time_ind, score))
                        self.coreset_indices.add(indices_tuple)
                        added.append((indices, time_ind))
                except Exception as e:
                    logger.error(f"Error adding replacement data point to coreset: {e}")
        else:
            # Below the fill threshold: append the best candidates directly.
            filtered_candidates.sort(key=lambda x: x[3], reverse=True)
            for indices, y, time_ind, score in filtered_candidates[:self.size_threshold - len(self.coreset)]:
                try:
                    indices_tuple = tuple(int(idx) for idx in indices)
                    if indices_tuple not in self.coreset_indices:
                        self.coreset.append((indices, y, time_ind, score))
                        self.coreset_indices.add(indices_tuple)
                        added.append((indices, time_ind))
                except Exception as e:
                    logger.error(f"Error adding data point to coreset: {e}")
        # Record all candidates in the bounded history (deque drops the oldest).
        for indices, y, time_ind, _ in candidates:
            self.historical_points.append((indices, y, time_ind))
        if added or removed:
            logger.info(f"Coreset update: {len(added)} points added, {len(removed)} points removed. "
                        f"Current size: {len(self.coreset)}/{self.max_size} (threshold: {self.size_threshold})")
        return added, removed

    def is_in_coreset(self, indices):
        """Return True if the given data point indices are in the coreset."""
        try:
            indices_tuple = tuple(int(idx) for idx in indices)
            return indices_tuple in self.coreset_indices
        except Exception:
            return False

    def get_coreset_data(self):
        """Return all coreset points as (indices, y, time_ind) tuples (scores dropped)."""
        return [(indices, y, time_ind) for indices, y, time_ind, _ in self.coreset]

    def get_coreset_size(self):
        """Return the current number of points in the coreset."""
        return len(self.coreset)
class MultiScaleWeighting:
    """
    Multi-scale weighting mechanism for time-dependent data.

    Produces one normalized attention weight per time scale from the
    scales' hidden states, via a small tanh-attention over learnable
    parameters.
    """

    def __init__(self, num_scales=3, hidden_dim=32, device=torch.device("cpu"), temperature=1.0):
        """
        Initialize the multi-scale weighting mechanism.

        Args:
            num_scales: Number of time scales.
            hidden_dim: Hidden layer dimension.
            device: Computation device.
            temperature: Softmax temperature for the attention weights.
        """
        self.num_scales = num_scales
        self.hidden_dim = hidden_dim
        self.device = device
        self.temperature = temperature
        # Learnable attention parameters: projection, query vector, bias,
        # and a per-scale gain.
        self.W = torch.nn.Parameter(torch.randn(hidden_dim, hidden_dim, device=device))
        self.v = torch.nn.Parameter(torch.randn(hidden_dim, device=device))
        self.b = torch.nn.Parameter(torch.zeros(hidden_dim, device=device))
        self.gamma_k = torch.nn.Parameter(torch.ones(num_scales, device=device))
        logger.info(f"Multi-scale weighting mechanism initialized: num_scales={num_scales}, hidden_dim={hidden_dim}")

    def compute_weights(self, h_k):
        """
        Calculate normalized weights for the given time-scale hidden states.

        Args:
            h_k: List of hidden states, one per scale.

        Returns:
            Weight tensor of shape [num_scales, 1] (or [1, 1] for empty input).
        """
        # Empty input: single trivial weight.
        if not h_k or len(h_k) == 0:
            return torch.ones(1, 1, device=self.device)
        try:
            n_active = min(self.num_scales, len(h_k))
            raw_scores = torch.zeros(n_active, device=self.device)
            for k in range(n_active):
                state = h_k[k]
                # Flatten multi-dimensional tensors to a vector.
                if isinstance(state, torch.Tensor) and state.dim() > 1:
                    state = state.reshape(-1)
                # Accept plain arrays/lists as well.
                if not isinstance(state, torch.Tensor):
                    state = torch.tensor(state, device=self.device)
                # tanh-attention score scaled by the per-scale gain.
                projected = torch.tanh(torch.matmul(self.W, state.float()) + self.b)
                attn = torch.matmul(self.v, projected) * self.gamma_k[k]
                raw_scores[k] = attn.item() if isinstance(attn, torch.Tensor) else attn
            # Replace NaN/inf scores with zeros.
            raw_scores = torch.where(torch.isfinite(raw_scores), raw_scores, torch.zeros_like(raw_scores))
            # Degenerate case: all scores zero -> uniform weights.
            if torch.sum(raw_scores) == 0:
                return torch.ones_like(raw_scores) / raw_scores.size(0)
            # Temperature-scaled softmax normalization.
            normalized = torch.nn.functional.softmax(raw_scores / self.temperature, dim=0)
            return normalized.unsqueeze(1)
        except Exception as e:
            logger.error(f"Error calculating multi-scale weights: {e}")
            # Fall back to uniform weights on any failure.
            fallback = torch.ones(min(self.num_scales, len(h_k)), device=self.device)
            return (fallback / fallback.sum()).unsqueeze(1)
class MartingaleDataPointCoreSetManager(DataPointCoreSetManager):
"""Enhanced coreset manager using martingale theory"""
def __init__(self, max_size=100, initial_threshold=0.5, adaptive_threshold=True,
             importance_weights=(0.3, 0.2, 0.2, 0.3), device=torch.device("cpu"),
             exploration_rate=0.9, decay_rate=0.1, prediction_history_size=50,
             discount_factor=0.9, simulation_samples=5, batch_replace_size=20,
             fill_ratio_threshold=0.8):
    """
    Initialize the martingale coreset manager.

    Args:
        importance_weights: Weights for (uncertainty, influence, novelty,
            martingale_increment).
        prediction_history_size: Number of past prediction entries retained.
        discount_factor: Bellman equation discount factor.
        simulation_samples: Sample count for simulation evaluation.
        batch_replace_size: Number of data points replaced per batch.
        fill_ratio_threshold: Fraction of capacity at which replacement
            starts (e.g. 0.8 means start at 80% capacity).
        Remaining parameters are identical to the parent class.
    """
    # The parent class only knows the first three criteria weights.
    super().__init__(max_size, initial_threshold, adaptive_threshold,
                     importance_weights[:3], device, exploration_rate, decay_rate,
                     batch_replace_size, fill_ratio_threshold)
    # Fourth weight scales the martingale-increment term; falls back to
    # 0.3 when a 3-tuple is supplied.
    self.martingale_weight = importance_weights[3] if len(importance_weights) > 3 else 0.3
    # Rolling window of past prediction distributions for increment scoring.
    self.prediction_history = deque(maxlen=prediction_history_size)
    self.value_function_cache = {}
    # Bellman optimization parameters.
    self.discount_factor = discount_factor
    self.simulation_samples = simulation_samples
    # Recent importance scores, used by the optimal-stopping threshold.
    self.score_history = []
    logger.info(f"Martingale theory coreset manager initialized: max_size={max_size}, "
                f"importance_weights={importance_weights}, fill_ratio_threshold={fill_ratio_threshold}")
def _compute_martingale_increment(self, indices, time_ind, model):
    """Estimate a data point's martingale increment (information gain).

    Compares the model's estimated prediction error before and after a
    simulated addition of the point; the resulting error reduction,
    squashed through tanh, becomes the increment in [0, 1). Returns 0.0
    on any failure.
    """
    try:
        # Baseline predictive error of the current model.
        baseline_error = self._compute_prediction_variance(model)
        # Hypothetical error with the point added. The model itself is
        # never modified; only the effect is estimated.
        hypothetical = [(indices, None, time_ind)]
        projected_error = self._evaluate_with_additional_points(model, hypothetical)
        # Non-negative error reduction, mapped into [0, 1).
        gain = max(0, baseline_error - projected_error)
        return np.tanh(gain * 5)
    except Exception as e:
        logger.error(f"Error calculating martingale increment: {e}")
        return 0.0
def _compute_prediction_variance(self, model):
    """Calculate the current model's prediction variance.

    Higher variance indicates weaker prediction ability. Three fallbacks:
    (1) residual variance on a random sample of the model's test data,
    (2) mean posterior factor variance at the latest time step,
    (3) a fixed default of 1.0 (also returned on any error).
    """
    try:
        # Preferred: evaluate on a small random sample of held-out data.
        if hasattr(model, 'te_ind') and hasattr(model, 'te_y'):
            sample_size = min(20, len(model.te_ind))
            if sample_size > 0:
                indices = np.random.choice(len(model.te_ind), sample_size, replace=False)
                sample_ind = model.te_ind[indices]
                sample_y = model.te_y[indices]
                # Evaluate at the most recent training time step.
                current_time = max(model.unique_train_time) if model.unique_train_time else 0
                # NOTE(review): ones_like(indices) inherits the integer dtype
                # of `indices`; assumes time indices are integral -- confirm.
                sample_time = np.ones_like(indices) * current_time
                pred, _ = model.model_test(sample_ind, sample_y, sample_time)
                # Variance of the prediction residuals.
                variance = torch.var((pred.squeeze() - sample_y.squeeze())).item()
                return variance
        # Fallback: average posterior variance as a proxy for uncertainty.
        if hasattr(model, 'post_U_v'):
            total_var = 0
            count = 0
            for mode in range(len(model.post_U_v)):
                if model.post_U_v[mode].numel() > 0:
                    # Mean of the covariance diagonals at the last time slice.
                    total_var += torch.mean(torch.diagonal(model.post_U_v[mode][:, :, :, -1], dim1=1, dim2=2)).item()
                    count += 1
            if count > 0:
                return total_var / count
        # Last resort: a fixed default.
        return 1.0
    except Exception as e:
        logger.error(f"Error calculating prediction variance: {e}")
        return 1.0
def _evaluate_with_additional_points(self, model, additional_points):
    """Heuristically estimate prediction variance after adding extra points.

    Actually retraining would be far too expensive, so the estimate
    combines (a) variance reduction from newly covered mode indices and
    time steps and (b) a 1/sqrt(n) shrinkage from the larger sample count.

    Args:
        model: Current model.
        additional_points: Points to add, each (indices, y, time_ind).

    Returns:
        Estimated prediction variance (floored at 0.1); falls back to the
        current variance on error.
    """
    try:
        existing = self.get_coreset_data()
        # Current coverage: (mode, index) pairs and time steps in the coreset.
        covered_modes = {(mode, int(idx))
                         for point_indices, _, _ in existing
                         for mode, idx in enumerate(point_indices)}
        covered_times = {t for _, _, t in existing}
        # Count genuinely new coverage contributed by the extra points.
        new_modes = 0
        new_times = 0
        for point_indices, _, t in additional_points:
            for mode, idx in enumerate(point_indices):
                if (mode, int(idx)) not in covered_modes:
                    new_modes += 1
            if t not in covered_times:
                new_times += 1
        # Heuristic: each new mode index trims variance by 1%, each new
        # time step by 0.5%.
        base_variance = self._compute_prediction_variance(model)
        estimate = base_variance * (1 - 0.01 * new_modes - 0.005 * new_times)
        # Sample-count shrinkage: variance ~ 1/sqrt(n); the +1 guards
        # against division by zero for an empty coreset.
        n_before = len(existing) + 1
        n_after = n_before + len(additional_points)
        estimate = estimate * math.sqrt(n_before / n_after)
        return max(0.1, estimate)  # keep the variance strictly positive
    except Exception as e:
        logger.error(f"Error evaluating extra points: {e}")
        return self._compute_prediction_variance(model)
def compute_importance_score(self, indices, y, time_ind, model):
    """Score a data point, extending the base criteria with the martingale increment.

    Returns the weighted sum of uncertainty, influence, novelty, and
    martingale increment as a float; 0.5 (medium importance) on error.
    The score is also recorded in a rolling 100-entry history.
    """
    try:
        w_u, w_i, w_n = self.weights
        # Base criteria plus the martingale (information gain) term.
        total = (w_u * self._compute_uncertainty(indices, time_ind, model)
                 + w_i * self._compute_influence(indices, time_ind, model)
                 + w_n * self._compute_novelty(indices, y, time_ind)
                 + self.martingale_weight * self._compute_martingale_increment(indices, time_ind, model))
        if isinstance(total, torch.Tensor):
            total = total.item()
        # Keep only the last 100 scores for threshold statistics.
        self.score_history.append(total)
        if len(self.score_history) > 100:
            self.score_history.pop(0)
        return total
    except Exception as e:
        logger.error(f"Error calculating importance score: {e}")
        return 0.5
def optimal_stopping_strategy(self, candidates, model, budget):
    """Use optimal stopping theory to decide which candidates to add.

    Candidates are scanned in descending score order and the scan stops at
    the first candidate below a dynamic threshold (the optimal-stopping
    rule), returning at most ``budget`` points.

    Args:
        candidates: Candidate data points, each (indices, y, time_ind, score).
        model: Current model (kept for interface symmetry).
        budget: Maximum number of data points to add.

    Returns:
        Selected (indices, y, time_ind, score) tuples; on internal error,
        falls back to the first ``budget`` candidates.
    """
    # Pre-bind the fallback so the except handler never hits an unbound
    # name (the original could raise NameError on `sorted_candidates` if
    # the failure occurred before the sort completed).
    sorted_candidates = list(candidates)
    try:
        if not candidates:
            return []
        sorted_candidates = sorted(candidates, key=lambda x: x[3], reverse=True)
        # Dynamic threshold from recent score statistics.
        if self.score_history:
            mean_score = np.mean(self.score_history)
            std_score = np.std(self.score_history) if len(self.score_history) > 1 else 0.1
            # A fuller coreset shrinks the subtracted margin, i.e. raises
            # the threshold.
            coreset_fill_ratio = len(self.coreset) / self.max_size
            threshold = mean_score - (1.0 - coreset_fill_ratio) * std_score
            # More exploration lowers the threshold.
            exploration_rate = self.compute_exploration_rate()
            threshold = threshold * (1 - exploration_rate)
        else:
            # No history yet: use the configured initial threshold.
            threshold = self.threshold
        selected = []
        for indices, y, time_ind, score in sorted_candidates:
            # Skip points already in the coreset.
            indices_tuple = tuple(int(idx) for idx in indices)
            if indices_tuple in self.coreset_indices:
                continue
            if score > threshold and len(selected) < budget:
                selected.append((indices, y, time_ind, score))
            else:
                # First below-threshold (or over-budget) point: stop.
                # This is the optimal-stopping rule in action.
                break
        return selected
    except Exception as e:
        logger.error(f"Error in optimal stopping strategy: {e}")
        # Fallback: take the first `budget` points.
        return sorted_candidates[:budget]
def _apply_bellman_optimization(self, candidates, model):
    """Choose candidate points via a one-step Bellman optimization.

    Evaluates small candidate combinations and picks the one maximizing
    immediate reward (sum of scores) plus discounted estimated future
    state value; taking no action is the baseline.

    Args:
        candidates: Candidate points, each (indices, y, time_ind, score).
        model: Current model used to estimate state values.

    Returns:
        The best combination as a list (possibly empty); on error, the
        top-3 candidates by score.
    """
    # Hoisted: the original re-executed this import on every iteration of
    # the outer loop.
    from itertools import combinations
    try:
        # Heuristic pre-filter: keep only the 10 best-scoring candidates
        # to bound the combinatorial search.
        if len(candidates) > 10:
            sorted_candidates = sorted(candidates, key=lambda x: x[3], reverse=True)
            candidates = sorted_candidates[:10]
        # Value of taking no action.
        current_value = self._estimate_state_value(model)
        best_action = []
        max_future_value = current_value
        # Cap combination size at 3 to avoid combinatorial explosion.
        max_to_add = min(3, len(candidates))
        for num_to_add in range(1, max_to_add + 1):
            for combo in combinations(candidates, num_to_add):
                # Only consider combinations that fit within capacity.
                if len(self.coreset) + len(combo) <= self.max_size:
                    immediate_reward = sum(c[3] for c in combo)
                    points_to_add = [(c[0], c[1], c[2]) for c in combo]
                    future_state_value = self._estimate_future_value(model, points_to_add)
                    # Bellman backup: reward + discounted future value.
                    total_value = immediate_reward + self.discount_factor * future_state_value
                    if total_value > max_future_value:
                        max_future_value = total_value
                        best_action = list(combo)
        return best_action
    except Exception as e:
        logger.error(f"Error in Bellman optimization: {e}")
        # Fallback: highest-scoring points.
        return sorted(candidates, key=lambda x: x[3], reverse=True)[:3]
def _estimate_state_value(self, model):
    """Estimate the value of the current state.

    Uses the reciprocal of the prediction variance: lower variance
    (stronger prediction ability) means a more valuable state.

    Args:
        model: Current model.

    Returns:
        Non-negative state value (0.0 if variance is not positive).
    """
    variance = self._compute_prediction_variance(model)
    return 1.0 / variance if variance > 0 else 0.0
def _estimate_future_value(self, model, additional_points):
    """Estimate the state value after hypothetically adding points.

    Args:
        model: Current model.
        additional_points: Extra points whose addition is simulated.

    Returns:
        Reciprocal of the projected prediction variance (0.0 if the
        projected variance is not positive).
    """
    projected = self._evaluate_with_additional_points(model, additional_points)
    return 1.0 / projected if projected > 0 else 0.0
def compute_exploration_rate(self):
    """Compute the exploration rate using a martingale measure transform.

    Combines the base confidence-driven decay with coreset saturation
    (damping) and recent prediction uncertainty (boosting). Clamped to
    [0.01, 1.0]; returns the initial rate on error.
    """
    try:
        # Confidence-driven exponential decay, as in the base class.
        decayed = self.epsilon_0 * math.exp(-self.lambda_decay * self.confidence)
        # A fuller coreset damps exploration exponentially.
        saturation = len(self.coreset) / self.max_size
        saturation_damping = math.exp(-2.0 * saturation)
        # Mean recorded uncertainty; defaults to 0.5 with no history yet.
        if self.prediction_history:
            mean_uncertainty = np.mean([entry['uncertainty'] for entry in self.prediction_history])
        else:
            mean_uncertainty = 0.5
        # Higher uncertainty encourages more exploration.
        boosted = decayed * saturation_damping * (1.0 + mean_uncertainty)
        return min(1.0, max(0.01, boosted))
    except Exception as e:
        logger.error(f"Error calculating exploration rate: {e}")
        return self.epsilon_0
def update_coreset(self, data_batch, model):
"""Update coreset using martingale theory, optimal stopping strategy, and improved replacement"""
try:
if not data_batch:
return [], []
# Store importance scores
candidates = []
scores = []
# Calculate importance score for each data point
for batch_idx, (indices, y, time_ind) in enumerate(data_batch):
# Check if already in coreset
indices_tuple = tuple(int(idx) for idx in indices)
if indices_tuple in self.coreset_indices:
continue
# Calculate importance score
score = self.compute_importance_score(indices, y, time_ind, model)
# Add to candidate list
candidates.append((indices, y, time_ind, score))
scores.append(score)
# Update threshold
if self.adaptive:
self.update_threshold(scores)
# Calculate maximum data points that can be added
remaining_budget = self.max_size - len(self.coreset)