From 36774e03e62b92abd2b7b64595d2550fd9083aea Mon Sep 17 00:00:00 2001 From: Steven Engelbert Date: Tue, 22 Oct 2024 15:21:09 -0400 Subject: [PATCH] Add file getting, sending to `cli.py` Convert device connection into a separate function Upgrade file sending to be in line with file getting regulate spacing between functions to two lines (there's been a mix of 1 and 2 lines so far) Remove leftover comment Add docstrings, typehinting --- admin/cli.py | 161 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 127 insertions(+), 34 deletions(-) diff --git a/admin/cli.py b/admin/cli.py index a7f8a02..6bbd844 100644 --- a/admin/cli.py +++ b/admin/cli.py @@ -3,7 +3,7 @@ import random import logging -from os import getenv +from os import getenv, path, makedirs, walk from uuid import UUID from time import sleep from typing import Optional @@ -69,6 +69,46 @@ def get_tunnel(tunnel_id) -> Optional[dict]: return None +def connect_tunnel(c, tunnel_id): + """Connects to a live support tunnel using the tunnel_id, and returns a connection to the device""" + # Care should be exercised here; we're taking data from a remote source and using it to + # run shell commands. Validate every last bit of data. + t = get_tunnel(tunnel_id) + assert TunnelState(t['state']) == TunnelState.running, "Device has not yet connected" + dip = device_ip(IPv4Network(t['network'])).ip + assert t['support_user'].isalnum() + assert t['support_user'].isascii() + support_user = t['support_user'] + + # set up local + c.run("gcloud compute config-ssh", hide="both") + user_from_oslogin = c.run("gcloud compute os-login describe-profile --format=json", hide="both") + ts_user = json.loads(user_from_oslogin.stdout)['posixAccounts'][0]['username'] + + # set up connection to bastion + ts = Connection( + host = str(get_ts_instance_public_ip(tunnel_id)), + user = ts_user, + connect_kwargs={"auth_timeout": 120} + ) + + # grab the ssh private key on the tunnel server + ssh_privkey = ts.run(f"sudo cat {SSH_KEYFILE_PATH}", hide="both") + assert ssh_privkey + + # set up connection to destination device + device = Connection( + host=str(dip), + user=support_user, + gateway=ts, + connect_kwargs = { + "pkey": Ed25519Key.from_private_key(io.StringIO(ssh_privkey.stdout)), + } + ) + + return device + + @task def show(c, tunnel_id): """ Show a single tunnel's details """ @@ -84,6 +124,7 @@ def list(c): res.raise_for_status() print(res.text) + @task def create(c, tunnel_id: Optional[UUID4] = None, preshared_key: Optional[WireguardKey] = None): """ Create a tunnel server. @@ -143,7 +184,7 @@ def create(c, tunnel_id: Optional[UUID4] = None, preshared_key: Optional[Wiregua connect_kwargs={"auth_timeout": 120} # long for 2FA ) - # things get hacky when being concerned with local ssh keys and all - + # things get hacky when being concerned with local ssh keys and all - # the below configures things to "just work", every time. c.run("gcloud compute config-ssh", hide="both") user_from_oslogin = c.run("gcloud compute os-login describe-profile --format=json", hide="both") @@ -226,46 +267,98 @@ def stop(c, tunnel_id): # being lazy and overzealous at the same time - we'll just garbage-college its resources. gc(c) + @task def connect(c, tunnel_id, command="/bin/bash", pty=True): """ Connect to a remote device, identified by a tunnel. """ - # Care should be exercised here; we're taking data from a remote source and using it to - # run shell commands. Validate every last bit of data. - t = get_tunnel(tunnel_id) - assert TunnelState(t['state']) == TunnelState.running, "Device has not yet connected" - dip = device_ip(IPv4Network(t['network'])).ip - assert t['support_user'].isalnum() - assert t['support_user'].isascii() - support_user = t['support_user'] + device = connect_tunnel(c, tunnel_id) + # and finally execute a shell + device.sudo(command, pty=pty) - # set up local - c.run("gcloud compute config-ssh", hide="both") - user_from_oslogin = c.run("gcloud compute os-login describe-profile --format=json", hide="both") - ts_user = json.loads(user_from_oslogin.stdout)['posixAccounts'][0]['username'] - # set up connection to bastion - ts = Connection( - host = str(get_ts_instance_public_ip(tunnel_id)), - user = ts_user, - connect_kwargs={"auth_timeout": 120} # long for 2FA - ) +def is_remote_directory(remote_device: Connection, remote_path: str): + """Determines if a remote path is a directory using a bash command over a tunnel connection, returns true or false""" + result = remote_device.run(f"if [ -d '{remote_path}' ]; then echo 'directory'; else echo 'file'; fi", hide=True).stdout.strip() + return result == 'directory' - # grab the ssh private key on the tunnel server - ssh_privkey = ts.run(f"sudo cat {SSH_KEYFILE_PATH}", hide="both") - assert ssh_privkey - # set up connection to destination device - device = Connection( - host=str(dip), - user=support_user, - gateway=ts, - connect_kwargs = { - "pkey": Ed25519Key.from_private_key(io.StringIO(ssh_privkey.stdout)), - } - ) +def send_file(remote_device: Connection, local_file: str, remote_dir: str): + """Copies local file to remote directory via a tunnel connection""" + filename = path.basename(local_file) + remote_path = path.join(remote_dir, filename).replace("./", "") # path.join adds an unneccessary ./ when combining things with their own directory, making the prints down the line look odd + + # Support user lacks permissions to send file to just any directory, scrape the filename and send to an intermediary and then sudo mv it to the proper location + print(f"Uploading {local_file} to {remote_path}...") + temp_remote = f"/tmp/{filename}" + remote_device.put(local=local_file, remote=temp_remote, preserve_mode=True) + remote_device.sudo(f"mv {temp_remote} {remote_path}", pty=False) + print(f"File successfully uploaded to {remote_path}") + + +def send_directory(remote_device: Connection, local_dir: str, remote_dir: str): + """Recursively copies everything in a local directory to a remote directory via a tunnel connection""" + for root, _, files in walk(local_dir): + relative_root = path.relpath(root, local_dir) + remote_root = path.join(remote_dir, relative_root) + + # Create the corresponding remote directory + remote_device.sudo(f"mkdir -p {remote_root}") + + for file in files: + local_file_path = path.join(root, file) + send_file(remote_device, local_file_path, remote_root) + print("") # Empty print to break up "sending file", "file sent" prints so full directory transfers more human readable + + +@task +def sfile(c, tunnel_id, local, remote): + """ Connect to a remote device and upload a local file or directory to the remote directory. + Requests are formatted as tunnel_id, local file/directory, and remote directory for installation. """ + device = connect_tunnel(c, tunnel_id) + + # Send either a single file or an entire directory + if path.isdir(local): + send_directory(device, local, remote) + else: + send_file(device, local, remote) + + +def get_file(remote_device: Connection, local_dir: str, remote_file: str): + """Copies a remote file to a local directory using a tunnel connection""" + filename = path.basename(remote_file) + local_path = path.join(local_dir, filename) + + print(f"Downloading {remote_file} to {local_path}...") + remote_device.get(remote=remote_file, local=local_path, preserve_mode=True) + print(f"File successfully downloaded to {local_path}.") + + +def get_directory(remote_device, remote_dir, local_dir): + """Recursively copies everything in a remote directory to the local directory using a tunnel connection""" + # List all files and directories under the remote directory + result = remote_device.run(f"find {remote_dir} -type d -or -type f", hide=True).stdout.splitlines() + + for item in result: + relative_path = path.relpath(item, remote_dir) + local_item_path = path.join(local_dir, relative_path).replace("\\", "/") + + if is_remote_directory(remote_device, item): + # Create the corresponding local directory + makedirs(local_item_path, exist_ok=True) + else: + get_file(remote_device, item, path.dirname(local_item_path)) + + +@task +def gfile(c, tunnel_id, remote, local): + """ Connect to a remote device and download a remote file or directory to the local directory. + Requests are formatted as tunnel_id, remote file/directory, and local directory for installation. """ + device = connect_tunnel(c, tunnel_id) + if is_remote_directory(device, remote): + get_directory(device, remote, local) + else: + get_file(device, remote, local) - # and finally execute a shell - device.sudo(command, pty=pty) @task def command(c, tunnel_id, command):