From 7791d7cf97364091aa3b3daf4fc810a0a91084e6 Mon Sep 17 00:00:00 2001 From: Saravanabalagi Ramachandran Date: Thu, 7 Mar 2024 11:38:05 +0000 Subject: [PATCH] Update mask_updater.py fixing _output_randomize error --- nni/compression/speedup/mask_updater.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nni/compression/speedup/mask_updater.py b/nni/compression/speedup/mask_updater.py index 22c20c87d5..eb05826189 100644 --- a/nni/compression/speedup/mask_updater.py +++ b/nni/compression/speedup/mask_updater.py @@ -258,7 +258,7 @@ def direct_update_process(self, model_speedup: 'ModelSpeedup', node: Node): if model_speedup.garbage_collect_values: # do memory collect to reduce memory usage for to_delete in model_speedup.user_to_last_uses.get(node, []): - del model_speedup.node_infos[to_delete]._output_randomize + del model_speedup.node_infos[to_delete].output_randomize # in all the following function, the first arg name is `input`, and don't have other tensor as input args. @@ -397,7 +397,7 @@ def direct_getitem(self, model_speedup: 'ModelSpeedup', node: Node): def indirect_getitem(self, model_speedup: 'ModelSpeedup', node: Node): assert len(node.args) == 2 - input_grad = tree_map_zip(lambda t, m: (t * m).type_as(t) if isinstance(m, torch.Tensor) else t, \ + input_grad = tree_map_zip(lambda t, m: (t * m).type_as(t) if isinstance(m, torch.Tensor) and t is not None else t, \ model_speedup.node_infos[node].output_grad, model_speedup.node_infos[node].output_masks) arg_1_val = model_speedup.node_infos[node.args[1]].output_randomize if isinstance(node.args[1], Node) else node.args[1]