1515# limitations under the License.
1616#
1717import dataclasses
18- from typing import List , Optional , Dict , Any
18+ from typing import List , Optional
19+ from vertexai .preview ._workflow .serialization_engine import (
20+ serializers_base ,
21+ )
1922
2023
2124@dataclasses .dataclass
@@ -72,16 +75,33 @@ class RemoteConfig(_BaseConfig):
7275 ]
7376
7477 # Specify the extra parameters needed for serializing objects.
75- model.train.vertex.remote_config.serializer_args = {
76- model: {
77- "extra_serializer_param1_for_model": param1_value,
78- "extra_serializer_param2_for_model": param2_value,
78+ from vertexai.preview.developer import SerializerArgs
79+
80+ # You can put all the hashable objects with their arguments in the
81+ # SerializerArgs all at once in a dict. Here we assume "model" is
82+ # hashable.
83+ model.train.vertex.remote_config.serializer_args = SerializerArgs({
84+ model: {
85+ "extra_serializer_param1_for_model": param1_value,
86+ "extra_serializer_param2_for_model": param2_value,
87+ },
88+ hashable_obj2: {
89+ "extra_serializer_param1_for_hashable2": param1_value,
90+ "extra_serializer_param2_for_hashable2": param2_value,
91+ },
92+ })
93+ # Or if the object to be serialized is unhashable, put them into the
94+ # serializer_args one by one. If this is the only use case, there is
95+ # no need to import `SerializerArgs`. Here we assume "X_train" and
96+ # "y_train" is not hashable.
97+ model.train.vertex.remote_config.serializer_args[X_train] = {
98+ "extra_serializer_param1_for_X_train": param1_value,
99+ "extra_serializer_param2_for_X_train": param2_value,
79100 },
80- X_train: {
81- "extra_serializer_param1 ": param1_value,
82- "extra_serializer_param2 ": param2_value,
101+ model.train.vertex.remote_config.serializer_args[y_train] = {
102+ "extra_serializer_param1_for_y_train ": param1_value,
103+ "extra_serializer_param2_for_y_train ": param2_value,
83104 }
84- }
85105
86106 # Train the model as usual
87107 model.train(X_train, y_train)
@@ -132,7 +152,7 @@ class RemoteConfig(_BaseConfig):
132152 custom_commands (List[str]):
133153 List of custom commands to be run in the remote job environment.
134154 These commands will be run before the requirements are installed.
135- serializer_args (Dict[Any, Dict[str, Any]]) :
155+ serializer_args: serializers_base.SerializerArgs :
136156 Map from object to extra arguments when serializing the object. The extra
137157 arguments is a dictionary from the argument names to the argument values.
138158 """
@@ -143,7 +163,9 @@ class RemoteConfig(_BaseConfig):
143163 service_account : Optional [str ] = None
144164 requirements : List [str ] = dataclasses .field (default_factory = list )
145165 custom_commands : List [str ] = dataclasses .field (default_factory = list )
146- serializer_args : Dict [Any , Dict [str , Any ]] = dataclasses .field (default_factory = dict )
166+ serializer_args : serializers_base .SerializerArgs = dataclasses .field (
167+ default_factory = serializers_base .SerializerArgs
168+ )
147169
148170
149171@dataclasses .dataclass
0 commit comments