Skip to content

Commit eafbd4b

Browse files
fix: apply proxy to (almost) all requests (#336)
Right now, the `BOT_PROXY_URL` env variable only applies to `bot.http_session`. This expands this to more requests: - All requests to the Discord API by the underlying disnake client (except interaction responses due to DisnakeDev/disnake#261) - GraphQL requests to GitHub The GraphQL one was a bit annoying, since graphql-python doesn't have a proxy parameter itself. It does, however, support passing arbitrary kwargs to the created `aiohttp.ClientSession`, which we can use to set a proxy directly. --------- Co-authored-by: arielle <me@arielle.codes>
1 parent 9cb92af commit eafbd4b

7 files changed

Lines changed: 56 additions & 25 deletions

File tree

monty/aiohttp_session.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
import socket
44
import sys
55
from datetime import timedelta
6-
from typing import TYPE_CHECKING, Any
6+
from typing import TYPE_CHECKING, Any, TypedDict
77
from unittest.mock import Mock
88

99
import aiohttp
1010
from multidict import CIMultiDict, CIMultiDictProxy
1111

1212
from monty import constants
1313
from monty.log import get_logger
14+
from monty.utils import helpers
1415
from monty.utils.caching import RedisCache
1516

1617

@@ -45,14 +46,29 @@ async def _on_request_end(
4546
)
4647

4748

49+
class SessionArgs(TypedDict):
50+
proxy: str | None
51+
connector: aiohttp.BaseConnector
52+
53+
54+
def session_args_for_proxy(proxy: str | None) -> SessionArgs:
55+
"""Create a dict with `proxy` and `connector` items, to be passed to aiohttp.ClientSession."""
56+
connector = aiohttp.TCPConnector(
57+
resolver=aiohttp.AsyncResolver(),
58+
family=socket.AF_INET,
59+
ssl=(
60+
helpers._SSL_CONTEXT_UNVERIFIED
61+
if (proxy and proxy.startswith("http://"))
62+
else helpers._SSL_CONTEXT_VERIFIED
63+
),
64+
)
65+
return {"proxy": proxy or None, "connector": connector}
66+
67+
4868
class CachingClientSession(aiohttp.ClientSession):
4969
def __init__(self, *args: Any, **kwargs: Any) -> None:
50-
if "connector" not in kwargs:
51-
kwargs["connector"] = aiohttp.TCPConnector(
52-
resolver=aiohttp.AsyncResolver(),
53-
family=socket.AF_INET,
54-
verify_ssl=not bool(constants.Client.proxy and constants.Client.proxy.startswith("http://")),
55-
)
70+
kwargs.update(session_args_for_proxy(kwargs.get("proxy")))
71+
5672
if "trace_configs" not in kwargs:
5773
trace_config = aiohttp.TraceConfig()
5874
trace_config.on_request_end.append(_on_request_end)

monty/bot.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sqlalchemy.orm import selectinload
1515

1616
from monty import constants
17-
from monty.aiohttp_session import CachingClientSession
17+
from monty.aiohttp_session import CachingClientSession, session_args_for_proxy
1818
from monty.database import Feature, Guild, GuildConfig
1919
from monty.database.rollouts import Rollout
2020
from monty.log import get_logger
@@ -60,6 +60,10 @@ def __init__(
6060
if TEST_GUILDS:
6161
kwargs["test_guilds"] = TEST_GUILDS
6262
log.warning("registering as test_guilds")
63+
64+
# pass proxy and connector to disnake client
65+
kwargs.update(session_args_for_proxy(proxy))
66+
6367
super().__init__(**kwargs)
6468

6569
self.redis_session = redis_session

monty/exts/info/docs/_batch_parser.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from monty import constants
1212
from monty.bot import Monty
1313
from monty.log import get_logger
14-
from monty.utils import helpers, scheduling
14+
from monty.utils import scheduling
1515
from monty.utils.html_parsing import get_symbol_markdown
1616

1717
from . import _cog, doc_cache
@@ -120,11 +120,8 @@ async def get_markdown(self, doc_item: "_cog.DocItem") -> str | None:
120120
if doc_item not in self._item_futures and doc_item not in self._queue:
121121
self._item_futures[doc_item].user_requested = True
122122

123-
# providing a context is workaround for cloudflare issues
124123
try:
125-
async with self._bot.http_session.get(
126-
doc_item.url, raise_for_status=True, ssl=helpers.ssl_create_default_context()
127-
) as response:
124+
async with self._bot.http_session.get(doc_item.url, raise_for_status=True) as response:
128125
soup = await self._bot.loop.run_in_executor(
129126
None,
130127
BeautifulSoup,

monty/exts/info/github_info.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import monty.utils.services
2525
from monty import constants
26+
from monty.aiohttp_session import session_args_for_proxy
2627
from monty.bot import Monty
2728
from monty.constants import Feature
2829
from monty.errors import MontyCommandError
@@ -221,7 +222,13 @@ class GithubInfo(
221222
def __init__(self, bot: Monty) -> None:
222223
self.bot = bot
223224

224-
transport = AIOHTTPTransport(url="https://api.github.com/graphql", timeout=20, headers=GITHUB_REQUEST_HEADERS)
225+
transport = AIOHTTPTransport(
226+
url="https://api.github.com/graphql",
227+
timeout=20,
228+
headers=GITHUB_REQUEST_HEADERS,
229+
# copy because invariance
230+
client_session_args=dict(session_args_for_proxy(bot.http.proxy)),
231+
)
225232

226233
self.gql_client = gql.Client(transport=transport, fetch_schema_from_transport=True)
227234

monty/utils/converters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from monty.bot import Monty
1717
from monty.database import Feature, Rollout
1818
from monty.log import get_logger
19-
from monty.utils import helpers, inventory_parser
19+
from monty.utils import inventory_parser
2020
from monty.utils.extensions import EXTENSIONS, unqualify
2121
from monty.utils.features import NAME_REGEX as FEATURE_NAME_REGEX
2222

@@ -204,7 +204,7 @@ class ValidURL(commands.Converter):
204204
async def convert(ctx: commands.Context, url: str) -> str:
205205
"""This converter checks whether the given URL can be reached with a status code of 200."""
206206
try:
207-
async with ctx.bot.http_session.get(url, ssl=helpers.ssl_create_default_context()) as resp:
207+
async with ctx.bot.http_session.get(url) as resp:
208208
if resp.status != 200:
209209
msg = f"HTTP GET on `{url}` returned status `{resp.status}`, expected 200"
210210
raise commands.BadArgument(msg)

monty/utils/helpers.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,21 @@ def fromisoformat(timestamp: str) -> datetime.datetime:
144144
return dt
145145

146146

147-
def ssl_create_default_context() -> ssl.SSLContext:
148-
"""Return an ssl context that CloudFlare shouldn't flag."""
149-
ssl_context = ssl.create_default_context()
150-
ssl_context.post_handshake_auth = True
151-
return ssl_context
147+
def _create_ssl_context(*, verify: bool) -> ssl.SSLContext:
148+
if verify:
149+
ctx = ssl.create_default_context()
150+
else:
151+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
152+
ctx.check_hostname = False
153+
ctx.verify_mode = ssl.CERT_NONE
154+
ctx.set_alpn_protocols(["http/1.1"])
155+
# change tls fingerprint to avoid being flagged by cloudflare
156+
ctx.post_handshake_auth = True
157+
return ctx
158+
159+
160+
_SSL_CONTEXT_VERIFIED = _create_ssl_context(verify=True)
161+
_SSL_CONTEXT_UNVERIFIED = _create_ssl_context(verify=False)
152162

153163

154164
@overload

monty/utils/inventory_parser.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import aiohttp
1010

1111
from monty.log import get_logger
12-
from monty.utils import helpers
1312
from monty.utils.caching import redis_cache
1413

1514

@@ -93,9 +92,7 @@ async def _load_v2(stream: aiohttp.StreamReader) -> InventoryDict:
9392
async def _fetch_inventory(bot: Monty, url: str) -> InventoryDict:
9493
"""Fetch, parse and return an intersphinx inventory file from an url."""
9594
timeout = aiohttp.ClientTimeout(sock_connect=5, sock_read=5)
96-
async with bot.http_session.get(
97-
url, timeout=timeout, raise_for_status=True, use_cache=False, ssl=helpers.ssl_create_default_context()
98-
) as response:
95+
async with bot.http_session.get(url, timeout=timeout, raise_for_status=True, use_cache=False) as response:
9996
stream = response.content
10097

10198
inventory_header = (await stream.readline()).decode().rstrip()

0 commit comments

Comments
 (0)