55import threading
66from abc import ABC , abstractmethod
77from collections .abc import Iterable
8+ from enum import Enum
89from functools import lru_cache
910from pathlib import Path
10- from typing import Callable , Dict , List , Literal , Optional
11+ from typing import Callable , Dict , List , Optional
1112
1213import git
1314import requests
1415import yaml
1516from cachetools import TTLCache , cachedmethod
17+ from gpuhunt import CPUArchitecture
1618
1719from dstack ._internal import settings
1820from dstack ._internal .core .backends .base .offers import filter_offers_by_requirements
4951DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
5052DEFAULT_PRIVATE_SUBNETS = ("10.0.0.0/8" , "172.16.0.0/12" , "192.168.0.0/16" )
5153
52- GoArchType = Literal ["amd64" , "arm64" ]
54+
55+ class GoArchType (str , Enum ):
56+ """
57+ A subset of GOARCH values
58+ """
59+
60+ AMD64 = "amd64"
61+ ARM64 = "arm64"
62+
63+ def to_cpu_architecture (self ) -> CPUArchitecture :
64+ if self == self .AMD64 :
65+ return CPUArchitecture .X86
66+ if self == self .ARM64 :
67+ return CPUArchitecture .ARM
68+ assert False , self
5369
5470
5571class Compute (ABC ):
@@ -688,14 +704,14 @@ def normalize_arch(arch: Optional[str] = None) -> GoArchType:
688704 If the arch is not specified, falls back to `amd64`.
689705 """
690706 if not arch :
691- return "amd64"
707+ return GoArchType . AMD64
692708 arch_lower = arch .lower ()
693709 if "32" in arch_lower or arch_lower in ["i386" , "i686" ]:
694710 raise ValueError (f"32-bit architectures are not supported: { arch } " )
695711 if arch_lower .startswith ("x86" ) or arch_lower .startswith ("amd" ):
696- return "amd64"
712+ return GoArchType . AMD64
697713 if arch_lower .startswith ("arm" ) or arch_lower .startswith ("aarch" ):
698- return "arm64"
714+ return GoArchType . ARM64
699715 raise ValueError (f"Unsupported architecture: { arch } " )
700716
701717
@@ -711,8 +727,7 @@ def get_dstack_runner_download_url(arch: Optional[str] = None) -> str:
711727 "/{version}/binaries/dstack-runner-linux-{arch}"
712728 )
713729 version = get_dstack_runner_version ()
714- arch = normalize_arch (arch )
715- return url_template .format (version = version , arch = arch )
730+ return url_template .format (version = version , arch = normalize_arch (arch ).value )
716731
717732
718733def get_dstack_shim_download_url (arch : Optional [str ] = None ) -> str :
@@ -727,8 +742,7 @@ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
727742 "/{version}/binaries/dstack-shim-linux-{arch}"
728743 )
729744 version = get_dstack_runner_version ()
730- arch = normalize_arch (arch )
731- return url_template .format (version = version , arch = arch )
745+ return url_template .format (version = version , arch = normalize_arch (arch ).value )
732746
733747
734748def get_setup_cloud_instance_commands (
0 commit comments