@@ -237,14 +237,21 @@ def change_type_map(
237237 remap_index , has_new_type = get_index_between_two_maps (self .type_map , type_map )
238238 super ().change_type_map (type_map = type_map )
239239 if has_new_type :
240+ xp = array_api_compat .array_namespace (self .scale )
240241 extend_shape = [len (type_map ), * list (self .scale .shape [1 :])]
241- extend_scale = np .ones (extend_shape , dtype = self .scale .dtype )
242- self .scale = np .concatenate ([self .scale , extend_scale ], axis = 0 )
242+ extend_scale = xp .ones (
243+ extend_shape ,
244+ dtype = self .scale .dtype ,
245+ device = array_api_compat .device (self .scale ),
246+ )
247+ self .scale = xp .concat ([self .scale , extend_scale ], axis = 0 )
243248 extend_shape = [len (type_map ), * list (self .constant_matrix .shape [1 :])]
244- extend_constant_matrix = np .zeros (
245- extend_shape , dtype = self .constant_matrix .dtype
249+ extend_constant_matrix = xp .zeros (
250+ extend_shape ,
251+ dtype = self .constant_matrix .dtype ,
252+ device = array_api_compat .device (self .constant_matrix ),
246253 )
247- self .constant_matrix = np . concatenate (
254+ self .constant_matrix = xp . concat (
248255 [self .constant_matrix , extend_constant_matrix ], axis = 0
249256 )
250257 self .scale = self .scale [remap_index ]
0 commit comments