Commit 5df9420

Refactor PrototypeDecomposer

1 parent 31a446a
1 file changed: 133 additions & 88 deletions

@@ -1,4 +1,7 @@
 from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, List, Set, Tuple, Optional
+
 import random
 
 from graphblas.binary import plus
@@ -7,107 +10,149 @@
 from graphblas.core.vector import Vector
 
 from cfpq_decomposer.abstract_decomposer import AbstractDecomposer
-from cfpq_decomposer.constants import HASH_PRIME_MODULUS, HASH_FUNCTIONS_COUNT, PROTOTYPE_MIN_LSH_BUCKET_SIZE, \
-    PROTOTYPE_OUTLIER_THRESHOLD, PROTOTYPE_MIN_VALUES_PER_ROW
+from cfpq_decomposer.constants import (
+    HASH_PRIME_MODULUS,
+    HASH_FUNCTIONS_COUNT,
+    PROTOTYPE_MIN_LSH_BUCKET_SIZE,
+    PROTOTYPE_OUTLIER_THRESHOLD,
+    PROTOTYPE_MIN_VALUES_PER_ROW,
+)
 
 
-class PrototypeDecomposer(AbstractDecomposer):
-    def row_based_decompose(self, input_matrix: Matrix):
-        number_of_rows, number_of_columns = input_matrix.shape
-        row_indices, column_indices, _ = input_matrix.to_coo()
+@dataclass
+class BucketFactor:
+    membership_vector: Vector
+    column_signature: Vector
 
-        row_to_column_sets = defaultdict(set)
-        for row_index, column_index in zip(row_indices, column_indices):
-            row_to_column_sets[row_index].add(column_index)
 
-        hash_coefficients_and_offsets = []
+class PrototypeDecomposer(AbstractDecomposer):
+    def row_based_decompose(self, matrix: Matrix) -> Tuple[Matrix, Matrix]:
+        row_to_column_sets = self._extract_row_to_column_sets(matrix)
+        row_minhash_signatures = self._compute_row_minhash_signatures(row_to_column_sets)
+        master_hash_to_rows = self._group_rows_by_master_hash(row_minhash_signatures)
+        bucket_factors = self._build_bucket_factors(master_hash_to_rows, row_to_column_sets, matrix)
+        return self._build_factor_matrices(bucket_factors, matrix)
+
+    @staticmethod
+    def _extract_row_to_column_sets(matrix: Matrix) -> Dict[int, Set[int]]:
+        row_to_column_sets: Dict[int, Set[int]] = defaultdict(set)
+        rows, cols, _ = matrix.to_coo()
+        for r, c in zip(rows, cols):
+            row_to_column_sets[r].add(c)
+        return row_to_column_sets
+
+    @staticmethod
+    def _generate_hash_coefficients_and_offsets() -> List[Tuple[int, int]]:
+        coefficients_and_offsets: List[Tuple[int, int]] = []
         for _ in range(HASH_FUNCTIONS_COUNT):
             coefficient = random.randint(1, HASH_PRIME_MODULUS - 1)
             offset = random.randint(0, HASH_PRIME_MODULUS - 1)
-            hash_coefficients_and_offsets.append((coefficient, offset))
-
-        row_to_minhash_signature = {}
+            coefficients_and_offsets.append((coefficient, offset))
+        return coefficients_and_offsets
+
+    @classmethod
+    def _compute_row_minhash_signatures(
+        cls,
+        row_to_column_sets: Dict[int, Set[int]],
+    ) -> Dict[int, Tuple[int, ...]]:
+        hash_params = cls._generate_hash_coefficients_and_offsets()
+        row_minhash_signatures: Dict[int, Tuple[int, ...]] = {}
         for row_index, column_set in row_to_column_sets.items():
             if len(column_set) < PROTOTYPE_MIN_VALUES_PER_ROW:
                 continue
-            signature = []
-            for coefficient, offset in hash_coefficients_and_offsets:
-                min_hash = min((coefficient * col + offset) % HASH_PRIME_MODULUS for col in column_set)
-                signature.append(min_hash)
-            row_to_minhash_signature[row_index] = tuple(signature)
-
-        row_to_master_hash = {
-            row_index: hash(signature)
-            for row_index, signature in row_to_minhash_signature.items()
-        }
-
-        master_hash_to_rows = defaultdict(list)
-        for row_index, master_hash in row_to_master_hash.items():
-            master_hash_to_rows[master_hash].append(row_index)
-
-        buckets_with_enough_rows = {
+            signature = tuple(
+                min((coef * col + off) % HASH_PRIME_MODULUS for col in column_set)
+                for coef, off in hash_params
+            )
+            row_minhash_signatures[row_index] = signature
+        return row_minhash_signatures
+
+    @staticmethod
+    def _group_rows_by_master_hash(
+        row_minhash_signatures: Dict[int, Tuple[int, ...]],
+    ) -> Dict[int, List[int]]:
+        master_hash_to_rows: Dict[int, List[int]] = defaultdict(list)
+        for row_index, signature in row_minhash_signatures.items():
+            master_hash_to_rows[hash(signature)].append(row_index)
+        return {
             master_hash: rows
             for master_hash, rows in master_hash_to_rows.items()
             if len(rows) >= PROTOTYPE_MIN_LSH_BUCKET_SIZE
         }
 
-        left_factor_column_vectors = []
-        right_factor_row_signatures = []
-
-        for master_hash, bucket_row_indices in buckets_with_enough_rows.items():
-            bucket_size = len(bucket_row_indices)
-            bucket_submatrix = input_matrix[bucket_row_indices, :].new()
-            column_sums = bucket_submatrix.dup(dtype=INT32).reduce_columnwise(plus).new()
-
-            first_threshold = int((1 - PROTOTYPE_OUTLIER_THRESHOLD) * bucket_size)
-            frequent_columns_after_first_filter = column_sums.select('>=', first_threshold).new()
-            if frequent_columns_after_first_filter.nvals == 0:
-                continue
-
-            frequent_column_indices = set(frequent_columns_after_first_filter.to_coo()[0])
-            first_filtered_rows = [
-                row_index
-                for row_index in bucket_row_indices
-                if frequent_column_indices <= row_to_column_sets[row_index]
-            ]
-            if not first_filtered_rows:
-                continue
-
-            filtered_submatrix = input_matrix[first_filtered_rows, :].new()
-            filtered_column_sums = filtered_submatrix.dup(dtype=INT32).reduce_columnwise(plus)
-
-            second_threshold = int((1 - PROTOTYPE_OUTLIER_THRESHOLD) * len(first_filtered_rows))
-            frequent_columns_after_second_filter = filtered_column_sums.select('>=', second_threshold).new()
-            if frequent_columns_after_second_filter.nvals == 0:
-                continue
-
-            frequent_filtered_column_indices = set(frequent_columns_after_second_filter.to_coo()[0])
-            second_filtered_rows = [
-                row_index
-                for row_index in first_filtered_rows
-                if frequent_filtered_column_indices <= row_to_column_sets[row_index]
-            ]
-            if len(second_filtered_rows) < PROTOTYPE_MIN_LSH_BUCKET_SIZE:
-                continue
-
-            right_factor_row_signatures.append(frequent_columns_after_second_filter)
-
-            core_membership_vector = Vector(BOOL, size=number_of_rows)
-            for core_row in second_filtered_rows:
-                core_membership_vector[core_row] = True
-            left_factor_column_vectors.append(core_membership_vector)
-
-        bucket_count = len(left_factor_column_vectors)
-        if bucket_count == 0:
-            return Matrix(input_matrix.dtype, number_of_rows, 0), \
-                Matrix(input_matrix.dtype, 0, number_of_columns)
-
-        left_factor = Matrix(bool, number_of_rows, bucket_count)
-        for idx, column_vector in enumerate(left_factor_column_vectors):
-            left_factor[:, idx] = column_vector
-
-        right_factor = Matrix(bool, bucket_count, number_of_columns)
-        for idx, row_signature in enumerate(right_factor_row_signatures):
-            right_factor[idx, :] = row_signature
-
+    @staticmethod
+    def _build_bucket_factors(
+        master_hash_to_rows: Dict[int, List[int]],
+        row_to_column_sets: Dict[int, Set[int]],
+        matrix: Matrix,
+    ) -> List[BucketFactor]:
+        bucket_factors: List[BucketFactor] = []
+        for bucket_rows in master_hash_to_rows.values():
+            factor = PrototypeDecomposer._build_bucket_factor(bucket_rows, row_to_column_sets, matrix)
+            if factor:
+                bucket_factors.append(factor)
+        return bucket_factors
+
+    @staticmethod
+    def _filter_rows_by_frequency(
+        candidate_rows: List[int],
+        row_to_column_sets: Dict[int, Set[int]],
+        matrix: Matrix,
+    ) -> Tuple[Optional[Vector], List[int]]:
+        submatrix = matrix[candidate_rows, :].new()
+        column_sums = submatrix.dup(dtype=INT32).reduce_columnwise(plus).new()
+        threshold = int((1 - PROTOTYPE_OUTLIER_THRESHOLD) * len(candidate_rows))
+        frequency_signature = column_sums.select('>=', threshold).new()
+        if frequency_signature.nvals == 0:
+            return None, []
+        frequent_column_indices = set(frequency_signature.to_coo()[0])
+        surviving_rows = [
+            r
+            for r in candidate_rows
+            if frequent_column_indices <= row_to_column_sets[r]
+        ]
+        return frequency_signature, surviving_rows
+
+    @staticmethod
+    def _build_bucket_factor(
+        bucket_rows: List[int],
+        row_to_column_sets: Dict[int, Set[int]],
+        matrix: Matrix,
+    ) -> Optional[BucketFactor]:
+        num_rows, _ = matrix.shape
+
+        first_round_signature, rows_after_first_filter = PrototypeDecomposer._filter_rows_by_frequency(
+            bucket_rows, row_to_column_sets, matrix
+        )
+        if first_round_signature is None or len(rows_after_first_filter) < PROTOTYPE_MIN_LSH_BUCKET_SIZE:
+            return None
+
+        second_round_signature, rows_after_second_filter = PrototypeDecomposer._filter_rows_by_frequency(
+            rows_after_first_filter, row_to_column_sets, matrix
+        )
+        if second_round_signature is None or len(rows_after_second_filter) < PROTOTYPE_MIN_LSH_BUCKET_SIZE:
+            return None
+
+        membership_vector = Vector(BOOL, size=num_rows)
+        for row_index in rows_after_second_filter:
+            membership_vector[row_index] = True
+        return BucketFactor(
+            membership_vector=membership_vector,
+            column_signature=second_round_signature
+        )
+
+    @staticmethod
+    def _build_factor_matrices(
+        bucket_factors: List[BucketFactor],
+        matrix: Matrix,
+    ) -> Tuple[Matrix, Matrix]:
+        num_rows, num_cols = matrix.shape
+        num_buckets = len(bucket_factors)
+        if num_buckets == 0:
+            return Matrix(BOOL, num_rows, 0), Matrix(BOOL, 0, num_cols)
+        left_factor = Matrix(BOOL, num_rows, num_buckets)
+        right_factor = Matrix(BOOL, num_buckets, num_cols)
+        for idx, factor in enumerate(bucket_factors):
+            left_factor[:, idx] = factor.membership_vector
+            right_factor[idx, :] = factor.column_signature
         return left_factor, right_factor
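
For reference, a minimal usage sketch of the refactored entry point. The import path of PrototypeDecomposer and its no-argument constructor are assumptions (neither is shown in this diff), and whether any bucket survives the filters depends on the thresholds in cfpq_decomposer.constants; the sketch only illustrates the shapes of the returned boolean factors.

from graphblas import Matrix
from graphblas.dtypes import BOOL

# Assumed module path; adjust to wherever PrototypeDecomposer lives in this repo.
from cfpq_decomposer.prototype_decomposer import PrototypeDecomposer

# Small sparse boolean matrix in which rows 0-3 all share columns {0, 1, 2}.
rows = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
cols = [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
matrix = Matrix.from_coo(rows, cols, [True] * len(rows), dtype=BOOL, nrows=4, ncols=3)

decomposer = PrototypeDecomposer()  # assumes a no-argument constructor
left_factor, right_factor = decomposer.row_based_decompose(matrix)

# left_factor is (n_rows x n_buckets) bucket membership; right_factor is
# (n_buckets x n_cols) shared column signatures, so their product approximates
# the dense row blocks of the input matrix.
print(left_factor.shape, right_factor.shape)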
