import shlex
import threading
from abc import ABC, abstractmethod
from functools import lru_cache
from pathlib import Path, PurePosixPath
from typing import Dict, List, Optional
910
1011import git
3637)
3738from dstack ._internal .core .services import is_valid_dstack_resource_name
3839from dstack ._internal .utils .logging import get_logger
40+ from dstack ._internal .utils .path import PathLike
3941
logger = get_logger(__name__)

# File names of the dstack binaries. Their install locations are no longer
# module constants: full paths are produced by get_dstack_shim_binary_path()
# and get_dstack_runner_binary_path(), which accept an optional directory.
DSTACK_SHIM_BINARY_NAME = "dstack-shim"
DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
4746
4847
4948class Compute (ABC ):
@@ -336,6 +335,24 @@ def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
336335 return True
337336
338337
def get_dstack_working_dir(base_path: Optional[PathLike] = None) -> str:
    """Return the path of the ``.dstack`` working directory.

    Args:
        base_path: Directory that contains the ``.dstack`` directory.
            ``None`` selects the default, ``/root``.

    Returns:
        ``<base_path>/.dstack`` as a string with POSIX separators.
    """
    if base_path is None:
        base_path = "/root"
    # PurePosixPath rather than Path: the result is interpolated into shell
    # commands (e.g. `sudo mkdir ...`), so it must not pick up the separator
    # of whatever platform this code happens to run on.
    return str(PurePosixPath(base_path, ".dstack"))
342+
343+
def get_dstack_shim_binary_path(bin_path: Optional[PathLike] = None) -> str:
    """Return the full path of the dstack-shim binary.

    Args:
        bin_path: Directory that holds the binary. ``None`` selects the
            default, ``/usr/local/bin``.

    Returns:
        ``<bin_path>/<DSTACK_SHIM_BINARY_NAME>`` as a string with POSIX
        separators.
    """
    if bin_path is None:
        bin_path = "/usr/local/bin"
    # PurePosixPath rather than Path: the result is interpolated into shell
    # commands, so it must not depend on the platform running this code.
    return str(PurePosixPath(bin_path, DSTACK_SHIM_BINARY_NAME))
348+
349+
def get_dstack_runner_binary_path(bin_path: Optional[PathLike] = None) -> str:
    """Return the full path of the dstack-runner binary.

    Args:
        bin_path: Directory that holds the binary. ``None`` selects the
            default, ``/usr/local/bin``.

    Returns:
        ``<bin_path>/<DSTACK_RUNNER_BINARY_NAME>`` as a string with POSIX
        separators.
    """
    if bin_path is None:
        bin_path = "/usr/local/bin"
    # PurePosixPath rather than Path: the result is interpolated into shell
    # commands, so it must not depend on the platform running this code.
    return str(PurePosixPath(bin_path, DSTACK_RUNNER_BINARY_NAME))
354+
355+
def get_job_instance_name(run: Run, job: Job) -> str:
    """Return the instance name for a job, which is the job spec's job name.

    Args:
        run: Not read by this function.
        job: Job whose ``job_spec.job_name`` is returned.
    """
    return job.job_spec.job_name
341358
@@ -442,39 +459,74 @@ def get_cloud_config(**config) -> str:
442459
443460
def get_user_data(
    authorized_keys: List[str],
    backend_specific_commands: Optional[List[str]] = None,
    base_path: Optional[PathLike] = None,
    bin_path: Optional[PathLike] = None,
    backend_shim_env: Optional[Dict[str, str]] = None,
) -> str:
    """Build the cloud-config user data that runs the shim commands.

    Args:
        authorized_keys: Public SSH keys; passed to the shim commands and set
            as ``ssh_authorized_keys`` in the cloud config.
        backend_specific_commands: Commands placed before the shim commands.
        base_path: Forwarded to ``get_shim_commands``.
        bin_path: Forwarded to ``get_shim_commands``.
        backend_shim_env: Forwarded to ``get_shim_commands``.

    Returns:
        The value returned by ``get_cloud_config``.
    """
    shim_commands = get_shim_commands(
        authorized_keys=authorized_keys,
        base_path=base_path,
        bin_path=bin_path,
        backend_shim_env=backend_shim_env,
    )
    commands = (backend_specific_commands or []) + shim_commands
    # Chain everything with `&&` into a single `sh -c` invocation so that a
    # failing command stops the ones after it.
    script = " && ".join(commands)
    return get_cloud_config(
        runcmd=[["sh", "-c", script]],
        ssh_authorized_keys=authorized_keys,
    )
453479
454480
def get_shim_env(
    authorized_keys: List[str],
    base_path: Optional[PathLike] = None,
    bin_path: Optional[PathLike] = None,
    backend_shim_env: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Build the environment variables exported before starting the shim.

    Args:
        authorized_keys: Public SSH keys, joined with newlines into
            ``DSTACK_PUBLIC_SSH_KEY``.
        base_path: Passed to ``get_dstack_working_dir`` for
            ``DSTACK_SHIM_HOME``.
        bin_path: Passed to ``get_dstack_runner_binary_path`` for
            ``DSTACK_RUNNER_BINARY_PATH``.
        backend_shim_env: Extra variables merged in last, so they override
            the defaults on a key collision.

    Returns:
        A new dict mapping variable names to string values.
    """
    log_level = "6"  # Trace
    envs = {
        "DSTACK_SHIM_HOME": get_dstack_working_dir(base_path),
        "DSTACK_SHIM_HTTP_PORT": str(DSTACK_SHIM_HTTP_PORT),
        "DSTACK_SHIM_LOG_LEVEL": log_level,
        "DSTACK_RUNNER_DOWNLOAD_URL": get_dstack_runner_download_url(),
        "DSTACK_RUNNER_BINARY_PATH": get_dstack_runner_binary_path(bin_path),
        "DSTACK_RUNNER_HTTP_PORT": str(DSTACK_RUNNER_HTTP_PORT),
        "DSTACK_RUNNER_SSH_PORT": str(DSTACK_RUNNER_SSH_PORT),
        "DSTACK_RUNNER_LOG_LEVEL": log_level,
        "DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys),
    }
    if backend_shim_env is not None:
        # dict.update rather than `|=`: same result, and it does not require
        # Python 3.9's dict union operators.
        envs.update(backend_shim_env)
    return envs
469502
470503
def get_shim_commands(
    authorized_keys: List[str],
    *,
    is_privileged: bool = False,
    pjrt_device: Optional[str] = None,
    base_path: Optional[PathLike] = None,
    bin_path: Optional[PathLike] = None,
    backend_shim_env: Optional[Dict[str, str]] = None,
) -> List[str]:
    """Return the shell commands that install, configure and start the shim.

    The list is, in order: the pre-start commands, one ``export`` per shim
    environment variable, and the command(s) that launch the shim.

    Args:
        authorized_keys: Public SSH keys passed to ``get_shim_env``.
        is_privileged: Forwarded to ``get_run_shim_script``.
        pjrt_device: Forwarded to ``get_run_shim_script``.
        base_path: Forwarded to the pre-start commands and the shim env.
        bin_path: Forwarded to the pre-start commands, the shim env and the
            run script.
        backend_shim_env: Extra environment variables passed to
            ``get_shim_env``.
    """
    commands = get_shim_pre_start_commands(
        base_path=base_path,
        bin_path=bin_path,
    )
    shim_env = get_shim_env(
        authorized_keys=authorized_keys,
        base_path=base_path,
        bin_path=bin_path,
        backend_shim_env=backend_shim_env,
    )
    # NOTE(review): values are placed inside double quotes without escaping,
    # so a value containing `"`, `$` or a backtick would be interpreted by the
    # shell. Confirm that backend_shim_env values are trusted.
    commands += [f'export "{k}={v}"' for k, v in shim_env.items()]
    commands += get_run_shim_script(
        is_privileged=is_privileged,
        pjrt_device=pjrt_device,
        bin_path=bin_path,
    )
    return commands
479531
480532
@@ -511,25 +563,33 @@ def get_dstack_shim_download_url() -> str:
511563 return f"https://{ bucket } .s3.eu-west-1.amazonaws.com/{ build } /binaries/dstack-shim-linux-amd64"
512564
513565
def get_shim_pre_start_commands(
    base_path: Optional[PathLike] = None,
    bin_path: Optional[PathLike] = None,
) -> List[str]:
    """Return the shell commands that download the shim and prepare its dirs.

    Args:
        base_path: Passed to ``get_dstack_working_dir`` to locate the working
            directory that is created.
        bin_path: Passed to ``get_dstack_shim_binary_path`` to locate where
            the downloaded binary is moved.

    Returns:
        Commands that download the shim to a temporary file, move it into
        place, make it executable and create the working directory.
    """
    url = get_dstack_shim_download_url()
    # The paths are caller-configurable, so quote them for the shell.
    # shlex.quote returns the string unchanged when it contains only safe
    # characters, which keeps the commands for the default paths identical.
    shim_binary = shlex.quote(get_dstack_shim_binary_path(bin_path))
    working_dir = shlex.quote(get_dstack_working_dir(base_path))
    return [
        f"dlpath=$(sudo mktemp -t {DSTACK_SHIM_BINARY_NAME}.XXXXXXXXXX)",
        # -sS -- disable progress meter and warnings, but still show errors (unlike bare -s)
        f'sudo curl -sS --compressed --connect-timeout 60 --max-time 240 --retry 1 --output "$dlpath" "{url}"',
        f'sudo mv "$dlpath" {shim_binary}',
        f"sudo chmod +x {shim_binary}",
        f"sudo mkdir {working_dir} -p",
    ]
525581
526582
def get_run_shim_script(
    is_privileged: bool,
    pjrt_device: Optional[str],
    bin_path: Optional[PathLike] = None,
) -> List[str]:
    """Return the shell command that starts the shim in the background.

    Args:
        is_privileged: When true, ``--privileged`` is passed to the shim.
        pjrt_device: When truthy, passed as ``--pjrt-device=<value>``.
        bin_path: Passed to ``get_dstack_shim_binary_path`` to locate the
            binary.

    Returns:
        A single-element list holding the ``nohup ... &`` command.
    """
    # The path is caller-configurable, so quote it for the shell; shlex.quote
    # leaves the default path unchanged.
    shim_binary = shlex.quote(get_dstack_shim_binary_path(bin_path))
    privileged_flag = "--privileged" if is_privileged else ""
    pjrt_device_env = f"--pjrt-device={pjrt_device}" if pjrt_device else ""
    # Unset flags leave extra spaces in the command, which the shell ignores.
    return [
        f"nohup {shim_binary} {privileged_flag} {pjrt_device_env} &",
    ]
534594
535595
@@ -555,7 +615,11 @@ def get_gateway_user_data(authorized_key: str) -> str:
555615 )
556616
557617
558- def get_docker_commands (authorized_keys : list [str ]) -> list [str ]:
618+ def get_docker_commands (
619+ authorized_keys : list [str ],
620+ bin_path : Optional [PathLike ] = None ,
621+ ) -> list [str ]:
622+ dstack_runner_binary_path = get_dstack_runner_binary_path (bin_path )
559623 authorized_keys_content = "\n " .join (authorized_keys ).strip ()
560624 commands = [
561625 # save and unset ld.so variables
@@ -606,10 +670,10 @@ def get_docker_commands(authorized_keys: list[str]) -> list[str]:
606670
607671 url = get_dstack_runner_download_url ()
608672 commands += [
609- f"curl --connect-timeout 60 --max-time 240 --retry 1 --output { DSTACK_RUNNER_BINARY_PATH } { url } " ,
610- f"chmod +x { DSTACK_RUNNER_BINARY_PATH } " ,
673+ f"curl --connect-timeout 60 --max-time 240 --retry 1 --output { dstack_runner_binary_path } { url } " ,
674+ f"chmod +x { dstack_runner_binary_path } " ,
611675 (
612- f"{ DSTACK_RUNNER_BINARY_PATH } "
676+ f"{ dstack_runner_binary_path } "
613677 " --log-level 6"
614678 " start"
615679 f" --http-port { DSTACK_RUNNER_HTTP_PORT } "
0 commit comments