Skip to content

Commit b153ad9

Browse files
sjhddhclaude
andauthored
test(security): add unit tests for network_guard and zip_safe modules (#1395)
* test(security): add unit tests for network_guard and zip_safe modules Add comprehensive test coverage for two security-critical modules that previously had zero tests: - network_guard: SSRF protection, internal IP blocking, DNS rebinding edge cases, protocol validation, malformed URL handling - zip_safe: Zip Slip traversal prevention, path normalization, special character handling These modules protect against OWASP Top 10 vulnerabilities (SSRF, path traversal) and should have regression tests to prevent accidental weakening of security boundaries. * test: add missing assertion for CJK filename normalization The test_repairs_cjk_filename_from_cp437_mojibake test was missing a final assertion to verify that normalize_zip_filenames() actually repaired the mojibake filename back to the original CJK name. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3b4dfd3 commit b153ad9

2 files changed

Lines changed: 681 additions & 0 deletions

File tree

tests/misc/test_network_guard.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
2+
# SPDX-License-Identifier: AGPL-3.0
3+
"""Tests for network_guard SSRF protection utilities."""
4+
5+
from __future__ import annotations
6+
7+
from unittest.mock import AsyncMock, patch
8+
9+
import pytest
10+
11+
from openviking.utils.network_guard import (
12+
_is_public_ip,
13+
_normalize_host,
14+
_resolve_host_addresses,
15+
build_httpx_request_validation_hooks,
16+
ensure_public_remote_target,
17+
extract_remote_host,
18+
)
19+
from openviking_cli.exceptions import PermissionDeniedError
20+
21+
22+
# ── extract_remote_host ──────────────────────────────────────────────────────
23+
24+
25+
class TestExtractRemoteHost:
26+
"""Verify host extraction from URLs and git SSH addresses."""
27+
28+
@pytest.mark.parametrize(
29+
("source", "expected"),
30+
[
31+
("https://example.com/repo.git", "example.com"),
32+
("http://example.com:8080/path", "example.com"),
33+
("https://sub.domain.example.com/foo", "sub.domain.example.com"),
34+
("ftp://files.example.org/data.zip", "files.example.org"),
35+
],
36+
)
37+
def test_extracts_host_from_http_urls(self, source: str, expected: str) -> None:
38+
assert extract_remote_host(source) == expected
39+
40+
@pytest.mark.parametrize(
41+
("source", "expected"),
42+
[
43+
("git@github.com:user/repo.git", "github.com"),
44+
("git@gitlab.com:group/project.git", "gitlab.com"),
45+
("git@[::1]:user/repo.git", "::1"),
46+
],
47+
)
48+
def test_extracts_host_from_git_ssh(self, source: str, expected: str) -> None:
49+
assert extract_remote_host(source) == expected
50+
51+
def test_git_ssh_missing_colon_returns_none(self) -> None:
52+
assert extract_remote_host("git@github.com") is None
53+
54+
def test_url_without_hostname_returns_none(self) -> None:
55+
assert extract_remote_host("/just/a/path") is None
56+
57+
def test_empty_string_returns_none(self) -> None:
58+
assert extract_remote_host("") is None
59+
60+
def test_strips_brackets_from_ipv6_host(self) -> None:
61+
result = extract_remote_host("http://[::1]:8080/path")
62+
assert result == "::1"
63+
64+
65+
# ── _normalize_host ──────────────────────────────────────────────────────────
66+
67+
68+
class TestNormalizeHost:
69+
"""Verify trailing-dot stripping and lowercasing."""
70+
71+
def test_strips_trailing_dot(self) -> None:
72+
assert _normalize_host("example.com.") == "example.com"
73+
74+
def test_lowercases_host(self) -> None:
75+
assert _normalize_host("EXAMPLE.COM") == "example.com"
76+
77+
def test_strips_dot_and_lowercases(self) -> None:
78+
assert _normalize_host("Example.COM.") == "example.com"
79+
80+
81+
# ── _is_public_ip ───────────────────────────────────────────────────────────
82+
83+
84+
class TestIsPublicIP:
85+
"""Verify classification of public vs non-public IPs."""
86+
87+
@pytest.mark.parametrize(
88+
"address",
89+
[
90+
"8.8.8.8",
91+
"1.1.1.1",
92+
"151.101.1.67",
93+
"2607:f8b0:4004:800::200e", # Google IPv6
94+
],
95+
)
96+
def test_public_addresses_are_global(self, address: str) -> None:
97+
assert _is_public_ip(address) is True
98+
99+
@pytest.mark.parametrize(
100+
"address",
101+
[
102+
"127.0.0.1",
103+
"10.0.0.1",
104+
"172.16.0.1",
105+
"172.31.255.255",
106+
"192.168.1.1",
107+
"0.0.0.0",
108+
"169.254.1.1", # link-local
109+
"::1",
110+
"fe80::1", # IPv6 link-local
111+
"fc00::1", # IPv6 ULA
112+
"::ffff:127.0.0.1", # IPv4-mapped IPv6 loopback
113+
"::ffff:10.0.0.1", # IPv4-mapped IPv6 private
114+
"::ffff:192.168.1.1", # IPv4-mapped IPv6 private
115+
],
116+
)
117+
def test_non_public_addresses_are_not_global(self, address: str) -> None:
118+
assert _is_public_ip(address) is False
119+
120+
def test_invalid_address_returns_false(self) -> None:
121+
assert _is_public_ip("not-an-ip") is False
122+
123+
def test_empty_string_returns_false(self) -> None:
124+
assert _is_public_ip("") is False
125+
126+
127+
# ── _resolve_host_addresses ──────────────────────────────────────────────────
128+
129+
130+
class TestResolveHostAddresses:
131+
"""Verify DNS resolution wrapper behavior."""
132+
133+
def test_returns_empty_set_for_unresolvable_host(self) -> None:
134+
result = _resolve_host_addresses("this.host.definitely.does.not.exist.invalid")
135+
assert result == set()
136+
137+
def test_returns_empty_set_for_unicode_error(self) -> None:
138+
# A hostname that triggers UnicodeError in getaddrinfo
139+
result = _resolve_host_addresses("\udcff.invalid")
140+
assert result == set()
141+
142+
@patch("openviking.utils.network_guard.socket.getaddrinfo")
143+
def test_strips_ipv6_scope_id(self, mock_getaddrinfo) -> None:
144+
import socket
145+
146+
mock_getaddrinfo.return_value = [
147+
(socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("fe80::1%eth0", 0, 0, 0)),
148+
]
149+
result = _resolve_host_addresses("some-host")
150+
assert "fe80::1" in result
151+
assert "fe80::1%eth0" not in result
152+
153+
@patch("openviking.utils.network_guard.socket.getaddrinfo")
154+
def test_skips_non_inet_families(self, mock_getaddrinfo) -> None:
155+
mock_getaddrinfo.return_value = [
156+
(999, 1, 0, "", ("1.2.3.4", 0)), # unknown AF
157+
]
158+
result = _resolve_host_addresses("some-host")
159+
assert result == set()
160+
161+
162+
# ── ensure_public_remote_target ──────────────────────────────────────────────
163+
164+
165+
class TestEnsurePublicRemoteTarget:
166+
"""End-to-end SSRF protection tests."""
167+
168+
# -- Rejection: no valid host --
169+
170+
def test_rejects_empty_source(self) -> None:
171+
with pytest.raises(PermissionDeniedError, match="valid destination host"):
172+
ensure_public_remote_target("")
173+
174+
def test_rejects_bare_path(self) -> None:
175+
with pytest.raises(PermissionDeniedError, match="valid destination host"):
176+
ensure_public_remote_target("/etc/passwd")
177+
178+
def test_rejects_git_ssh_without_colon(self) -> None:
179+
with pytest.raises(PermissionDeniedError, match="valid destination host"):
180+
ensure_public_remote_target("git@github.com")
181+
182+
# -- Rejection: localhost variants --
183+
184+
@pytest.mark.parametrize(
185+
"source",
186+
[
187+
"http://localhost/path",
188+
"http://localhost.localdomain/path",
189+
"http://LOCALHOST/path",
190+
"http://sub.localhost/path",
191+
"http://anything.localhost/path",
192+
],
193+
)
194+
def test_rejects_localhost_variants(self, source: str) -> None:
195+
with pytest.raises(PermissionDeniedError, match="non-public"):
196+
ensure_public_remote_target(source)
197+
198+
def test_rejects_localhost_with_trailing_dot(self) -> None:
199+
with pytest.raises(PermissionDeniedError, match="non-public"):
200+
ensure_public_remote_target("http://localhost./path")
201+
202+
# -- Rejection: non-public resolved IPs --
203+
204+
@pytest.mark.parametrize(
205+
("source", "resolved_ip"),
206+
[
207+
("http://evil.attacker.com/path", "127.0.0.1"),
208+
("http://evil.attacker.com/path", "10.0.0.1"),
209+
("http://evil.attacker.com/path", "172.16.0.1"),
210+
("http://evil.attacker.com/path", "192.168.1.1"),
211+
("http://evil.attacker.com/path", "0.0.0.0"),
212+
("http://evil.attacker.com/path", "::1"),
213+
("http://evil.attacker.com/path", "fe80::1"),
214+
("http://evil.attacker.com/path", "::ffff:127.0.0.1"),
215+
("http://evil.attacker.com/path", "::ffff:10.0.0.1"),
216+
("http://evil.attacker.com/path", "169.254.169.254"), # AWS metadata
217+
],
218+
)
219+
@patch("openviking.utils.network_guard._resolve_host_addresses")
220+
def test_rejects_non_public_resolved_addresses(
221+
self, mock_resolve, source: str, resolved_ip: str
222+
) -> None:
223+
mock_resolve.return_value = {resolved_ip}
224+
with pytest.raises(PermissionDeniedError, match="non-public address"):
225+
ensure_public_remote_target(source)
226+
227+
# -- Rejection: DNS rebinding with mixed results --
228+
229+
@patch("openviking.utils.network_guard._resolve_host_addresses")
230+
def test_rejects_when_any_resolved_address_is_non_public(self, mock_resolve) -> None:
231+
"""DNS rebinding: even if some IPs are public, one private IP is enough to reject."""
232+
mock_resolve.return_value = {"8.8.8.8", "127.0.0.1"}
233+
with pytest.raises(PermissionDeniedError, match="non-public address"):
234+
ensure_public_remote_target("http://rebinding.attacker.com/path")
235+
236+
# -- Pass-through: valid public targets --
237+
238+
@patch("openviking.utils.network_guard._resolve_host_addresses")
239+
def test_allows_public_http_url(self, mock_resolve) -> None:
240+
mock_resolve.return_value = {"151.101.1.67"}
241+
ensure_public_remote_target("https://github.com/repo.git") # should not raise
242+
243+
@patch("openviking.utils.network_guard._resolve_host_addresses")
244+
def test_allows_public_git_ssh(self, mock_resolve) -> None:
245+
mock_resolve.return_value = {"140.82.121.4"}
246+
ensure_public_remote_target("git@github.com:user/repo.git") # should not raise
247+
248+
@patch("openviking.utils.network_guard._resolve_host_addresses")
249+
def test_allows_when_dns_returns_empty(self, mock_resolve) -> None:
250+
"""Unresolvable host is allowed through (fail-open for DNS)."""
251+
mock_resolve.return_value = set()
252+
ensure_public_remote_target("http://new-host.example.com/path") # should not raise
253+
254+
@patch("openviking.utils.network_guard._resolve_host_addresses")
255+
def test_allows_multiple_public_addresses(self, mock_resolve) -> None:
256+
mock_resolve.return_value = {"8.8.8.8", "8.8.4.4"}
257+
ensure_public_remote_target("http://dns-rr.example.com/path") # should not raise
258+
259+
260+
# ── build_httpx_request_validation_hooks ─────────────────────────────────────
261+
262+
263+
class TestBuildHttpxRequestValidationHooks:
264+
"""Verify httpx hook construction."""
265+
266+
def test_returns_none_when_no_validator(self) -> None:
267+
assert build_httpx_request_validation_hooks(None) is None
268+
269+
def test_returns_request_hook_dict(self) -> None:
270+
def dummy_validator(url: str) -> None:
271+
pass
272+
273+
hooks = build_httpx_request_validation_hooks(dummy_validator)
274+
assert hooks is not None
275+
assert "request" in hooks
276+
assert len(hooks["request"]) == 1
277+
278+
@pytest.mark.asyncio
279+
async def test_hook_calls_validator_with_url(self) -> None:
280+
calls: list[str] = []
281+
282+
def tracking_validator(url: str) -> None:
283+
calls.append(url)
284+
285+
hooks = build_httpx_request_validation_hooks(tracking_validator)
286+
assert hooks is not None
287+
288+
mock_request = AsyncMock()
289+
mock_request.url = "http://example.com/test"
290+
291+
hook_fn = hooks["request"][0]
292+
await hook_fn(mock_request)
293+
294+
assert calls == ["http://example.com/test"]
295+
296+
@pytest.mark.asyncio
297+
async def test_hook_propagates_validator_exception(self) -> None:
298+
def failing_validator(url: str) -> None:
299+
raise PermissionDeniedError("blocked")
300+
301+
hooks = build_httpx_request_validation_hooks(failing_validator)
302+
assert hooks is not None
303+
304+
mock_request = AsyncMock()
305+
mock_request.url = "http://evil.com"
306+
307+
with pytest.raises(PermissionDeniedError, match="blocked"):
308+
await hooks["request"][0](mock_request)

0 commit comments

Comments
 (0)