feat(ollama): add max_retries to chat generator #2899
```diff
@@ -22,6 +22,7 @@
 )
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 from pydantic.json_schema import JsonSchemaValue
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

 from ollama import AsyncClient, ChatResponse, Client
```
```diff
@@ -216,6 +217,7 @@ def __init__(
         url: str = "http://localhost:11434",
         generation_kwargs: dict[str, Any] | None = None,
         timeout: int = 120,
+        max_retries: int = 0,
         keep_alive: float | str | None = None,
         streaming_callback: Callable[[StreamingChunk], None] | None = None,
         tools: ToolsType | None = None,
```
```diff
@@ -233,6 +235,8 @@ def __init__(
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
         :param timeout:
             The number of seconds before throwing a timeout error from the Ollama API.
+        :param max_retries:
+            Maximum number of retries to attempt for failed requests.
         :param think:
             If True, the model will "think" before producing a response.
             Only [thinking models](https://ollama.com/search?c=thinking) support this feature.
```

Contributor: Let's make this parameter more descriptive.
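A minimal usage sketch of the new parameter, assuming the integration's usual import path; the model name is a placeholder, and this is not the reviewer's suggested wording:

```python
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

# max_retries counts re-attempts only: max_retries=2 allows up to 3 requests
# in total (1 initial + 2 retries), with exponential backoff between them.
# The default of 0 keeps the previous single-attempt behavior.
generator = OllamaChatGenerator(
    model="llama3.2",  # placeholder model name
    url="http://localhost:11434",
    timeout=120,
    max_retries=2,
)
```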
```diff
@@ -268,6 +272,7 @@ def __init__(
         self.url = url
         self.generation_kwargs = generation_kwargs or {}
         self.timeout = timeout
+        self.max_retries = max_retries
         self.keep_alive = keep_alive
         self.streaming_callback = streaming_callback
         self.tools = tools  # Store original tools for serialization
```
```diff
@@ -292,6 +297,7 @@ def to_dict(self) -> dict[str, Any]:
            url=self.url,
            generation_kwargs=self.generation_kwargs,
            timeout=self.timeout,
+           max_retries=self.max_retries,
            keep_alive=self.keep_alive,
            streaming_callback=callback_name,
            tools=serialize_tools_or_toolset(self.tools),
```
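Since `max_retries` now flows through `to_dict`, it should survive a serialization round trip. A hedged sketch, assuming the standard Haystack `init_parameters` layout and the component's `from_dict` counterpart:

```python
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

generator = OllamaChatGenerator(model="llama3.2", max_retries=2)  # placeholder model

data = generator.to_dict()
assert data["init_parameters"]["max_retries"] == 2  # assumes default_to_dict layout

restored = OllamaChatGenerator.from_dict(data)
assert restored.max_retries == 2
```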
```diff
@@ -518,16 +524,25 @@ def run(

         ollama_messages = [_convert_chatmessage_to_ollama_format(m) for m in messages]

-        response = self._client.chat(
-            model=self.model,
-            messages=ollama_messages,
-            tools=ollama_tools,
-            stream=is_stream,  # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
-            keep_alive=self.keep_alive,
-            options=generation_kwargs,
-            format=self.response_format,
-            think=self.think,
-        )
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.max_retries + 1),
+            retry=retry_if_exception_type(Exception),
```

Contributor: Retrying on all kinds of exceptions is too broad.

Contributor (author): Retry on 429 and 5xx? Let me know if you want to retry on something else as well.
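One way to act on this thread — narrowing retries to 429 and 5xx — is a tenacity predicate over `ollama.ResponseError`. A sketch under the assumption that `ResponseError` exposes `status_code` (it does in recent ollama-python releases); names here are illustrative, not the PR's final code:

```python
from ollama import Client, ResponseError
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

def _is_retryable(exc: BaseException) -> bool:
    # Retry only on rate limiting (429) and server-side errors (5xx);
    # client errors such as 400/404 fail fast.
    return isinstance(exc, ResponseError) and (
        exc.status_code == 429 or exc.status_code >= 500
    )

@retry(
    reraise=True,
    stop=stop_after_attempt(3),  # 1 initial attempt + 2 retries
    retry=retry_if_exception(_is_retryable),
    wait=wait_exponential(),
)
def chat_once(client: Client, model: str, messages: list[dict]):
    return client.chat(model=model, messages=messages)
```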
```diff
+            wait=wait_exponential(),
+        )
+        def chat_with_retry() -> ChatResponse | Iterator[ChatResponse]:
+            return self._client.chat(
+                model=self.model,
+                messages=ollama_messages,
+                tools=ollama_tools,
+                stream=is_stream,  # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
+                keep_alive=self.keep_alive,
+                options=generation_kwargs,
+                format=self.response_format,
+                think=self.think,
+            )
+
+        response = chat_with_retry()
```
Contributor: Can we have this not nested inside of `run`?

Contributor (author): The `@retry` annotation requires a function, so the call has to be wrapped in one.

Contributor (author): Will separate out the function in the next commit.
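One possible shape for the un-nesting, using tenacity's `Retrying` object instead of the decorator so the helper can live outside `run`; this is a sketch, not the PR's eventual refactor:

```python
from collections.abc import Iterator
from ollama import ChatResponse
from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_exponential

def _chat_with_retry(client, max_retries: int, **chat_kwargs) -> ChatResponse | Iterator[ChatResponse]:
    # Build the policy as an object rather than a decorator, so this helper
    # can be defined at module level and shared by run() and run_async()
    # (the async side would use tenacity.AsyncRetrying instead).
    retryer = Retrying(
        reraise=True,
        stop=stop_after_attempt(max_retries + 1),
        retry=retry_if_exception_type(Exception),
        wait=wait_exponential(),
    )
    return retryer(client.chat, **chat_kwargs)
```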
```diff

         if isinstance(response, Iterator):
             return self._handle_streaming_response(response_iter=response, callback=callback)
```
```diff
@@ -579,16 +594,25 @@ async def run_async(

         ollama_messages = [_convert_chatmessage_to_ollama_format(m) for m in messages]

-        response = await self._async_client.chat(
-            model=self.model,
-            messages=ollama_messages,
-            tools=ollama_tools,
-            stream=is_stream,  # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
-            keep_alive=self.keep_alive,
-            options=generation_kwargs,
-            format=self.response_format,
-            think=self.think,
-        )
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.max_retries + 1),
+            retry=retry_if_exception_type(Exception),
+            wait=wait_exponential(),
+        )
+        async def chat_with_retry() -> ChatResponse | AsyncIterator[ChatResponse]:
+            return await self._async_client.chat(
+                model=self.model,
+                messages=ollama_messages,
+                tools=ollama_tools,
+                stream=is_stream,  # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
+                keep_alive=self.keep_alive,
+                options=generation_kwargs,
+                format=self.response_format,
+                think=self.think,
+            )
+
+        response = await chat_with_retry()

         if isinstance(response, AsyncIterator):
             # response is an async iterator for streaming
```
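A quick check of the stop condition used in both hunks: `stop_after_attempt(max_retries + 1)` counts the initial call as the first attempt, so `max_retries` is the number of re-attempts. A self-contained demo (the failing function is illustrative; `multiplier=0` just disables sleeping):

```python
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

max_retries = 2
calls = 0

@retry(
    reraise=True,
    stop=stop_after_attempt(max_retries + 1),
    retry=retry_if_exception_type(RuntimeError),
    wait=wait_exponential(multiplier=0),  # no backoff delay in this demo
)
def always_fails():
    global calls
    calls += 1
    raise RuntimeError("boom")

try:
    always_fails()
except RuntimeError:
    pass

assert calls == max_retries + 1  # 1 initial attempt + 2 retries
```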
Contributor: Is there a specific reason for `>=8.2.3` for tenacity?

Contributor (author): In the initial commit I did not include this, so the Python 3.13 check built with an older version of tenacity that did not include `retry`, `retry_if_exception_type`, or `wait_exponential`, which caused the verification checks to fail. Because of this, I had to explicitly specify the version.
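To confirm the installed tenacity against the `>=8.2.3` floor discussed here, the standard library can report the version:

```python
# Check the installed tenacity version against the pin discussed above.
from importlib.metadata import version

print(version("tenacity"))  # e.g. "8.2.3" or newer
```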