Skip to content

Commit 5e4d26d

Browse files
authored
test: Remove redundant dtype kwargs (#499)
1 parent fc498b1 commit 5e4d26d

File tree

3 files changed

+37
-45
lines changed

3 files changed

+37
-45
lines changed

tests/unit/aggregation/_asserts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def assert_linear_under_scaling(
6767
"""Tests empirically that a given `Aggregator` satisfies the linear under scaling property."""
6868

6969
for _ in range(n_runs):
70-
c1 = rand_(matrix.shape[0], dtype=matrix.dtype)
71-
c2 = rand_(matrix.shape[0], dtype=matrix.dtype)
72-
alpha = rand_([], dtype=matrix.dtype)
73-
beta = rand_([], dtype=matrix.dtype)
70+
c1 = rand_(matrix.shape[0])
71+
c2 = rand_(matrix.shape[0])
72+
alpha = rand_([])
73+
beta = rand_([])
7474

7575
x1 = aggregator(torch.diag(c1) @ matrix)
7676
x2 = aggregator(torch.diag(c2) @ matrix)

tests/unit/aggregation/_matrix_samplers.py

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,26 @@
11
from abc import ABC, abstractmethod
22

33
import torch
4-
from settings import DTYPE
54
from torch import Tensor
65
from torch.nn.functional import normalize
76
from utils.tensors import randint_, randn_, randperm_, zeros_
87

98

109
class MatrixSampler(ABC):
11-
"""Abstract base class for sampling matrices of a given shape, rank and dtype."""
10+
"""Abstract base class for sampling matrices of a given shape, rank."""
1211

13-
def __init__(self, m: int, n: int, rank: int, dtype: torch.dtype = DTYPE):
14-
self._check_params(m, n, rank, dtype)
12+
def __init__(self, m: int, n: int, rank: int):
13+
self._check_params(m, n, rank)
1514
self.m = m
1615
self.n = n
1716
self.rank = rank
18-
self.dtype = dtype
1917

20-
def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
18+
def _check_params(self, m: int, n: int, rank: int) -> None:
2119
"""Checks that the provided __init__ parameters are acceptable."""
2220

2321
assert m >= 0
2422
assert n >= 0
2523
assert 0 <= rank <= min(m, n)
26-
assert dtype in {torch.float32, torch.float64}
2724

2825
@abstractmethod
2926
def __call__(self, rng: torch.Generator | None = None) -> Tensor:
@@ -35,24 +32,24 @@ def __repr__(self) -> str:
3532
def __str__(self) -> str:
3633
return (
3734
f"{self.__class__.__name__.replace('MatrixSampler', '')}"
38-
f"({self.m}x{self.n}r{self.rank}:{str(self.dtype)[6:]})"
35+
f"({self.m}x{self.n}r{self.rank})"
3936
)
4037

4138

4239
class NormalSampler(MatrixSampler):
43-
"""Sampler for random normal matrices of shape [m, n] with provided rank and dtype."""
40+
"""Sampler for random normal matrices of shape [m, n] with provided rank."""
4441

4542
def __call__(self, rng: torch.Generator | None = None) -> Tensor:
46-
U = _sample_orthonormal_matrix(self.m, dtype=self.dtype, rng=rng)
47-
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
48-
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
43+
U = _sample_orthonormal_matrix(self.m, rng=rng)
44+
Vt = _sample_orthonormal_matrix(self.n, rng=rng)
45+
S = torch.diag(torch.abs(randn_([self.rank], generator=rng)))
4946
A = U[:, : self.rank] @ S @ Vt[: self.rank, :]
5047
return A
5148

5249

5350
class StrongSampler(MatrixSampler):
5451
"""
55-
Sampler for random strongly stationary matrices of shape [m, n] with provided rank and dtype.
52+
Sampler for random strongly stationary matrices of shape [m, n] with provided rank.
5653
5754
Definition: A matrix A is said to be strongly stationary if there exists a vector 0 < v such
5855
that v^T A = 0.
@@ -61,25 +58,24 @@ class StrongSampler(MatrixSampler):
6158
orthogonal to v.
6259
"""
6360

64-
def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
65-
super()._check_params(m, n, rank, dtype)
61+
def _check_params(self, m: int, n: int, rank: int) -> None:
62+
super()._check_params(m, n, rank)
6663
assert 1 < m
6764
assert 0 < rank <= min(m - 1, n)
6865

6966
def __call__(self, rng: torch.Generator | None = None) -> Tensor:
70-
v = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng))
67+
v = torch.abs(randn_([self.m], generator=rng))
7168
U1 = normalize(v, dim=0).unsqueeze(1)
7269
U2 = _sample_semi_orthonormal_complement(U1, rng=rng)
73-
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
74-
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
70+
Vt = _sample_orthonormal_matrix(self.n, rng=rng)
71+
S = torch.diag(torch.abs(randn_([self.rank], generator=rng)))
7572
A = U2[:, : self.rank] @ S @ Vt[: self.rank, :]
7673
return A
7774

7875

7976
class StrictlyWeakSampler(MatrixSampler):
8077
"""
81-
Sampler for random strictly weakly stationary matrices of shape [m, n] with provided rank and
82-
dtype.
78+
Sampler for random strictly weakly stationary matrices of shape [m, n] with provided rank.
8379
8480
Definition: A matrix A is said to be weakly stationary if there exists a vector 0 <= v, v != 0,
8581
such that v^T A = 0.
@@ -97,60 +93,57 @@ class StrictlyWeakSampler(MatrixSampler):
9793
stationary.
9894
"""
9995

100-
def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
101-
super()._check_params(m, n, rank, dtype)
96+
def _check_params(self, m: int, n: int, rank: int) -> None:
97+
super()._check_params(m, n, rank)
10298
assert 1 < m
10399
assert 0 < rank <= min(m - 1, n)
104100

105101
def __call__(self, rng: torch.Generator | None = None) -> Tensor:
106-
u = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng))
102+
u = torch.abs(randn_([self.m], generator=rng))
107103
split_index = randint_(1, self.m, [], generator=rng).item()
108104
shuffled_range = randperm_(self.m, generator=rng)
109-
v = zeros_(self.m, dtype=self.dtype)
105+
v = zeros_(self.m)
110106
v[shuffled_range[:split_index]] = normalize(u[shuffled_range[:split_index]], dim=0)
111-
v_prime = zeros_(self.m, dtype=self.dtype)
107+
v_prime = zeros_(self.m)
112108
v_prime[shuffled_range[split_index:]] = normalize(u[shuffled_range[split_index:]], dim=0)
113109
U1 = torch.stack([v, v_prime]).T
114110
U2 = _sample_semi_orthonormal_complement(U1, rng=rng)
115111
U = torch.hstack([U1, U2])
116-
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
117-
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
112+
Vt = _sample_orthonormal_matrix(self.n, rng=rng)
113+
S = torch.diag(torch.abs(randn_([self.rank], generator=rng)))
118114
A = U[:, 1 : self.rank + 1] @ S @ Vt[: self.rank, :]
119115
return A
120116

121117

122118
class NonWeakSampler(MatrixSampler):
123119
"""
124-
Sampler for a random non weakly-stationary matrices of shape [m, n] with provided rank and
125-
dtype.
120+
Sampler for a random non weakly-stationary matrices of shape [m, n] with provided rank.
126121
127122
Obtaining such a matrix is done by sampling a positive u, and by then sampling a matrix A that
128123
has u as one of its left-singular vectors, with positive singular value s. Any 0 <= v, v != 0,
129124
satisfies v^T A != 0. Otherwise, we would have 0 = v^T A A^T u = s v^T u > 0, which is a
130125
contradiction. A is thus not weakly stationary.
131126
"""
132127

133-
def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
134-
super()._check_params(m, n, rank, dtype)
128+
def _check_params(self, m: int, n: int, rank: int) -> None:
129+
super()._check_params(m, n, rank)
135130
assert 0 < rank
136131

137132
def __call__(self, rng: torch.Generator | None = None) -> Tensor:
138-
u = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng))
133+
u = torch.abs(randn_([self.m], generator=rng))
139134
U1 = normalize(u, dim=0).unsqueeze(1)
140135
U2 = _sample_semi_orthonormal_complement(U1, rng=rng)
141136
U = torch.hstack([U1, U2])
142-
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
143-
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
137+
Vt = _sample_orthonormal_matrix(self.n, rng=rng)
138+
S = torch.diag(torch.abs(randn_([self.rank], generator=rng)))
144139
A = U[:, : self.rank] @ S @ Vt[: self.rank, :]
145140
return A
146141

147142

148-
def _sample_orthonormal_matrix(
149-
dim: int, dtype: torch.dtype, rng: torch.Generator | None = None
150-
) -> Tensor:
143+
def _sample_orthonormal_matrix(dim: int, rng: torch.Generator | None = None) -> Tensor:
151144
"""Uniformly samples a random orthonormal matrix of shape [dim, dim]."""
152145

153-
return _sample_semi_orthonormal_complement(zeros_([dim, 0], dtype=dtype), rng=rng)
146+
return _sample_semi_orthonormal_complement(zeros_([dim, 0]), rng=rng)
154147

155148

156149
def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None = None) -> Tensor:
@@ -161,9 +154,8 @@ def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None =
161154
:param Q: A semi-orthonormal matrix (i.e. Q^T Q = I) of shape [m, k], with k <= m.
162155
"""
163156

164-
dtype = Q.dtype
165157
m, k = Q.shape
166-
A = randn_([m, m - k], dtype=dtype, generator=rng)
158+
A = randn_([m, m - k], generator=rng)
167159

168160
# project A onto the orthogonal complement of Q
169161
A_proj = A - Q @ (Q.T @ A)

tests/unit/aggregation/test_constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
def _make_aggregator(matrix: Tensor) -> Constant:
2020
n_rows = matrix.shape[0]
21-
weights = tensor_([1.0 / n_rows] * n_rows, dtype=matrix.dtype)
21+
weights = tensor_([1.0 / n_rows] * n_rows)
2222
return Constant(weights)
2323

2424

0 commit comments

Comments
 (0)