-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
138 lines (127 loc) · 4.85 KB
/
utils.py
File metadata and controls
138 lines (127 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# rlte/utils.py
# Helper functions: seeding, schedules, logistic-normal bijection, allocation, logging
# References: Eq. (6)-(11); App. D for normalization.
from __future__ import annotations
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
import torch
def seed_everything(seed: int) -> None:
    """Seed all random number generators used by this module.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG,
    PyTorch's default generator and finally every CUDA device generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def linear_sigma_schedule(iter_idx: int, total_iters: int, sigma_init: float, sigma_final: float) -> float:
    """Linearly interpolate sigma from ``sigma_init`` to ``sigma_final`` (Eq. (11)).

    ``iter_idx`` is treated as 1-based: ``iter_idx == 1`` yields
    ``sigma_init`` and ``iter_idx == total_iters`` yields ``sigma_final``.
    If ``total_iters <= 1`` there is nothing to interpolate over and
    ``sigma_final`` is returned. No clamping is applied, so an ``iter_idx``
    outside ``[1, total_iters]`` extrapolates along the same line.
    """
    steps = total_iters - 1
    if steps <= 0:
        return sigma_final
    progress = (iter_idx - 1) / steps
    return sigma_init + progress * (sigma_final - sigma_init)
def shifted_half_normal_int(sigma: float, max_sigma_mult: float = 5.0) -> int:
    """Draw an integer order size from a shifted, truncated half-normal.

    The size is ``round(1 + sigma * |Z|)`` where ``Z`` is a standard normal
    sample taken from NumPy's global RNG. ``|Z|`` is capped at
    ``max_sigma_mult``, so with the default the size is capped at
    ``1 + 5 * sigma`` (App. A.1). The result is never smaller than 1.
    """
    magnitude = min(abs(np.random.randn()), max_sigma_mult)
    size = int(round(1.0 + sigma * magnitude))
    return size if size > 1 else 1
def softplus(x: torch.Tensor) -> torch.Tensor:
    """Return ``log(1 + exp(x))`` element-wise, without overflowing.

    The direct form ``log1p(exp(x))`` turns into ``inf`` as soon as
    ``exp(x)`` overflows (roughly ``x > 88`` in float32). ``logaddexp(x, 0)``
    evaluates the same function, ``log(exp(x) + exp(0))``, in a numerically
    stable way, so large inputs return approximately ``x`` instead of ``inf``.

    Args:
        x: Input tensor of any shape.

    Returns:
        A tensor of the same shape holding the softplus of every element.
    """
    if not x.is_floating_point() and not x.is_complex():
        # torch.exp promoted integer/bool input to the default float dtype;
        # do the same here because logaddexp expects floating-point input.
        x = x.to(torch.get_default_dtype())
    return torch.logaddexp(x, torch.zeros_like(x))
# === Logistic-Normal Transformation h / h^{-1} (Eq. (6)-(8)) ===
def h_logistic_normal(x: torch.Tensor) -> torch.Tensor:
    """Map logits in R^K onto the simplex with K+1 components (h in Eq. (6)-(8)).

    For the last axis of ``x``::

        a_k = exp(x_k) / (1 + sum_l exp(x_l)),   k = 0..K-1
        a_K = 1        / (1 + sum_l exp(x_l))

    This equals a softmax over the K logits with one extra logit fixed at 0.
    It is computed that way because softmax stays finite for large logits,
    whereas evaluating ``exp(x)`` directly overflows and yields ``nan``.

    Args:
        x: Tensor whose last axis holds the K logits; leading axes are
            treated as batch dimensions.

    Returns:
        Tensor with the same leading shape and a last axis of size K+1 whose
        entries sum to 1.
    """
    if not x.is_floating_point() and not x.is_complex():
        # torch.exp promoted integer/bool input to the default float dtype;
        # softmax needs floating-point input, so promote explicitly.
        x = x.to(torch.get_default_dtype())
    # The appended zero logit produces the a_K = 1 / (1 + sum exp(x)) component.
    zero_logit = x.new_zeros((*x.shape[:-1], 1))
    return torch.softmax(torch.cat([x, zero_logit], dim=-1), dim=-1)
def h_inv_logistic_normal(a: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Map simplex points back to logits: ``x_k = log(a_k / a_K)``, k = 0..K-1.

    Every entry of ``a`` is first clamped into ``[eps, 1]`` so that neither
    the division nor the logarithm sees a zero. The last component along the
    final axis is the reference ``a_K``; the result has one fewer entry on
    that axis than ``a``.
    """
    clipped = a.clamp(min=eps, max=1.0)  # numerical guard
    numerators = clipped[..., :-1]
    reference = clipped[..., -1:]
    return (numerators / reference).log()
def bijection_test(K: int = 6, trials: int = 100, tol: float = 1e-6) -> bool:
    """Acceptance check that ``h_inv_logistic_normal`` inverts ``h_logistic_normal``.

    Draws up to ``trials`` random batches of shape ``(4, K)``, maps each
    through ``h`` and back through ``h^{-1}``, and compares the result with
    the input using absolute tolerance ``tol`` (no relative tolerance).
    Stops at the first failing batch.

    Returns:
        ``True`` if every tested batch round-trips within ``tol``,
        ``False`` otherwise.
    """

    def _round_trip_ok() -> bool:
        sample = torch.randn(4, K)  # one batch of logits
        recovered = h_inv_logistic_normal(h_logistic_normal(sample))
        return bool(torch.allclose(sample, recovered, atol=tol, rtol=0))

    # all() short-circuits, so no further batches are drawn after a failure.
    return all(_round_trip_ok() for _ in range(trials))
# === Integer allocation & re-allocation (Sec. 3.2, p.4) ===
def integer_allocation(a: np.ndarray, M: int, K: int) -> Tuple[int, List[int], int]:
    """Convert simplex weights into integer lot quantities summing to at most ``M``.

    ``a`` has ``K + 1`` components: ``a[0]`` (market sell), ``a[1]..a[K-1]``
    (limit orders k ticks above the best bid) and ``a[K]`` (hold). The active
    components ``0..K-1`` are scaled by ``M`` and rounded to the nearest
    integer. If rounding pushes their sum above ``M``, lots are removed from
    the components that were rounded *up* the most, one lot per component per
    pass, so that a component already rounded down is not reduced further
    while a rounded-up one is still available. Hold receives whatever is
    left over.

    Args:
        a: Array-like of length ``K + 1`` with the allocation weights.
        M: Total number of lots to distribute.
        K: Index of the hold component (``len(a) - 1``).

    Returns:
        ``(mkt_qty, limit_qtys, hold_qty)`` where ``limit_qtys`` lists the
        quantities for components ``1..K-1``.

    Raises:
        AssertionError: If ``len(a) != K + 1``.
    """
    assert len(a) == K + 1
    # Scaling creates a fresh array, so the caller's ``a`` is never mutated.
    active = np.asarray(a, dtype=float) * M
    active[-1] = 0.0  # hold is not assigned directly; it receives the remainder
    rounded = np.rint(active).astype(int)
    excess = int(rounded.sum()) - M
    if excess > 0:
        # Positive where a component was rounded up, negative where rounded down.
        round_up = rounded - active
        order = np.argsort(-round_up, kind="stable")  # largest round-up first
        while excess > 0:
            trimmed = False
            for idx in order:
                if excess == 0:
                    break
                if idx == K or rounded[idx] <= 0:
                    continue  # never trim hold or an empty component
                rounded[idx] -= 1
                excess -= 1
                trimmed = True
            if not trimmed:
                break  # nothing left that could be reduced
    # Remainder goes to hold.
    hold_qty = max(0, M - int(rounded.sum()))
    mkt_qty = int(rounded[0])
    limit_qtys = [int(rounded[k]) for k in range(1, K)]
    return mkt_qty, limit_qtys, hold_qty
@dataclass
class AllocationDiff:
    """Lots to cancel and lots to add, each keyed by level."""

    # level -> number of lots that have to be cancelled (only what is necessary)
    cancel_levels: Dict[int, int]
    # level -> number of lots to place anew
    add_levels: Dict[int, int]
def compute_reallocation(current_level_counts: Dict[int, int],
                         desired_level_counts: Dict[int, int]) -> AllocationDiff:
    """Compute the minimal per-level changes from the current to the desired counts.

    Only what is necessary is cancelled: for every level present in either
    mapping (a missing level counts as 0), a surplus goes into
    ``cancel_levels`` and a shortfall into ``add_levels``. Levels whose
    counts already match appear in neither. Per the original note, choosing
    which orders to cancel by queue position is handled later in the env.
    """
    to_cancel: Dict[int, int] = {}
    to_add: Dict[int, int] = {}
    all_levels = set([*current_level_counts, *desired_level_counts])
    for level in all_levels:
        have = current_level_counts.get(level, 0)
        want = desired_level_counts.get(level, 0)
        if have > want:
            to_cancel[level] = have - want
            continue
        if want > have:
            to_add[level] = want - have
    return AllocationDiff(cancel_levels=to_cancel, add_levels=to_add)
# Simple logging helper
class SmoothedValue:
    """Exponentially smoothed running value of a scalar stream.

    ``value`` is ``None`` until the first update; ``m`` is the momentum,
    i.e. the weight kept on the previous smoothed value.
    """

    def __init__(self, momentum: float = 0.95):
        self.m = momentum
        self.value = None

    def update(self, x: float) -> float:
        """Fold ``x`` into the running value and return the new value.

        The first observation is stored as-is; afterwards the value becomes
        ``m * value + (1 - m) * x``.
        """
        previous = self.value
        self.value = x if previous is None else self.m * previous + (1.0 - self.m) * x
        return self.value