|
| 1 | +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. |
| 2 | +# SPDX-License-Identifier: AGPL-3.0 |
| 3 | +"""Tests for network_guard SSRF protection utilities.""" |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +from unittest.mock import AsyncMock, patch |
| 8 | + |
| 9 | +import pytest |
| 10 | + |
| 11 | +from openviking.utils.network_guard import ( |
| 12 | + _is_public_ip, |
| 13 | + _normalize_host, |
| 14 | + _resolve_host_addresses, |
| 15 | + build_httpx_request_validation_hooks, |
| 16 | + ensure_public_remote_target, |
| 17 | + extract_remote_host, |
| 18 | +) |
| 19 | +from openviking_cli.exceptions import PermissionDeniedError |
| 20 | + |
| 21 | + |
| 22 | +# ── extract_remote_host ────────────────────────────────────────────────────── |
| 23 | + |
| 24 | + |
class TestExtractRemoteHost:
    """Host extraction from URL-style and ``git@host:path`` style sources."""

    @pytest.mark.parametrize(
        "source, expected",
        [
            ("https://example.com/repo.git", "example.com"),
            ("http://example.com:8080/path", "example.com"),
            ("https://sub.domain.example.com/foo", "sub.domain.example.com"),
            ("ftp://files.example.org/data.zip", "files.example.org"),
        ],
    )
    def test_extracts_host_from_http_urls(self, source: str, expected: str) -> None:
        """Only the hostname is returned for scheme-qualified URLs."""
        host = extract_remote_host(source)
        assert host == expected

    @pytest.mark.parametrize(
        "source, expected",
        [
            ("git@github.com:user/repo.git", "github.com"),
            ("git@gitlab.com:group/project.git", "gitlab.com"),
            ("git@[::1]:user/repo.git", "::1"),
        ],
    )
    def test_extracts_host_from_git_ssh(self, source: str, expected: str) -> None:
        """The host sits between ``git@`` and the first path-separating colon."""
        host = extract_remote_host(source)
        assert host == expected

    def test_git_ssh_missing_colon_returns_none(self) -> None:
        """A ``git@host`` address without a ``:path`` part yields no host."""
        host = extract_remote_host("git@github.com")
        assert host is None

    def test_url_without_hostname_returns_none(self) -> None:
        """A bare filesystem-like path has no host component."""
        host = extract_remote_host("/just/a/path")
        assert host is None

    def test_empty_string_returns_none(self) -> None:
        """Empty input yields no host."""
        host = extract_remote_host("")
        assert host is None

    def test_strips_brackets_from_ipv6_host(self) -> None:
        """Bracketed IPv6 literals come back without the surrounding brackets."""
        assert extract_remote_host("http://[::1]:8080/path") == "::1"
| 63 | + |
| 64 | + |
| 65 | +# ── _normalize_host ────────────────────────────────────────────────────────── |
| 66 | + |
| 67 | + |
class TestNormalizeHost:
    """Normalization removes a trailing dot and folds the host to lowercase."""

    def test_strips_trailing_dot(self) -> None:
        """A fully-qualified name loses its trailing root dot."""
        normalized = _normalize_host("example.com.")
        assert normalized == "example.com"

    def test_lowercases_host(self) -> None:
        """Upper-case input is folded to lower case."""
        normalized = _normalize_host("EXAMPLE.COM")
        assert normalized == "example.com"

    def test_strips_dot_and_lowercases(self) -> None:
        """Both transformations apply to the same input."""
        normalized = _normalize_host("Example.COM.")
        assert normalized == "example.com"
| 79 | + |
| 80 | + |
| 81 | +# ── _is_public_ip ─────────────────────────────────────────────────────────── |
| 82 | + |
| 83 | + |
class TestIsPublicIP:
    """Classification of IP literals as public or non-public."""

    # Addresses expected to be classified as public.
    _PUBLIC_ADDRESSES = (
        "8.8.8.8",
        "1.1.1.1",
        "151.101.1.67",
        "2607:f8b0:4004:800::200e",  # Google IPv6
    )

    # Addresses expected to be classified as non-public.
    _NON_PUBLIC_ADDRESSES = (
        "127.0.0.1",
        "10.0.0.1",
        "172.16.0.1",
        "172.31.255.255",
        "192.168.1.1",
        "0.0.0.0",
        "169.254.1.1",  # link-local
        "::1",
        "fe80::1",  # IPv6 link-local
        "fc00::1",  # IPv6 ULA
        "::ffff:127.0.0.1",  # IPv4-mapped IPv6 loopback
        "::ffff:10.0.0.1",  # IPv4-mapped IPv6 private
        "::ffff:192.168.1.1",  # IPv4-mapped IPv6 private
    )

    @pytest.mark.parametrize("address", _PUBLIC_ADDRESSES)
    def test_public_addresses_are_global(self, address: str) -> None:
        """Each listed public address is reported as exactly ``True``."""
        verdict = _is_public_ip(address)
        assert verdict is True

    @pytest.mark.parametrize("address", _NON_PUBLIC_ADDRESSES)
    def test_non_public_addresses_are_not_global(self, address: str) -> None:
        """Each listed non-public address is reported as exactly ``False``."""
        verdict = _is_public_ip(address)
        assert verdict is False

    def test_invalid_address_returns_false(self) -> None:
        """Text that is not an IP literal is treated as non-public."""
        verdict = _is_public_ip("not-an-ip")
        assert verdict is False

    def test_empty_string_returns_false(self) -> None:
        """Empty input is treated as non-public."""
        verdict = _is_public_ip("")
        assert verdict is False
| 125 | + |
| 126 | + |
| 127 | +# ── _resolve_host_addresses ────────────────────────────────────────────────── |
| 128 | + |
| 129 | + |
class TestResolveHostAddresses:
    """Verify DNS resolution wrapper behavior.

    ``socket.getaddrinfo`` is patched wherever a lookup would otherwise reach
    the system resolver, so the results do not depend on the network or on
    the resolver configuration of the machine running the tests.
    """

    @patch("openviking.utils.network_guard.socket.getaddrinfo")
    def test_returns_empty_set_for_unresolvable_host(self, mock_getaddrinfo) -> None:
        """A resolver failure (``socket.gaierror``) yields an empty set."""
        import socket

        # Simulate "name not found" instead of performing a real lookup: some
        # resolvers answer for non-existent names, and offline or slow CI
        # would otherwise make this test environment-dependent.
        mock_getaddrinfo.side_effect = socket.gaierror(
            socket.EAI_NONAME, "Name or service not known"
        )
        result = _resolve_host_addresses("this.host.definitely.does.not.exist.invalid")
        assert result == set()
        # Guard against the test passing without exercising the failure path.
        assert mock_getaddrinfo.called

    def test_returns_empty_set_for_unicode_error(self) -> None:
        """A hostname that cannot be encoded yields an empty set."""
        # A hostname that triggers UnicodeError in getaddrinfo (lone surrogate);
        # presumably this fails while encoding the name, before any lookup.
        result = _resolve_host_addresses("\udcff.invalid")
        assert result == set()

    @patch("openviking.utils.network_guard.socket.getaddrinfo")
    def test_strips_ipv6_scope_id(self, mock_getaddrinfo) -> None:
        """The ``%scope`` suffix of an IPv6 address is removed from the result."""
        import socket

        mock_getaddrinfo.return_value = [
            (socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("fe80::1%eth0", 0, 0, 0)),
        ]
        result = _resolve_host_addresses("some-host")
        assert "fe80::1" in result
        assert "fe80::1%eth0" not in result

    @patch("openviking.utils.network_guard.socket.getaddrinfo")
    def test_skips_non_inet_families(self, mock_getaddrinfo) -> None:
        """Entries whose address family is neither IPv4 nor IPv6 are ignored."""
        mock_getaddrinfo.return_value = [
            (999, 1, 0, "", ("1.2.3.4", 0)),  # unknown AF
        ]
        result = _resolve_host_addresses("some-host")
        assert result == set()
| 160 | + |
| 161 | + |
| 162 | +# ── ensure_public_remote_target ────────────────────────────────────────────── |
| 163 | + |
| 164 | + |
class TestEnsurePublicRemoteTarget:
    """SSRF guard: which sources are rejected and which are let through."""

    # -- Rejection: no valid host --

    def test_rejects_empty_source(self) -> None:
        """An empty source has no host to validate."""
        expectation = pytest.raises(PermissionDeniedError, match="valid destination host")
        with expectation:
            ensure_public_remote_target("")

    def test_rejects_bare_path(self) -> None:
        """A local filesystem path has no host to validate."""
        expectation = pytest.raises(PermissionDeniedError, match="valid destination host")
        with expectation:
            ensure_public_remote_target("/etc/passwd")

    def test_rejects_git_ssh_without_colon(self) -> None:
        """A ``git@host`` address without a path separator has no usable host."""
        expectation = pytest.raises(PermissionDeniedError, match="valid destination host")
        with expectation:
            ensure_public_remote_target("git@github.com")

    # -- Rejection: localhost variants --

    @pytest.mark.parametrize(
        "source",
        [
            "http://localhost/path",
            "http://localhost.localdomain/path",
            "http://LOCALHOST/path",
            "http://sub.localhost/path",
            "http://anything.localhost/path",
        ],
    )
    def test_rejects_localhost_variants(self, source: str) -> None:
        """Localhost names, in any case and with any subdomain, are refused."""
        expectation = pytest.raises(PermissionDeniedError, match="non-public")
        with expectation:
            ensure_public_remote_target(source)

    def test_rejects_localhost_with_trailing_dot(self) -> None:
        """The fully-qualified spelling ``localhost.`` is refused as well."""
        expectation = pytest.raises(PermissionDeniedError, match="non-public")
        with expectation:
            ensure_public_remote_target("http://localhost./path")

    # -- Rejection: non-public resolved IPs --

    @pytest.mark.parametrize(
        "source, resolved_ip",
        [
            ("http://evil.attacker.com/path", ip)
            for ip in (
                "127.0.0.1",
                "10.0.0.1",
                "172.16.0.1",
                "192.168.1.1",
                "0.0.0.0",
                "::1",
                "fe80::1",
                "::ffff:127.0.0.1",
                "::ffff:10.0.0.1",
                "169.254.169.254",  # AWS metadata
            )
        ],
    )
    @patch("openviking.utils.network_guard._resolve_host_addresses")
    def test_rejects_non_public_resolved_addresses(
        self, resolver, source: str, resolved_ip: str
    ) -> None:
        """A public-looking name resolving to a non-public address is refused."""
        resolver.return_value = {resolved_ip}
        expectation = pytest.raises(PermissionDeniedError, match="non-public address")
        with expectation:
            ensure_public_remote_target(source)

    # -- Rejection: DNS rebinding with mixed results --

    @patch("openviking.utils.network_guard._resolve_host_addresses")
    def test_rejects_when_any_resolved_address_is_non_public(self, resolver) -> None:
        """DNS rebinding: even if some IPs are public, one private IP is enough to reject."""
        resolver.return_value = {"8.8.8.8", "127.0.0.1"}
        expectation = pytest.raises(PermissionDeniedError, match="non-public address")
        with expectation:
            ensure_public_remote_target("http://rebinding.attacker.com/path")

    # -- Pass-through: valid public targets --

    @patch("openviking.utils.network_guard._resolve_host_addresses")
    def test_allows_public_http_url(self, resolver) -> None:
        """An HTTPS URL resolving to a public address passes (no exception)."""
        resolver.return_value = {"151.101.1.67"}
        ensure_public_remote_target("https://github.com/repo.git")

    @patch("openviking.utils.network_guard._resolve_host_addresses")
    def test_allows_public_git_ssh(self, resolver) -> None:
        """A git SSH address resolving to a public address passes (no exception)."""
        resolver.return_value = {"140.82.121.4"}
        ensure_public_remote_target("git@github.com:user/repo.git")

    @patch("openviking.utils.network_guard._resolve_host_addresses")
    def test_allows_when_dns_returns_empty(self, resolver) -> None:
        """Unresolvable host is allowed through (fail-open for DNS)."""
        resolver.return_value = set()
        ensure_public_remote_target("http://new-host.example.com/path")

    @patch("openviking.utils.network_guard._resolve_host_addresses")
    def test_allows_multiple_public_addresses(self, resolver) -> None:
        """A name resolving to several addresses passes when all are public."""
        resolver.return_value = {"8.8.8.8", "8.8.4.4"}
        ensure_public_remote_target("http://dns-rr.example.com/path")
| 258 | + |
| 259 | + |
| 260 | +# ── build_httpx_request_validation_hooks ───────────────────────────────────── |
| 261 | + |
| 262 | + |
class TestBuildHttpxRequestValidationHooks:
    """Construction and behavior of the httpx request-validation hooks."""

    def test_returns_none_when_no_validator(self) -> None:
        """Without a validator there is nothing to hook up."""
        hooks = build_httpx_request_validation_hooks(None)
        assert hooks is None

    def test_returns_request_hook_dict(self) -> None:
        """A validator produces exactly one hook under the ``"request"`` key."""
        hooks = build_httpx_request_validation_hooks(lambda url: None)
        assert hooks is not None
        assert "request" in hooks
        request_hooks = hooks["request"]
        assert len(request_hooks) == 1

    @pytest.mark.asyncio
    async def test_hook_calls_validator_with_url(self) -> None:
        """Awaiting the hook hands the request URL to the validator once."""
        seen_urls: list[str] = []

        hooks = build_httpx_request_validation_hooks(seen_urls.append)
        assert hooks is not None

        request = AsyncMock()
        request.url = "http://example.com/test"

        request_hook = hooks["request"][0]
        await request_hook(request)

        assert seen_urls == ["http://example.com/test"]

    @pytest.mark.asyncio
    async def test_hook_propagates_validator_exception(self) -> None:
        """An exception raised by the validator surfaces from the hook."""

        def rejecting_validator(url: str) -> None:
            raise PermissionDeniedError("blocked")

        hooks = build_httpx_request_validation_hooks(rejecting_validator)
        assert hooks is not None

        request = AsyncMock()
        request.url = "http://evil.com"

        request_hook = hooks["request"][0]
        with pytest.raises(PermissionDeniedError, match="blocked"):
            await request_hook(request)
0 commit comments