Skip to content

Commit 791fbe2

Browse files
committed
Adapt prototype decomposer to common decomposer interface
1 parent 9beaec7 commit 791fbe2

4 files changed

Lines changed: 118 additions & 320 deletions

File tree

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import random
2+
from collections import defaultdict
3+
4+
from graphblas.binary import plus
5+
from graphblas.core.dtypes import BOOL, INT32
6+
from graphblas.core.matrix import Matrix
7+
from graphblas.core.vector import Vector
8+
9+
from cfpq_decomposer.abstract_decomposer import AbstractDecomposer
10+
11+
12+
class PrototypeDecomposer(AbstractDecomposer):
    """Prototype decomposer that factors out groups of rows sharing columns.

    Rows are grouped by MinHash signatures of their column sets. For every
    group, the columns stored in nearly all of its rows become one row of
    RIGHT, and the rows that contain all of those columns become the
    matching column of LEFT.
    """

    # Modulus of the affine hash functions ``(a * x + b) % p`` (2**31 - 1).
    _HASH_MODULUS = 2147483647
    # Number of hash functions that make up one MinHash signature.
    _NUM_HASHES = 3
    # Rows with fewer stored columns than this are not hashed at all.
    _MIN_ROW_NVALS = 5
    # Groups (and final cores) with fewer rows than this are discarded.
    _MIN_GROUP_SIZE = 5
    # A column is kept when it is stored in at least
    # ``int(_COLUMN_SUPPORT * group_size)`` rows of the group.
    _COLUMN_SUPPORT = 0.95

    def row_based_decompose(self, M: Matrix):
        """Build LEFT and RIGHT factors from groups of similar rows of ``M``.

        Parameters:
            M: Input matrix; only the positions of its stored values are
                used, never the values themselves.

        Returns:
            A ``(LEFT, RIGHT)`` pair. With ``k`` accepted groups, LEFT is an
            ``n_rows x k`` and RIGHT a ``k x n_cols`` boolean matrix. Every
            row selected in a LEFT column stores all columns of the matching
            RIGHT row in ``M``. When no group is accepted, two empty
            matrices of ``M.dtype`` with shapes ``(n_rows, 0)`` and
            ``(0, n_cols)`` are returned.

        The hash functions are drawn from the global ``random`` module, so
        the grouping can differ between calls.
        """
        n_rows, n_cols = M.shape

        # Column set of every row that has at least one stored value.
        row_idx, col_idx, _ = M.to_coo()
        rows = defaultdict(set)
        for i, j in zip(row_idx, col_idx):
            rows[i].add(j)

        left_columns = []
        right_rows = []
        for bucket in self._group_rows_by_minhash(rows):
            core = self._extract_core(M, bucket, rows)
            if core is None:
                continue
            core_rows, core_columns = core

            right_rows.append(core_columns)

            indicator = Vector(BOOL, size=n_rows)
            for i in core_rows:
                indicator[i] = True
            left_columns.append(indicator)

        rank = len(left_columns)
        if rank == 0:
            return Matrix(M.dtype, M.nrows, 0), Matrix(M.dtype, 0, M.ncols)

        LEFT = Matrix(bool, n_rows, rank)
        for idx, indicator in enumerate(left_columns):
            LEFT[:, idx] = indicator

        RIGHT = Matrix(bool, rank, n_cols)
        for idx, core_columns in enumerate(right_rows):
            RIGHT[idx, :] = core_columns

        return LEFT, RIGHT

    def _group_rows_by_minhash(self, rows):
        """Group row indices whose column sets share a MinHash signature."""
        p = self._HASH_MODULUS
        hash_funcs = [
            (random.randint(1, p - 1), random.randint(0, p - 1))
            for _ in range(self._NUM_HASHES)
        ]

        groups = defaultdict(list)
        for i, columns in rows.items():
            if len(columns) < self._MIN_ROW_NVALS:
                continue
            # Key on the signature tuple itself rather than on its hash(),
            # so a hash collision can never merge two different signatures.
            signature = tuple(
                min((a * x + b) % p for x in columns) for a, b in hash_funcs
            )
            groups[signature].append(i)

        return [
            members
            for members in groups.values()
            if len(members) >= self._MIN_GROUP_SIZE
        ]

    def _dense_columns(self, M: Matrix, row_indices) -> Vector:
        """Return per-column counts for columns stored in most given rows."""
        submatrix: Matrix = M[row_indices, :].new()
        # Materialize the reduction explicitly instead of relying on
        # automatic computation of the lazy expression.
        counts = submatrix.dup(dtype=INT32).reduce_columnwise(plus).new()
        threshold = int(self._COLUMN_SUPPORT * len(row_indices))
        return counts.select('>=', threshold).new()

    def _extract_core(self, M: Matrix, bucket, rows):
        """Refine a bucket twice; return (core rows, core columns) or None."""
        # First pass: columns dense in the whole bucket, then the rows that
        # contain all of them.
        dense = self._dense_columns(M, bucket)
        if dense.nvals == 0:
            return None
        support = set(dense.to_coo()[0])
        refined = [i for i in bucket if support <= rows[i]]
        if not refined:
            return None

        # Second pass: recompute the dense columns on the refined rows.
        dense = self._dense_columns(M, refined)
        if dense.nvals == 0:
            return None
        support = set(dense.to_coo()[0])
        core_rows = [i for i in refined if support <= rows[i]]
        if len(core_rows) < self._MIN_GROUP_SIZE:
            return None

        return core_rows, dense

cfpq_matrix/matrix_utils.py

Lines changed: 1 addition & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import random
2-
from collections import defaultdict
31
from typing import Any, Tuple
42

53
import graphblas
64
import numpy as np
75
from graphblas.binary import plus
8-
from graphblas.core.dtypes import DataType, BOOL, INT32
6+
from graphblas.core.dtypes import DataType
97
from graphblas.core.matrix import Matrix
108
from graphblas.core.vector import Vector
119

@@ -31,167 +29,6 @@ def expand_matrix(matrix: Matrix, new_shape: Tuple[int, int]) -> Matrix:
3129
(rows, columns, values) = matrix.to_coo()
3230
return Matrix.from_coo(rows, columns, values, dtype=matrix.dtype, nrows=new_shape[0], ncols=new_shape[1])
3331

34-
def row_based_decompose(M: Matrix):
    """
    Builds LEFT and RIGHT factors of a sparse boolean matrix M from groups of similar rows.

    Rows are grouped by MinHash signatures of their column sets. For each group, the
    columns stored in nearly all of its rows form one row of RIGHT, and the rows
    containing all of those columns form the matching column of LEFT. Every row
    selected in a LEFT column stores all columns of the matching RIGHT row in M.

    Only the two factors are returned; the remainder M' with M = LEFT * RIGHT + M'
    is not computed here.

    Parameters:
    M (gb.Matrix): Input sparse boolean matrix.

    Returns:
    LEFT (gb.Matrix): Left factor matrix (n_rows x k).
    RIGHT (gb.Matrix): Right factor matrix (k x n_cols).
    """
    n_rows, n_cols = M.shape

    # Column set of every row that has at least one stored value.
    I, J, _ = M.to_coo()
    rows = defaultdict(set)
    for i, j in zip(I, J):
        rows[i].add(j)

    # Random affine hash functions (a * x + b) % p used for MinHash.
    p = 2147483647
    num_hashes = 5  # TODO 2 or 3 is probably better for real world data
    hash_funcs = []
    for _ in range(num_hashes):
        a = random.randint(1, p - 1)
        b = random.randint(0, p - 1)
        hash_funcs.append((a, b))

    # Group rows with at least 5 stored columns by the hash of their signature.
    buckets = defaultdict(list)
    for i, S_i in rows.items():
        if len(S_i) < 5:
            continue
        signature = tuple(min(((a * x + b) % p) for x in S_i) for a, b in hash_funcs)
        buckets[hash(signature)].append(i)

    LEFT_columns = []
    RIGHT_rows = []

    for B in buckets.values():
        N = len(B)
        if N < 5:
            continue

        # First pass: columns stored in at least int(0.95 * N) rows of the bucket.
        M_B: Matrix = M[B, :].new()
        A1 = M_B.dup(dtype=INT32).reduce_columnwise(plus).new()
        A2: Vector = A1.select('>=', int(0.95 * N)).new()
        if A2.nvals == 0:
            continue

        S_A2 = set(A2.to_coo()[0])
        B_prime = [i for i in B if S_A2 <= rows[i]]
        K = len(B_prime)
        if K == 0:
            continue

        # Second pass on the refined rows. The reduction is materialized with
        # .new(), as in the first pass, instead of relying on automatic
        # computation of the lazy expression.
        M_B_prime = M[B_prime, :].new()
        A3 = M_B_prime.dup(dtype=INT32).reduce_columnwise(plus).new()
        A4 = A3.select('>=', int(0.95 * K)).new()
        if A4.nvals == 0:
            continue

        S_A4 = set(A4.to_coo()[0])
        B_double_prime = [i for i in B_prime if S_A4 <= rows[i]]
        if len(B_double_prime) < 5:
            continue

        RIGHT_rows.append(A4)

        CORE = Vector(BOOL, size=n_rows)
        for i in B_double_prime:
            CORE[i] = True
        LEFT_columns.append(CORE)

    num_buckets_remaining = len(LEFT_columns)
    if num_buckets_remaining == 0:
        return Matrix(M.dtype, M.nrows, 0), Matrix(M.dtype, 0, M.ncols)

    LEFT = Matrix(bool, n_rows, num_buckets_remaining)
    for idx, CORE in enumerate(LEFT_columns):
        LEFT[:, idx] = CORE

    RIGHT = Matrix(bool, num_buckets_remaining, n_cols)
    for idx, A4 in enumerate(RIGHT_rows):
        RIGHT[idx, :] = A4

    return LEFT, RIGHT
142-
143-
def column_based_decompose(M: Matrix):
    """
    Runs row_based_decompose on the transpose of M and transposes the factors back.

    Returns the transposed RIGHT factor first and the transposed LEFT factor
    second, so the pair is again in (LEFT, RIGHT) order for M itself.
    """
    transposed = M.T.new()
    left_of_transpose, right_of_transpose = row_based_decompose(transposed)
    left = right_of_transpose.T.new()
    right = left_of_transpose.T.new()
    return left, right
146-
147-
def decompose(M: Matrix):
    """
    Repeatedly applies row- and column-based decomposition to M and stacks the factors.

    Each iteration runs row_based_decompose, removes the entries covered by
    LEFT1 * RIGHT1 from M, runs column_based_decompose on what is left and removes
    the entries covered by LEFT2 * RIGHT2. Iteration stops when a round removes
    less than 5% of the entries it started with, when the factors of a round
    hold more than 0.3 stored values per removed entry, or when M becomes empty.

    Parameters:
    M (gb.Matrix): Input sparse matrix.

    Returns:
    LEFT (gb.Matrix): All left factors stacked side by side.
    RIGHT (gb.Matrix): All right factors stacked on top of each other.
    For an empty M, two empty matrices of M.dtype with shapes (nrows, 0) and
    (0, ncols) are returned.
    """
    if M.nvals == 0:
        return Matrix(M.dtype, M.nrows, 0), Matrix(M.dtype, 0, M.ncols)

    def without_covered(matrix, left, right):
        # Drop from `matrix` every position present in the structure of left * right.
        if left.nvals == 0:
            return matrix
        covered = left.mxm(right, op=graphblas.semiring.any_pair).new(dtype=BOOL)
        return matrix.dup(mask=~covered.S)

    accumulated_LEFT = []
    accumulated_RIGHT = []

    while True:
        nvals_before = M.nvals

        LEFT1, RIGHT1 = row_based_decompose(M)
        M = without_covered(M, LEFT1, RIGHT1)

        LEFT2, RIGHT2 = column_based_decompose(M)
        M = without_covered(M, LEFT2, RIGHT2)

        nvals_LEFT_RIGHT = LEFT1.nvals + RIGHT1.nvals + LEFT2.nvals + RIGHT2.nvals
        delta_M = nvals_before - M.nvals

        reduction_ratio = delta_M / nvals_before if nvals_before > 0 else 0
        size_ratio = nvals_LEFT_RIGHT / delta_M if delta_M > 0 else float('inf')

        # The factors of the final round are kept even when that round
        # triggers the stop condition below.
        accumulated_LEFT.extend([LEFT1, LEFT2])
        accumulated_RIGHT.extend([RIGHT1, RIGHT2])

        if reduction_ratio < 0.05 or size_ratio > 0.3:
            break

        if M.nvals == 0:
            break

    # Both lists hold at least two factors here: the loop body always runs
    # once and appends before it can break.
    LEFT = stack([accumulated_LEFT])
    RIGHT = stack([[RIGHT] for RIGHT in accumulated_RIGHT])

    return LEFT, RIGHT
194-
19532
def stack(matrix_grid: list[list[Matrix]]) -> Matrix:
19633
"""
19734
Stack a 2D list of matrices into a single larger matrix.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from cfpq_decomposer.decomposer import Decomposer
2+
from cfpq_decomposer.prototype_decomposer import PrototypeDecomposer
3+
from test.cfpq_decomposer.test_abstract_decomposer import TestAbstractDecomposer
4+
5+
6+
class TestPrototypeDecomposer(TestAbstractDecomposer):
    """Subclass of TestAbstractDecomposer that supplies a PrototypeDecomposer."""

    def create_decomposer(self) -> Decomposer:
        """Return a freshly constructed PrototypeDecomposer."""
        decomposer = PrototypeDecomposer()
        return decomposer

0 commit comments

Comments
 (0)