Skip to content

Commit ad38e3a

Browse files
speedstorm1 authored and copybara-github committed
feat: support hyperparameters in distillation tuning
PiperOrigin-RevId: 890622141
1 parent 4dbb277 commit ad38e3a

File tree

3 files changed

+66
-10
lines changed

3 files changed

+66
-10
lines changed

google/genai/tests/tunings/test_tune.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,25 @@
245245
),
246246
exception_if_mldev="not supported in Gemini API",
247247
),
248+
pytest_helper.TestTableItem(
249+
name="test_tune_oss_distillation_hyperparams",
250+
parameters=genai_types.CreateTuningJobParameters(
251+
base_model="qwen/qwen3@qwen3-4b",
252+
training_dataset=genai_types.TuningDataset(
253+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
254+
),
255+
config=genai_types.CreateTuningJobConfig(
256+
method="DISTILLATION",
257+
base_teacher_model="deepseek-ai/deepseek-r1-0528-maas",
258+
learning_rate=1e-4,
259+
batch_size=4,
260+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
261+
tuning_mode="TUNING_MODE_FULL",
262+
http_options=VERTEX_HTTP_OPTIONS,
263+
),
264+
),
265+
exception_if_mldev="not supported in Gemini API",
266+
),
248267
pytest_helper.TestTableItem(
249268
name="test_tune_encryption_spec",
250269
parameters=genai_types.CreateTuningJobParameters(

google/genai/tunings.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,13 @@ def _CreateTuningJobConfig_to_vertex(
409409
['supervisedTuningSpec', 'tuningMode'],
410410
getv(from_object, ['tuning_mode']),
411411
)
412+
elif discriminator == 'DISTILLATION':
413+
if getv(from_object, ['tuning_mode']) is not None:
414+
setv(
415+
parent_object,
416+
['distillationSpec', 'tuningMode'],
417+
getv(from_object, ['tuning_mode']),
418+
)
412419

413420
if getv(from_object, ['custom_base_model']) is not None:
414421
setv(
@@ -427,6 +434,13 @@ def _CreateTuningJobConfig_to_vertex(
427434
['supervisedTuningSpec', 'hyperParameters', 'batchSize'],
428435
getv(from_object, ['batch_size']),
429436
)
437+
elif discriminator == 'DISTILLATION':
438+
if getv(from_object, ['batch_size']) is not None:
439+
setv(
440+
parent_object,
441+
['distillationSpec', 'hyperParameters', 'batchSize'],
442+
getv(from_object, ['batch_size']),
443+
)
430444

431445
discriminator = getv(root_object, ['config', 'method'])
432446
if discriminator is None:
@@ -438,6 +452,13 @@ def _CreateTuningJobConfig_to_vertex(
438452
['supervisedTuningSpec', 'hyperParameters', 'learningRate'],
439453
getv(from_object, ['learning_rate']),
440454
)
455+
elif discriminator == 'DISTILLATION':
456+
if getv(from_object, ['learning_rate']) is not None:
457+
setv(
458+
parent_object,
459+
['distillationSpec', 'hyperParameters', 'learningRate'],
460+
getv(from_object, ['learning_rate']),
461+
)
441462

442463
discriminator = getv(root_object, ['config', 'method'])
443464
if discriminator is None:

google/genai/types.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11649,10 +11649,7 @@ class PreferenceOptimizationSpecDict(TypedDict, total=False):
1164911649

1165011650

1165111651
class DistillationHyperParameters(_common.BaseModel):
11652-
"""Hyperparameters for Distillation.
11653-
11654-
This data type is not supported in Gemini API.
11655-
"""
11652+
"""Hyperparameters for distillation."""
1165611653

1165711654
adapter_size: Optional[AdapterSize] = Field(
1165811655
default=None, description="""Optional. Adapter size for distillation."""
@@ -11665,13 +11662,19 @@ class DistillationHyperParameters(_common.BaseModel):
1166511662
default=None,
1166611663
description="""Optional. Multiplier for adjusting the default learning rate.""",
1166711664
)
11665+
batch_size: Optional[int] = Field(
11666+
default=None,
11667+
description="""The batch size hyperparameter for tuning.
11668+
This is only supported for OSS models in Vertex.""",
11669+
)
11670+
learning_rate: Optional[float] = Field(
11671+
default=None,
11672+
description="""The learning rate for tuning. OSS models only.""",
11673+
)
1166811674

1166911675

1167011676
class DistillationHyperParametersDict(TypedDict, total=False):
11671-
"""Hyperparameters for Distillation.
11672-
11673-
This data type is not supported in Gemini API.
11674-
"""
11677+
"""Hyperparameters for distillation."""
1167511678

1167611679
adapter_size: Optional[AdapterSize]
1167711680
"""Optional. Adapter size for distillation."""
@@ -11682,6 +11685,13 @@ class DistillationHyperParametersDict(TypedDict, total=False):
1168211685
learning_rate_multiplier: Optional[float]
1168311686
"""Optional. Multiplier for adjusting the default learning rate."""
1168411687

11688+
batch_size: Optional[int]
11689+
"""The batch size hyperparameter for tuning.
11690+
This is only supported for OSS models in Vertex."""
11691+
11692+
learning_rate: Optional[float]
11693+
"""The learning rate for tuning. OSS models only."""
11694+
1168511695

1168611696
DistillationHyperParametersOrDict = Union[
1168711697
DistillationHyperParameters, DistillationHyperParametersDict
@@ -11723,6 +11733,9 @@ class DistillationSpec(_common.BaseModel):
1172311733
default=None,
1172411734
description="""Optional. Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file.""",
1172511735
)
11736+
tuning_mode: Optional[TuningMode] = Field(
11737+
default=None, description="""Tuning mode for tuning."""
11738+
)
1172611739

1172711740

1172811741
class DistillationSpecDict(TypedDict, total=False):
@@ -11752,6 +11765,9 @@ class DistillationSpecDict(TypedDict, total=False):
1175211765
validation_dataset_uri: Optional[str]
1175311766
"""Optional. Cloud Storage path to file containing validation dataset for tuning. The dataset must be formatted as a JSONL file."""
1175411767

11768+
tuning_mode: Optional[TuningMode]
11769+
"""Tuning mode for tuning."""
11770+
1175511771

1175611772
DistillationSpecOrDict = Union[DistillationSpec, DistillationSpecDict]
1175711773

@@ -13933,7 +13949,7 @@ class CreateTuningJobConfig(_common.BaseModel):
1393313949
default=None, description="""Adapter size for tuning."""
1393413950
)
1393513951
tuning_mode: Optional[TuningMode] = Field(
13936-
default=None, description="""Tuning mode for SFT tuning."""
13952+
default=None, description="""Tuning mode for tuning."""
1393713953
)
1393813954
custom_base_model: Optional[str] = Field(
1393913955
default=None,
@@ -14014,7 +14030,7 @@ class CreateTuningJobConfigDict(TypedDict, total=False):
1401414030
"""Adapter size for tuning."""
1401514031

1401614032
tuning_mode: Optional[TuningMode]
14017-
"""Tuning mode for SFT tuning."""
14033+
"""Tuning mode for tuning."""
1401814034

1401914035
custom_base_model: Optional[str]
1402014036
"""Custom base model for tuning. This is only supported for OSS models in Vertex."""

0 commit comments

Comments (0)