Skip to content

Commit 610fc60

Browse files
committed
fix a bug with parameter name
closes #7
1 parent af76672 commit 610fc60

File tree

8 files changed

+148
-3
lines changed

8 files changed

+148
-3
lines changed

Cargo.lock

Lines changed: 67 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pprof = { version = "0.15.0", default-features = false, features = ["flamegraph"
2727
pyo3 = { version = "0.28", features = ["extension-module"] }
2828
pyo3-async-runtimes = { version = "0.28", features = ["async-std-runtime"] }
2929
rustls = "0.23"
30+
rustls-native-certs = "0.8"
3031
rustls-pemfile = "2"
3132
signal-hook = "0.4"
3233
socket2 = "0.6"

python/rsloop/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,17 @@ def __install_ssl_tracking() -> None:
104104
if getattr(context_cls, "_rsloop_tracking_installed", False):
105105
return
106106

107+
def mark_default_verify_paths(context):
108+
context.__dict__["_rsloop_use_default_verify_paths"] = True
109+
return context
110+
111+
orig_create_default_context = __ssl.create_default_context
107112
orig_load_cert_chain = context_cls.load_cert_chain
113+
orig_load_default_certs = context_cls.load_default_certs
114+
orig_set_default_verify_paths = context_cls.set_default_verify_paths
115+
116+
def create_default_context(*args, **kwargs):
117+
return mark_default_verify_paths(orig_create_default_context(*args, **kwargs))
108118

109119
def load_cert_chain(self, certfile, keyfile=None, password=None):
110120
result = orig_load_cert_chain(
@@ -128,7 +138,20 @@ def load_cert_chain(self, certfile, keyfile=None, password=None):
128138
self.__dict__["_rsloop_key_password"] = password_value
129139
return result
130140

141+
def load_default_certs(self, *args, **kwargs):
142+
result = orig_load_default_certs(self, *args, **kwargs)
143+
mark_default_verify_paths(self)
144+
return result
145+
146+
def set_default_verify_paths(self):
147+
result = orig_set_default_verify_paths(self)
148+
mark_default_verify_paths(self)
149+
return result
150+
151+
__ssl.create_default_context = create_default_context
131152
context_cls.load_cert_chain = load_cert_chain
153+
context_cls.load_default_certs = load_default_certs
154+
context_cls.set_default_verify_paths = set_default_verify_paths
132155
context_cls._rsloop_tracking_installed = True
133156

134157

src/python_api.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,7 +1994,7 @@ impl PyLoop {
19941994
})
19951995
}
19961996

1997-
#[pyo3(signature=(host, port, *, family=0, sock_type=0, proto=0, flags=0))]
1997+
#[pyo3(signature=(host, port, *, family=0, r#type=0, proto=0, flags=0))]
19981998
#[expect(
19991999
clippy::too_many_arguments,
20002000
reason = "Mirrors asyncio loop.getaddrinfo()"
@@ -2005,7 +2005,7 @@ impl PyLoop {
20052005
host: Option<Py<PyAny>>,
20062006
port: Option<Py<PyAny>>,
20072007
family: i32,
2008-
sock_type: i32,
2008+
r#type: i32,
20092009
proto: i32,
20102010
flags: i32,
20112011
) -> PyResult<Bound<'_, PyAny>> {
@@ -2016,7 +2016,7 @@ impl PyLoop {
20162016
let socket = py.import("socket")?;
20172017
let kwargs = PyDict::new(py);
20182018
kwargs.set_item("family", family)?;
2019-
kwargs.set_item("type", sock_type)?;
2019+
kwargs.set_item("type", r#type)?;
20202020
kwargs.set_item("proto", proto)?;
20212021
kwargs.set_item("flags", flags)?;
20222022

src/tls.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,26 @@ fn root_store_from_context(py: Python<'_>, ssl_context: &Py<PyAny>) -> PyResult<
208208
.add(CertificateDer::from(bytes.as_bytes().to_vec()))
209209
.map_err(to_py_tls_err)?;
210210
}
211+
212+
let use_default_verify_paths = ssl_context
213+
.bind(py)
214+
.getattr("__dict__")?
215+
.cast::<PyDict>()?
216+
.get_item("_rsloop_use_default_verify_paths")?
217+
.and_then(|value| value.extract::<bool>().ok())
218+
.unwrap_or(false);
219+
if use_default_verify_paths {
220+
let native = rustls_native_certs::load_native_certs();
221+
for error in native.errors {
222+
return Err(PyRuntimeError::new_err(format!(
223+
"failed to load native CA certificates: {error}"
224+
)));
225+
}
226+
for cert in native.certs {
227+
roots.add(cert).map_err(to_py_tls_err)?;
228+
}
229+
}
230+
211231
Ok(roots)
212232
}
213233

tests/packages/aiohttp_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import rsloop
2+
import aiohttp
3+
4+
5+
async def fetch_url(url):
6+
async with aiohttp.ClientSession() as session:
7+
async with session.get(url) as response:
8+
return await response.text()
9+
10+
11+
if __name__ == "__main__":
12+
html_content = rsloop.run(fetch_url("https://httpbin.org/get"))
13+
print(html_content)

tests/test_run.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import os
3+
import socket
34
import sys
45
import unittest
56

@@ -160,6 +161,21 @@ async def main() -> dict[str, object]:
160161
},
161162
)
162163

164+
def test_getaddrinfo_accepts_type_keyword(self) -> None:
165+
async def main() -> list[tuple[object, ...]]:
166+
loop = asyncio.get_running_loop()
167+
return await loop.getaddrinfo(
168+
"localhost",
169+
80,
170+
type=socket.SOCK_STREAM,
171+
)
172+
173+
addrinfos = rsloop.run(main())
174+
self.assertTrue(addrinfos)
175+
self.assertTrue(
176+
all(addrinfo[1] == socket.SOCK_STREAM for addrinfo in addrinfos),
177+
)
178+
163179

164180
if __name__ == "__main__":
165181
unittest.main()

tests/test_tls.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import pathlib
44
import socket
5+
import ssl
56
import tempfile
67
import unittest
78

@@ -42,6 +43,10 @@ def make_ssl_contexts(tmpdir: str):
4243

4344

4445
class TlsTests(unittest.TestCase):
46+
def test_create_default_context_marks_default_verify_paths(self) -> None:
47+
context = ssl.create_default_context()
48+
self.assertTrue(context.__dict__.get("_rsloop_use_default_verify_paths"))
49+
4550
def test_create_connection_and_server_tls_round_trip(self) -> None:
4651
async def main() -> str:
4752
loop = asyncio.get_running_loop()

0 commit comments

Comments
 (0)