-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathsparse-matmul.rs
More file actions
78 lines (65 loc) · 2.3 KB
/
sparse-matmul.rs
File metadata and controls
78 lines (65 loc) · 2.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
use std::collections::HashMap;
use std::collections::HashSet;
use std::io;
use std::io::Write;

use itertools::Itertools;
/// Trims the string expression `$text`, parses it as `$target`, and unwraps
/// the result (panicking if the text is not a valid `$target`).
macro_rules! parse_input {
    ($text:expr, $target:ident) => {
        $text.trim().parse::<$target>().unwrap()
    };
}
/// Reads `entries` sparse-matrix triples `row col value` from standard input,
/// one triple per line, and groups them into a nested map.
///
/// With `row_wise == true` the outer key is the row index and the inner key is
/// the column index. With `row_wise == false` the roles are swapped, so the
/// outer key is the column index. A later entry for the same position
/// overwrites an earlier one.
///
/// # Panics
/// Panics if stdin cannot be read, if it ends before `entries` lines were read,
/// or if a line does not start with `usize usize f64` tokens.
fn read_matrix(entries: usize, row_wise: bool) -> HashMap<usize, HashMap<usize, f64>> {
    let stdin = io::stdin();
    // Lock once for the whole matrix instead of once per line.
    let mut reader = stdin.lock();
    read_matrix_from(&mut reader, entries, row_wise)
}

/// Reads `entries` triples from `reader`. This is the reader-agnostic core of
/// `read_matrix`, with the same layout and panic behaviour.
fn read_matrix_from<R: io::BufRead>(
    reader: &mut R,
    entries: usize,
    row_wise: bool,
) -> HashMap<usize, HashMap<usize, f64>> {
    let mut matrix: HashMap<usize, HashMap<usize, f64>> = HashMap::new();
    // One buffer reused for every line.
    let mut line = String::new();
    for read_so_far in 0..entries {
        line.clear();
        let bytes = reader
            .read_line(&mut line)
            .expect("failed to read matrix entry");
        if bytes == 0 {
            panic!(
                "unexpected end of input: expected {} matrix entries, got {}",
                entries, read_so_far
            );
        }
        let mut tokens = line.split_whitespace();
        let r: usize = parse_field(tokens.next(), "row", &line);
        let c: usize = parse_field(tokens.next(), "column", &line);
        let v: f64 = parse_field(tokens.next(), "value", &line);
        // The outer key is the row in row-wise layout and the column otherwise.
        let (outer, inner) = if row_wise { (r, c) } else { (c, r) };
        matrix.entry(outer).or_default().insert(inner, v);
    }
    matrix
}

/// Parses one whitespace-separated `field` of a matrix entry `line` as `T`.
/// On failure it panics with a message that names the field and quotes the line.
fn parse_field<T>(token: Option<&str>, field: &str, line: &str) -> T
where
    T: std::str::FromStr,
    T::Err: std::fmt::Debug,
{
    let token = token
        .unwrap_or_else(|| panic!("missing {} in matrix entry {:?}", field, line.trim_end()));
    token.parse().unwrap_or_else(|err| {
        panic!(
            "invalid {} {:?} in matrix entry {:?}: {:?}",
            field,
            token,
            line.trim_end(),
            err
        )
    })
}
/// Computes entry `(i, k)` of the product of two sparse matrices.
///
/// `a` maps row index to (column index to value). `b` maps column index to
/// (row index to value), as produced by `read_matrix(_, false)`. The result is
/// the sum of `a[i][j] * b[k][j]` over every `j` stored in both row `i` of `a`
/// and column `k` of `b`.
///
/// Returns zero (compares equal to `0.0`) when either side has no stored
/// entries or when they share no index.
///
/// The products are summed in ascending `j` order. Float addition is not
/// associative, so a fixed order keeps the printed result deterministic; hash
/// iteration order would vary from run to run.
fn dot_product(a: &HashMap<usize, HashMap<usize, f64>>,
               b: &HashMap<usize, HashMap<usize, f64>>,
               i: usize, k: usize) -> f64 {
    match (a.get(&i), b.get(&k)) {
        (Some(row_a), Some(col_b)) => {
            // Walk the smaller map and probe the larger one. This needs no
            // intermediate key sets.
            let (small, large) = if row_a.len() <= col_b.len() {
                (row_a, col_b)
            } else {
                (col_b, row_a)
            };
            let mut terms: Vec<(usize, f64)> = small
                .iter()
                .filter_map(|(&j, &x)| large.get(&j).map(|&y| (j, x * y)))
                .collect();
            terms.sort_unstable_by_key(|&(j, _)| j);
            terms.iter().map(|&(_, product)| product).sum()
        }
        _ => 0.0,
    }
}
/// Formats a product value for output.
///
/// Rust's `Display` for `f64` drops the fractional part of integral values
/// (`6.0` prints as `6`). Values within `f64::EPSILON` of an integer are
/// therefore forced to exactly one decimal place (`"6.0"`). Every other value
/// uses the shortest `Display` form.
fn format_number(num: &f64) -> String {
    let value = *num;
    let nearly_integral = (value - value.round()).abs() < f64::EPSILON;
    if nearly_integral {
        return format!("{:.1}", value);
    }
    value.to_string()
}
fn main() {
let mut input_line = String::new();
io::stdin().read_line(&mut input_line).unwrap();
let mut input_line = String::new();
io::stdin().read_line(&mut input_line).unwrap();
let inputs = input_line.split_whitespace().collect::<Vec<_>>();
let a_entries = parse_input!(inputs[0], usize);
let b_entries = parse_input!(inputs[1], usize);
let a = read_matrix(a_entries, true);
let b = read_matrix(b_entries, false);
for i in a.keys().sorted() {
for k in b.keys().sorted() {
let dot_prod = dot_product(&a, &b, *i, *k);
if dot_prod != 0.0 {
println!("{} {} {}", i, k, format_number(&dot_prod));
}
}
}
}