diff --git a/plugins/flytekit-ray/flytekitplugins/ray/models.py b/plugins/flytekit-ray/flytekitplugins/ray/models.py index 1f3a830f16..a677c632a6 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/models.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/models.py @@ -220,12 +220,14 @@ def __init__( runtime_env_yaml: typing.Optional[str] = None, ttl_seconds_after_finished: typing.Optional[int] = None, shutdown_after_job_finishes: bool = False, + address: typing.Optional[str] = None, ): self._ray_cluster = ray_cluster self._runtime_env = runtime_env self._runtime_env_yaml = runtime_env_yaml self._ttl_seconds_after_finished = ttl_seconds_after_finished self._shutdown_after_job_finishes = shutdown_after_job_finishes + self._address = address @property def ray_cluster(self) -> RayCluster: @@ -249,6 +251,12 @@ def shutdown_after_job_finishes(self) -> bool: # shutdown_after_job_finishes specifies whether the RayCluster should be deleted after the RayJob finishes. return self._shutdown_after_job_finishes + @property + def address(self) -> typing.Optional[str]: + # address specifies the Ray head address to connect to for an existing cluster. + # When set, the RayJob submits to an existing RayCluster instead of creating a new one. + return self._address + def to_flyte_idl(self) -> _ray_pb2.RayJob: return _ray_pb2.RayJob( ray_cluster=self.ray_cluster.to_flyte_idl(), @@ -256,6 +264,7 @@ def to_flyte_idl(self) -> _ray_pb2.RayJob: runtime_env_yaml=self.runtime_env_yaml, ttl_seconds_after_finished=self.ttl_seconds_after_finished, shutdown_after_job_finishes=self.shutdown_after_job_finishes, + address=self.address if self.address else "", ) @classmethod @@ -266,4 +275,5 @@ def from_flyte_idl(cls, proto: _ray_pb2.RayJob): runtime_env_yaml=proto.runtime_env_yaml, ttl_seconds_after_finished=proto.ttl_seconds_after_finished, shutdown_after_job_finishes=proto.shutdown_after_job_finishes, + address=proto.address if proto.address else None, ) diff --git a/plugins/flytekit-ray/flytekitplugins/ray/task.py b/plugins/flytekit-ray/flytekitplugins/ray/task.py index 8bdea4dd5a..e842a8f247 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/task.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/task.py @@ -166,6 +166,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] runtime_env_yaml=runtime_env_yaml, ttl_seconds_after_finished=cfg.ttl_seconds_after_finished, shutdown_after_job_finishes=cfg.shutdown_after_job_finishes, + address=cfg.address, ) return MessageToDict(ray_job.to_flyte_idl())