Skip to content

Commit 806ce8e

Browse files
fix: adapt tests
1 parent 4577614 commit 806ce8e

1 file changed

Lines changed: 1 addition & 8 deletions

File tree

tests/test_pipeline.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,7 @@ def run_full_pipeline(
152152
categorical_embedding_dims=model_params["categorical_embedding_dims"],
153153
num_classes=model_params["num_classes"],
154154
attention_config=attention_config,
155-
label_attention_config=(
156-
LabelAttentionConfig(
157-
n_head=attention_config.n_head,
158-
num_classes=model_params["num_classes"],
159-
)
160-
if label_attention_enabled
161-
else None
162-
),
155+
n_heads_label_attention=attention_config.n_head,
163156
)
164157

165158
# Create training config

0 commit comments

Comments
 (0)