Skip to content

Commit 1c67554

Browse files
Compute connect port 22 path fix (#9082)
* port 22 compute connect fix * add changelog
1 parent f8edddf commit 1c67554

4 files changed

Lines changed: 14 additions & 5 deletions

File tree

src/machinelearningservices/CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
## Azure Machine Learning CLI (v2) (unreleased)
22
- `az ml compute update`
33
- Fix a bug compute update which caused Enable SSO property to reset.
4+
- `az ml compute connect-ssh`
5+
- Fix proxy endpoint path
46

57
## 2025-05-15
68

src/machinelearningservices/azext_mlv2/manual/custom/_ssh_command.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def get_ssh_command(
2424
services_dict: Dict[str, ServiceInstance],
2525
node_index: int,
2626
private_key_file_path: str,
27-
ssh_args: Optional[Sequence[str]] = None
27+
ssh_args: Optional[Sequence[str]] = None,
28+
connector_args: Optional[Sequence[str]] = None
2829
) -> Tuple[bool, str]:
2930
proxyEndpoint = _get_proxy_endpoint(services_dict, node_index).replace("<nodeIndex>", str(node_index))
3031
connect_ssh_path = pathlib.Path(__file__).parent / "_ssh_connector.py"
@@ -45,9 +46,10 @@ def get_ssh_command(
4546
identity_param = " -i {}".format(private_key_file_path) if private_key_file_path else ""
4647
# TODO: Find how to enable debug mode
4748
ssh_args_str = " ".join(ssh_args) if ssh_args else ""
49+
connector_args_str = " ".join(connector_args) if connector_args else ""
4850
return (
4951
connect_ssh_path_has_space,
50-
f'{ssh_path} -v -o ProxyCommand="{sys.executable} {connect_ssh_path} {proxyEndpoint}" '
52+
f'{ssh_path} -v -o ProxyCommand="{sys.executable} {connect_ssh_path} {proxyEndpoint} {connector_args_str}" '
5153
f"azureuser@{proxyEndpoint}{identity_param}{ssh_args_str}",
5254
)
5355

src/machinelearningservices/azext_mlv2/manual/custom/_ssh_connector.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,17 @@ async def _connect_ssh(self):
6161
)
6262
raise Exception(msg) # pylint: disable=broad-exception-raised
6363
proxy_endpoint = sys.argv[1]
64+
65+
is_compute = len(sys.argv) > 2 and sys.argv[2] == "--is-compute"
66+
uri = f"{proxy_endpoint}/nbip/v1.0/ws-tcp"
67+
if is_compute:
68+
uri += "/port/22"
6469
mgtScope = ["https://management.core.windows.net/.default"]
6570

6671
aml_token = run_az_cli(["account", "get-access-token", "--scope", mgtScope[0]])["accessToken"]
6772

6873
async with websockets.client.connect(
69-
uri=f"{proxy_endpoint}/nbip/v1.0/ws-tcp",
74+
uri=uri,
7075
extra_headers={"Authorization": f"Bearer {aml_token}"},
7176
) as websocket:
7277

src/machinelearningservices/azext_mlv2/manual/custom/compute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,12 @@ def ml_compute_connect_ssh(cmd, resource_group_name, workspace_name, name, priva
274274
# create proxy endpoint for CI based on endpoint for jupyter
275275
# TODO: Improve with a call to get proxyendpoint from CI, requires API
276276
jupyter = [f["endpoint_uri"] for f in compute.services if f["display_name"] == "Jupyter"][0]
277-
proxyEndpoint = jupyter.replace(name, f"{name}-22").replace("https://", "wss://").replace("/tree/", "")
277+
proxyEndpoint = jupyter.replace("https://", "wss://").replace("/tree/", "")
278278

279279
services_dict = {
280280
"ssh": ServiceInstance(type="SSH", status="Running", properties={"ProxyEndpoint": proxyEndpoint})
281281
}
282-
path_has_space, ssh_command = get_ssh_command(services_dict, 0, private_key_file_path)
282+
path_has_space, ssh_command = get_ssh_command(services_dict, 0, private_key_file_path, connector_args=["--is-compute"])
283283
print(f"ssh_command: {ssh_command}")
284284
if path_has_space:
285285
module_logger.error(ssh_connector_file_path_space_message())

0 commit comments

Comments
 (0)