@@ -47,7 +47,7 @@ def multiclass_dataset():
4747
4848def test_gcn_scorer_multilabel (multilabel_dataset ):
4949 torch .manual_seed (42 )
50- scorer = GCNScorer (embedder_config = _embedder_config , num_train_epochs = 1 , batch_size = 2 , seed = 42 )
50+ scorer = GCNScorer (embedder_config = _embedder_config , label_embedder_config = _embedder_config , num_train_epochs = 1 , batch_size = 2 , seed = 42 )
5151 train_utterances = multilabel_dataset ["train" ]["utterance" ]
5252 train_labels = multilabel_dataset ["train" ]["label" ]
5353 descriptions = [intent .name for intent in multilabel_dataset .intents ]
@@ -62,7 +62,7 @@ def test_gcn_scorer_multilabel(multilabel_dataset):
6262
6363def test_gcn_scorer_multiclass (multiclass_dataset ):
6464 torch .manual_seed (42 )
65- scorer = GCNScorer (embedder_config = _embedder_config , num_train_epochs = 1 , batch_size = 2 , seed = 42 )
65+ scorer = GCNScorer (embedder_config = _embedder_config , label_embedder_config = _embedder_config , num_train_epochs = 1 , batch_size = 2 , seed = 42 )
6666 train_utterances = multiclass_dataset ["train" ]["utterance" ]
6767 train_labels = multiclass_dataset ["train" ]["label" ]
6868 descriptions = [intent .name for intent in multiclass_dataset .intents ]
@@ -78,7 +78,7 @@ def test_gcn_scorer_multiclass(multiclass_dataset):
7878
7979def test_gcn_scorer_dump_load (tmp_path , multilabel_dataset ):
8080 torch .manual_seed (42 )
81- scorer = GCNScorer (embedder_config = _embedder_config , num_train_epochs = 1 , batch_size = 2 , seed = 42 )
81+ scorer = GCNScorer (embedder_config = _embedder_config , label_embedder_config = _embedder_config , num_train_epochs = 1 , batch_size = 2 , seed = 42 )
8282 train_utterances = multilabel_dataset ["train" ]["utterance" ]
8383 train_labels = multilabel_dataset ["train" ]["label" ]
8484 descriptions = [intent .name for intent in multilabel_dataset .intents ]
0 commit comments