examples/classification.py: 32 changes (19 additions & 13 deletions)
@@ -82,6 +82,13 @@ def _make_dataset(training: bool) -> kd.data.Pipeline:
   _LABEL_FIELD = "label"  # pylint: disable=invalid-name
 
   tokenizer = gm.text.Gemma3Tokenizer()
+  yes_tokens = tokenizer.encode("Yes", add_special_tokens=False)
+  no_tokens = tokenizer.encode("No", add_special_tokens=False)
+
+  if len(yes_tokens) != 1 or len(no_tokens) != 1:
+    raise ValueError(
+        "'Yes' and 'No' must map to a single token for classification."
+    )
 
   return kd.data.py.Tfds(
       name="glue/cola",
@@ -96,7 +103,7 @@ def _make_dataset(training: bool) -> kd.data.Pipeline:
           gm.data.FormatText(
               key=_INPUT_FIELD,
               template="""<start_of_turn>user
-Please classify whether the following sentence is grammaticaly correct, please answer only with Yes or No.
+Please classify whether the following sentence is grammatically correct, please answer only with Yes or No.
 Sentence: {text}<end_of_turn>
 <start_of_turn>model""",
           ),
@@ -110,18 +117,17 @@ def _make_dataset(training: bool) -> kd.data.Pipeline:
               max_length=128,
           ),
           # Process the label
-          gm.data.MapInts(
-              key=_LABEL_FIELD,
-              # Rather than predicting the token 0 and 1, we are using the
-              # token 1294 and 3553 which respectivelly correspond to "No" and
-              # "Yes". We do this because those token already contain semantic
-              # information, so even zero-shot prediction without any
-              # finetuning has better than random performances.
-              old_to_new={
-                  0: 1294,  # Token -> "No"
-                  1: 3553,  # Token -> "Yes"
-              },
-          ),
+          gm.data.MapInts(
+              key=_LABEL_FIELD,
+              # Rather than predicting tokens 0 and 1, we map labels to the
+              # tokenizer-derived token IDs for "No" and "Yes". These tokens
+              # contain semantic information, which improves zero-shot
+              # performance even without finetuning.
+              old_to_new={
+                  0: no_tokens[0],  # "No"
+                  1: yes_tokens[0],  # "Yes"
+              },
+          ),
           kd.data.Rearrange(
               key=_LABEL_FIELD,
               pattern="... -> ... 1",  # For shape compatibility with the loss.
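
For context on the comment in the new MapInts block: because labels are mapped to the token IDs for "No" and "Yes", a model can be scored zero-shot by comparing its next-token logits at those two IDs. A minimal sketch of that comparison, assuming the logits were computed by a forward pass elsewhere (the helper name zero_shot_predict and its signature are illustrative, not part of this diff):

import jax.numpy as jnp

def zero_shot_predict(logits: jnp.ndarray, yes_id: int, no_id: int) -> int:
  """Returns 1 ("Yes") or 0 ("No") given next-token logits.

  Args:
    logits: [vocab_size] next-token scores from a model forward pass.
    yes_id: single token ID for "Yes" (yes_tokens[0] in the diff above).
    no_id: single token ID for "No" (no_tokens[0] in the diff above).
  """
  # The class whose token has the higher logit wins; because "Yes" and
  # "No" already carry semantic meaning in the pretrained vocabulary,
  # this beats random guessing even before any finetuning.
  return int(logits[yes_id] > logits[no_id])

This is also why the diff insists each answer encodes to exactly one token: a multi-token answer would have no single vocabulary position to read a logit from.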