2929
3030from __future__ import annotations
3131
32- from pydantic import BaseModel , ConfigDict , field_validator
32+ from typing import Literal
33+
34+ from pydantic import BaseModel , ConfigDict , field_validator , model_validator
3335
3436
3537class ModelArguments (BaseModel ):
@@ -47,6 +49,7 @@ class DataArguments(BaseModel):
4749
4850 model_config = ConfigDict (extra = "forbid" )
4951
52+ mode : Literal ["online" , "offline" , "streaming" ] = "online"
5053 data_path : str | None = None
5154 offline_data_path : str | None = None
5255 lazy_preprocess : bool = True
@@ -67,6 +70,25 @@ def _check_sample_size(cls, v: int) -> int:
6770 raise ValueError ("sample_size must be -1 (use all samples) or a positive integer" )
6871 return v
6972
@model_validator(mode="after")
def _check_mode_requirements(self) -> DataArguments:
    """Infer ``mode`` for legacy configs, then check each mode's required fields.

    Backward-compat: when the caller did not supply ``mode`` (so it holds the
    "online" default) but a mode-specific field is set, ``mode`` is promoted to
    the matching value. ``offline_data_path`` takes precedence over
    ``streaming_server_url`` when both are set. A ``mode`` that was supplied
    explicitly -- including an explicit "online" -- is never overridden.

    Returns:
        The validated instance, with ``mode`` possibly promoted.

    Raises:
        ValueError: if ``mode`` is "offline" and ``offline_data_path`` is missing
            or empty, or if ``mode`` is "streaming" and either
            ``streaming_server_url`` or ``streaming_model_name`` is missing or
            empty.
    """
    # Comparing ``self.mode`` to "online" cannot distinguish an explicit "online"
    # from the default, so consult pydantic's record of caller-supplied fields.
    mode_was_given = "mode" in self.model_fields_set
    if not mode_was_given:
        if self.offline_data_path is not None:
            self.mode = "offline"
        elif self.streaming_server_url is not None:
            self.mode = "streaming"

    if self.mode == "offline" and not self.offline_data_path:
        raise ValueError("data.mode='offline' requires data.offline_data_path")
    if self.mode == "streaming" and not (
        self.streaming_server_url and self.streaming_model_name
    ):
        raise ValueError(
            "data.mode='streaming' requires data.streaming_server_url and data.streaming_model_name"
        )
    return self
91+
7092
7193class TrainingArguments (BaseModel ):
7294 """Speculative-decoding extensions on top of ``transformers.TrainingArguments``.
0 commit comments