Skip to content

Commit 52d4b4e

Browse files
Merge pull request #56 from thewebscraping/bugs/shared-extra-configs
Fix per-request header order and shared config issue (#53)
2 parents e80972f + 08463e2 commit 52d4b4e

22 files changed

Lines changed: 1940 additions & 94 deletions

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@ dependencies = [
3535
"idna~=3.10",
3636
"charset-normalizer",
3737
"orjson",
38-
"chardet~=5.2.0",
3938
]
4039

4140
[dependency-groups]

src/tls_requests/client.py

Lines changed: 12 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -141,25 +141,16 @@ def __init__(
141141
) -> None:
142142
if "tls_identifier" in config:
143143
logger.warning(
144-
"The 'tls_identifier' parameter is deprecated and will be removed in version 1.3.0. "
145-
"Please use 'client_identifier' instead."
144+
"The 'client_identifier' parameter is deprecated and will be removed in version 1.3.0. "
145+
"Please use 'tls_identifier' instead."
146146
)
147-
if client_identifier == DEFAULT_CLIENT_IDENTIFIER:
148-
client_identifier = config.pop("tls_identifier")
149-
150-
if "tls_debug" in config:
151-
logger.warning(
152-
"The 'tls_debug' parameter is deprecated and will be removed in version 1.3.0. "
153-
"Please use 'debug' instead."
154-
)
155-
if debug == DEFAULT_DEBUG:
156-
debug = config.pop("tls_debug")
147+
client_identifier = config.pop("tls_identifier")
157148

158149
self._session = TLSClient.initialize()
159150
self._config = TLSConfig.from_kwargs(
160151
http2=http2,
161152
verify=verify,
162-
tls_identifier=self.prepare_tls_identifier(client_identifier),
153+
client_identifier=self.prepare_client_identifier(client_identifier),
163154
debug=debug,
164155
protocol_racing=protocol_racing,
165156
allow_http=allow_http,
@@ -297,14 +288,14 @@ def prepare_proxy(self, proxy: Optional[ProxyTypes]) -> Optional[Proxy]:
297288
return Proxy(str(proxy))
298289
raise ProxyError(f"Unsupported proxy type: {type(proxy)}")
299290

300-
def prepare_tls_identifier(self, identifier: Optional[IdentifierArgTypes]) -> str:
291+
def prepare_client_identifier(self, identifier: Optional[IdentifierArgTypes]) -> str:
301292
if isinstance(identifier, str):
302293
return identifier
303294
if isinstance(identifier, TLSIdentifierRotator):
304295
return str(identifier.next())
305296
return str(DEFAULT_CLIENT_IDENTIFIER)
306297

307-
def prepare_config(self, request: Request, tls_identifier: str = DEFAULT_CLIENT_IDENTIFIER):
298+
def prepare_config(self, request: Request, client_identifier: str = DEFAULT_CLIENT_IDENTIFIER):
308299
"""Prepare TLS Config"""
309300

310301
config = self.config.copy_with(
@@ -317,7 +308,7 @@ def prepare_config(self, request: Request, tls_identifier: str = DEFAULT_CLIENT_
317308
timeout=request.timeout,
318309
http2=True if self.http2 in ["auto", "http2", True, None] else False,
319310
verify=self.verify,
320-
tls_identifier=tls_identifier,
311+
client_identifier=client_identifier,
321312
protocol_racing=request.protocol_racing,
322313
allow_http=request.allow_http,
323314
stream_id=request.stream_id,
@@ -457,8 +448,8 @@ def _send(
457448
) -> Response:
458449
history = [] if history is None else history
459450
start = start or time.perf_counter()
460-
tls_identifier = self.prepare_tls_identifier(self.client_identifier)
461-
config = self.prepare_config(request, tls_identifier=tls_identifier)
451+
client_identifier = self.prepare_client_identifier(self.client_identifier)
452+
config = self.prepare_config(request, client_identifier=client_identifier)
462453
response = Response.from_tls_response(
463454
self.session.request(config.to_dict()),
464455
is_byte_response=config.isByteResponse,
@@ -894,7 +885,7 @@ async def aprepare_proxy(self, proxy: Optional[ProxyTypes]) -> Optional[Proxy]:
894885
return Proxy(str(proxy))
895886
raise ProxyError(f"Unsupported proxy type: {type(proxy)}")
896887

897-
async def aprepare_tls_identifier(self, identifier) -> str:
888+
async def aprepare_client_identifier(self, identifier) -> str:
898889
if isinstance(identifier, str):
899890
return identifier
900891
if isinstance(identifier, TLSIdentifierRotator):
@@ -1246,8 +1237,8 @@ async def _send( # type: ignore[override]
12461237
) -> Response:
12471238
history = [] if history is None else history
12481239
start = start or time.perf_counter()
1249-
tls_identifier = await self.aprepare_tls_identifier(self.client_identifier)
1250-
config = self.prepare_config(request, tls_identifier=tls_identifier)
1240+
client_identifier = await self.aprepare_client_identifier(self.client_identifier)
1241+
config = self.prepare_config(request, client_identifier=client_identifier)
12511242
response = Response.from_tls_response(
12521243
await self.session.arequest(config.to_dict()),
12531244
is_byte_response=config.isByteResponse,

src/tls_requests/models/encoders.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import os
55
from io import BufferedReader, BytesIO, TextIOWrapper
66
from mimetypes import guess_type
7-
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Tuple, TypeVar
7+
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Tuple, TypeVar, cast
88
from urllib.parse import urlencode
99

1010
from ..types import BufferTypes, ByteOrStr, RequestData, RequestFiles, RequestFileValue, RequestJson
@@ -106,7 +106,7 @@ def unpack(self, value: RequestFileValue) -> Tuple[str, BufferTypes, str]:
106106
if args:
107107
content_type = args[0]
108108
else:
109-
buffer = value
109+
buffer = value[0]
110110

111111
elif isinstance(value, str):
112112
buffer = value.encode("utf-8")
@@ -124,12 +124,12 @@ def unpack(self, value: RequestFileValue) -> Tuple[str, BufferTypes, str]:
124124
buffer.close()
125125
buffer = open(buffer.name, "rb")
126126

127-
elif not isinstance(buffer, bytes):
128-
raise ValueError
129-
else:
127+
elif isinstance(buffer, bytes):
130128
buffer = BytesIO(buffer)
129+
elif not hasattr(buffer, "read"):
130+
raise ValueError
131131

132-
return str(filename or "upload"), buffer, str(content_type or "application/octet-stream")
132+
return str(filename or "upload"), cast(BufferTypes, buffer), str(content_type or "application/octet-stream")
133133

134134
def render_data(self, chunk_size: int = 65_536) -> Iterator[bytes]:
135135
yield from iter_buffer(self._buffer, chunk_size)

src/tls_requests/models/headers.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,10 @@ def __setitem__(self, key, value) -> None:
115115
self._items.append((key, value))
116116

117117
def __getitem__(self, key):
118-
return self.get(key)
118+
val = self.get(key)
119+
if val is None:
120+
raise KeyError(key)
121+
return val
119122

120123
def __delitem__(self, key):
121124
key = self._normalize_key(key)

src/tls_requests/models/rotators.py

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
T = TypeVar("T")
2525
R = TypeVar("R", bound="BaseRotator")
2626

27-
TLS_IDENTIFIER_TEMPLATES = [
27+
CLIENT_IDENTIFIER_TEMPLATES = [
2828
"chrome_120",
2929
"chrome_124",
3030
"chrome_131",
@@ -237,9 +237,20 @@ def __init__(
237237
self.strategy = strategy
238238
self._iterator: Optional[Iterator[T]] = None
239239
self._lock = threading.Lock()
240-
self._async_lock = asyncio.Lock()
240+
self._async_lock: Optional[asyncio.Lock] = None
241241
self._rebuild_iterator()
242242

243+
@property
244+
def async_lock(self) -> asyncio.Lock:
245+
"""
246+
Lazily creates and returns an `asyncio.Lock`. This ensures that the lock
247+
is created within the correct event loop and avoids issues in
248+
synchronous contexts or older Python versions (like 3.9).
249+
"""
250+
if self._async_lock is None:
251+
self._async_lock = asyncio.Lock()
252+
return self._async_lock
253+
243254
@classmethod
244255
def from_file(
245256
cls: type[R],
@@ -354,7 +365,7 @@ async def anext(self, *args, **kwargs) -> T:
354365
Raises:
355366
ValueError: If the rotator contains no items.
356367
"""
357-
async with self._async_lock:
368+
async with self.async_lock:
358369
if not self.items:
359370
raise ValueError("Rotator is empty.")
360371
if self.strategy == "random":
@@ -368,15 +379,15 @@ async def aadd(self, item: T) -> None:
368379
"""
369380
Adds a new item to the rotator in a coroutine-safe manner.
370381
"""
371-
async with self._async_lock:
382+
async with self.async_lock:
372383
self.items.append(item)
373384
self._rebuild_iterator()
374385

375386
async def aremove(self, item: T) -> None:
376387
"""
377388
Removes an item from the rotator in a coroutine-safe manner.
378389
"""
379-
async with self._async_lock:
390+
async with self.async_lock:
380391
self.items = [i for i in self.items if i != item]
381392
self._rebuild_iterator()
382393

@@ -418,7 +429,7 @@ async def amark_result(self, proxy: Proxy, success: bool, latency: Optional[floa
418429
"""
419430
Coroutine-safely updates a proxy's performance statistics.
420431
"""
421-
async with self._async_lock:
432+
async with self.async_lock:
422433
self._update_proxy_stats(proxy, success, latency)
423434

424435
def _update_proxy_stats(self, proxy: Proxy, success: bool, latency: Optional[float] = None):
@@ -442,7 +453,7 @@ def __init__(
442453
items: Optional[Iterable[IdentifierTypes]] = None,
443454
strategy: Literal["round_robin", "random", "weighted"] = "round_robin",
444455
) -> None:
445-
super().__init__(items or TLS_IDENTIFIER_TEMPLATES, strategy) # type: ignore[arg-type]
456+
super().__init__(items or CLIENT_IDENTIFIER_TEMPLATES, strategy) # type: ignore[arg-type]
446457

447458
@classmethod
448459
def rebuild_item(cls, item: Any) -> Optional[IdentifierTypes]:

src/tls_requests/models/tls.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -200,7 +200,7 @@ async def _aread(cls, fn: Callable, payload: dict):
200200
class _BaseConfig:
201201
"""Base configuration for TLSSession"""
202202

203-
_extra_kwargs: dict = field(default_factory=dict, init=False, repr=False)
203+
_extra_config: dict = field(default_factory=dict, init=False, repr=False)
204204

205205
@classmethod
206206
def model_fields_set(cls) -> Set[str]:
@@ -212,13 +212,13 @@ def from_kwargs(cls: type[T], **kwargs: Any) -> T:
212212
known_kwargs = {cls.to_camel_case(k): v for k, v in kwargs.items() if k in model_fields_set}
213213
extra_kwargs = {cls.to_camel_case(k): v for k, v in kwargs.items() if k not in model_fields_set}
214214
instance = cls(**known_kwargs)
215-
instance._extra_kwargs = extra_kwargs
215+
instance._extra_config = extra_kwargs
216216
return instance
217217

218218
def to_dict(self) -> dict:
219219
data = asdict(self)
220-
if hasattr(self, "_extra_kwargs"):
221-
data.update(self._extra_kwargs)
220+
if hasattr(self, "_extra_config"):
221+
data.update(self._extra_config)
222222
return {k: v for k, v in data.items() if not k.startswith("_") and v is not None}
223223

224224
def to_payload(self) -> dict:
@@ -552,13 +552,13 @@ def copy_with(
552552
kwargs.update(filtered_mapping)
553553

554554
current_kwargs = asdict(self)
555-
if hasattr(self, "_extra"):
556-
current_kwargs.update(self._extra_kwargs)
555+
if hasattr(self, "_extra_config"):
556+
current_kwargs.update(self._extra_config)
557557

558558
for k, v in kwargs.items():
559559
current_kwargs[k] = v
560560

561-
return self.__class__.from_kwargs(**current_kwargs)
561+
return super().from_kwargs(**current_kwargs)
562562

563563
@classmethod
564564
def from_kwargs(

src/tls_requests/models/urls.py

Lines changed: 80 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -339,12 +339,81 @@ def _prepare(self, url: Union["URL", str, bytes]) -> ParseResult:
339339
url = str(url)
340340

341341
if not isinstance(url, str):
342-
raise URLError("Invalid URL: %s" % url)
342+
raise URLError(f"Invalid URL: {url}")
343+
344+
url_to_parse = url.lstrip()
345+
346+
# 0. Pre-parsing: default to http if scheme is missing
347+
if "://" not in url_to_parse and not url_to_parse.startswith("/") and not url_to_parse.startswith("./"):
348+
# Check if it doesn't look like a potential relative URL with query/fragment
349+
if not (url_to_parse.startswith("?") or url_to_parse.startswith("#")):
350+
url_to_parse = f"http://{url_to_parse}"
351+
352+
# 1. Pre-parsing repair for raw IPv6 addresses
353+
if ":" in url_to_parse:
354+
# Extract authority candidate: part between scheme and path
355+
if "://" in url_to_parse:
356+
authority_candidate = url_to_parse.split("://", 1)[1].split("/", 1)[0].split("?", 1)[0].split("#", 1)[0]
357+
else:
358+
authority_candidate = url_to_parse.split("/", 1)[0].split("?", 1)[0].split("#", 1)[0]
359+
360+
# Extract host part (ignoring user:pass@)
361+
host_candidate = authority_candidate.rsplit("@", 1)[-1]
362+
363+
# If it looks like IPv6 but lacks brackets
364+
if host_candidate.count(":") > 1 and not (host_candidate.startswith("[") and "]" in host_candidate):
365+
# Try to determine if it's IP:PORT or just IP
366+
# We prioritize IP:PORT if the last part is digits
367+
possible_ips = []
368+
h_p, _, port = host_candidate.rpartition(":")
369+
if port.isdigit() and h_p.count(":") >= 1:
370+
possible_ips.append((h_p, port))
371+
possible_ips.append((host_candidate, ""))
372+
373+
for ip, p_val in possible_ips:
374+
try:
375+
ipaddress.IPv6Address(ip)
376+
repaired = f"[{ip}]"
377+
if p_val:
378+
repaired += f":{p_val}"
379+
url_to_parse = url_to_parse.replace(host_candidate, repaired, 1)
380+
break
381+
except ValueError:
382+
continue
343383

344384
for attr in self.__attrs__:
345385
setattr(self, attr, None)
346386

347-
parsed = urlparse(url.lstrip())
387+
# 2. Parse and Validate
388+
try:
389+
# First, check for malformed brackets in the string we're about to parse
390+
# We strictly enforce one '[' and one ']' in the authority if any exist
391+
authority = ""
392+
if "://" in url_to_parse:
393+
authority = url_to_parse.split("://", 1)[1].split("/", 1)[0]
394+
else:
395+
authority = url_to_parse.split("/", 1)[0]
396+
397+
if "[" in authority or "]" in authority:
398+
if authority.count("[") != 1 or authority.count("]") != 1:
399+
raise ValueError("Malformed bracketed host")
400+
401+
start = authority.find("[")
402+
end = authority.find("]")
403+
if start > end:
404+
raise ValueError("Invalid bracket order")
405+
406+
# Content inside brackets MUST be a valid IPv6
407+
ip_content = authority[start + 1 : end]
408+
try:
409+
ipaddress.IPv6Address(ip_content)
410+
except ValueError:
411+
raise ValueError(f"Invalid IPv6 in brackets: {ip_content}")
412+
413+
parsed = urlparse(url_to_parse)
414+
415+
except (ValueError, AttributeError) as e:
416+
raise URLError(f"Invalid URL: {url}. {str(e)}") from e
348417

349418
self.auth = parsed.username, parsed.password
350419
self.scheme = parsed.scheme
@@ -363,14 +432,14 @@ def _prepare(self, url: Union["URL", str, bytes]) -> ParseResult:
363432
try:
364433
self.host = idna.encode(hostname).decode("ascii")
365434
except idna.IDNAError:
366-
raise URLError("Invalid IDNA hostname: %s" % hostname)
435+
raise URLError(f"Invalid IDNA hostname: {hostname}")
367436

368437
self.port = ""
369438
try:
370439
if parsed.port:
371440
self.port = str(parsed.port)
372441
except ValueError as e:
373-
raise URLError("%s. port range must be 0 - 65535." % e.args[0])
442+
raise URLError(f"{e.args[0]}. port range must be 0 - 65535.")
374443

375444
self.path = parsed.path
376445
self.fragment = parsed.fragment
@@ -388,25 +457,28 @@ def _build(self, secure: bool = False) -> str:
388457
Returns:
389458
The final URL string.
390459
"""
391-
urls = [self.scheme, "://"]
460+
scheme = self.scheme or ""
461+
urls = [scheme, "://"] if scheme else []
392462
authority = self.netloc
393463
if self.username or self.password:
464+
username = self.username or ""
394465
password = self.password or ""
395466
if secure:
396467
password = "[secure]"
397468

398469
authority = "@".join(
399470
[
400-
":".join([self.username, password]),
471+
":".join([username, password]),
401472
self.netloc,
402473
]
403474
)
404475

405476
urls.append(authority)
477+
path = self.path or ""
406478
if self.query:
407-
urls.append("?".join([self.path, self.query]))
479+
urls.append("?".join([path, self.query]))
408480
else:
409-
urls.append(self.path)
481+
urls.append(path)
410482

411483
if self.fragment:
412484
urls.append("#" + self.fragment)

0 commit comments

Comments
 (0)