Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,19 +1353,19 @@ def _parse_host(host: Optional[str]) -> str:
>>> _parse_host('1.2.3.4:56789')
'http://1.2.3.4:56789'
>>> _parse_host('http://1.2.3.4')
'http://1.2.3.4:80'
'http://1.2.3.4'
>>> _parse_host('https://1.2.3.4')
'https://1.2.3.4:443'
'https://1.2.3.4'
>>> _parse_host('https://1.2.3.4:56789')
'https://1.2.3.4:56789'
>>> _parse_host('example.com')
'http://example.com:11434'
>>> _parse_host('example.com:56789')
'http://example.com:56789'
>>> _parse_host('http://example.com')
'http://example.com:80'
'http://example.com'
>>> _parse_host('https://example.com')
'https://example.com:443'
'https://example.com'
>>> _parse_host('https://example.com:56789')
'https://example.com:56789'
>>> _parse_host('example.com/')
Expand All @@ -1378,16 +1378,18 @@ def _parse_host(host: Optional[str]) -> str:
'http://example.com:56789/path'
>>> _parse_host('https://example.com:56789/path')
'https://example.com:56789/path'
>>> _parse_host('https://example.com/path')
'https://example.com/path'
>>> _parse_host('example.com:56789/path/')
'http://example.com:56789/path'
>>> _parse_host('[0001:002:003:0004::1]')
'http://[0001:002:003:0004::1]:11434'
>>> _parse_host('[0001:002:003:0004::1]:56789')
'http://[0001:002:003:0004::1]:56789'
>>> _parse_host('http://[0001:002:003:0004::1]')
'http://[0001:002:003:0004::1]:80'
'http://[0001:002:003:0004::1]'
>>> _parse_host('https://[0001:002:003:0004::1]')
'https://[0001:002:003:0004::1]:443'
'https://[0001:002:003:0004::1]'
>>> _parse_host('https://[0001:002:003:0004::1]:56789')
'https://[0001:002:003:0004::1]:56789'
>>> _parse_host('[0001:002:003:0004::1]/')
Expand All @@ -1400,22 +1402,21 @@ def _parse_host(host: Optional[str]) -> str:
'http://[0001:002:003:0004::1]:56789/path'
>>> _parse_host('https://[0001:002:003:0004::1]:56789/path')
'https://[0001:002:003:0004::1]:56789/path'
>>> _parse_host('https://[0001:002:003:0004::1]/path')
'https://[0001:002:003:0004::1]/path'
>>> _parse_host('[0001:002:003:0004::1]:56789/path/')
'http://[0001:002:003:0004::1]:56789/path'
"""

host, port = host or '', 11434
host, default_port = host or '', 11434
scheme, _, hostport = host.partition('://')
has_scheme = bool(hostport)
if not hostport:
scheme, hostport = 'http', host
elif scheme == 'http':
port = 80
elif scheme == 'https':
port = 443

split = urllib.parse.urlsplit(f'{scheme}://{hostport}')
host = split.hostname or '127.0.0.1'
port = split.port or port
port = split.port

try:
if isinstance(ipaddress.ip_address(host), ipaddress.IPv6Address):
Expand All @@ -1425,6 +1426,18 @@ def _parse_host(host: Optional[str]) -> str:
...

if path := split.path.strip('/'):
if port is None:
if has_scheme:
return f'{scheme}://{host}/{path}'

port = default_port

return f'{scheme}://{host}:{port}/{path}'

if port is None:
if has_scheme:
return f'{scheme}://{host}'

port = default_port

return f'{scheme}://{host}:{port}'
18 changes: 17 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Request, Response

from ollama._client import CONNECTION_ERROR_MESSAGE, AsyncClient, Client, _copy_tools
from ollama._client import CONNECTION_ERROR_MESSAGE, AsyncClient, Client, _copy_tools, _parse_host
from ollama._types import Image, Message

PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'
Expand All @@ -26,6 +26,22 @@ def anyio_backend():
return 'asyncio'


@pytest.mark.parametrize(
('host', 'expected'),
[
('example.com', 'http://example.com:11434'),
('http://example.com', 'http://example.com'),
('https://example.com', 'https://example.com'),
('https://example.com/path', 'https://example.com/path'),
('http://[0001:002:003:0004::1]', 'http://[0001:002:003:0004::1]'),
('https://[0001:002:003:0004::1]', 'https://[0001:002:003:0004::1]'),
('https://[0001:002:003:0004::1]/path', 'https://[0001:002:003:0004::1]/path'),
],
)
def test_parse_host_default_ports(host: str, expected: str):
assert _parse_host(host) == expected


class PrefixPattern(URIPattern):
def __init__(self, prefix: str):
self.prefix = prefix
Expand Down