|
| 1 | +from typing import Mapping |
1 | 2 | from urllib.parse import urljoin |
2 | 3 |
|
3 | 4 | from requests import Session |
| 5 | +from requests_toolbelt.multipart.encoder import MultipartEncoder |
4 | 6 |
|
5 | 7 | from ctfcli.core.config import Config |
6 | 8 | from ctfcli.core.exceptions import MissingAPIKey, MissingInstanceURL |
@@ -46,20 +48,56 @@ def __init__(self): |
46 | 48 | if "cookies" in config: |
47 | 49 | self.cookies.update(dict(config["cookies"])) |
48 | 50 |
|
49 | | - def request(self, method, url, *args, **kwargs): |
| 51 | + def request(self, method, url, data=None, files=None, *args, **kwargs): |
50 | 52 | # Strip out the preceding / so that urljoin creates the right url |
51 | 53 | # considering the appended / on the prefix_url |
52 | 54 | url = urljoin(self.prefix_url, url.lstrip("/")) |
53 | 55 |
|
54 | | - # if data= is present, do not modify the content-type |
55 | | - if kwargs.get("data") is not None: |
56 | | - return super().request(method, url, *args, **kwargs) |
| 56 | + # If data or files are any kind of key/value iterable |
| 57 | + # then encode the body as form-data |
| 58 | + if isinstance(data, (list, tuple, Mapping)) or isinstance(files, (list, tuple, Mapping)): |
| 59 | + # In order to use the MultipartEncoder, we need to convert data and files to the following structure : |
| 60 | + # A list of tuple containing the key and the values : List[Tuple[str, str]] |
| 61 | + # For files, the structure can be List[Tuple[str, Tuple[str, str, Optional[str]]]] |
| 62 | + # Example: [ ('file', ('doc.pdf', open('doc.pdf'), 'text/plain') ) ] |
| 63 | + |
| 64 | + fields = list() |
| 65 | + if isinstance(data, dict): |
| 66 | + # int are not allowed as value in MultipartEncoder |
| 67 | + fields = list(map(lambda v: (v[0], str(v[1]) if isinstance(v[1], int) else v[1]), data.items())) |
| 68 | + |
| 69 | + if files is not None: |
| 70 | + if isinstance(files, dict): |
| 71 | + files = list(files.items()) |
| 72 | + fields.extend(files) # type: ignore |
| 73 | + |
| 74 | + multipart = MultipartEncoder(fields) |
| 75 | + headers = kwargs.pop("headers", {}) or {} |
| 76 | + headers = dict(headers) |
| 77 | + headers["Content-Type"] = multipart.content_type |
| 78 | + |
| 79 | + return super(API, self).request( |
| 80 | + method, |
| 81 | + url, |
| 82 | + data=multipart, |
| 83 | + headers=headers, |
| 84 | + *args, |
| 85 | + **kwargs, |
| 86 | + ) |
57 | 87 |
|
58 | 88 | # otherwise set the content-type to application/json for all API requests |
59 | 89 | # modify the headers here instead of using self.headers because we don't want to |
60 | 90 | # override the multipart/form-data case above |
61 | | - if kwargs.get("headers") is None: |
62 | | - kwargs["headers"] = {} |
63 | | - |
64 | | - kwargs["headers"]["Content-Type"] = "application/json" |
65 | | - return super().request(method, url, *args, **kwargs) |
| 91 | + if data is None and files is None: |
| 92 | + if kwargs.get("headers", None) is None: |
| 93 | + kwargs["headers"] = {} |
| 94 | + kwargs["headers"]["Content-Type"] = "application/json" |
| 95 | + |
| 96 | + return super(API, self).request( |
| 97 | + method, |
| 98 | + url, |
| 99 | + data=data, |
| 100 | + files=files, |
| 101 | + *args, |
| 102 | + **kwargs, |
| 103 | + ) |
0 commit comments