diff --git a/src/dstack/_internal/core/backends/verda/compute.py b/src/dstack/_internal/core/backends/verda/compute.py index 4ad995d9ea..2fd2359777 100644 --- a/src/dstack/_internal/core/backends/verda/compute.py +++ b/src/dstack/_internal/core/backends/verda/compute.py @@ -19,8 +19,9 @@ get_offers_disk_modifier, ) from dstack._internal.core.backends.verda.models import VerdaConfig -from dstack._internal.core.errors import NoCapacityError +from dstack._internal.core.errors import BackendError, NoCapacityError from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import CoreModel from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceConfiguration, @@ -31,7 +32,6 @@ from dstack._internal.core.models.resources import Memory, Range from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.utils.logging import get_logger -from dstack._internal.utils.ssh import get_public_key_fingerprint logger = get_logger("verda.compute") @@ -101,54 +101,72 @@ def create_instance( instance_config, max_length=MAX_INSTANCE_NAME_LEN ) public_keys = instance_config.get_public_keys() - ssh_ids = [] - for ssh_public_key in public_keys: - ssh_ids.append( - # verda allows you to use the same name - _get_or_create_ssh_key( - client=self.client, - name=f"dstack-{instance_config.instance_name}.key", - public_key=ssh_public_key, + ssh_ids: List[str] = [] + startup_script_id: Optional[str] = None + try: + for idx, ssh_public_key in enumerate(public_keys): + ssh_ids.append( + _create_ssh_key( + client=self.client, + name=f"{instance_name}-{idx}.key", + public_key=ssh_public_key, + ) ) - ) - commands = get_shim_commands() - startup_script = " ".join([" && ".join(commands)]) - script_name = f"dstack-{instance_config.instance_name}.sh" - startup_script_ids = _get_or_create_startup_scrpit( - client=self.client, - name=script_name, - script=startup_script, - ) + commands = get_shim_commands() + startup_script = " ".join([" && ".join(commands)]) + script_name = f"{instance_name}.sh" + startup_script_id = _create_startup_script( + client=self.client, + name=script_name, + script=startup_script, + ) - disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) - image_id = _get_vm_image_id(instance_offer) - - logger.debug( - "Deploying Verda instance", - { - "instance_type": instance_offer.instance.name, - "ssh_key_ids": ssh_ids, - "startup_script_id": startup_script_ids, - "hostname": instance_name, - "description": instance_name, - "image": image_id, - "disk_size": disk_size, - "location": instance_offer.region, - }, - ) - instance = _deploy_instance( - client=self.client, - instance_type=instance_offer.instance.name, - ssh_key_ids=ssh_ids, - startup_script_id=startup_script_ids, - hostname=instance_name, - description=instance_name, - image=image_id, - disk_size=disk_size, - is_spot=instance_offer.instance.resources.spot, - location=instance_offer.region, - ) + disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) + image_id = _get_vm_image_id(instance_offer) + + logger.debug( + "Deploying Verda instance", + { + "instance_type": instance_offer.instance.name, + "ssh_key_ids": ssh_ids, + "startup_script_id": startup_script_id, + "hostname": instance_name, + "description": instance_name, + "image": image_id, + "disk_size": disk_size, + "location": instance_offer.region, + }, + ) + instance = _deploy_instance( + client=self.client, + instance_type=instance_offer.instance.name, + ssh_key_ids=ssh_ids, + startup_script_id=startup_script_id, + hostname=instance_name, + description=instance_name, + image=image_id, + disk_size=disk_size, + is_spot=instance_offer.instance.resources.spot, + location=instance_offer.region, + ) + except Exception: + # startup_script_id and ssh_key_ids are per-instance. Ensure no leaks on failures. + try: + _delete_startup_script(self.client, startup_script_id) + except Exception: + logger.exception( + "Failed to cleanup startup script %s after provisioning failure.", + startup_script_id, + ) + try: + _delete_ssh_keys(self.client, ssh_ids) + except Exception: + logger.exception( + "Failed to cleanup ssh keys %s after provisioning failure.", + ssh_ids, + ) + raise return JobProvisioningData( backend=instance_offer.backend, instance_type=instance_offer.instance, @@ -161,12 +179,16 @@ def create_instance( ssh_port=22, dockerized=True, ssh_proxy=None, - backend_data=None, + backend_data=VerdaInstanceBackendData( + startup_script_id=startup_script_id, + ssh_key_ids=ssh_ids, + ).json(), ) def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ): + backend_data_parsed = VerdaInstanceBackendData.load(backend_data) try: self.client.instances.action(id_list=[instance_id], action="delete") except APIException as e: @@ -175,8 +197,10 @@ def terminate_instance( "Can't discontinue a discontinued instance", ]: logger.debug("Skipping instance %s termination. Instance not found.", instance_id) - return - raise + else: + raise + _delete_startup_script(self.client, backend_data_parsed.startup_script_id) + _delete_ssh_keys(self.client, backend_data_parsed.ssh_key_ids) def update_provisioning_data( self, @@ -200,26 +224,84 @@ def _get_vm_image_id(instance_offer: InstanceOfferWithAvailability) -> str: return "77777777-4f48-4249-82b3-f199fb9b701b" -def _get_or_create_ssh_key(client: VerdaClient, name: str, public_key: str) -> str: - fingerprint = get_public_key_fingerprint(public_key) - keys = client.ssh_keys.get() - found_keys = [key for key in keys if fingerprint == get_public_key_fingerprint(key.public_key)] - if found_keys: - key = found_keys[0] +def _create_ssh_key(client: VerdaClient, name: str, public_key: str) -> str: + try: + key = client.ssh_keys.create(name, public_key) return key.id - key = client.ssh_keys.create(name, public_key) - return key.id + except APIException as e: + raise BackendError(f"Verda API error while creating SSH key: {e.message}") -def _get_or_create_startup_scrpit(client: VerdaClient, name: str, script: str) -> str: - scripts = client.startup_scripts.get() - found_scripts = [startup_script for startup_script in scripts if script == startup_script] - if found_scripts: - startup_script = found_scripts[0] +def _create_startup_script(client: VerdaClient, name: str, script: str) -> str: + try: + startup_script = client.startup_scripts.create(name, script) return startup_script.id + except APIException as e: + raise BackendError(f"Verda API error while creating startup script: {e.message}") - startup_script = client.startup_scripts.create(name, script) - return startup_script.id + +def _delete_startup_script(client: VerdaClient, startup_script_id: Optional[str]) -> None: + if startup_script_id is None: + return + try: + client.startup_scripts.delete_by_id(startup_script_id) + except APIException as e: + if _is_startup_script_not_found_error(e): + logger.debug( + "Skipping startup script %s deletion. Startup script not found.", + startup_script_id, + ) + return + raise + + +def _delete_ssh_keys(client: VerdaClient, ssh_key_ids: Optional[List[str]]) -> None: + if not ssh_key_ids: + return + for ssh_key_id in ssh_key_ids: + _delete_ssh_key(client, ssh_key_id) + + +def _delete_ssh_key(client: VerdaClient, ssh_key_id: str) -> None: + try: + client.ssh_keys.delete_by_id(ssh_key_id) + except APIException as e: + if _is_ssh_key_not_found_error(e): + logger.debug("Skipping ssh key %s deletion. SSH key not found.", ssh_key_id) + return + raise + + +def _is_ssh_key_not_found_error(error: APIException) -> bool: + code = (error.code or "").lower() + message = (error.message or "").lower() + if code == "not_found": + return True + if code not in {"", "invalid_request"}: + return False + return ( + message == "invalid ssh-key id" + or message == "invalid ssh key id" + or message == "not found" + or ("ssh-key id" in message and "invalid" in message) + or ("ssh key id" in message and "invalid" in message) + ) + + +def _is_startup_script_not_found_error(error: APIException) -> bool: + code = (error.code or "").lower() + message = (error.message or "").lower() + if code == "not_found": + return True + if code not in {"", "invalid_request"}: + return False + return ( + message == "invalid startup script id" + or message == "invalid script id" + or message == "not found" + or ("startup script id" in message and "invalid" in message) + or ("script id" in message and "invalid" in message) + ) def _get_instance_by_id( @@ -264,3 +346,14 @@ def _deploy_instance( raise NoCapacityError(f"Verda API error: {e.message}") return instance + + +class VerdaInstanceBackendData(CoreModel): + startup_script_id: Optional[str] = None + ssh_key_ids: Optional[List[str]] = None + + @classmethod + def load(cls, raw: Optional[str]) -> "VerdaInstanceBackendData": + if raw is None: + return cls() + return cls.__response__.parse_raw(raw) diff --git a/src/tests/_internal/core/backends/verda/test_compute.py b/src/tests/_internal/core/backends/verda/test_compute.py new file mode 100644 index 0000000000..a1777af62d --- /dev/null +++ b/src/tests/_internal/core/backends/verda/test_compute.py @@ -0,0 +1,484 @@ +import sys +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +if sys.version_info < (3, 10): + pytest.skip("Verda requires Python 3.10", allow_module_level=True) + +from verda.exceptions import APIException + +from dstack._internal.core.backends.verda.compute import ( + VerdaCompute, + VerdaInstanceBackendData, + _create_ssh_key, + _create_startup_script, + _is_ssh_key_not_found_error, + _is_startup_script_not_found_error, +) +from dstack._internal.core.errors import BackendError, NoCapacityError + + +class TestCreateSSHKey: + def test_creates_ssh_key(self): + client = MagicMock() + client.ssh_keys.create.return_value = SimpleNamespace(id="new-ssh-key-id") + + key_id = _create_ssh_key( + client=client, + name="dstack-test-key", + public_key="ssh-rsa test", + ) + + assert key_id == "new-ssh-key-id" + client.ssh_keys.create.assert_called_once_with("dstack-test-key", "ssh-rsa test") + + def test_raises_backend_error_on_api_exception(self): + client = MagicMock() + client.ssh_keys.create.side_effect = APIException("invalid_request", "Boom") + + with pytest.raises(BackendError, match="creating SSH key: Boom"): + _create_ssh_key( + client=client, + name="dstack-test-key", + public_key="ssh-rsa test", + ) + + +class TestCreateStartupScript: + def test_creates_startup_script(self): + client = MagicMock() + client.startup_scripts.create.return_value = SimpleNamespace(id="new-script-id") + + script_id = _create_startup_script( + client=client, + name="dstack-test-script.sh", + script="echo bye", + ) + + assert script_id == "new-script-id" + client.startup_scripts.create.assert_called_once_with( + "dstack-test-script.sh", + "echo bye", + ) + + def test_raises_backend_error_on_api_exception(self): + client = MagicMock() + client.startup_scripts.create.side_effect = APIException("invalid_request", "Boom") + + with pytest.raises(BackendError, match="creating startup script: Boom"): + _create_startup_script( + client=client, + name="dstack-test-script.sh", + script="echo bye", + ) + + +class TestCreateInstance: + def test_cleans_up_created_ssh_keys_if_later_ssh_key_create_fails(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + + instance_offer = SimpleNamespace( + backend="verda", + instance=SimpleNamespace( + name="CPU.4V.16G", + resources=SimpleNamespace( + disk=SimpleNamespace(size_mib=102400), + gpus=[], + spot=False, + ), + ), + region="FIN-01", + price=0.0279, + ) + instance_config = SimpleNamespace( + instance_name="verda-one-node-0", + get_public_keys=lambda: ["ssh-rsa test-1", "ssh-rsa test-2"], + ) + + with ( + patch( + "dstack._internal.core.backends.verda.compute.generate_unique_instance_name", + return_value="verda-one-node-0", + ), + patch( + "dstack._internal.core.backends.verda.compute._create_ssh_key", + side_effect=["ssh-key-id-1", BackendError("ssh create failed")], + ), + patch( + "dstack._internal.core.backends.verda.compute._create_startup_script" + ) as create_startup_script, + patch( + "dstack._internal.core.backends.verda.compute._delete_startup_script" + ) as delete_startup_script, + patch( + "dstack._internal.core.backends.verda.compute._delete_ssh_keys" + ) as delete_ssh_keys, + ): + with pytest.raises(BackendError, match="ssh create failed"): + compute.create_instance(instance_offer, instance_config, None) + + create_startup_script.assert_not_called() + delete_startup_script.assert_called_once_with(compute.client, None) + delete_ssh_keys.assert_called_once_with(compute.client, ["ssh-key-id-1"]) + + def test_cleans_up_ssh_keys_if_startup_script_create_fails(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + + instance_offer = SimpleNamespace( + backend="verda", + instance=SimpleNamespace( + name="CPU.4V.16G", + resources=SimpleNamespace( + disk=SimpleNamespace(size_mib=102400), + gpus=[], + spot=False, + ), + ), + region="FIN-01", + price=0.0279, + ) + instance_config = SimpleNamespace( + instance_name="verda-one-node-0", + get_public_keys=lambda: ["ssh-rsa test-1", "ssh-rsa test-2"], + ) + + with ( + patch( + "dstack._internal.core.backends.verda.compute.generate_unique_instance_name", + return_value="verda-one-node-0", + ), + patch( + "dstack._internal.core.backends.verda.compute._create_ssh_key", + side_effect=["ssh-key-id-1", "ssh-key-id-2"], + ), + patch( + "dstack._internal.core.backends.verda.compute._create_startup_script", + side_effect=BackendError("script create failed"), + ), + patch( + "dstack._internal.core.backends.verda.compute._delete_startup_script" + ) as delete_startup_script, + patch( + "dstack._internal.core.backends.verda.compute._delete_ssh_keys" + ) as delete_ssh_keys, + ): + with pytest.raises(BackendError, match="script create failed"): + compute.create_instance(instance_offer, instance_config, None) + + delete_startup_script.assert_called_once_with(compute.client, None) + delete_ssh_keys.assert_called_once_with(compute.client, ["ssh-key-id-1", "ssh-key-id-2"]) + + def test_cleans_up_startup_script_if_deploy_fails(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + + instance_offer = SimpleNamespace( + backend="verda", + instance=SimpleNamespace( + name="CPU.4V.16G", + resources=SimpleNamespace( + disk=SimpleNamespace(size_mib=102400), + gpus=[], + spot=False, + ), + ), + region="FIN-01", + price=0.0279, + ) + instance_config = SimpleNamespace( + instance_name="verda-one-node-0", + get_public_keys=lambda: ["ssh-rsa test"], + ) + + with ( + patch( + "dstack._internal.core.backends.verda.compute.generate_unique_instance_name", + return_value="verda-one-node-0", + ), + patch( + "dstack._internal.core.backends.verda.compute.get_shim_commands", + return_value=["echo ready"], + ), + patch( + "dstack._internal.core.backends.verda.compute._create_ssh_key", + return_value="ssh-key-id", + ), + patch( + "dstack._internal.core.backends.verda.compute._create_startup_script", + return_value="startup-script-id", + ), + patch( + "dstack._internal.core.backends.verda.compute._deploy_instance", + side_effect=NoCapacityError("no capacity"), + ), + patch( + "dstack._internal.core.backends.verda.compute._delete_startup_script" + ) as delete_startup_script, + patch( + "dstack._internal.core.backends.verda.compute._delete_ssh_keys" + ) as delete_ssh_keys, + ): + with pytest.raises(NoCapacityError): + compute.create_instance(instance_offer, instance_config, None) + + delete_startup_script.assert_called_once_with(compute.client, "startup-script-id") + delete_ssh_keys.assert_called_once_with(compute.client, ["ssh-key-id"]) + + def test_stores_ssh_key_ids_in_backend_data(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + + instance_offer = SimpleNamespace( + backend="verda", + instance=SimpleNamespace( + name="CPU.4V.16G", + resources=SimpleNamespace( + disk=SimpleNamespace(size_mib=102400), + gpus=[], + spot=False, + ), + ), + region="FIN-01", + price=0.0279, + ) + instance_config = SimpleNamespace( + instance_name="verda-one-node-0", + get_public_keys=lambda: ["ssh-rsa test-1", "ssh-rsa test-2"], + ) + provider_instance = SimpleNamespace(id="provider-instance-id", location="FIN-01") + + with ( + patch( + "dstack._internal.core.backends.verda.compute.generate_unique_instance_name", + return_value="verda-one-node-0", + ), + patch( + "dstack._internal.core.backends.verda.compute.get_shim_commands", + return_value=["echo ready"], + ), + patch( + "dstack._internal.core.backends.verda.compute._create_ssh_key", + side_effect=["ssh-key-id-1", "ssh-key-id-2"], + ), + patch( + "dstack._internal.core.backends.verda.compute._create_startup_script", + return_value="startup-script-id", + ), + patch( + "dstack._internal.core.backends.verda.compute._deploy_instance", + return_value=provider_instance, + ), + patch( + "dstack._internal.core.backends.verda.compute.JobProvisioningData", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ), + ): + jpd = compute.create_instance(instance_offer, instance_config, None) + + backend_data = VerdaInstanceBackendData.load(jpd.backend_data) + assert backend_data.startup_script_id == "startup-script-id" + assert backend_data.ssh_key_ids == ["ssh-key-id-1", "ssh-key-id-2"] + + +class TestTerminateInstance: + def test_terminate_instance_without_backend_data(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + + compute.terminate_instance("instance-id", "FIN-01", None) + + compute.client.instances.action.assert_called_once_with( + id_list=["instance-id"], + action="delete", + ) + compute.client.startup_scripts.delete_by_id.assert_not_called() + compute.client.ssh_keys.delete_by_id.assert_not_called() + + def test_terminate_instance_deletes_startup_script(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + backend_data = VerdaInstanceBackendData( + startup_script_id="script-id", + ssh_key_ids=["ssh-key-id-1", "ssh-key-id-2"], + ).json() + + compute.terminate_instance("instance-id", "FIN-01", backend_data) + + compute.client.instances.action.assert_called_once_with( + id_list=["instance-id"], + action="delete", + ) + compute.client.startup_scripts.delete_by_id.assert_called_once_with("script-id") + assert compute.client.ssh_keys.delete_by_id.call_count == 2 + + def test_terminate_instance_still_deletes_script_when_instance_is_missing(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + compute.client.instances.action.side_effect = APIException("", "Invalid instance id") + backend_data = VerdaInstanceBackendData( + startup_script_id="script-id", + ssh_key_ids=["ssh-key-id-1"], + ).json() + + compute.terminate_instance("instance-id", "FIN-01", backend_data) + + compute.client.startup_scripts.delete_by_id.assert_called_once_with("script-id") + compute.client.ssh_keys.delete_by_id.assert_called_once_with("ssh-key-id-1") + + def test_terminate_instance_ignores_missing_startup_script(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + compute.client.startup_scripts.delete_by_id.side_effect = APIException( + "", + "Invalid startup script id", + ) + backend_data = VerdaInstanceBackendData( + startup_script_id="script-id", + ssh_key_ids=["ssh-key-id-1"], + ).json() + + compute.terminate_instance("instance-id", "FIN-01", backend_data) + + compute.client.instances.action.assert_called_once_with( + id_list=["instance-id"], + action="delete", + ) + compute.client.ssh_keys.delete_by_id.assert_called_once_with("ssh-key-id-1") + + def test_terminate_instance_ignores_missing_startup_script_invalid_script_id(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + compute.client.startup_scripts.delete_by_id.side_effect = APIException( + "invalid_request", + "Invalid script ID", + ) + backend_data = VerdaInstanceBackendData( + startup_script_id="script-id", + ssh_key_ids=["ssh-key-id-1"], + ).json() + + compute.terminate_instance("instance-id", "FIN-01", backend_data) + + compute.client.instances.action.assert_called_once_with( + id_list=["instance-id"], + action="delete", + ) + compute.client.ssh_keys.delete_by_id.assert_called_once_with("ssh-key-id-1") + + def test_terminate_instance_retries_on_script_delete_error(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + compute.client.startup_scripts.delete_by_id.side_effect = APIException( + "", "Random API error" + ) + backend_data = VerdaInstanceBackendData( + startup_script_id="script-id", + ssh_key_ids=["ssh-key-id-1"], + ).json() + + with pytest.raises(APIException): + compute.terminate_instance("instance-id", "FIN-01", backend_data) + + compute.client.ssh_keys.delete_by_id.assert_not_called() + + def test_terminate_instance_ignores_missing_ssh_key(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + compute.client.ssh_keys.delete_by_id.side_effect = APIException( + "invalid_request", + "Invalid ssh-key ID", + ) + backend_data = VerdaInstanceBackendData( + startup_script_id="script-id", + ssh_key_ids=["ssh-key-id-1"], + ).json() + + compute.terminate_instance("instance-id", "FIN-01", backend_data) + + compute.client.instances.action.assert_called_once_with( + id_list=["instance-id"], + action="delete", + ) + compute.client.startup_scripts.delete_by_id.assert_called_once_with("script-id") + compute.client.ssh_keys.delete_by_id.assert_called_once_with("ssh-key-id-1") + + def test_terminate_instance_deletes_remaining_ssh_keys_when_one_missing(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + compute.client.ssh_keys.delete_by_id.side_effect = [ + APIException("invalid_request", "Invalid ssh-key ID"), + None, + ] + backend_data = VerdaInstanceBackendData( + startup_script_id="script-id", + ssh_key_ids=["ssh-key-id-1", "ssh-key-id-2"], + ).json() + + compute.terminate_instance("instance-id", "FIN-01", backend_data) + + compute.client.startup_scripts.delete_by_id.assert_called_once_with("script-id") + compute.client.ssh_keys.delete_by_id.assert_any_call("ssh-key-id-1") + compute.client.ssh_keys.delete_by_id.assert_any_call("ssh-key-id-2") + assert compute.client.ssh_keys.delete_by_id.call_count == 2 + + def test_terminate_instance_retries_on_ssh_key_delete_error(self): + compute = VerdaCompute.__new__(VerdaCompute) + compute.client = MagicMock() + compute.client.ssh_keys.delete_by_id.side_effect = APIException("", "Random API error") + backend_data = VerdaInstanceBackendData( + startup_script_id="script-id", + ssh_key_ids=["ssh-key-id-1"], + ).json() + + with pytest.raises(APIException): + compute.terminate_instance("instance-id", "FIN-01", backend_data) + + +class TestIsStartupScriptNotFoundError: + def test_returns_true_for_not_found_code_even_with_custom_message(self): + assert _is_startup_script_not_found_error( + APIException("not_found", "Startup script does not exist anymore") + ) + + def test_returns_true_for_invalid_script_id(self): + assert _is_startup_script_not_found_error( + APIException("invalid_request", "Invalid script ID") + ) + + def test_returns_true_for_not_found(self): + assert _is_startup_script_not_found_error(APIException("not_found", "Not Found")) + + def test_returns_false_for_unrelated_error(self): + assert not _is_startup_script_not_found_error( + APIException("forbidden", "Permission denied") + ) + + def test_returns_false_for_unrelated_invalid_request(self): + assert not _is_startup_script_not_found_error( + APIException("invalid_request", "Some other invalid request") + ) + + +class TestIsSSHKeyNotFoundError: + def test_returns_true_for_not_found_code_even_with_custom_message(self): + assert _is_ssh_key_not_found_error( + APIException("not_found", "SSH key does not exist anymore") + ) + + def test_returns_true_for_invalid_ssh_key_id(self): + assert _is_ssh_key_not_found_error(APIException("invalid_request", "Invalid ssh-key ID")) + + def test_returns_true_for_not_found(self): + assert _is_ssh_key_not_found_error(APIException("not_found", "Not Found")) + + def test_returns_false_for_unrelated_error(self): + assert not _is_ssh_key_not_found_error(APIException("forbidden", "Permission denied")) + + def test_returns_false_for_unrelated_invalid_request(self): + assert not _is_ssh_key_not_found_error( + APIException("invalid_request", "Some other invalid request") + )