22from collections .abc import Mapping
33from numbers import Number
44
5+ from bindings .looptree import TemporalTag , SequentialTag , PipelineTemporalTag
6+
57import islpy as isl
68
79from pytimeloop .isl .singular import get_sum_of_pw_qpolynomial
10+ from pytimeloop .isl .sum import sum_with_mask
811from pytimeloop .looptree .mapping_utilities import *
912
1013
@@ -29,7 +32,8 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping,
2932 reads_to_parent ,
3033 mapping ,
3134 workload ,
32- is_path = False ):
35+ is_path = False ,
36+ per_unit = False ):
3337 mapping = mapping ['nodes' ]
3438 dspace_id_to_name = workload .data_space_id_to_name ()
3539 einsum_id_to_name = workload .einsum_id_to_name ()
@@ -49,8 +53,33 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping,
4953 for (buffer_id , dspace_id , einsum_id ), (tags , fill ) in fills .items ():
5054 read_to_parent = reads_to_parent [(buffer_id , dspace_id , einsum_id )][1 ]
5155
52- read_to_parent = get_sum_of_pw_qpolynomial (read_to_parent )
53- fill = get_sum_of_pw_qpolynomial (fill )
56+ if not per_unit :
57+ read_to_parent = get_sum_of_pw_qpolynomial (read_to_parent )
58+ fill = get_sum_of_pw_qpolynomial (fill )
59+ else :
60+ fill = sum_with_mask (
61+ [
62+ (
63+ isinstance (t , TemporalTag ) or
64+ isinstance (t , PipelineTemporalTag ) or
65+ isinstance (t , SequentialTag )
66+ )
67+ for t in tags
68+ ],
69+ fill
70+ ).max ().to_python ()
71+ n_read_to_parent_dim = read_to_parent .dim (isl .dim_type .in_ )
72+ read_to_parent = sum_with_mask (
73+ [
74+ (
75+ isinstance (t , TemporalTag ) or
76+ isinstance (t , PipelineTemporalTag ) or
77+ isinstance (t , SequentialTag )
78+ )
79+ for t in tags [:n_read_to_parent_dim ]
80+ ],
81+ read_to_parent
82+ ).max ().to_python ()
5483
5584 dspace_name = dspace_id_to_name [dspace_id ]
5685 einsum_name = einsum_id_to_name [einsum_id ]
@@ -61,24 +90,32 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping,
6190 key = (parent_buffer , dspace_name , einsum_name )
6291 if dspace_id in workload .tensors_written_by_einsum (einsum_id ):
6392 writes [key ] += read_to_parent
93+ reads [key ] += read_to_parent
6494 # Subtracted term: elided first read of a read-write tensor
65- reads [key ] += \
66- read_to_parent - workload .get_tensor_volume (dspace_id )
95+ # TODO: figure out how to do this per unit
96+ if not per_unit :
97+ reads [key ] -= workload .get_tensor_volume (dspace_id )
6798 elif dspace_id in workload .tensors_read_by_einsum (einsum_id ):
6899 reads [key ] += read_to_parent
69100 # Fills will write into current buffer except for compute (which does
70101 # not have write action) and top-level buffer
71102 if buffer_id not in compute_targets and parent_buffer is not None :
72103 if dspace_id in workload .tensors_written_by_einsum (einsum_id ):
73- writes [(buffer_id , dspace_name , einsum_name )] += \
74- fill - workload .get_tensor_volume (dspace_id )
104+ writes [(buffer_id , dspace_name , einsum_name )] += fill
105+ if not per_unit :
106+ writes [(buffer_id , dspace_name , einsum_name )] -= \
107+ workload .get_tensor_volume (dspace_id )
75108 else :
76109 writes [(buffer_id , dspace_name , einsum_name )] += fill
77110
78111 return reads , writes
79112
80113
81- def reads_and_writes_from_fill_by_peer (fills : Mapping , mapping , workload , is_path = False ):
114+ def reads_and_writes_from_fill_by_peer (fills : Mapping ,
115+ mapping ,
116+ workload ,
117+ is_path = False ,
118+ per_unit = False ):
82119 mapping = mapping ['nodes' ]
83120 dspace_id_to_name = workload .data_space_id_to_name ()
84121 einsum_id_to_name = workload .einsum_id_to_name ()
@@ -89,14 +126,27 @@ def reads_and_writes_from_fill_by_peer(fills: Mapping, mapping, workload, is_pat
89126 einsums_with_complete_mappings = get_einsums_with_complete_mappings (mapping , workload , is_path )
90127
91128 for (buffer_id , dspace_id , einsum_id ), (tags , fill ) in fills .items ():
92- fill = get_sum_of_pw_qpolynomial (fill )
129+ if not per_unit :
130+ fill = get_sum_of_pw_qpolynomial (fill )
131+ else :
132+ fill = sum_with_mask (
133+ [
134+ (
135+ isinstance (t , TemporalTag ) or
136+ isinstance (t , PipelineTemporalTag ) or
137+ isinstance (t , SequentialTag )
138+ )
139+ for t in tags
140+ ],
141+ fill
142+ ).max ().to_python ()
93143 einsum_name = einsum_id_to_name [einsum_id ]
94144 dspace_name = dspace_id_to_name [dspace_id ]
95145 if einsum_id not in einsums_with_complete_mappings :
96146 continue
97147
98148 reads [(buffer_id , dspace_name , einsum_name )] = fill
99- writes [(buffer_id , dspace_name , einsum_name )] = 0 # already accounted for in above
149+ writes [(buffer_id , dspace_name , einsum_name )] = 0 # already accounted for in fill_by_parent
100150
101151 return reads , writes
102152
0 commit comments