@@ -59,7 +59,7 @@ def IntTupleIfNotSingleItemUnpack(value: Any) -> Union[int, Tuple[int, ...]]:
5959 return IntTupleUnpack (value )
6060
6161
62- def attrToTensor (node : gs .Node , attr : str ) -> None :
62+ def attrToInputTensor (node : gs .Node , attr : str ) -> None :
6363 values = node .attrs [attr ]
6464 if isinstance (values , (int , float )):
6565 values = np .array ([values ])
@@ -71,6 +71,14 @@ def attrToTensor(node: gs.Node, attr: str) -> None:
7171 node .attrs .pop (attr )
7272
7373
74+ def inputTensorToAttr (node : gs .Node , tensorIdx : int , attr : str ) -> None :
75+ tensor = node .inputs [tensorIdx ]
76+ assert isinstance (tensor , gs .Constant ), \
77+ f"Can convert only constant tensors to attributes. Received tensor of type { tensor } "
78+ node .attrs [attr ] = tensor .values
79+ tensor .outputs .clear ()
80+
81+
7482concatDesc = OperatorDescriptor (
7583 inputDescriptor = VariadicIoDesc ("data_in" , minNumTensors = 2 ),
7684 outputDescriptor = IoDesc ("data_out" ),
@@ -91,10 +99,10 @@ class SliceDescriptor(OperatorDescriptor):
9199
92100 def canonicalize (self , node : gs .Node , opset : int ) -> bool :
93101 if opset < 10 :
94- attrToTensor (node , "starts" )
95- attrToTensor (node , "ends" )
102+ attrToInputTensor (node , "starts" )
103+ attrToInputTensor (node , "ends" )
96104 if "axes" in node .attrs :
97- attrToTensor (node , "axes" )
105+ attrToInputTensor (node , "axes" )
98106
99107 return super ().canonicalize (node , opset )
100108
@@ -184,7 +192,7 @@ class ReduceMeanDescriptor(OperatorDescriptor):
184192 def canonicalize (self , node : gs .Node , opset : int ) -> bool :
185193 if opset < 18 :
186194 if "axes" in node .attrs :
187- attrToTensor (node , "axes" )
195+ attrToInputTensor (node , "axes" )
188196 return super ().canonicalize (node , opset )
189197
190198
0 commit comments