diff --git a/configs/model/cell/can.yaml b/configs/model/cell/can.yaml index be81d82fa..3a4e17ce5 100755 --- a/configs/model/cell/can.yaml +++ b/configs/model/cell/can.yaml @@ -38,7 +38,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/cell/cccn.yaml b/configs/model/cell/cccn.yaml index 19bfd9839..7d626e30f 100755 --- a/configs/model/cell/cccn.yaml +++ b/configs/model/cell/cccn.yaml @@ -32,7 +32,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/cell/ccxn.yaml b/configs/model/cell/ccxn.yaml index ce3e44339..7a655df4f 100755 --- a/configs/model/cell/ccxn.yaml +++ b/configs/model/cell/ccxn.yaml @@ -34,7 +34,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} 
out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/cell/cwn.yaml b/configs/model/cell/cwn.yaml index ccd85647c..f4584cddc 100755 --- a/configs/model/cell/cwn.yaml +++ b/configs/model/cell/cwn.yaml @@ -31,7 +31,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/cell/topotune.yaml b/configs/model/cell/topotune.yaml index abea53926..833254a57 100755 --- a/configs/model/cell/topotune.yaml +++ b/configs/model/cell/topotune.yaml @@ -47,7 +47,7 @@ readout: num_cell_dimensions: ${infer_topotune_num_cell_dimensions:${oc.select:model.backbone.neighborhoods}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/cell/topotune_onehasse.yaml b/configs/model/cell/topotune_onehasse.yaml index 02cf88a0c..b23c04de0 100644 --- 
a/configs/model/cell/topotune_onehasse.yaml +++ b/configs/model/cell/topotune_onehasse.yaml @@ -46,7 +46,7 @@ readout: num_cell_dimensions: ${infer_topotune_num_cell_dimensions:${oc.select:model.backbone.neighborhoods}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/combinatorial/topotune.yaml b/configs/model/combinatorial/topotune.yaml index 0fb9d9cea..b1e065563 100755 --- a/configs/model/combinatorial/topotune.yaml +++ b/configs/model/combinatorial/topotune.yaml @@ -47,7 +47,7 @@ readout: num_cell_dimensions: ${infer_topotune_num_cell_dimensions:${oc.select:model.backbone.neighborhoods}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/graph/gat.yaml b/configs/model/graph/gat.yaml index 642e2bca2..45bb8d810 100755 --- a/configs/model/graph/gat.yaml +++ b/configs/model/graph/gat.yaml @@ -34,7 +34,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: 
${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/graph/gcn.yaml b/configs/model/graph/gcn.yaml index d9b71c87e..06a6a4ebe 100755 --- a/configs/model/graph/gcn.yaml +++ b/configs/model/graph/gcn.yaml @@ -31,7 +31,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/graph/gcn_dgm.yaml b/configs/model/graph/gcn_dgm.yaml index 90e352eb0..afdbe5741 100755 --- a/configs/model/graph/gcn_dgm.yaml +++ b/configs/model/graph/gcn_dgm.yaml @@ -34,7 +34,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/graph/gin.yaml b/configs/model/graph/gin.yaml index ad64587cc..3236ec90d 100755 --- a/configs/model/graph/gin.yaml +++ b/configs/model/graph/gin.yaml @@ -31,7 +31,7 @@ readout: 
num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/graph/gps.yaml b/configs/model/graph/gps.yaml index 81cc1bf1f..7daa8c8de 100644 --- a/configs/model/graph/gps.yaml +++ b/configs/model/graph/gps.yaml @@ -35,7 +35,7 @@ readout: in_channels: ${model.feature_encoder.out_channels} hidden_layers: [16] out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # Extra MLP params dropout: 0.2 diff --git a/configs/model/graph/graph_mlp.yaml b/configs/model/graph/graph_mlp.yaml index f2a643f26..c3038af4d 100755 --- a/configs/model/graph/graph_mlp.yaml +++ b/configs/model/graph/graph_mlp.yaml @@ -35,7 +35,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum diff --git a/configs/model/graph/nsd.yaml b/configs/model/graph/nsd.yaml index a3254bd56..2e713233d 100644 
--- a/configs/model/graph/nsd.yaml +++ b/configs/model/graph/nsd.yaml @@ -35,7 +35,7 @@ readout: in_channels: ${model.feature_encoder.out_channels} hidden_layers: [16] out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # Extra MLP params dropout: 0.2 diff --git a/configs/model/hypergraph/alldeepset.yaml b/configs/model/hypergraph/alldeepset.yaml index 202900a50..c3dfa556b 100755 --- a/configs/model/hypergraph/alldeepset.yaml +++ b/configs/model/hypergraph/alldeepset.yaml @@ -39,7 +39,7 @@ readout: num_cell_dimensions: 1 # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/hypergraph/allsettransformer.yaml b/configs/model/hypergraph/allsettransformer.yaml index 52b9072f3..c4414993f 100755 --- a/configs/model/hypergraph/allsettransformer.yaml +++ b/configs/model/hypergraph/allsettransformer.yaml @@ -33,7 +33,7 @@ readout: num_cell_dimensions: 1 # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/hypergraph/edgnn.yaml b/configs/model/hypergraph/edgnn.yaml index 
6144f9143..81815e86c 100755 --- a/configs/model/hypergraph/edgnn.yaml +++ b/configs/model/hypergraph/edgnn.yaml @@ -34,7 +34,7 @@ readout: num_cell_dimensions: 1 # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/hypergraph/unignn.yaml b/configs/model/hypergraph/unignn.yaml index adc88c262..15b867325 100755 --- a/configs/model/hypergraph/unignn.yaml +++ b/configs/model/hypergraph/unignn.yaml @@ -29,7 +29,7 @@ readout: num_cell_dimensions: 1 # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/hypergraph/unignn2.yaml b/configs/model/hypergraph/unignn2.yaml index 5beb1d622..eea78f58d 100755 --- a/configs/model/hypergraph/unignn2.yaml +++ b/configs/model/hypergraph/unignn2.yaml @@ -33,7 +33,7 @@ readout: num_cell_dimensions: 1 hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/non_relational/mlp.yaml b/configs/model/non_relational/mlp.yaml index 
a14e5654e..b990a82ec 100644 --- a/configs/model/non_relational/mlp.yaml +++ b/configs/model/non_relational/mlp.yaml @@ -29,7 +29,7 @@ readout: readout_name: NoReadOut hidden_dim: ${model.backbone.out_channels} out_channels: ${model.backbone.out_channels} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum logits_linear_layer: false diff --git a/configs/model/pointcloud/deepset.yaml b/configs/model/pointcloud/deepset.yaml index 7e59d169f..2ae96dbd2 100644 --- a/configs/model/pointcloud/deepset.yaml +++ b/configs/model/pointcloud/deepset.yaml @@ -29,7 +29,7 @@ readout: in_channels: ${model.feature_encoder.out_channels} hidden_layers: [64,32] out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # Extra MLP params dropout: 0.0 diff --git a/configs/model/simplicial/san.yaml b/configs/model/simplicial/san.yaml index 4ff1c9b7e..7b14027d7 100755 --- a/configs/model/simplicial/san.yaml +++ b/configs/model/simplicial/san.yaml @@ -35,7 +35,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/simplicial/sccn.yaml b/configs/model/simplicial/sccn.yaml 
index fd60e0cb2..fe79d4ee9 100755 --- a/configs/model/simplicial/sccn.yaml +++ b/configs/model/simplicial/sccn.yaml @@ -30,7 +30,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/simplicial/sccnn.yaml b/configs/model/simplicial/sccnn.yaml index d507e3e45..d7562a77a 100755 --- a/configs/model/simplicial/sccnn.yaml +++ b/configs/model/simplicial/sccnn.yaml @@ -43,7 +43,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/simplicial/sccnn_custom.yaml b/configs/model/simplicial/sccnn_custom.yaml index 1c1dc91d9..21af02bbf 100755 --- a/configs/model/simplicial/sccnn_custom.yaml +++ b/configs/model/simplicial/sccnn_custom.yaml @@ -43,7 +43,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: 
${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/simplicial/scn.yaml b/configs/model/simplicial/scn.yaml index caeaa24f6..887281287 100755 --- a/configs/model/simplicial/scn.yaml +++ b/configs/model/simplicial/scn.yaml @@ -34,7 +34,7 @@ readout: num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/simplicial/topotune.yaml b/configs/model/simplicial/topotune.yaml index 208328640..1401453b3 100755 --- a/configs/model/simplicial/topotune.yaml +++ b/configs/model/simplicial/topotune.yaml @@ -47,7 +47,7 @@ readout: num_cell_dimensions: ${infer_topotune_num_cell_dimensions:${oc.select:model.backbone.neighborhoods}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/model/simplicial/topotune_onehasse.yaml 
b/configs/model/simplicial/topotune_onehasse.yaml index d7504109c..b0e5febd6 100644 --- a/configs/model/simplicial/topotune_onehasse.yaml +++ b/configs/model/simplicial/topotune_onehasse.yaml @@ -46,7 +46,7 @@ readout: num_cell_dimensions: ${infer_topotune_num_cell_dimensions:${oc.select:model.backbone.neighborhoods}} # The highest order of cell dimensions to consider hidden_dim: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - task_level: ${dataset.parameters.task_level} + task_level: ${define_task_level:${dataset.parameters.task_level},${dataset.split_params.learning_setting}} # Handles the edge case of node-inductive task pooling_type: sum # compile model for faster training with pytorch 2.0 diff --git a/configs/run.yaml b/configs/run.yaml index 94bba373f..b2b57052a 100755 --- a/configs/run.yaml +++ b/configs/run.yaml @@ -31,9 +31,6 @@ defaults: # debugging config (enable through command line, e.g. `python train.py debug=default) - debug: null -# evaluator: ${dataset.parameters.task} -# callbacks: ${dataset.parameters.task} - # task name, determines output directory path task_name: "train" diff --git a/test/model/__init__.py b/test/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/model/test_model.py b/test/model/test_model.py new file mode 100644 index 000000000..e3c7a7353 --- /dev/null +++ b/test/model/test_model.py @@ -0,0 +1,176 @@ +"""Unit tests for TBModel.process_outputs.""" + +import pytest +import torch +from unittest.mock import MagicMock + +from topobench.model.model import TBModel + + +def make_model(task_level): + """Instantiate TBModel with mocked dependencies for a given task_level. + + Parameters + ---------- + task_level : str + The task level to assign to the readout mock. + + Returns + ------- + TBModel + A TBModel instance with mocked backbone, readout, loss, evaluator and optimizer. 
+ """ + backbone = MagicMock() + backbone.parameters.return_value = [] + + readout = MagicMock() + readout.task_level = task_level + readout.parameters.return_value = [] + + loss = MagicMock() + + evaluator = MagicMock() + + optimizer = MagicMock() + optimizer.configure_optimizer.return_value = {"optimizer": MagicMock()} + + feature_encoder = MagicMock() + feature_encoder.parameters.return_value = [] + + model = TBModel( + backbone=backbone, + readout=readout, + loss=loss, + feature_encoder=feature_encoder, + evaluator=evaluator, + optimizer=optimizer, + ) + return model + + +class TestProcessOutputs: + """Tests for TBModel.process_outputs covering all branches.""" + + def _make_batch(self, n=10): + """Create a simple batch mock with train/val/test masks. + + Parameters + ---------- + n : int + Number of nodes (currently unused; masks are fixed at 10 nodes). + + Returns + ------- + MagicMock + Batch mock with boolean masks. + """ + batch = MagicMock() + # First 6 are train, next 2 val, last 2 test + batch.train_mask = torch.tensor([1, 1, 1, 1, 1, 1, 0, 0, 0, 0], dtype=torch.bool) + batch.val_mask = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 0, 0], dtype=torch.bool) + batch.test_mask = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=torch.bool) + return batch + + def _make_model_out(self, n=10): + """Create a sample model output dict. + + Parameters + ---------- + n : int + Number of nodes. + + Returns + ------- + dict + Model output with logits, labels, and an extra key.
+ """ + return { + "logits": torch.randn(n, 3), + "labels": torch.randint(0, 3, (n,)), + "x_0": torch.randn(n, 8), + } + + # ------------------------------------------------------------------ + # node-level masking + # ------------------------------------------------------------------ + + def test_node_training_filters_by_train_mask(self): + """process_outputs with task_level='node' and state 'Training' filters by train_mask.""" + model = make_model("node") + model.state_str = "Training" + batch = self._make_batch() + model_out = self._make_model_out() + + result = model.process_outputs(model_out, batch) + + n_train = batch.train_mask.sum().item() + assert result["logits"].shape[0] == n_train + assert result["labels"].shape[0] == n_train + # non-masked keys are untouched + assert result["x_0"].shape[0] == 10 + + def test_node_validation_filters_by_val_mask(self): + """process_outputs with task_level='node' and state 'Validation' filters by val_mask.""" + model = make_model("node") + model.state_str = "Validation" + batch = self._make_batch() + model_out = self._make_model_out() + + result = model.process_outputs(model_out, batch) + + n_val = batch.val_mask.sum().item() + assert result["logits"].shape[0] == n_val + assert result["labels"].shape[0] == n_val + + def test_node_test_filters_by_test_mask(self): + """process_outputs with task_level='node' and state 'Test' filters by test_mask.""" + model = make_model("node") + model.state_str = "Test" + batch = self._make_batch() + model_out = self._make_model_out() + + result = model.process_outputs(model_out, batch) + + n_test = batch.test_mask.sum().item() + assert result["logits"].shape[0] == n_test + assert result["labels"].shape[0] == n_test + + def test_node_invalid_state_raises_value_error(self): + """process_outputs with task_level='node' and an invalid state_str raises ValueError.""" + model = make_model("node") + model.state_str = "Invalid" + batch = self._make_batch() + model_out = self._make_model_out() + + 
with pytest.raises(ValueError, match="Invalid state_str"): + model.process_outputs(model_out, batch) + + # ------------------------------------------------------------------ + # no-op task levels + # ------------------------------------------------------------------ + + def test_graph_level_returns_unchanged(self): + """process_outputs with task_level='graph' returns model_out unchanged.""" + model = make_model("graph") + model.state_str = "Training" + batch = self._make_batch() + model_out = self._make_model_out() + original_logits = model_out["logits"].clone() + + result = model.process_outputs(model_out, batch) + + assert result["logits"].shape == original_logits.shape + assert torch.equal(result["logits"], original_logits) + + def test_node_inductive_returns_unchanged(self): + """process_outputs with task_level='node_inductive' returns model_out unchanged (inductive bug-fix path).""" + model = make_model("node_inductive") + model.state_str = "Training" + batch = self._make_batch() + model_out = self._make_model_out() + original_logits = model_out["logits"].clone() + + result = model.process_outputs(model_out, batch) + + assert result["logits"].shape == original_logits.shape + assert torch.equal(result["logits"], original_logits) diff --git a/test/utils/test_config_resolvers.py b/test/utils/test_config_resolvers.py index 4f57601bd..cc94dc850 100644 --- a/test/utils/test_config_resolvers.py +++ b/test/utils/test_config_resolvers.py @@ -4,6 +4,7 @@ from omegaconf import OmegaConf import hydra from topobench.utils.config_resolvers import ( + define_task_level, infer_in_channels, infer_num_cell_dimensions, infer_topotune_num_cell_dimensions, @@ -30,6 +31,16 @@ def setup_method(self): self.feature_lift_transform = OmegaConf.load("configs/transforms/feature_liftings/concatenate.yaml") hydra.initialize(version_base="1.3", config_path="../../configs", job_name="job") + def test_define_task_level(self): + """Test define_task_level.""" + # node + inductive -> 
node_inductive (the bug-fix branch) + assert define_task_level("node", "inductive") == "node_inductive" + + # else branch: any other combination returns dataset_task_level unchanged + assert define_task_level("node", "transductive") == "node" + assert define_task_level("graph", "inductive") == "graph" + assert define_task_level("graph", "transductive") == "graph" + def test_get_default_trainer(self): """Test get_default_trainer.""" out = get_default_trainer() diff --git a/topobench/model/model.py b/topobench/model/model.py index 2cbbae891..a29cd4266 100755 --- a/topobench/model/model.py +++ b/topobench/model/model.py @@ -232,17 +232,17 @@ def process_outputs(self, model_out: dict, batch: Data) -> dict: dict Dictionary containing the updated model output. """ - # Get the correct mask - if self.state_str == "Training": - mask = batch.train_mask - elif self.state_str == "Validation": - mask = batch.val_mask - elif self.state_str == "Test": - mask = batch.test_mask - else: - raise ValueError("Invalid state_str") - if self.task_level == "node": + # Get the correct mask + if self.state_str == "Training": + mask = batch.train_mask + elif self.state_str == "Validation": + mask = batch.val_mask + elif self.state_str == "Test": + mask = batch.test_mask + else: + raise ValueError("Invalid state_str") + # Keep only train data points for key, val in model_out.items(): if key in ["logits", "labels"]: diff --git a/topobench/nn/readouts/base.py b/topobench/nn/readouts/base.py index 6fdd8412f..b4b658868 100755 --- a/topobench/nn/readouts/base.py +++ b/topobench/nn/readouts/base.py @@ -42,7 +42,7 @@ def __init__( if hidden_dim != out_channels or logits_linear_layer else torch.nn.Identity() ) - assert task_level in ["graph", "node"], "Invalid task_level" + assert task_level in ["graph", "node", "node_inductive"], "Invalid task_level" self.task_level = task_level self.logits_linear_layer = logits_linear_layer diff --git a/topobench/run.py b/topobench/run.py index eec0d61a0..c02dfd9bf 
100755 --- a/topobench/run.py +++ b/topobench/run.py @@ -27,6 +27,7 @@ task_wrapper, ) from topobench.utils.config_resolvers import ( + define_task_level, get_default_metrics, get_default_trainer, get_default_transform, @@ -59,6 +60,9 @@ # ------------------------------------------------------------------------------------ # +OmegaConf.register_new_resolver( + "define_task_level", define_task_level, replace=True +) OmegaConf.register_new_resolver( "get_default_metrics", get_default_metrics, replace=True ) diff --git a/topobench/utils/config_resolvers.py b/topobench/utils/config_resolvers.py index e797b0c23..4c773c6a7 100644 --- a/topobench/utils/config_resolvers.py +++ b/topobench/utils/config_resolvers.py @@ -6,6 +6,32 @@ import torch +def define_task_level(dataset_task_level, learning_setting): + r"""Define the task level for a given dataset task level and learning setting. + + Parameters + ---------- + dataset_task_level : str + Task level defined in the dataset configuration file. + learning_setting : str + Learning setting defined in the dataset split parameters. + + Returns + ------- + str + Task level for the model. + + Notes + ----- + Inputs are not validated and nothing is raised: only ``node`` with + ``inductive`` is remapped; any other pair is returned unchanged. + """ + if dataset_task_level == "node" and learning_setting == "inductive": + return "node_inductive" + else: + return dataset_task_level + + def get_flattened_channels(num_nodes, channels): r"""Get the output dimension of flattening a feature matrix.