|
| 1 | +from typing import Callable, Optional, Protocol, TypeVar |
| 2 | + |
| 3 | +from pydantic import BaseModel |
| 4 | +from typing_extensions import Self |
| 5 | + |
| 6 | +from dstack._internal.core.models.profiles import Profile, SpotPolicy |
| 7 | +from dstack._internal.core.models.resources import ( |
| 8 | + CPUSpec, |
| 9 | + DiskSpec, |
| 10 | + GPUSpec, |
| 11 | + Memory, |
| 12 | + Range, |
| 13 | + ResourcesSpec, |
| 14 | +) |
| 15 | +from dstack._internal.core.models.runs import Requirements |
| 16 | +from dstack._internal.utils.typing import SupportsRichComparison |
| 17 | + |
| 18 | + |
| 19 | +class CombineError(ValueError): |
| 20 | + pass |
| 21 | + |
| 22 | + |
| 23 | +def combine_fleet_and_run_profiles( |
| 24 | + fleet_profile: Profile, run_profile: Profile |
| 25 | +) -> Optional[Profile]: |
| 26 | + """ |
| 27 | + Combines fleet and run profile parameters that affect offer selection or provisioning. |
| 28 | + """ |
| 29 | + try: |
| 30 | + return Profile( |
| 31 | + backends=_intersect_lists_optional(fleet_profile.backends, run_profile.backends), |
| 32 | + regions=_intersect_lists_optional(fleet_profile.regions, run_profile.regions), |
| 33 | + availability_zones=_intersect_lists_optional( |
| 34 | + fleet_profile.availability_zones, run_profile.availability_zones |
| 35 | + ), |
| 36 | + instance_types=_intersect_lists_optional( |
| 37 | + fleet_profile.instance_types, run_profile.instance_types |
| 38 | + ), |
| 39 | + reservation=_get_single_value_optional( |
| 40 | + fleet_profile.reservation, run_profile.reservation |
| 41 | + ), |
| 42 | + spot_policy=_combine_spot_policy_optional( |
| 43 | + fleet_profile.spot_policy, run_profile.spot_policy |
| 44 | + ), |
| 45 | + max_price=_get_min_optional(fleet_profile.max_price, run_profile.max_price), |
| 46 | + idle_duration=_combine_idle_duration_optional( |
| 47 | + fleet_profile.idle_duration, run_profile.idle_duration |
| 48 | + ), |
| 49 | + tags=_combine_tags_optional(fleet_profile.tags, run_profile.tags), |
| 50 | + ) |
| 51 | + except CombineError: |
| 52 | + return None |
| 53 | + |
| 54 | + |
| 55 | +def combine_fleet_and_run_requirements( |
| 56 | + fleet_requirements: Requirements, run_requirements: Requirements |
| 57 | +) -> Optional[Requirements]: |
| 58 | + try: |
| 59 | + return Requirements( |
| 60 | + resources=_combine_resources(fleet_requirements.resources, run_requirements.resources), |
| 61 | + max_price=_get_min_optional(fleet_requirements.max_price, run_requirements.max_price), |
| 62 | + spot=_combine_spot_optional(fleet_requirements.spot, run_requirements.spot), |
| 63 | + reservation=_get_single_value_optional( |
| 64 | + fleet_requirements.reservation, run_requirements.reservation |
| 65 | + ), |
| 66 | + ) |
| 67 | + except CombineError: |
| 68 | + return None |
| 69 | + |
| 70 | + |
| 71 | +_T = TypeVar("_T") |
| 72 | +_ModelT = TypeVar("_ModelT", bound=BaseModel) |
| 73 | +_CompT = TypeVar("_CompT", bound=SupportsRichComparison) |
| 74 | + |
| 75 | + |
| 76 | +class _SupportsCopy(Protocol): |
| 77 | + def copy(self) -> Self: ... |
| 78 | + |
| 79 | + |
| 80 | +_CopyT = TypeVar("_CopyT", bound=_SupportsCopy) |
| 81 | + |
| 82 | + |
| 83 | +def _intersect_lists_optional( |
| 84 | + list1: Optional[list[_T]], list2: Optional[list[_T]] |
| 85 | +) -> Optional[list[_T]]: |
| 86 | + if list1 is None: |
| 87 | + if list2 is None: |
| 88 | + return None |
| 89 | + return list2.copy() |
| 90 | + if list2 is None: |
| 91 | + return list1.copy() |
| 92 | + return [x for x in list1 if x in list2] |
| 93 | + |
| 94 | + |
| 95 | +def _get_min(value1: _CompT, value2: _CompT) -> _CompT: |
| 96 | + return min(value1, value2) |
| 97 | + |
| 98 | + |
| 99 | +def _get_min_optional(value1: Optional[_CompT], value2: Optional[_CompT]) -> Optional[_CompT]: |
| 100 | + return _combine_optional(value1, value2, _get_min) |
| 101 | + |
| 102 | + |
| 103 | +def _get_single_value(value1: _T, value2: _T) -> _T: |
| 104 | + if value1 == value2: |
| 105 | + return value1 |
| 106 | + raise CombineError(f"Values {value1} and {value2} cannot be combined") |
| 107 | + |
| 108 | + |
| 109 | +def _get_single_value_optional(value1: Optional[_T], value2: Optional[_T]) -> Optional[_T]: |
| 110 | + return _combine_optional(value1, value2, _get_single_value) |
| 111 | + |
| 112 | + |
| 113 | +def _combine_spot_policy(value1: SpotPolicy, value2: SpotPolicy) -> SpotPolicy: |
| 114 | + if value1 == SpotPolicy.AUTO: |
| 115 | + return value2 |
| 116 | + if value2 == SpotPolicy.AUTO: |
| 117 | + return value1 |
| 118 | + if value1 == value2: |
| 119 | + return value1 |
| 120 | + raise CombineError(f"spot_policy values {value1} and {value2} cannot be combined") |
| 121 | + |
| 122 | + |
| 123 | +def _combine_spot_policy_optional( |
| 124 | + value1: Optional[SpotPolicy], value2: Optional[SpotPolicy] |
| 125 | +) -> Optional[SpotPolicy]: |
| 126 | + return _combine_optional(value1, value2, _combine_spot_policy) |
| 127 | + |
| 128 | + |
| 129 | +def _combine_idle_duration(value1: int, value2: int) -> int: |
| 130 | + if value1 < 0 and value2 >= 0 or value2 < 0 and value1 >= 0: |
| 131 | + raise CombineError(f"idle_duration values {value1} and {value2} cannot be combined") |
| 132 | + return min(value1, value2) |
| 133 | + |
| 134 | + |
| 135 | +def _combine_idle_duration_optional(value1: Optional[int], value2: Optional[int]) -> Optional[int]: |
| 136 | + return _combine_optional(value1, value2, _combine_idle_duration) |
| 137 | + |
| 138 | + |
| 139 | +def _combine_tags_optional( |
| 140 | + value1: Optional[dict[str, str]], value2: Optional[dict[str, str]] |
| 141 | +) -> Optional[dict[str, str]]: |
| 142 | + return _combine_copy_optional(value1, value2, _combine_tags) |
| 143 | + |
| 144 | + |
| 145 | +def _combine_tags(value1: dict[str, str], value2: dict[str, str]) -> dict[str, str]: |
| 146 | + return value1 | value2 |
| 147 | + |
| 148 | + |
| 149 | +def _combine_resources(value1: ResourcesSpec, value2: ResourcesSpec) -> ResourcesSpec: |
| 150 | + return ResourcesSpec( |
| 151 | + cpu=_combine_cpu(value1.cpu, value2.cpu), # type: ignore[attr-defined] |
| 152 | + memory=_combine_memory(value1.memory, value2.memory), |
| 153 | + shm_size=_combine_shm_size_optional(value1.shm_size, value2.shm_size), |
| 154 | + gpu=_combine_gpu_optional(value1.gpu, value2.gpu), |
| 155 | + disk=_combine_disk_optional(value1.disk, value2.disk), |
| 156 | + ) |
| 157 | + |
| 158 | + |
| 159 | +def _combine_cpu(value1: CPUSpec, value2: CPUSpec) -> CPUSpec: |
| 160 | + return CPUSpec( |
| 161 | + arch=_get_single_value_optional(value1.arch, value2.arch), |
| 162 | + count=_combine_range(value1.count, value2.count), |
| 163 | + ) |
| 164 | + |
| 165 | + |
| 166 | +def _combine_memory(value1: Range[Memory], value2: Range[Memory]) -> Range[Memory]: |
| 167 | + return _combine_range(value1, value2) |
| 168 | + |
| 169 | + |
| 170 | +def _combine_shm_size_optional( |
| 171 | + value1: Optional[Memory], value2: Optional[Memory] |
| 172 | +) -> Optional[Memory]: |
| 173 | + return _get_min_optional(value1, value2) |
| 174 | + |
| 175 | + |
| 176 | +def _combine_gpu(value1: GPUSpec, value2: GPUSpec) -> GPUSpec: |
| 177 | + return GPUSpec( |
| 178 | + vendor=_get_single_value_optional(value1.vendor, value2.vendor), |
| 179 | + name=_intersect_lists_optional(value1.name, value2.name), |
| 180 | + count=_combine_range(value1.count, value2.count), |
| 181 | + memory=_combine_range_optional(value1.memory, value2.memory), |
| 182 | + total_memory=_combine_range_optional(value1.total_memory, value2.total_memory), |
| 183 | + compute_capability=_get_min_optional(value1.compute_capability, value2.compute_capability), |
| 184 | + ) |
| 185 | + |
| 186 | + |
| 187 | +def _combine_gpu_optional( |
| 188 | + value1: Optional[GPUSpec], value2: Optional[GPUSpec] |
| 189 | +) -> Optional[GPUSpec]: |
| 190 | + return _combine_models_optional(value1, value2, _combine_gpu) |
| 191 | + |
| 192 | + |
| 193 | +def _combine_disk(value1: DiskSpec, value2: DiskSpec) -> DiskSpec: |
| 194 | + return DiskSpec(size=_combine_range(value1.size, value2.size)) |
| 195 | + |
| 196 | + |
| 197 | +def _combine_disk_optional( |
| 198 | + value1: Optional[DiskSpec], value2: Optional[DiskSpec] |
| 199 | +) -> Optional[DiskSpec]: |
| 200 | + return _combine_models_optional(value1, value2, _combine_disk) |
| 201 | + |
| 202 | + |
| 203 | +def _combine_spot(value1: bool, value2: bool) -> bool: |
| 204 | + if value1 != value2: |
| 205 | + raise CombineError(f"spot values {value1} and {value2} cannot be combined") |
| 206 | + return value1 |
| 207 | + |
| 208 | + |
| 209 | +def _combine_spot_optional(value1: Optional[bool], value2: Optional[bool]) -> Optional[bool]: |
| 210 | + return _combine_optional(value1, value2, _combine_spot) |
| 211 | + |
| 212 | + |
| 213 | +def _combine_range(value1: Range, value2: Range) -> Range: |
| 214 | + res = value1.intersect(value2) |
| 215 | + if res is None: |
| 216 | + raise CombineError(f"Ranges {value1} and {value2} cannot be combined") |
| 217 | + return res |
| 218 | + |
| 219 | + |
| 220 | +def _combine_range_optional(value1: Optional[Range], value2: Optional[Range]) -> Optional[Range]: |
| 221 | + return _combine_models_optional(value1, value2, _combine_range) |
| 222 | + |
| 223 | + |
| 224 | +def _combine_optional( |
| 225 | + value1: Optional[_T], value2: Optional[_T], combiner: Callable[[_T, _T], _T] |
| 226 | +) -> Optional[_T]: |
| 227 | + if value1 is None: |
| 228 | + return value2 |
| 229 | + if value2 is None: |
| 230 | + return value1 |
| 231 | + return combiner(value1, value2) |
| 232 | + |
| 233 | + |
| 234 | +def _combine_models_optional( |
| 235 | + value1: Optional[_ModelT], |
| 236 | + value2: Optional[_ModelT], |
| 237 | + combiner: Callable[[_ModelT, _ModelT], _ModelT], |
| 238 | +) -> Optional[_ModelT]: |
| 239 | + if value1 is None: |
| 240 | + if value2 is not None: |
| 241 | + return value2.copy(deep=True) |
| 242 | + return None |
| 243 | + if value2 is None: |
| 244 | + return value1.copy(deep=True) |
| 245 | + return combiner(value1, value2) |
| 246 | + |
| 247 | + |
| 248 | +def _combine_copy_optional( |
| 249 | + value1: Optional[_CopyT], |
| 250 | + value2: Optional[_CopyT], |
| 251 | + combiner: Callable[[_CopyT, _CopyT], _CopyT], |
| 252 | +) -> Optional[_CopyT]: |
| 253 | + if value1 is None: |
| 254 | + if value2 is not None: |
| 255 | + return value2.copy() |
| 256 | + return None |
| 257 | + if value2 is None: |
| 258 | + return value1.copy() |
| 259 | + return combiner(value1, value2) |
0 commit comments