2 changes: 1 addition & 1 deletion integrations/ollama/pyproject.toml
@@ -26,7 +26,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai>=2.22.0", "ollama>=0.5.0", "pydantic"]
dependencies = ["haystack-ai>=2.22.0", "ollama>=0.5.0", "pydantic", "tenacity>=8.2.3"]
Contributor

Is there a specific reason for >=8.2.3 for tenacity?

Contributor Author

In the initial commit I did not pin a version, so the Python 3.13 check built with an older version of tenacity that was missing one of retry, retry_if_exception_type, or wait_exponential, which caused the verification checks to fail. Because of this, I had to specify the minimum version explicitly.


[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"
@@ -22,6 +22,7 @@
)
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from pydantic.json_schema import JsonSchemaValue
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from ollama import AsyncClient, ChatResponse, Client

@@ -216,6 +217,7 @@ def __init__(
url: str = "http://localhost:11434",
generation_kwargs: dict[str, Any] | None = None,
timeout: int = 120,
max_retries: int = 0,
keep_alive: float | str | None = None,
streaming_callback: Callable[[StreamingChunk], None] | None = None,
tools: ToolsType | None = None,
@@ -233,6 +235,8 @@
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:param timeout:
The number of seconds before throwing a timeout error from the Ollama API.
:param max_retries:
Maximum number of retries to attempt for failed requests.
Contributor

Let's make this parameter more descriptive.

Suggested change
:param max_retries:
Maximum number of retries to attempt for failed requests.
:param max_retries:
Maximum number of retries to attempt for failed requests (HTTP 429, 5xx, connection/timeout errors).
Uses exponential backoff between attempts. Set to 0 (default) to disable retries.

:param think:
If True, the model will "think" before producing a response.
Only [thinking models](https://ollama.com/search?c=thinking) support this feature.
@@ -268,6 +272,7 @@ def __init__(
self.url = url
self.generation_kwargs = generation_kwargs or {}
self.timeout = timeout
self.max_retries = max_retries
self.keep_alive = keep_alive
self.streaming_callback = streaming_callback
self.tools = tools # Store original tools for serialization
@@ -292,6 +297,7 @@ def to_dict(self) -> dict[str, Any]:
url=self.url,
generation_kwargs=self.generation_kwargs,
timeout=self.timeout,
max_retries=self.max_retries,
keep_alive=self.keep_alive,
streaming_callback=callback_name,
tools=serialize_tools_or_toolset(self.tools),
@@ -518,16 +524,25 @@ def run(

ollama_messages = [_convert_chatmessage_to_ollama_format(m) for m in messages]

response = self._client.chat(
model=self.model,
messages=ollama_messages,
tools=ollama_tools,
stream=is_stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
keep_alive=self.keep_alive,
options=generation_kwargs,
format=self.response_format,
think=self.think,
@retry(
reraise=True,
stop=stop_after_attempt(self.max_retries + 1),
retry=retry_if_exception_type(Exception),
Contributor

Retrying on all kinds of exceptions is too broad.

Contributor Author
@Keyur-S-Patel Mar 5, 2026

Retry on 429 and 5xx?

Lemme know if you want to retry on something else as well.
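
For reference, a narrower policy could look like the following sketch. It assumes the ollama client raises ollama.ResponseError (which carries a status_code) for HTTP errors and httpx connection/timeout exceptions for transport failures; the predicate name and thresholds are illustrative, not the merged change.

import httpx
from ollama import ResponseError
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential


def _is_retryable(exc: BaseException) -> bool:
    # Retry only on HTTP 429 / 5xx responses and on connection or timeout errors.
    if isinstance(exc, ResponseError):
        return exc.status_code == 429 or exc.status_code >= 500
    return isinstance(exc, (httpx.ConnectError, httpx.TimeoutException))


retry_on_transient_errors = retry(
    reraise=True,
    stop=stop_after_attempt(3),  # e.g. max_retries=2 -> 3 attempts in total
    retry=retry_if_exception(_is_retryable),
    wait=wait_exponential(),
)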

wait=wait_exponential(),
)
def chat_with_retry() -> ChatResponse | Iterator[ChatResponse]:
return self._client.chat(
model=self.model,
messages=ollama_messages,
tools=ollama_tools,
stream=is_stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
keep_alive=self.keep_alive,
options=generation_kwargs,
format=self.response_format,
think=self.think,
)

response = chat_with_retry()
Contributor

Can we have this not nested inside of run?

Contributor Author

The @retry decorator requires a function, so we can either:

  1. use a wrapper function, or
  2. build a new method and put the retry decorator there

Contributor Author

Will separate out the function in the next commit.
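
One way to avoid the closure inside run is to drive tenacity imperatively with a Retrying object, roughly as sketched below; the helper name and keyword-argument plumbing are assumptions, not the follow-up commit.

from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_exponential


def chat_with_retry(client, max_retries, **chat_kwargs):
    # Call client.chat, retrying failed attempts with exponential backoff.
    retryer = Retrying(
        reraise=True,
        stop=stop_after_attempt(max_retries + 1),
        retry=retry_if_exception_type(Exception),
        wait=wait_exponential(),
    )
    return retryer(client.chat, **chat_kwargs)

An AsyncRetrying counterpart could cover run_async in the same way.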


if isinstance(response, Iterator):
return self._handle_streaming_response(response_iter=response, callback=callback)
@@ -579,16 +594,25 @@ async def run_async(

ollama_messages = [_convert_chatmessage_to_ollama_format(m) for m in messages]

response = await self._async_client.chat(
model=self.model,
messages=ollama_messages,
tools=ollama_tools,
stream=is_stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
keep_alive=self.keep_alive,
options=generation_kwargs,
format=self.response_format,
think=self.think,
@retry(
reraise=True,
stop=stop_after_attempt(self.max_retries + 1),
retry=retry_if_exception_type(Exception),
wait=wait_exponential(),
)
async def chat_with_retry() -> ChatResponse | AsyncIterator[ChatResponse]:
return await self._async_client.chat(
model=self.model,
messages=ollama_messages,
tools=ollama_tools,
stream=is_stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
keep_alive=self.keep_alive,
options=generation_kwargs,
format=self.response_format,
think=self.think,
)

response = await chat_with_retry()

if isinstance(response, AsyncIterator):
# response is an async iterator for streaming
38 changes: 38 additions & 0 deletions integrations/ollama/tests/test_chat_generator.py
@@ -518,6 +518,7 @@ def test_init_default(self):
assert component.url == "http://localhost:11434"
assert component.generation_kwargs == {}
assert component.timeout == 120
assert component.max_retries == 0
assert component.streaming_callback is None
assert component.tools is None
assert component.keep_alive is None
@@ -529,6 +530,7 @@ def test_init(self, tools):
url="http://my-custom-endpoint:11434",
generation_kwargs={"temperature": 0.5},
timeout=5,
max_retries=2,
keep_alive="10m",
streaming_callback=print_streaming_chunk,
tools=tools,
@@ -539,6 +541,7 @@ def test_init(self, tools):
assert component.url == "http://my-custom-endpoint:11434"
assert component.generation_kwargs == {"temperature": 0.5}
assert component.timeout == 5
assert component.max_retries == 2
assert component.keep_alive == "10m"
assert component.streaming_callback is print_streaming_chunk
assert component.tools == tools
@@ -603,6 +606,7 @@ def test_to_dict(self):
"type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
"init_parameters": {
"timeout": 120,
"max_retries": 0,
"model": "llama2",
"url": "custom_url",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
@@ -650,6 +654,7 @@ def test_from_dict(self):
"type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
"init_parameters": {
"timeout": 120,
"max_retries": 0,
"model": "llama2",
"url": "custom_url",
"keep_alive": "5m",
@@ -689,6 +694,7 @@ def test_from_dict(self):
"some_test_param": "test-params",
}
assert component.timeout == 120
assert component.max_retries == 0
assert component.tools == [tool]
assert component.response_format == {
"type": "object",
@@ -790,6 +796,38 @@ def test_run(self, mock_client):
assert result["replies"][0].text == "Fine. How can I help you today?"
assert result["replies"][0].role == "assistant"

@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
def test_run_retries_after_failure(self, mock_client):
generator = OllamaChatGenerator(max_retries=1)

mock_response = ChatResponse(
model="qwen3:0.6b",
created_at="2023-12-12T14:13:43.416799Z",
message={"role": "assistant", "content": "Recovered after retry"},
done=True,
prompt_eval_count=1,
eval_count=2,
)

mock_client_instance = mock_client.return_value
mock_client_instance.chat.side_effect = [RuntimeError("temporary failure"), mock_response]

result = generator.run(messages=[ChatMessage.from_user("Hello!")])

assert mock_client_instance.chat.call_count == 2
assert result["replies"][0].text == "Recovered after retry"

@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
def test_run_raises_after_retry_exhausted(self, mock_client):
generator = OllamaChatGenerator(max_retries=1)
mock_client_instance = mock_client.return_value
mock_client_instance.chat.side_effect = RuntimeError("persistent failure")

with pytest.raises(RuntimeError, match="persistent failure"):
generator.run(messages=[ChatMessage.from_user("Hello!")])

assert mock_client_instance.chat.call_count == 2

@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
def test_run_streaming(self, mock_client):
collected_chunks = []