Skip to content

Commit d109faf

Browse files
authored
Merge pull request #302 from asogaard/restructure-models
Restructure models
2 parents 9808766 + 39c9148 commit d109faf

5 files changed

Lines changed: 426 additions & 410 deletions

File tree

examples/train_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def main():
9393
)
9494
gnn = DynEdge(
9595
nb_inputs=detector.nb_outputs,
96+
global_pooling_schemes=["min", "max", "mean", "sum"],
9697
)
9798
task = EnergyReconstruction(
9899
hidden_size=gnn.nb_outputs,

src/graphnet/models/components/layers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Optional, Sequence
1+
from typing import Callable, List, Optional, Sequence, Union
22

33
from torch.functional import Tensor
44

@@ -13,7 +13,7 @@ def __init__(
1313
nn: Callable,
1414
aggr: str = "max",
1515
nb_neighbors: int = 8,
16-
features_subset: Optional[Sequence] = None,
16+
features_subset: Optional[Union[Sequence[int], List[int]]] = None,
1717
**kwargs,
1818
):
1919
# Check(s)
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
from .dynedge import DynEdge, DynEdge_V2, DynEdge_V3
21
from .convnet import ConvNet
2+
from .dynedge import DynEdge
3+
from .dynedge_jinst import DynEdgeJINST

0 commit comments

Comments
 (0)