-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Expand file tree
/
Copy path_ssh_command.py
More file actions
146 lines (125 loc) · 5.83 KB
/
_ssh_command.py
File metadata and controls
146 lines (125 loc) · 5.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is
# regenerated.
# --------------------------------------------------------------------------
# ---------------------------------------------------------
import os
import pathlib
import platform
import subprocess
import sys
from typing import Dict, Tuple, Sequence, Optional
from xmlrpc.client import boolean
from azure.ai.ml.entities import ServiceInstance
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException, ValidationErrorType, ValidationException
def get_ssh_command(
    services_dict: Dict[str, ServiceInstance],
    node_index: int,
    private_key_file_path: str,
    ssh_args: Optional[Sequence[str]] = None,
    connector_args: Optional[Sequence[str]] = None,
) -> Tuple[bool, str]:
    """Build the ssh command line that connects to a job node through its SSH proxy endpoint.

    :param services_dict: Services of the job node, keyed by service name. Must contain a
        running service of type "SSH" whose properties include "ProxyEndpoint".
    :param node_index: Index of the node to connect to; substituted for the "<nodeIndex>"
        placeholder in the proxy endpoint.
    :param private_key_file_path: Private key passed to ssh via ``-i``; omitted when empty/None.
    :param ssh_args: Extra arguments appended, space-separated, to the end of the ssh command.
    :param connector_args: Extra arguments appended to the ``_ssh_connector.py`` invocation
        inside the ProxyCommand.
    :return: ``(has_space, command)`` where ``has_space`` is True when the path of
        ``_ssh_connector.py`` contains a space (the generated ProxyCommand does not quote it),
        and ``command`` is the ssh command string.
    :raises ValidationException: If the node has no services, no SSH service, or the SSH
        service is not running (raised by ``_get_proxy_endpoint``).
    :raises JobException: If the SSH service has no "ProxyEndpoint" property
        (raised by ``_get_proxy_endpoint``).
    """
    proxy_endpoint = _get_proxy_endpoint(services_dict, node_index).replace("<nodeIndex>", str(node_index))
    connect_ssh_path = pathlib.Path(__file__).parent / "_ssh_connector.py"
    # The connector path is embedded unquoted in ProxyCommand, so a space in it would split the
    # argument; report that to the caller instead of failing later inside ssh.
    connect_ssh_path_has_space = " " in str(connect_ssh_path)

    ssh_path = "ssh"
    if os.name == "nt":
        # On Windows, need to set the path to ssh since if using Python 32 bit, the file system redirector will cause
        # 32 bit python to be unable to find ssh
        system32 = os.path.join(
            os.environ["SystemRoot"], "SysNative" if platform.architecture()[0] == "32bit" else "System32"
        )
        ssh_path = os.path.join(system32, "OpenSSH\\ssh.exe")

    # The command ssh runs to open the transport: this interpreter executing the connector script
    # against the proxy endpoint, followed by any connector args (no trailing space when there are none).
    proxy_command = " ".join([sys.executable, str(connect_ssh_path), proxy_endpoint, *(connector_args or [])])
    identity_param = f" -i {private_key_file_path}" if private_key_file_path else ""
    # Leading space keeps the first extra ssh argument from being glued onto the host or key path.
    ssh_args_str = " " + " ".join(ssh_args) if ssh_args else ""
    # TODO: Find how to enable debug mode
    return (
        connect_ssh_path_has_space,
        f'{ssh_path} -v -o ProxyCommand="{proxy_command}" '
        f"azureuser@{proxy_endpoint}{identity_param}{ssh_args_str}",
    )
def _get_proxy_endpoint(services_dict: Dict[str, ServiceInstance], node_index: int) -> str:
    """Return the "ProxyEndpoint" property of the first SSH service of a job node.

    :param services_dict: Services of the node, keyed by service name.
    :param node_index: Index of the node; only used to build error messages.
    :return: The "ProxyEndpoint" property value, returned as stored (no substitution applied).
    :raises ValidationException: If there are no services, no service whose ``type`` is "SSH",
        or the first SSH service's ``status`` is not "Running".
    :raises JobException: If the SSH service's properties are empty or lack a truthy "ProxyEndpoint".
    """

    def _invalid(message: str) -> ValidationException:
        # Every user-facing validation failure in this function shares target, category and type.
        return ValidationException(
            message=message,
            no_personal_data_message=message,
            target=ErrorTarget.JOB,
            error_category=ErrorCategory.USER_ERROR,
            error_type=ValidationErrorType.INVALID_VALUE,
        )

    if not (services_dict and len(services_dict.values()) >= 1):
        raise _invalid(
            f"The node {node_index} of the job does not have services. Please ensure that the job has services."
        )

    # Take the first service (in dict iteration order) whose type is "SSH".
    ssh_service = None
    for candidate in services_dict.values():
        if candidate.type == "SSH":
            ssh_service = candidate
            break
    if not ssh_service:
        raise _invalid(f"Please ensure that the job is ssh enabled on node '{node_index}'.")

    if ssh_service.status != "Running":
        raise _invalid(
            f"Please ensure that ssh service at node '{node_index}' has the status as 'Running'. "
            f"The current status is '{ssh_service.status}'."
        )

    properties = ssh_service.properties
    if not properties or not properties.get("ProxyEndpoint"):
        # A running SSH service without an endpoint is reported as a system error, not a user error.
        missing = "The ssh JobService.properties is missing ProxyEndpoint."
        raise JobException(
            message=missing,
            no_personal_data_message=missing,
            target=ErrorTarget.JOB,
            error_category=ErrorCategory.SYSTEM_ERROR,
            error_type=ValidationErrorType.INVALID_VALUE,
        )
    return properties.get("ProxyEndpoint")
def has_ssh_dependencies_installed() -> bool:
    """Ensure the ``websockets`` package is available, offering to pip-install it if missing.

    Availability is checked with :func:`importlib.util.find_spec` against the running
    interpreter, i.e. exactly what ``import websockets`` would resolve. This avoids shelling
    out to ``pip freeze`` (which fails when pip is absent and misses editable installs).

    :return: True if websockets was already importable or has just been installed;
        False if the user declined the installation.
    :raises subprocess.CalledProcessError: If ``pip install websockets`` fails.
    """
    import importlib.util

    if importlib.util.find_spec("websockets") is not None:
        return True
    if _confirm("connect-ssh command requires websockets package. Do you like to install websockets now?"):
        subprocess.check_call([sys.executable, "-m", "pip", "install", "websockets"])
        # Clear the import system's cached directory listings so the freshly installed
        # package can be imported later in this same process.
        importlib.invalidate_caches()
        print("Successfully installed websockets.")
        return True
    print("Exiting as you preferred not to install websockets.")
    return False
def ssh_connector_file_path_space_message():
    """Return the guidance shown when the path of ``_ssh_connector.py`` contains a space.

    The example command at the end now has its closing double quote; the previous
    triple-quoted literal could not end in ``"`` and so showed an unbalanced ProxyCommand.

    :return: A multi-line message that starts with a newline.
    """
    return (
        "\n"
        "File path for _ssh_connector.py has space, unfortunately which will not work with ProxyCommand. To work\n"
        "around this you can copy the _ssh_connector.py file from the location above (see ssh_command) to the\n"
        "current working directory and run ssh_command output above, swapping the new ssh_connector file path:\n"
        'ssh -v -o ProxyCommand="... _ssh_connector.py ..."'
    )
def _confirm(question, default="no"):
if default is None:
prompt = " [y/n] "
elif default == "yes":
prompt = " [Y/n] "
else:
prompt = " [y/N] "
while True:
answer = input(question + prompt).strip().lower()
if default is not None and answer == "":
return default == "yes"
if answer in ["yes", "y"]:
return True
if answer in ["no", "n"]:
return False
print("Please answer 'yes' or 'no' (or 'y' or 'n').\n")