-
Notifications
You must be signed in to change notification settings - Fork 68
Expand file tree
/
Copy pathsimplex.py
More file actions
108 lines (81 loc) · 2.83 KB
/
simplex.py
File metadata and controls
108 lines (81 loc) · 2.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from .constraint import Constraint
import numpy as np
class Simplex(Constraint):
    r"""Simplex constraint

    A set of the form

    :math:`X = \left\{x \in \mathrm{I\!R}^n : \sum_{i=1}^{n} x_i = \alpha, x \geq 0\right\}`

    where :math:`\alpha` is a positive constant.
    """

    def __init__(self, alpha: float = 1.0):  # unless specified, alpha=1
        """Constructor for a Simplex constraint

        :param alpha: size parameter of simplex (default: 1); must be a
            finite number strictly greater than zero
        :raises ValueError: if ``alpha`` is not a finite positive number
        """
        # `not alpha > 0` (instead of `alpha <= 0`) also rejects NaN, for which
        # every comparison is False. An infinite alpha would describe an
        # unbounded set (contradicting is_compact()) and turns the arithmetic
        # in project() into inf - inf.
        if not alpha > 0 or not np.isfinite(alpha):
            raise ValueError("Alpha must be a positive number")
        self.__alpha = float(alpha)

    @property
    def alpha(self):
        """Returns the simplex size parameter alpha (as a float)"""
        return self.__alpha

    def distance_squared(self, u):
        """Squared distance to the simplex; not implemented for this constraint

        :raises NotImplementedError: always
        """
        raise NotImplementedError()

    def project(self, y):
        r"""Computes the projection of a given point :math:`y\in{\rm I\!R}^n`
        onto the current simplex.

        The steps (numbered in the comments below) follow L. Condat's
        algorithm ("Fast projection onto the simplex and the l1 ball",
        Math. Programming, 2016): find the threshold ``tau`` such that
        ``x = max(y - tau, 0)`` sums to ``alpha``.

        :param y: given point; must be a list of numbers (float, int) or
                  a numpy n-dim array (`ndarray`); entries are assumed finite
        :return: the projection point in :math:`{\rm I\!R}^n` as a numpy array
                 of float64s (same shape as ``np.asarray(y)``)
        :raises ValueError: if ``y`` has no entries
        """
        y_arr = np.asarray(y, dtype=np.float64)
        if y_arr.size == 0:
            raise ValueError("Cannot project an empty vector onto a simplex")
        a = self.__alpha

        # The projection is translation-equivariant: shifting y by c shifts the
        # threshold by c. Working with y - max(y) keeps the largest entry at
        # exactly 0, avoiding the cancellation in `y[0] - a` when |y| >> alpha.
        # Without the shift, y = [1e17] gives rho == y[0], so step 4 dropped
        # the only element and divided by zero.
        shifted = y_arr - np.max(y_arr)
        values = shifted.ravel().tolist()

        # 1 ---- start the candidate support v with the first entry
        v = [values[0]]
        v_tilde = []  # entries set aside, re-examined in step 3
        rho = values[0] - a

        # 2 ---- single pass, maintaining rho = (sum(v) - a) / len(v)
        for yn in values[1:]:
            if yn > rho:
                rho += (yn - rho) / (len(v) + 1)
                if rho > yn - a:
                    v.append(yn)
                else:
                    # yn alone gives a larger threshold: restart from it.
                    v_tilde.extend(v)
                    v = [yn]
                    rho = yn - a

        # 3 ---- bring back set-aside entries that exceed the current rho
        for v_tilde_i in v_tilde:
            if v_tilde_i > rho:
                v.append(v_tilde_i)
                rho += (v_tilde_i - rho) / len(v)

        # 4 ---- drop entries <= rho (updating rho as we go) until stable
        while True:
            remaining = len(v)
            kept = []
            for vj in v:
                # `remaining > 1` is a defensive guard: the maximum entry
                # (exactly 0 after the shift) stays in v in exact arithmetic,
                # but never let rounding empty v and divide by zero.
                if vj <= rho and remaining > 1:
                    remaining -= 1
                    rho += (rho - vj) / remaining
                else:
                    kept.append(vj)
            removed_any = len(kept) != len(v)
            v = kept
            if not removed_any:
                break

        # 6 ---- threshold (rho is Condat's tau, relative to the shift)
        return np.maximum(shifted - rho, 0.0)

    def is_convex(self):
        """Whether the set is convex (`True`)"""
        return True

    def is_compact(self):
        """Whether the set is compact (`True`)"""
        return True

    def dimension(self):
        """Dimension of the ambient space; not fixed by this constraint (`None`)"""
        return None