Skip to content

Commit 36774e0

Browse files
Add file getting, sending to cli.py
Convert device connection into a separate function Upgrade file sending to be in line with file getting regulate spacing between functions to two lines (there's been a mix of 1 and 2 lines so far) Remove leftover comment Add docstrings, typehinting
1 parent ad189ce commit 36774e0

1 file changed

Lines changed: 127 additions & 34 deletions

File tree

admin/cli.py

Lines changed: 127 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import random
44
import logging
55

6-
from os import getenv
6+
from os import getenv, path, makedirs, walk
77
from uuid import UUID
88
from time import sleep
99
from typing import Optional
@@ -69,6 +69,46 @@ def get_tunnel(tunnel_id) -> Optional[dict]:
6969
return None
7070

7171

72+
def connect_tunnel(c, tunnel_id):
73+
"""Connects to a live support tunnel using the tunnel_id, and returns a connection to the device"""
74+
# Care should be exercised here; we're taking data from a remote source and using it to
75+
# run shell commands. Validate every last bit of data.
76+
t = get_tunnel(tunnel_id)
77+
assert TunnelState(t['state']) == TunnelState.running, "Device has not yet connected"
78+
dip = device_ip(IPv4Network(t['network'])).ip
79+
assert t['support_user'].isalnum()
80+
assert t['support_user'].isascii()
81+
support_user = t['support_user']
82+
83+
# set up local
84+
c.run("gcloud compute config-ssh", hide="both")
85+
user_from_oslogin = c.run("gcloud compute os-login describe-profile --format=json", hide="both")
86+
ts_user = json.loads(user_from_oslogin.stdout)['posixAccounts'][0]['username']
87+
88+
# set up connection to bastion
89+
ts = Connection(
90+
host = str(get_ts_instance_public_ip(tunnel_id)),
91+
user = ts_user,
92+
connect_kwargs={"auth_timeout": 120}
93+
)
94+
95+
# grab the ssh private key on the tunnel server
96+
ssh_privkey = ts.run(f"sudo cat {SSH_KEYFILE_PATH}", hide="both")
97+
assert ssh_privkey
98+
99+
# set up connection to destination device
100+
device = Connection(
101+
host=str(dip),
102+
user=support_user,
103+
gateway=ts,
104+
connect_kwargs = {
105+
"pkey": Ed25519Key.from_private_key(io.StringIO(ssh_privkey.stdout)),
106+
}
107+
)
108+
109+
return device
110+
111+
72112
@task
73113
def show(c, tunnel_id):
74114
""" Show a single tunnel's details """
@@ -84,6 +124,7 @@ def list(c):
84124
res.raise_for_status()
85125
print(res.text)
86126

127+
87128
@task
88129
def create(c, tunnel_id: Optional[UUID4] = None, preshared_key: Optional[WireguardKey] = None):
89130
""" Create a tunnel server.
@@ -143,7 +184,7 @@ def create(c, tunnel_id: Optional[UUID4] = None, preshared_key: Optional[Wiregua
143184
connect_kwargs={"auth_timeout": 120} # long for 2FA
144185
)
145186

146-
# things get hacky when being concerned with local ssh keys and all -
187+
# things get hacky when being concerned with local ssh keys and all -
147188
# the below configures things to "just work", every time.
148189
c.run("gcloud compute config-ssh", hide="both")
149190
user_from_oslogin = c.run("gcloud compute os-login describe-profile --format=json", hide="both")
@@ -226,46 +267,98 @@ def stop(c, tunnel_id):
226267
# being lazy and overzealous at the same time - we'll just garbage-college its resources.
227268
gc(c)
228269

270+
229271
@task
230272
def connect(c, tunnel_id, command="/bin/bash", pty=True):
231273
""" Connect to a remote device, identified by a tunnel. """
232-
# Care should be exercised here; we're taking data from a remote source and using it to
233-
# run shell commands. Validate every last bit of data.
234-
t = get_tunnel(tunnel_id)
235-
assert TunnelState(t['state']) == TunnelState.running, "Device has not yet connected"
236-
dip = device_ip(IPv4Network(t['network'])).ip
237-
assert t['support_user'].isalnum()
238-
assert t['support_user'].isascii()
239-
support_user = t['support_user']
274+
device = connect_tunnel(c, tunnel_id)
275+
# and finally execute a shell
276+
device.sudo(command, pty=pty)
240277

241-
# set up local
242-
c.run("gcloud compute config-ssh", hide="both")
243-
user_from_oslogin = c.run("gcloud compute os-login describe-profile --format=json", hide="both")
244-
ts_user = json.loads(user_from_oslogin.stdout)['posixAccounts'][0]['username']
245278

246-
# set up connection to bastion
247-
ts = Connection(
248-
host = str(get_ts_instance_public_ip(tunnel_id)),
249-
user = ts_user,
250-
connect_kwargs={"auth_timeout": 120} # long for 2FA
251-
)
279+
def is_remote_directory(remote_device: Connection, remote_path: str):
280+
"""Determines if a remote path is a directory using a bash command over a tunnel connection, returns true or false"""
281+
result = remote_device.run(f"if [ -d '{remote_path}' ]; then echo 'directory'; else echo 'file'; fi", hide=True).stdout.strip()
282+
return result == 'directory'
252283

253-
# grab the ssh private key on the tunnel server
254-
ssh_privkey = ts.run(f"sudo cat {SSH_KEYFILE_PATH}", hide="both")
255-
assert ssh_privkey
256284

257-
# set up connection to destination device
258-
device = Connection(
259-
host=str(dip),
260-
user=support_user,
261-
gateway=ts,
262-
connect_kwargs = {
263-
"pkey": Ed25519Key.from_private_key(io.StringIO(ssh_privkey.stdout)),
264-
}
265-
)
285+
def send_file(remote_device: Connection, local_file: str, remote_dir: str):
286+
"""Copies local file to remote directory via a tunnel connection"""
287+
filename = path.basename(local_file)
288+
remote_path = path.join(remote_dir, filename).replace("./", "") # path.join adds an unneccessary ./ when combining things with their own directory, making the prints down the line look odd
289+
290+
# Support user lacks permissions to send file to just any directory, scrape the filename and send to an intermediary and then sudo mv it to the proper location
291+
print(f"Uploading {local_file} to {remote_path}...")
292+
temp_remote = f"/tmp/{filename}"
293+
remote_device.put(local=local_file, remote=temp_remote, preserve_mode=True)
294+
remote_device.sudo(f"mv {temp_remote} {remote_path}", pty=False)
295+
print(f"File successfully uploaded to {remote_path}")
296+
297+
298+
def send_directory(remote_device: Connection, local_dir: str, remote_dir: str):
299+
"""Recursively copies everything in a local directory to a remote directory via a tunnel connection"""
300+
for root, _, files in walk(local_dir):
301+
relative_root = path.relpath(root, local_dir)
302+
remote_root = path.join(remote_dir, relative_root)
303+
304+
# Create the corresponding remote directory
305+
remote_device.sudo(f"mkdir -p {remote_root}")
306+
307+
for file in files:
308+
local_file_path = path.join(root, file)
309+
send_file(remote_device, local_file_path, remote_root)
310+
print("") # Empty print to break up "sending file", "file sent" prints so full directory transfers more human readable
311+
312+
313+
@task
314+
def sfile(c, tunnel_id, local, remote):
315+
""" Connect to a remote device and upload a local file or directory to the remote directory.
316+
Requests are formatted as tunnel_id, local file/directory, and remote directory for installation. """
317+
device = connect_tunnel(c, tunnel_id)
318+
319+
# Send either a single file or an entire directory
320+
if path.isdir(local):
321+
send_directory(device, local, remote)
322+
else:
323+
send_file(device, local, remote)
324+
325+
326+
def get_file(remote_device: Connection, local_dir: str, remote_file: str):
327+
"""Copies a remote file to a local directory using a tunnel connection"""
328+
filename = path.basename(remote_file)
329+
local_path = path.join(local_dir, filename)
330+
331+
print(f"Downloading {remote_file} to {local_path}...")
332+
remote_device.get(remote=remote_file, local=local_path, preserve_mode=True)
333+
print(f"File successfully downloaded to {local_path}.")
334+
335+
336+
def get_directory(remote_device, remote_dir, local_dir):
337+
"""Recursively copies everything in a remote directory to the local directory using a tunnel connection"""
338+
# List all files and directories under the remote directory
339+
result = remote_device.run(f"find {remote_dir} -type d -or -type f", hide=True).stdout.splitlines()
340+
341+
for item in result:
342+
relative_path = path.relpath(item, remote_dir)
343+
local_item_path = path.join(local_dir, relative_path).replace("\\", "/")
344+
345+
if is_remote_directory(remote_device, item):
346+
# Create the corresponding local directory
347+
makedirs(local_item_path, exist_ok=True)
348+
else:
349+
get_file(remote_device, item, path.dirname(local_item_path))
350+
351+
352+
@task
353+
def gfile(c, tunnel_id, remote, local):
354+
""" Connect to a remote device and download a remote file or directory to the local directory.
355+
Requests are formatted as tunnel_id, remote file/directory, and local directory for installation. """
356+
device = connect_tunnel(c, tunnel_id)
357+
if is_remote_directory(device, remote):
358+
get_directory(device, remote, local)
359+
else:
360+
get_file(device, remote, local)
266361

267-
# and finally execute a shell
268-
device.sudo(command, pty=pty)
269362

270363
@task
271364
def command(c, tunnel_id, command):

0 commit comments

Comments
 (0)