Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,14 @@ def __init__(
runtime_env_yaml: typing.Optional[str] = None,
ttl_seconds_after_finished: typing.Optional[int] = None,
shutdown_after_job_finishes: bool = False,
address: typing.Optional[str] = None,
):
self._ray_cluster = ray_cluster
self._runtime_env = runtime_env
self._runtime_env_yaml = runtime_env_yaml
self._ttl_seconds_after_finished = ttl_seconds_after_finished
self._shutdown_after_job_finishes = shutdown_after_job_finishes
self._address = address

@property
def ray_cluster(self) -> RayCluster:
Expand All @@ -249,13 +251,20 @@ def shutdown_after_job_finishes(self) -> bool:
# shutdown_after_job_finishes specifies whether the RayCluster should be deleted after the RayJob finishes.
return self._shutdown_after_job_finishes

@property
def address(self) -> typing.Optional[str]:
# address specifies the Ray head address to connect to for an existing cluster.
# When set, the RayJob submits to an existing RayCluster instead of creating a new one.
return self._address

def to_flyte_idl(self) -> _ray_pb2.RayJob:
return _ray_pb2.RayJob(
ray_cluster=self.ray_cluster.to_flyte_idl(),
runtime_env=self.runtime_env,
runtime_env_yaml=self.runtime_env_yaml,
ttl_seconds_after_finished=self.ttl_seconds_after_finished,
shutdown_after_job_finishes=self.shutdown_after_job_finishes,
address=self.address if self.address else "",
)

@classmethod
Expand All @@ -266,4 +275,5 @@ def from_flyte_idl(cls, proto: _ray_pb2.RayJob):
runtime_env_yaml=proto.runtime_env_yaml,
ttl_seconds_after_finished=proto.ttl_seconds_after_finished,
shutdown_after_job_finishes=proto.shutdown_after_job_finishes,
address=proto.address if proto.address else None,
)
1 change: 1 addition & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
runtime_env_yaml=runtime_env_yaml,
ttl_seconds_after_finished=cfg.ttl_seconds_after_finished,
shutdown_after_job_finishes=cfg.shutdown_after_job_finishes,
address=cfg.address,
)
return MessageToDict(ray_job.to_flyte_idl())

Expand Down