|
30 | 30 | from google.genai.errors import ClientError |
31 | 31 | from typing_extensions import override |
32 | 32 |
|
33 | | -from ..utils._client_labels_utils import get_client_labels |
| 33 | +from ..utils._google_client_headers import get_tracking_headers |
| 34 | +from ..utils._google_client_headers import merge_tracking_headers |
34 | 35 | from ..utils.context_utils import Aclosing |
35 | 36 | from ..utils.streaming_utils import StreamingResponseAggregator |
36 | 37 | from ..utils.variant_utils import GoogleLLMVariant |
@@ -316,13 +317,7 @@ def _api_backend(self) -> GoogleLLMVariant: |
316 | 317 | ) |
317 | 318 |
|
318 | 319 | def _tracking_headers(self) -> dict[str, str]: |
319 | | - labels = get_client_labels() |
320 | | - header_value = ' '.join(labels) |
321 | | - tracking_headers = { |
322 | | - 'x-goog-api-client': header_value, |
323 | | - 'user-agent': header_value, |
324 | | - } |
325 | | - return tracking_headers |
| 320 | + return get_tracking_headers() |
326 | 321 |
|
327 | 322 | @cached_property |
328 | 323 | def _live_api_version(self) -> str: |
@@ -362,8 +357,10 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: |
362 | 357 | ): |
363 | 358 | if not llm_request.live_connect_config.http_options.headers: |
364 | 359 | llm_request.live_connect_config.http_options.headers = {} |
365 | | - llm_request.live_connect_config.http_options.headers.update( |
366 | | - self._tracking_headers() |
| 360 | + llm_request.live_connect_config.http_options.headers = ( |
| 361 | + self._merge_tracking_headers( |
| 362 | + llm_request.live_connect_config.http_options.headers |
| 363 | + ) |
367 | 364 | ) |
368 | 365 | llm_request.live_connect_config.http_options.api_version = ( |
369 | 366 | self._live_api_version |
@@ -456,20 +453,7 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None: |
456 | 453 |
|
457 | 454 | def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: |
458 | 455 | """Merge tracking headers to the given headers.""" |
459 | | - headers = headers or {} |
460 | | - for key, tracking_header_value in self._tracking_headers().items(): |
461 | | - custom_value = headers.get(key, None) |
462 | | - if not custom_value: |
463 | | - headers[key] = tracking_header_value |
464 | | - continue |
465 | | - |
466 | | - # Merge tracking headers with existing headers and avoid duplicates. |
467 | | - value_parts = tracking_header_value.split(' ') |
468 | | - for custom_value_part in custom_value.split(' '): |
469 | | - if custom_value_part not in value_parts: |
470 | | - value_parts.append(custom_value_part) |
471 | | - headers[key] = ' '.join(value_parts) |
472 | | - return headers |
| 456 | + return merge_tracking_headers(headers) |
473 | 457 |
|
474 | 458 |
|
475 | 459 | def _build_function_declaration_log( |
|
0 commit comments