Skip to content

Commit 161441e

Browse files
Copilotjustinchuby
andcommitted
Update nn.Sequential signature to accept *modules (varargs) matching PyTorch
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent c4a25ac commit 161441e

2 files changed

Lines changed: 15 additions & 12 deletions

File tree

onnxscript/nn/_module_test.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def forward(self, op, x):
854854
return op.Add(x, op.Constant(value_float=1.0))
855855

856856
graph, op, x = self._make_input()
857-
seq = Sequential([AddOne(), AddOne(), AddOne()])
857+
seq = Sequential(AddOne(), AddOne(), AddOne())
858858
result = seq(op, x)
859859

860860
self.assertIsInstance(result, ir.Value)
@@ -869,7 +869,7 @@ def forward(self, op, x):
869869
return op.Identity(x)
870870

871871
_, op, x = self._make_input()
872-
seq = Sequential([PassThrough()])
872+
seq = Sequential(PassThrough())
873873
result = seq(op, x)
874874
self.assertIsInstance(result, ir.Value)
875875

@@ -899,7 +899,7 @@ def forward(self, op, pair):
899899
return op.Add(a, b)
900900

901901
graph, op, x = self._make_input()
902-
seq = Sequential([SplitTwo(), UnpackAndAdd()])
902+
seq = Sequential(SplitTwo(), UnpackAndAdd())
903903
result = seq(op, x)
904904

905905
self.assertIsInstance(result, ir.Value)
@@ -920,7 +920,7 @@ def forward(self, op, pair):
920920
return op.Add(a, b)
921921

922922
_, op, x = self._make_input()
923-
seq = Sequential([SplitTwoList(), UnpackAndAdd()])
923+
seq = Sequential(SplitTwoList(), UnpackAndAdd())
924924
result = seq(op, x)
925925
self.assertIsInstance(result, ir.Value)
926926

@@ -938,7 +938,7 @@ def forward(self, op, pair):
938938
return pair
939939

940940
_, op, x = self._make_input()
941-
seq = Sequential([ReturnPair(), TupleIdentity()])
941+
seq = Sequential(ReturnPair(), TupleIdentity())
942942
result = seq(op, x)
943943
self.assertIsInstance(result, tuple)
944944
self.assertEqual(len(result), 2)
@@ -951,7 +951,7 @@ def forward(self, op, x):
951951
return (op.Identity(x), op.Identity(x))
952952

953953
_, op, x = self._make_input()
954-
seq = Sequential([ReturnPair()])
954+
seq = Sequential(ReturnPair())
955955
result = seq(op, x)
956956
self.assertIsInstance(result, tuple)
957957
self.assertEqual(len(result), 2)
@@ -968,7 +968,7 @@ def forward(self, op, x):
968968
return (op.Identity(x), op.Identity(x), op.Identity(x))
969969

970970
_, op, x = self._make_input()
971-
seq = Sequential([Identity(), SplitThree()])
971+
seq = Sequential(Identity(), SplitThree())
972972
result = seq(op, x)
973973
self.assertIsInstance(result, tuple)
974974
self.assertEqual(len(result), 3)
@@ -987,7 +987,7 @@ def forward(self, op, x):
987987

988988
_, op = _create_graph_and_op()
989989
accept = AcceptNone()
990-
seq = Sequential([ReturnNone(), accept])
990+
seq = Sequential(ReturnNone(), accept)
991991
result = seq(op, "anything")
992992
self.assertIsNone(result)
993993
self.assertIsNone(accept.received)
@@ -1008,7 +1008,7 @@ def forward(self, op, x):
10081008
class Model(Module):
10091009
def __init__(self):
10101010
super().__init__("model")
1011-
self.layers = Sequential([Linear(4, 4), Linear(4, 4)])
1011+
self.layers = Sequential(Linear(4, 4), Linear(4, 4))
10121012

10131013
def forward(self, op, x):
10141014
return self.layers(op, x)
@@ -1035,7 +1035,7 @@ def __init__(self, size):
10351035
def forward(self, op, x):
10361036
return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0]))
10371037

1038-
seq = Sequential([SiLU(), Linear(4)])
1038+
seq = Sequential(SiLU(), Linear(4))
10391039
named = dict(seq.named_parameters())
10401040
# SiLU at index 0 has no params; Linear at index 1 has weight
10411041
self.assertIn("1.weight", named)
@@ -1061,7 +1061,7 @@ def forward(self, op, x):
10611061
class Model(Module):
10621062
def __init__(self):
10631063
super().__init__("model")
1064-
self.blocks = Sequential([])
1064+
self.blocks = Sequential()
10651065
# Append AFTER __setattr__ has set Sequential._name = "blocks"
10661066
self.blocks.append(Linear(4))
10671067
self.blocks.append(Linear(4))

onnxscript/nn/_sequential.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,16 @@ def forward(self, op, x):
2424
2525
# Produces parameter names: "mod.0.weight", "mod.0.bias"
2626
# SiLU at index 0 has no parameters.
27-
mod = Sequential([SiLU(), Linear(4, 4)])
27+
mod = Sequential(SiLU(), Linear(4, 4))
2828
2929
# Calling mod(op, x) is equivalent to:
3030
# x = silu(op, x)
3131
# x = linear(op, x)
3232
"""
3333

34+
def __init__(self, *modules: _module_list.Module) -> None:
35+
super().__init__(list(modules))
36+
3437
def _set_name(self, name: str) -> None:
3538
"""Set this container's name. Children keep simple ``"0"``, ``"1"`` names.
3639

0 commit comments

Comments
 (0)