
Commit 0a1880b

test(label attention): add a test_pipeline with label attention activated
Parent: 30ef8af

1 file changed: tests/test_pipeline.py (51 additions, 2 deletions)
@@ -9,6 +9,7 @@
     AttentionConfig,
     CategoricalVariableNet,
     ClassificationHead,
+    LabelAttentionConfig,
     TextEmbedder,
     TextEmbedderConfig,
 )
@@ -51,7 +52,14 @@ def model_params():
     }
 
 
-def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params):
+def run_full_pipeline(
+    tokenizer,
+    sample_text_data,
+    categorical_data,
+    labels,
+    model_params,
+    label_attention_enabled: bool = False,
+):
     """Helper function to run the complete pipeline for a given tokenizer."""
     # Create dataset
     dataset = TextClassificationDataset(
@@ -83,6 +91,15 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params):
         embedding_dim=model_params["embedding_dim"],
         padding_idx=padding_idx,
         attention_config=attention_config,
+        label_attention_config=(
+            LabelAttentionConfig(
+                n_head=attention_config.n_head,
+                n_kv_head=attention_config.n_kv_head,
+                num_classes=model_params["num_classes"],
+            )
+            if label_attention_enabled
+            else None
+        ),
     )
 
     text_embedder = TextEmbedder(text_embedder_config=text_embedder_config)
@@ -98,7 +115,7 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params):
     expected_input_dim = model_params["embedding_dim"] + categorical_var_net.output_dim
     classification_head = ClassificationHead(
         input_dim=expected_input_dim,
-        num_classes=model_params["num_classes"],
+        num_classes=model_params["num_classes"] if not label_attention_enabled else 1,
     )
 
     # Create model
@@ -136,6 +153,15 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params):
         categorical_embedding_dims=model_params["categorical_embedding_dims"],
         num_classes=model_params["num_classes"],
         attention_config=attention_config,
+        label_attention_config=(
+            LabelAttentionConfig(
+                n_head=attention_config.n_head,
+                n_kv_head=attention_config.n_kv_head,
+                num_classes=model_params["num_classes"],
+            )
+            if label_attention_enabled
+            else None
+        ),
     )
 
     # Create training config
@@ -239,3 +265,26 @@ def test_ngram_tokenizer(sample_data, model_params):
 
     # Run full pipeline
     run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params)
+
+
+def test_label_attention_enabled(sample_data, model_params):
+    """Test the full pipeline with label attention enabled (using WordPieceTokenizer)."""
+    sample_text_data, categorical_data, labels = sample_data
+
+    vocab_size = 100
+    tokenizer = WordPieceTokenizer(vocab_size, output_dim=50)
+    tokenizer.train(sample_text_data)
+
+    # Check tokenizer works
+    result = tokenizer.tokenize(sample_text_data)
+    assert result.input_ids.shape[0] == len(sample_text_data)
+
+    # Run full pipeline with label attention enabled
+    run_full_pipeline(
+        tokenizer,
+        sample_text_data,
+        categorical_data,
+        labels,
+        model_params,
+        label_attention_enabled=True,
+    )
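
The ClassificationHead change (num_classes=1 once label attention is on) hints at the design: each class attends over the token sequence to get its own pooled vector, a shared head then emits one logit per pooled vector, and stacking those logits restores the (batch, num_classes) shape. The repository's actual LabelAttention module is not part of this diff, so the following is a minimal sketch of that reading; the class name LabelAttentionPooling and all internals are hypothetical, and the n_kv_head (grouped-query) field of LabelAttentionConfig is ignored for simplicity.

# Hedged sketch only — one plausible reading of the diff, not the
# repository's actual LabelAttention implementation.
import torch
import torch.nn as nn


class LabelAttentionPooling(nn.Module):
    """Per-class attention pooling (hypothetical).

    Each class owns a learned query that attends over the token states,
    producing one pooled vector per class. A shared single-logit head
    (the num_classes=1 ClassificationHead in the diff) scores each pooled
    vector, and the scores stack into (batch, num_classes) logits.
    """

    def __init__(self, embedding_dim: int, num_classes: int, n_head: int):
        super().__init__()
        # One learned query vector per class
        self.label_queries = nn.Parameter(torch.randn(num_classes, embedding_dim))
        self.attn = nn.MultiheadAttention(embedding_dim, n_head, batch_first=True)
        self.score = nn.Linear(embedding_dim, 1)  # plays the num_classes=1 head

    def forward(self, token_states: torch.Tensor) -> torch.Tensor:
        # token_states: (batch, seq_len, embedding_dim)
        batch = token_states.size(0)
        queries = self.label_queries.unsqueeze(0).expand(batch, -1, -1)
        # pooled: (batch, num_classes, embedding_dim)
        pooled, _ = self.attn(queries, token_states, token_states)
        # (batch, num_classes) logits, one per class query
        return self.score(pooled).squeeze(-1)


# Smoke test with toy shapes
logits = LabelAttentionPooling(embedding_dim=32, num_classes=5, n_head=4)(
    torch.randn(2, 10, 32)
)
assert logits.shape == (2, 5)

Under that assumption the model keeps the (batch, num_classes) logit shape regardless of the flag, which would explain why the rest of run_full_pipeline (loss, metrics, training loop) needs no branching beyond the two config changes in this commit.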
