-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path_inputs.py
More file actions
57 lines (46 loc) · 1.63 KB
/
_inputs.py
File metadata and controls
57 lines (46 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from settings import DEVICE
from utils.tensors import zeros_
from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler
# Each shape below is an (m, n, r) triple: m rows, n columns, and a third
# parameter r that is forwarded to the matrix samplers. r is presumably the
# target rank (the all-zero shapes use r=0) -- TODO confirm in _matrix_samplers.

# Shapes for NormalSampler, grouped by (m, n): a 1x1 case, 4x3 with r in 1..3,
# and 9x11 with r in {5, 9}.
_normal_dims = [(1, 1, 1)] + [(4, 3, r) for r in (1, 2, 3)] + [(9, 11, r) for r in (5, 9)]

# Shapes of the all-zero matrices; only (m, n) is consumed, r is always 0.
_zero_dims = [(m, n, 0) for m, n in ((1, 1), (4, 3), (9, 11))]

# Shapes shared by StrongSampler, StrictlyWeakSampler and NonWeakSampler:
# 20 rows with either 10 or 100 columns, and several values of r.
_stationarity_dims = [(20, 10, r) for r in (10, 5, 1)] + [(20, 100, r) for r in (1, 19)]

# Factors that every normally-sampled matrix is multiplied by to build
# `scaled_matrices`, ranging from exactly zero to very large magnitudes.
_scales = [
    0.0,
    1e-10,
    1e3,
    1e5,
    1e10,
    1e15,
]
# Single generator, seeded with 0, that is passed to every sampler call in this
# module so the sampled matrices are reproducible (given the same device and
# torch version). NOTE: each sampler below receives `_rng` and presumably draws
# from it, so the ORDER of these statements (and of the later
# `nash_mtl_matrices` sampling) determines the values produced. Reordering or
# inserting sampler calls changes every matrix sampled afterwards.
_rng = torch.Generator(device=DEVICE).manual_seed(0)
# One NormalSampler matrix per (m, n, r) in _normal_dims.
matrices = [NormalSampler(m, n, r)(_rng) for m, n, r in _normal_dims]
# All-zero matrices of shape [m, n]. They are built with zeros_ rather than a
# sampler, so `_rng` is not passed here.
zero_matrices = [zeros_([m, n]) for m, n, _ in _zero_dims]
# One matrix per shape in _stationarity_dims from each of the three samplers;
# the properties each sampler guarantees are defined in ._matrix_samplers.
strong_matrices = [StrongSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]
strictly_weak_matrices = [StrictlyWeakSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]
non_weak_matrices = [NonWeakSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]
# Every normal matrix multiplied by every scale factor. The outer loop runs over
# the scales, so all matrices at 0.0 come first, then all at 1e-10, and so on.
scaled_matrices = [factor * mat for factor in _scales for mat in matrices]

# Concatenation of the strictly-weak and non-weak matrices.
non_strong_matrices = [*strictly_weak_matrices, *non_weak_matrices]

# All unscaled inputs: zeros, normal, strong, then the non-strong ones.
typical_matrices = [*zero_matrices, *matrices, *strong_matrices, *non_strong_matrices]

# Subsets keeping only matrices that have at least two rows.
scaled_matrices_2_plus_rows = [mat for mat in scaled_matrices if mat.shape[0] > 1]
typical_matrices_2_plus_rows = [mat for mat in typical_matrices if mat.shape[0] > 1]
# NashMTL apparently does not work on matrices with a single row, so it gets its
# own inputs. The shapes are those of _normal_dims except that (1, 1, 1) is
# replaced by (3, 1, 1), so every matrix here has at least two rows.
_nashmtl_dims = [
    (3, 1, 1),
    (4, 3, 1),
    (4, 3, 2),
    (4, 3, 3),
    (9, 11, 5),
    (9, 11, 9),
]
# Sampled from the shared `_rng` after all the matrices above, so these values
# presumably depend on every earlier draw from that generator.
nash_mtl_matrices = [NormalSampler(m, n, r)(_rng) for m, n, r in _nashmtl_dims]