Skip to content

Commit f180f85

Browse files
committed
Rename attrToTensor to attrToInputTensor and add inputTensorToAttr
1 parent e8f1721 commit f180f85

1 file changed

Lines changed: 13 additions & 5 deletions

File tree

Deeploy/OperatorDescriptor.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def IntTupleIfNotSingleItemUnpack(value: Any) -> Union[int, Tuple[int, ...]]:
5959
return IntTupleUnpack(value)
6060

6161

62-
def attrToTensor(node: gs.Node, attr: str) -> None:
62+
def attrToInputTensor(node: gs.Node, attr: str) -> None:
6363
values = node.attrs[attr]
6464
if isinstance(values, (int, float)):
6565
values = np.array([values])
@@ -71,6 +71,14 @@ def attrToTensor(node: gs.Node, attr: str) -> None:
7171
node.attrs.pop(attr)
7272

7373

74+
def inputTensorToAttr(node: gs.Node, tensorIdx: int, attr: str) -> None:
75+
tensor = node.inputs[tensorIdx]
76+
assert isinstance(tensor, gs.Constant), \
77+
f"Can convert only constant tensors to attributes. Received tensor of type {tensor}"
78+
node.attrs[attr] = tensor.values
79+
tensor.outputs.clear()
80+
81+
7482
concatDesc = OperatorDescriptor(
7583
inputDescriptor = VariadicIoDesc("data_in", minNumTensors = 2),
7684
outputDescriptor = IoDesc("data_out"),
@@ -91,10 +99,10 @@ class SliceDescriptor(OperatorDescriptor):
9199

92100
def canonicalize(self, node: gs.Node, opset: int) -> bool:
93101
if opset < 10:
94-
attrToTensor(node, "starts")
95-
attrToTensor(node, "ends")
102+
attrToInputTensor(node, "starts")
103+
attrToInputTensor(node, "ends")
96104
if "axes" in node.attrs:
97-
attrToTensor(node, "axes")
105+
attrToInputTensor(node, "axes")
98106

99107
return super().canonicalize(node, opset)
100108

@@ -184,7 +192,7 @@ class ReduceMeanDescriptor(OperatorDescriptor):
184192
def canonicalize(self, node: gs.Node, opset: int) -> bool:
185193
if opset < 18:
186194
if "axes" in node.attrs:
187-
attrToTensor(node, "axes")
195+
attrToInputTensor(node, "axes")
188196
return super().canonicalize(node, opset)
189197

190198

0 commit comments

Comments
 (0)