Skip to content

Commit 0daad89

Browse files
committed
fix pdp unstructured shape bug
1 parent 58266da commit 0daad89

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

  • src/pquant/pruning_methods

src/pquant/pruning_methods/pdp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def build(self, input_shape):
3131
shape = (input_shape[0], 1, 1)
3232
else:
3333
shape = (input_shape[0], 1, 1, 1)
34+
else:
35+
shape = input_shape
3436
self.mask = self.add_weight(shape=shape, initializer="ones", name="mask", trainable=False)
3537
self.flat_weight_size = ops.cast(ops.size(self.mask), self.mask.dtype)
3638
super().build(input_shape)

0 commit comments

Comments
 (0)