diff --git a/examples/web-search-crawl.py b/examples/web-search-crawl.py index ac7ec512..6222d9dc 100644 --- a/examples/web-search-crawl.py +++ b/examples/web-search-crawl.py @@ -5,12 +5,11 @@ # "ollama", # ] # /// -import os from typing import Union from rich import print -from ollama import Client, WebCrawlResponse, WebSearchResponse +from ollama import WebCrawlResponse, WebSearchResponse, chat, web_crawl, web_search def format_tool_results(results: Union[WebSearchResponse, WebCrawlResponse]): @@ -49,15 +48,17 @@ def format_tool_results(results: Union[WebSearchResponse, WebCrawlResponse]): return '\n'.join(output).rstrip() -client = Client(headers={'Authorization': (os.getenv('OLLAMA_API_KEY'))}) -available_tools = {'web_search': client.web_search, 'web_crawl': client.web_crawl} +# Set OLLAMA_API_KEY in the environment variable or use the headers parameter to set the authorization header +# client = Client(headers={'Authorization': 'Bearer '}) + +available_tools = {'web_search': web_search, 'web_crawl': web_crawl} query = "ollama's new engine" print('Query: ', query) messages = [{'role': 'user', 'content': query}] while True: - response = client.chat(model='qwen3', messages=messages, tools=[client.web_search, client.web_crawl], think=True) + response = chat(model='qwen3', messages=messages, tools=[web_search, web_crawl], think=True) if response.message.thinking: print('Thinking: ') print(response.message.thinking + '\n\n') diff --git a/ollama/_client.py b/ollama/_client.py index 9abd5cc6..d6a26c69 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -94,20 +94,25 @@ def __init__( `kwargs` are passed to the httpx client. """ + headers = { + k.lower(): v + for k, v in { + **(headers or {}), + 'Content-Type': 'application/json', + 'Accept': 'application/json', + 'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}', + }.items() + if v is not None + } + api_key = os.getenv('OLLAMA_API_KEY', None) + if not headers.get('authorization') and api_key: + headers['authorization'] = f'Bearer {api_key}' + self._client = client( base_url=_parse_host(host or os.getenv('OLLAMA_HOST')), follow_redirects=follow_redirects, timeout=timeout, - # Lowercase all headers to ensure override - headers={ - k.lower(): v - for k, v in { - **(headers or {}), - 'Content-Type': 'application/json', - 'Accept': 'application/json', - 'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}', - }.items() - }, + headers=headers, **kwargs, ) @@ -638,7 +643,12 @@ def web_search(self, queries: Sequence[str], max_results: int = 3) -> WebSearchR Returns: WebSearchResponse with the search results + Raises: + ValueError: If OLLAMA_API_KEY environment variable is not set """ + if not self._client.headers.get('authorization', '').startswith('Bearer '): + raise ValueError('Authorization header with Bearer token is required for web search') + return self._request( WebSearchResponse, 'POST', @@ -658,7 +668,12 @@ def web_crawl(self, urls: Sequence[str]) -> WebCrawlResponse: Returns: WebCrawlResponse with the crawl results + Raises: + ValueError: If OLLAMA_API_KEY environment variable is not set """ + if not self._client.headers.get('authorization', '').startswith('Bearer '): + raise ValueError('Authorization header with Bearer token is required for web fetch') + return self._request( WebCrawlResponse, 'POST', diff --git a/tests/test_client.py b/tests/test_client.py index 6917edc2..17d5750b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1195,3 +1195,83 @@ async def test_arbitrary_roles_accepted_in_message_request_async(monkeypatch: py client = AsyncClient() await client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}]) + + +def test_client_web_search_requires_bearer_auth_header(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv('OLLAMA_API_KEY', raising=False) + + client = Client() + + with pytest.raises(ValueError, match='Authorization header with Bearer token is required for web search'): + client.web_search(['test query']) + + +def test_client_web_crawl_requires_bearer_auth_header(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv('OLLAMA_API_KEY', raising=False) + + client = Client() + + with pytest.raises(ValueError, match='Authorization header with Bearer token is required for web fetch'): + client.web_crawl(['https://example.com']) + + +def _mock_request_web_search(self, cls, method, url, json=None, **kwargs): + assert method == 'POST' + assert url == 'https://ollama.com/api/web_search' + assert json is not None and 'queries' in json and 'max_results' in json + return httpxResponse(status_code=200, content='{"results": {}, "success": true}') + + +def _mock_request_web_crawl(self, cls, method, url, json=None, **kwargs): + assert method == 'POST' + assert url == 'https://ollama.com/api/web_crawl' + assert json is not None and 'urls' in json + return httpxResponse(status_code=200, content='{"results": {}, "success": true}') + + +def test_client_web_search_with_env_api_key(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv('OLLAMA_API_KEY', 'test-key') + monkeypatch.setattr(Client, '_request', _mock_request_web_search) + + client = Client() + client.web_search(['what is ollama?'], max_results=2) + + +def test_client_web_crawl_with_env_api_key(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv('OLLAMA_API_KEY', 'test-key') + monkeypatch.setattr(Client, '_request', _mock_request_web_crawl) + + client = Client() + client.web_crawl(['https://example.com']) + + +def test_client_web_search_with_explicit_bearer_header(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv('OLLAMA_API_KEY', raising=False) + monkeypatch.setattr(Client, '_request', _mock_request_web_search) + + client = Client(headers={'Authorization': 'Bearer custom-token'}) + client.web_search(['what is ollama?'], max_results=1) + + +def test_client_web_crawl_with_explicit_bearer_header(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv('OLLAMA_API_KEY', raising=False) + monkeypatch.setattr(Client, '_request', _mock_request_web_crawl) + + client = Client(headers={'Authorization': 'Bearer custom-token'}) + client.web_crawl(['https://example.com']) + + +def test_client_bearer_header_from_env(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv('OLLAMA_API_KEY', 'env-token') + + client = Client() + assert client._client.headers['authorization'] == 'Bearer env-token' + + +def test_client_explicit_bearer_header_overrides_env(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv('OLLAMA_API_KEY', 'env-token') + monkeypatch.setattr(Client, '_request', _mock_request_web_search) + + client = Client(headers={'Authorization': 'Bearer explicit-token'}) + assert client._client.headers['authorization'] == 'Bearer explicit-token' + client.web_search(['override check'])