-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathtest_server_urls.py
More file actions
237 lines (218 loc) · 7.82 KB
/
test_server_urls.py
File metadata and controls
237 lines (218 loc) · 7.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import pytest
from dataclasses import dataclass
from unstructured_client import UnstructuredClient, utils
from typing import Optional
# The mocked request builder raises one of the two exceptions below so that
# control jumps straight back to the test code.
class BaseUrlCorrect(Exception):
    """Raised by the mocked request builder when the base URL matches.

    Raising it aborts the SDK call at the point where the request would be
    built and returns control to the test.
    """
class BaseUrlIncorrect(Exception):
    """Raised by the mocked request builder when the base URL differs.

    The URL that actually reached the builder is the exception's only
    argument, so ``str()`` of the exception yields that URL.
    """
def get_client_method_with_mock(
    sdk_endpoint_name,
    client_instance,
    mocked_server_url,
    monkeypatch
):
    """
    Look up an SDK endpoint method by dotted name and rig it for URL checks.

    ``sdk_endpoint_name`` has the form "<attribute>.<method>", e.g.
    "general.partition". The attribute is read off ``client_instance`` and
    the method off that attribute; the bound method is returned.

    Two things are monkeypatched along the way:
      * ``utils.unmarshal`` becomes a stub returning ``{}``, to get past
        param validation.
      * The endpoint's request builder (``_build_request_async`` when the
        method name contains "async", otherwise ``_build_request``) becomes
        a stub that raises BaseUrlCorrect when it receives
        ``base_url == mocked_server_url``, and BaseUrlIncorrect(base_url)
        otherwise.
    """
    def fake_unmarshal(*_args, **_kwargs):
        # Accept any request payload so validation does not block the call.
        return {}

    monkeypatch.setattr(utils, "unmarshal", fake_unmarshal)

    def fake_build_request(*_args, base_url, **_kwargs):
        # Short-circuit the call and report whether the expected URL arrived.
        if base_url != mocked_server_url:
            raise BaseUrlIncorrect(base_url)
        raise BaseUrlCorrect

    attr_name, method_name = sdk_endpoint_name.split(".")
    endpoint = getattr(client_instance, attr_name)
    bound_method = getattr(endpoint, method_name)

    builder_name = "_build_request_async" if "async" in method_name else "_build_request"
    monkeypatch.setattr(endpoint, builder_name, fake_build_request)

    return bound_method
@dataclass
class URLTestCase:
    """One parametrized scenario for the server-URL tests.

    Fields:
        description: human-readable label used in failure messages.
        sdk_endpoint_name: dotted "<attribute>.<method>" path on the client.
        expected_url: base URL the mocked request builder must receive when
            the HTTP request is built.
        client_url: ``server_url`` given when constructing the client, if
            any (global for all endpoints).
        endpoint_url: ``server_url`` given when calling the endpoint, if any
            (varies per endpoint call).
    """
    description: str
    sdk_endpoint_name: str
    expected_url: str
    client_url: Optional[str] = None
    endpoint_url: Optional[str] = None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "case",
    [
        URLTestCase(
            description="custom client-level URL, no path",
            sdk_endpoint_name="general.partition_async",
            client_url="http://localhost:8000/",
            endpoint_url=None,
            expected_url="http://localhost:8000"
        ),
        URLTestCase(
            description="custom client-level URL, with path",
            sdk_endpoint_name="general.partition_async",
            client_url="http://localhost:8000/my/endpoint/",
            endpoint_url=None,
            expected_url="http://localhost:8000/my/endpoint"
        ),
        URLTestCase(
            description="custom endpoint-level URL, no path",
            sdk_endpoint_name="general.partition_async",
            client_url=None,
            endpoint_url="http://localhost:8000/",
            expected_url="http://localhost:8000"
        ),
        URLTestCase(
            description="custom endpoint-level URL, with path",
            sdk_endpoint_name="general.partition_async",
            client_url=None,
            endpoint_url="http://localhost:8000/my/endpoint/",
            expected_url="http://localhost:8000/my/endpoint"
        ),
        URLTestCase(
            description="default URL fallback",
            sdk_endpoint_name="general.partition_async",
            client_url=None,
            endpoint_url=None,
            expected_url="https://api.unstructuredapp.io"
        ),
    ],
    # Readable test IDs instead of case0, case1, ...
    ids=lambda case: f"{case.sdk_endpoint_name}-{case.description}",
)
async def test_async_endpoint_uses_correct_url(monkeypatch, case: URLTestCase):
    """
    Check that an async SDK endpoint builds its request against
    ``case.expected_url``.

    The client is constructed with ``case.client_url`` when set. The endpoint
    is awaited with ``server_url=case.endpoint_url`` when set. The mocked
    ``_build_request_async`` raises BaseUrlCorrect or BaseUrlIncorrect.
    The test fails if the URL is wrong, and also if the call returns
    without reaching the mock, since then no URL was checked at all.
    """
    if case.client_url:
        s = UnstructuredClient(server_url=case.client_url)
    else:
        s = UnstructuredClient()

    client_method = get_client_method_with_mock(
        case.sdk_endpoint_name,
        s,
        case.expected_url,
        monkeypatch
    )

    # Include the endpoint name: descriptions alone are not unique.
    label = f"{case.sdk_endpoint_name} ({case.description})"
    try:
        if case.endpoint_url:
            await client_method(request={}, server_url=case.endpoint_url)
        else:
            await client_method(request={})
    except BaseUrlCorrect:
        pass
    except BaseUrlIncorrect as e:
        pytest.fail(
            f"{label}: Expected {case.expected_url}, got {e}"
        )
    else:
        # Without this branch, a call that never reaches the mocked builder
        # would pass silently without checking any URL.
        pytest.fail(
            f"{label}: request builder was never called; URL was not checked"
        )
@pytest.mark.parametrize(
    "case",
    [
        URLTestCase(
            description="custom client-level URL, no path",
            sdk_endpoint_name="destinations.create_destination",
            client_url="http://localhost:8000/",
            endpoint_url=None,
            expected_url="http://localhost:8000"
        ),
        URLTestCase(
            description="custom client-level URL, with path",
            sdk_endpoint_name="sources.create_source",
            client_url="http://localhost:8000/my/endpoint/",
            endpoint_url=None,
            expected_url="http://localhost:8000/my/endpoint"
        ),
        URLTestCase(
            description="custom endpoint-level URL, no path",
            sdk_endpoint_name="jobs.get_job",
            client_url=None,
            endpoint_url="http://localhost:8000",
            expected_url="http://localhost:8000"
        ),
        URLTestCase(
            description="custom endpoint-level URL, with path",
            sdk_endpoint_name="workflows.create_workflow",
            client_url=None,
            endpoint_url="http://localhost:8000/my/endpoint",
            expected_url="http://localhost:8000/my/endpoint"
        ),
        URLTestCase(
            description="partition client level with path",
            sdk_endpoint_name="general.partition",
            client_url="https://api.unstructuredapp.io/general/v0/general",
            endpoint_url=None,
            expected_url="https://api.unstructuredapp.io"
        ),
        URLTestCase(
            description="partition endpoint level with path",
            sdk_endpoint_name="general.partition",
            client_url=None,
            endpoint_url="https://api.unstructuredapp.io/general/v0/general",
            expected_url="https://api.unstructuredapp.io"
        ),
        URLTestCase(
            description="partition default url",
            sdk_endpoint_name="general.partition",
            client_url=None,
            endpoint_url=None,
            expected_url="https://api.unstructuredapp.io"
        ),
        URLTestCase(
            description="default URL fallback",
            sdk_endpoint_name="destinations.create_destination",
            client_url=None,
            endpoint_url=None,
            expected_url="https://platform.unstructuredapp.io"
        ),
        URLTestCase(
            description="default URL fallback",
            sdk_endpoint_name="sources.create_source",
            client_url=None,
            endpoint_url=None,
            expected_url="https://platform.unstructuredapp.io"
        ),
        URLTestCase(
            description="default URL fallback",
            sdk_endpoint_name="jobs.get_job",
            client_url=None,
            endpoint_url=None,
            expected_url="https://platform.unstructuredapp.io"
        ),
        URLTestCase(
            description="default URL fallback",
            sdk_endpoint_name="workflows.create_workflow",
            client_url=None,
            endpoint_url=None,
            expected_url="https://platform.unstructuredapp.io"
        ),
    ],
    # Several cases share a description, so prefix the endpoint name to make
    # each test ID self-explanatory.
    ids=lambda case: f"{case.sdk_endpoint_name}-{case.description}",
)
def test_endpoint_uses_correct_url(monkeypatch, case: URLTestCase):
    """
    Check that a sync SDK endpoint builds its request against
    ``case.expected_url``.

    The client is constructed with ``case.client_url`` when set. The endpoint
    is called with ``server_url=case.endpoint_url`` when set. The mocked
    ``_build_request`` raises BaseUrlCorrect or BaseUrlIncorrect.
    The test fails if the URL is wrong, and also if the call returns
    without reaching the mock, since then no URL was checked at all.
    """
    if case.client_url:
        s = UnstructuredClient(server_url=case.client_url)
    else:
        s = UnstructuredClient()

    client_method = get_client_method_with_mock(
        case.sdk_endpoint_name,
        s,
        case.expected_url,
        monkeypatch
    )

    # Include the endpoint name: "default URL fallback" appears four times.
    label = f"{case.sdk_endpoint_name} ({case.description})"
    try:
        if case.endpoint_url:
            client_method(request={}, server_url=case.endpoint_url)
        else:
            client_method(request={})
    except BaseUrlCorrect:
        pass
    except BaseUrlIncorrect as e:
        pytest.fail(
            f"{label}: Expected {case.expected_url}, got {e}"
        )
    else:
        # Without this branch, a call that never reaches the mocked builder
        # would pass silently without checking any URL.
        pytest.fail(
            f"{label}: request builder was never called; URL was not checked"
        )