hn/loop-order-cost#521
    match node:
        case Table(Alias() as tns, idxs):
            base = stats_bindings.get(tns)
            st = stats_factory.relabel(base, tuple(idxs))
You probably don't need this relabel call.
        return penalty


    def loop_order_cost(
Unfortunately, this function is not quite right. An earlier iteration of the Julia code had a version that worked like this, but it suffers from a specific problem.
Suppose you have the query:
E_ijklm = A_ij * B_jk * C_kl * D_lm
Further, suppose that D is entirely zeros (i.e., it's empty). Then the output will always be empty, so your full_stats object will represent an empty tensor, and it will still be empty when you project it down. As a result, the i->j->k->l->m and m->l->k->j->i orders will look identical in cost. In execution, however, the i->j->k->l->m order is asymptotically worse: it essentially performs the matmul between A and B in the first three loops, O(n^3), before discovering that the l and m loops are empty. The m->l->k->j->i order would simply exit immediately in O(1) time.
To fix this, you need to consider the sub-query that each loop prefix induces. When considering i-j-k, you should estimate the nnz of A_ij * B_jk * C_k. The Julia version does this in the function get_loop_lookups (https://github.com/finch-tensor/Finch.jl/blob/8b6554b2c73668670037bfb8413425035589d915/src/Galley/TensorStats/cost-estimates.jl#L15). That function in turn relies on get_conjuncts_and_disjuncts (https://github.com/finch-tensor/Finch.jl/blob/8b6554b2c73668670037bfb8413425035589d915/src/Galley/PlanAST/plan.jl#L514) to figure out which input tensors in the expression act as 'conjuncts' (i.e., join-like arguments) and which act as 'disjuncts' (i.e., union-like arguments).
So, to carry over the key logic, you'll have to port both of those functions.
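To make the prefix idea concrete, here is a rough toy sketch. It is not the Galley implementation and not a port of get_loop_lookups. The names ToyTensor, projected_density, prefix_nnz, and toy_loop_order_cost are hypothetical, and it assumes a pure product query (every input is a conjunct), a single dimension size n for all indices, and uniformly distributed nonzeros. Tensors that share no index with the prefix are skipped, which is also a simplifying assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in for tensor stats: index names plus a nonzero count."""
    idxs: tuple
    nnz: int


def projected_density(t, prefix, n):
    """Estimated density of t after projecting away the indices not in prefix.

    Assumes uniformly distributed nonzeros: a projected cell is nonzero if
    at least one of the n**dropped cells in its fiber is nonzero.
    """
    kept = [i for i in t.idxs if i in prefix]
    dropped = len(t.idxs) - len(kept)
    p = t.nnz / n ** len(t.idxs)
    return 1.0 - (1.0 - p) ** (n ** dropped)


def prefix_nnz(tensors, prefix, n):
    """Estimated nnz of the sub-query induced by a loop prefix.

    Only tensors touching a prefix index take part (an assumption of this toy).
    """
    est = float(n ** len(prefix))
    for t in tensors:
        if any(i in prefix for i in t.idxs):
            est *= projected_density(t, prefix, n)
    return est


def toy_loop_order_cost(tensors, order, n):
    """Sum the estimated iteration counts over every prefix of the loop order."""
    return sum(prefix_nnz(tensors, order[:d], n) for d in range(1, len(order) + 1))


n = 10
A = ToyTensor(("i", "j"), n * n)  # dense
B = ToyTensor(("j", "k"), n * n)  # dense
C = ToyTensor(("k", "l"), n * n)  # dense
D = ToyTensor(("l", "m"), 0)      # empty
tensors = [A, B, C, D]

forward = toy_loop_order_cost(tensors, ("i", "j", "k", "l", "m"), n)
backward = toy_loop_order_cost(tensors, ("m", "l", "k", "j", "i"), n)
print(forward, backward)  # 1110.0 0.0
```

With this estimate the i->j->k->l->m order pays for the i, i-j, and i-j-k prefixes (10 + 100 + 1000) before D zeroes out the rest, while the m->l->k->j->i order costs nothing. An estimate derived only from the full output's stats would score both orders as 0.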
loop order cost port