1+ from typing import overload
2+
13from pytimeloop .isl .singular import get_value_from_singular_qpolynomial
24from pytimeloop .looptree .latency .processors import LATENCY_PROCESSORS
3- from pytimeloop .looptree .reuse .isl .des import IslReuseAnalysisOutput
5+ from pytimeloop .looptree .reuse .isl import IslReuseAnalysisOutput
6+ from pytimeloop .looptree .reuse .summarized import SummarizedAnalysisOutput
47from pytimeloop .looptree .latency .memory import memory_latency
58
69from bindings .looptree import SpatialTag
@@ -11,9 +14,9 @@ def get_latency(looptree_results: IslReuseAnalysisOutput,
1114 workload ,
1215 arch ,
1316 bindings ):
14- comp_latency = compute_latency ( mapping ,
15- looptree_results . temporal_steps ,
16- workload )
17+ comp_latency = calculate_compute_latency ( looptree_results ,
18+ mapping ,
19+ workload )
1720 mem_latency = memory_latency (looptree_results ,
1821 arch ,
1922 mapping ,
@@ -23,12 +26,40 @@ def get_latency(looptree_results: IslReuseAnalysisOutput,
2326 return overall_latency , comp_latency , mem_latency
2427
2528
26- def compute_latency (mapping , temporal_steps , workload ):
@overload
def calculate_compute_latency(reuse_analysis_results: IslReuseAnalysisOutput,
                              mapping,
                              workload):
    ...


@overload
def calculate_compute_latency(reuse_analysis_results: SummarizedAnalysisOutput,
                              mapping,
                              workload):
    ...


def calculate_compute_latency(reuse_analysis_results, mapping, workload):
    """Compute the compute latency for either kind of reuse-analysis result.

    Dispatches on the runtime type of ``reuse_analysis_results`` and forwards
    its ``temporal_steps`` attribute, together with ``mapping`` and
    ``workload``, to the matching implementation.

    Args:
        reuse_analysis_results: An ``IslReuseAnalysisOutput`` or a
            ``SummarizedAnalysisOutput``.
        mapping: Passed through unchanged to the selected implementation.
        workload: Passed through unchanged to the selected implementation.

    Returns:
        Whatever ``compute_isl_latency`` or ``compute_summarized_latency``
        returns for the given input.

    Raises:
        TypeError: If ``reuse_analysis_results`` is neither supported type.
    """
    if isinstance(reuse_analysis_results, IslReuseAnalysisOutput):
        return compute_isl_latency(reuse_analysis_results.temporal_steps,
                                   mapping,
                                   workload)
    if isinstance(reuse_analysis_results, SummarizedAnalysisOutput):
        return compute_summarized_latency(
            reuse_analysis_results.temporal_steps,
            mapping,
            workload
        )
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an obscure error where the latency is consumed.
    raise TypeError(
        "unsupported reuse analysis result type: "
        f"{type(reuse_analysis_results).__name__}"
    )
50+
51+
def compute_isl_latency(temporal_steps, mapping, workload):
    """Evaluate the compute latency of an ISL-based analysis as a Python value.

    Runs ``_compute_latency`` over ``mapping.nodes`` starting at index 0,
    takes element 1 of its result, extracts the value with
    ``get_value_from_singular_qpolynomial`` and converts it via
    ``to_python()``.

    Args:
        temporal_steps: Forwarded to ``_compute_latency``.
        mapping: Object exposing a ``nodes`` attribute.
        workload: Forwarded to ``_compute_latency``.

    Returns:
        The result of ``to_python()`` on the extracted value.
    """
    traversal_result = _compute_latency(mapping.nodes,
                                        0,
                                        temporal_steps,
                                        workload)
    # Element 1 is the piece handed to the singular-qpolynomial extractor
    # (presumably the latency expression -- confirm against _compute_latency).
    latency_expr = traversal_result[1]
    latency_value = get_value_from_singular_qpolynomial(latency_expr)
    return latency_value.to_python()
3056
3157
def compute_summarized_latency(temporal_steps, mapping, workload):
    """Sum the step counts stored in ``temporal_steps``.

    Args:
        temporal_steps: Either a mapping whose values are step counts, or an
            iterable of ``(key, value)`` pairs. Only the values are summed.
        mapping: Unused; kept so the signature parallels
            ``compute_isl_latency``.
        workload: Unused; kept so the signature parallels
            ``compute_isl_latency``.

    Returns:
        The sum of all values, or 0 when ``temporal_steps`` is empty.
    """
    from collections.abc import Mapping

    # TODO: this is only for single-Einsum!!!
    if isinstance(temporal_steps, Mapping):
        # Iterating a mapping directly yields keys only, so unpacking
        # ``key, value`` from it would fail (or silently split tuple keys).
        return sum(temporal_steps.values())
    return sum(value for _, value in temporal_steps)
61+
62+
3263def _compute_latency (mapping , top_idx : int , temporal_steps , workload ):
3364 einsum_name_to_id = workload .einsum_name_to_id ()
3465
0 commit comments