@@ -124,8 +124,7 @@ def _create_kernel_parameters(parameter_constraints: Sequence[ParameterConstrain
124124 constant_parameter_mask : Sequence [bool ],
125125 parameter_names : Sequence [str ],
126126 parameter_locations : Sequence [Loc ],
127- ir_ctx : ir .IRContext ,
128- array_memory_space = None ) -> _KernelParameters :
127+ ir_ctx : ir .IRContext ) -> _KernelParameters :
129128 aggregate_vars = []
130129 nonconstant_flat_vars = []
131130 for pos , (constraint , is_const , name , loc ) in enumerate (
@@ -140,10 +139,10 @@ def _create_kernel_parameters(parameter_constraints: Sequence[ParameterConstrain
140139 if isinstance (constraint , ScalarConstraint ):
141140 ty = TileTy (constraint .dtype , ())
142141 elif isinstance (constraint , ArrayConstraint ):
143- ty = _get_array_ty (constraint , array_memory_space )
142+ ty = _get_array_ty (constraint )
144143 elif isinstance (constraint , ListConstraint ):
145144 assert isinstance (constraint .element , ArrayConstraint )
146- array_ty = _get_array_ty (constraint .element , array_memory_space )
145+ array_ty = _get_array_ty (constraint .element )
147146 ty = ListTy (array_ty )
148147 else :
149148 raise TypeError (f"Unexpected parameter descriptor type"
@@ -157,7 +156,7 @@ def _create_kernel_parameters(parameter_constraints: Sequence[ParameterConstrain
157156 return _KernelParameters (aggregate_vars , nonconstant_flat_vars )
158157
159158
160- def _get_array_ty (param : ArrayConstraint , memory_space ):
159+ def _get_array_ty (param : ArrayConstraint ):
161160 for static_stride , bound in zip (param .stride_constant , param .stride_lower_bound_incl ,
162161 strict = True ):
163162 if static_stride is not None :
@@ -169,8 +168,7 @@ def _get_array_ty(param: ArrayConstraint, memory_space):
169168 return ArrayTy (make_tile_ty (param .dtype , ()),
170169 shape = (None ,) * param .ndim ,
171170 strides = param .stride_constant ,
172- index_dtype = param .index_dtype ,
173- memory_space = memory_space )
171+ index_dtype = param .index_dtype )
174172
175173
176174def _log_mlir (bytecode_buf ):
0 commit comments