We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4577614 commit 806ce8eCopy full SHA for 806ce8e
1 file changed
tests/test_pipeline.py
@@ -152,14 +152,7 @@ def run_full_pipeline(
152
categorical_embedding_dims=model_params["categorical_embedding_dims"],
153
num_classes=model_params["num_classes"],
154
attention_config=attention_config,
155
- label_attention_config=(
156
- LabelAttentionConfig(
157
- n_head=attention_config.n_head,
158
- num_classes=model_params["num_classes"],
159
- )
160
- if label_attention_enabled
161
- else None
162
- ),
+ n_heads_label_attention=attention_config.n_head,
163
)
164
165
# Create training config
0 commit comments