Skip to content

Commit 44da596

Browse files
Stream POST request in order to handle large files (#161)
* Add MultipartEncoder to support request streaming

  The Multipart encoder helps requests to upload large files without the need to read the entire file in memory

* Remove unused typing
* Remove duplicate call
* Update _create_solution to use new file upload strategy
* Only override content type instead of all headers

---------

Co-authored-by: Kevin Chung <kchung@ctfd.io>
1 parent 4d47e5e commit 44da596

6 files changed

Lines changed: 424 additions & 359 deletions

File tree

ctfcli/cli/media.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@ def add(self, path):
1515

1616
api = API()
1717

18-
new_file = ("file", open(path, mode="rb")) # noqa: SIM115
1918
filename = os.path.basename(path)
19+
new_file = (filename, open(path, mode="rb"))
2020
location = f"media/{filename}"
2121
file_payload = {
2222
"type": "page",
2323
"location": location,
2424
}
2525

2626
# Specifically use data= here to send multipart/form-data
27-
r = api.post("/api/v1/files", files=[new_file], data=file_payload)
27+
r = api.post("/api/v1/files", files={"file": new_file}, data=file_payload)
2828
r.raise_for_status()
2929
resp = r.json()
3030
server_location = resp["data"][0]["location"]

ctfcli/core/api.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from typing import Mapping
12
from urllib.parse import urljoin
23

34
from requests import Session
5+
from requests_toolbelt.multipart.encoder import MultipartEncoder
46

57
from ctfcli.core.config import Config
68
from ctfcli.core.exceptions import MissingAPIKey, MissingInstanceURL
@@ -46,20 +48,56 @@ def __init__(self):
4648
if "cookies" in config:
4749
self.cookies.update(dict(config["cookies"]))
4850

49-
def request(self, method, url, *args, **kwargs):
51+
def request(self, method, url, data=None, files=None, *args, **kwargs):
5052
# Strip out the preceding / so that urljoin creates the right url
5153
# considering the appended / on the prefix_url
5254
url = urljoin(self.prefix_url, url.lstrip("/"))
5355

54-
# if data= is present, do not modify the content-type
55-
if kwargs.get("data") is not None:
56-
return super().request(method, url, *args, **kwargs)
56+
# If data or files are any kind of key/value iterable
57+
# then encode the body as form-data
58+
if isinstance(data, (list, tuple, Mapping)) or isinstance(files, (list, tuple, Mapping)):
59+
# In order to use the MultipartEncoder, we need to convert data and files to the following structure :
60+
# A list of tuple containing the key and the values : List[Tuple[str, str]]
61+
# For files, the structure can be List[Tuple[str, Tuple[str, str, Optional[str]]]]
62+
# Example: [ ('file', ('doc.pdf', open('doc.pdf'), 'text/plain') ) ]
63+
64+
fields = list()
65+
if isinstance(data, dict):
66+
# int are not allowed as value in MultipartEncoder
67+
fields = list(map(lambda v: (v[0], str(v[1]) if isinstance(v[1], int) else v[1]), data.items()))
68+
69+
if files is not None:
70+
if isinstance(files, dict):
71+
files = list(files.items())
72+
fields.extend(files) # type: ignore
73+
74+
multipart = MultipartEncoder(fields)
75+
headers = kwargs.pop("headers", {}) or {}
76+
headers = dict(headers)
77+
headers["Content-Type"] = multipart.content_type
78+
79+
return super(API, self).request(
80+
method,
81+
url,
82+
data=multipart,
83+
headers=headers,
84+
*args,
85+
**kwargs,
86+
)
5787

5888
# otherwise set the content-type to application/json for all API requests
5989
# modify the headers here instead of using self.headers because we don't want to
6090
# override the multipart/form-data case above
61-
if kwargs.get("headers") is None:
62-
kwargs["headers"] = {}
63-
64-
kwargs["headers"]["Content-Type"] = "application/json"
65-
return super().request(method, url, *args, **kwargs)
91+
if data is None and files is None:
92+
if kwargs.get("headers", None) is None:
93+
kwargs["headers"] = {}
94+
kwargs["headers"]["Content-Type"] = "application/json"
95+
96+
return super(API, self).request(
97+
method,
98+
url,
99+
data=data,
100+
files=files,
101+
*args,
102+
**kwargs,
103+
)

ctfcli/core/challenge.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -389,22 +389,21 @@ def _delete_file(self, remote_location: str):
389389
r.raise_for_status()
390390

391391
def _create_file(self, local_path: Path):
392-
new_file = ("file", open(local_path, mode="rb")) # noqa: SIM115
392+
new_file = (local_path.name, open(local_path, mode="rb"))
393393
file_payload = {"challenge_id": self.challenge_id, "type": "challenge"}
394394

395395
# Specifically use data= here to send multipart/form-data
396-
r = self.api.post("/api/v1/files", files=[new_file], data=file_payload)
396+
r = self.api.post("/api/v1/files", files={"file": new_file}, data=file_payload)
397397
r.raise_for_status()
398398

399399
# Close the file handle
400400
new_file[1].close()
401401

402402
def _create_all_files(self):
403403
new_files = []
404-
405-
files = self.get("files") or []
406-
for challenge_file in files:
407-
new_files.append(("file", open(self.challenge_directory / challenge_file, mode="rb"))) # noqa: SIM115
404+
for challenge_file in self["files"]:
405+
file_path = self.challenge_directory / challenge_file
406+
new_files.append(("file", (file_path.name, file_path.open("rb"))))
408407

409408
files_payload = {"challenge_id": self.challenge_id, "type": "challenge"}
410409

@@ -414,7 +413,7 @@ def _create_all_files(self):
414413

415414
# Close the file handles
416415
for file_payload in new_files:
417-
file_payload[1].close()
416+
file_payload[1][1].close()
418417

419418
def _delete_existing_hints(self):
420419
remote_hints = self.api.get("/api/v1/hints").json()["data"]
@@ -585,14 +584,15 @@ def _create_solution(self):
585584
snippet_includes = re.findall(r'(--8<--\s+["\']([^"\']+)["\'])', content)
586585

587586
for mdx, alt, path in markdown_images:
588-
new_file = ("file", open(solution_path.parent / path, mode="rb"))
587+
local_path = solution_path.parent / path
588+
new_file = (local_path.name, open(solution_path.parent / path, mode="rb"))
589589
file_payload = {
590590
"type": "solution",
591591
"solution_id": solution_id,
592592
}
593593

594594
# Specifically use data= here to send multipart/form-data
595-
r = self.api.post("/api/v1/files", files=[new_file], data=file_payload)
595+
r = self.api.post("/api/v1/files", files={"file": new_file}, data=file_payload)
596596
r.raise_for_status()
597597
resp = r.json()
598598
server_location = resp["data"][0]["location"]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies = [
1919
"fire>=0.7.0,<0.8",
2020
"typing-extensions>=4.7.1,<5",
2121
"python-slugify>=8.0.4,<9",
22+
"requests-toolbelt==1.0.0",
2223
]
2324

2425
[project.scripts]
@@ -49,4 +50,3 @@ exclude = ["build/**", "ctfcli/templates/**"]
4950

5051
[tool.ruff.format]
5152
exclude = ["build/**", "ctfcli/templates/**"]
52-

tests/core/test_api.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,20 @@ def test_api_object_request_strips_preceding_slash_from_url_path(self, mock_requ
4646

4747
mock_request.assert_has_calls(
4848
[
49-
call("GET", "https://example.com/test/path", headers={"Content-Type": "application/json"}),
50-
call("GET", "https://example.com/test/path", headers={"Content-Type": "application/json"}),
49+
call(
50+
"GET",
51+
"https://example.com/test/path",
52+
headers={"Content-Type": "application/json"},
53+
data=None,
54+
files=None,
55+
),
56+
call(
57+
"GET",
58+
"https://example.com/test/path",
59+
headers={"Content-Type": "application/json"},
60+
data=None,
61+
files=None,
62+
),
5163
]
5264
)
5365

@@ -60,7 +72,7 @@ def test_api_object_request_assigns_prefix_url(self, mock_request: MagicMock, *a
6072
api = API()
6173
api.request("GET", "path")
6274
mock_request.assert_called_once_with(
63-
"GET", "https://example.com/test/path", headers={"Content-Type": "application/json"}
75+
"GET", "https://example.com/test/path", headers={"Content-Type": "application/json"}, data=None, files=None
6476
)
6577

6678
def test_api_object_assigns_ssl_verify(self, *args, **kwargs):
@@ -170,4 +182,4 @@ def test_api_object_assigns_cookies(self, *args, **kwargs):
170182
def test_request_does_not_override_form_data_content_type(self, mock_request: MagicMock, *args, **kwargs):
171183
api = API()
172184
api.request("GET", "/test", data="some-file")
173-
mock_request.assert_called_once_with("GET", "https://example.com/test", data="some-file")
185+
mock_request.assert_called_once_with("GET", "https://example.com/test", data="some-file", files=None)

0 commit comments

Comments
 (0)