diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md index 960b6f40c2..56f35dd785 100644 --- a/contributing/samples/bigquery/README.md +++ b/contributing/samples/bigquery/README.md @@ -24,11 +24,11 @@ distributed via the `google.adk.tools.bigquery` module. These tools include: 5. `get_job_info` Fetches metadata about a BigQuery job. -5. `execute_sql` +6. `execute_sql` Runs or dry-runs a SQL query in BigQuery. -6. `ask_data_insights` +7. `ask_data_insights` Natural language-in, natural language-out tool that answers questions about structured data in BigQuery. Provides a one-stop solution for generating @@ -38,23 +38,26 @@ distributed via the `google.adk.tools.bigquery` module. These tools include: the official [Conversational Analytics API documentation](https://cloud.google.com/gemini/docs/conversational-analytics-api/overview) for instructions. -7. `forecast` +8. `forecast` Perform time series forecasting using BigQuery's `AI.FORECAST` function, leveraging the TimesFM 2.0 model. -8. `analyze_contribution` +9. `analyze_contribution` Perform contribution analysis in BigQuery by creating a temporary `CONTRIBUTION_ANALYSIS` model and then querying it with `ML.GET_INSIGHTS` to find top contributors for a given metric. -9. `detect_anomalies` +10. `detect_anomalies` Perform time series anomaly detection in BigQuery by creating a temporary `ARIMA_PLUS` model and then querying it with `ML.DETECT_ANOMALIES` to detect time series data anomalies. +11. `search_catalog` + Searches for data entries across projects using the Dataplex Catalog. This allows discovery of datasets, tables, and other assets. 
+ ## How to use Set up environment variables in your `.env` file for using diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index d20741b84f..ec0b365c63 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -19,8 +19,10 @@ from .._google_credentials import BaseGoogleCredentialsConfig BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" -BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] - +BIGQUERY_SCOPES = [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/cloud-platform", +] @experimental(FeatureName.GOOGLE_CREDENTIALS_CONFIG) class BigQueryCredentialsConfig(BaseGoogleCredentialsConfig): @@ -34,8 +36,7 @@ def __post_init__(self) -> BigQueryCredentialsConfig: super().__post_init__() if not self.scopes: - self.scopes = BIGQUERY_DEFAULT_SCOPE - + self.scopes = BIGQUERY_SCOPES # Set the token cache key self._token_cache_key = BIGQUERY_TOKEN_CACHE_KEY diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py index 2800c19e38..1122774c12 100644 --- a/src/google/adk/tools/bigquery/bigquery_toolset.py +++ b/src/google/adk/tools/bigquery/bigquery_toolset.py @@ -24,6 +24,7 @@ from . import data_insights_tool from . import metadata_tool from . import query_tool +from . 
import search_tool from ...features import experimental from ...features import FeatureName from ...tools.base_tool import BaseTool @@ -87,6 +88,7 @@ async def get_tools( query_tool.analyze_contribution, query_tool.detect_anomalies, data_insights_tool.ask_data_insights, + search_tool.search_catalog, ] ] diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index 85912ce891..dfc5a3cd02 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -19,10 +19,14 @@ import google.api_core.client_info from google.auth.credentials import Credentials from google.cloud import bigquery +from google.cloud import dataplex_v1 +from google.api_core.gapic_v1 import client_info as gapic_client_info from ... import version -USER_AGENT = f"adk-bigquery-tool google-adk/{version.__version__}" +USER_AGENT_BASE = f"google-adk/{version.__version__}" +BQ_USER_AGENT = f"adk-bigquery-tool {USER_AGENT_BASE}" +DP_USER_AGENT = f"adk-dataplex-tool {USER_AGENT_BASE}" from typing import List @@ -48,7 +52,7 @@ def get_bigquery_client( A BigQuery client. """ - user_agents = [USER_AGENT] + user_agents = [BQ_USER_AGENT] if user_agent: if isinstance(user_agent, str): user_agents.append(user_agent) @@ -67,3 +71,36 @@ def get_bigquery_client( ) return bigquery_client + +def get_dataplex_catalog_client( + *, + credentials: Credentials, + user_agent: Optional[Union[str, List[str]]] = None, +) -> dataplex_v1.CatalogServiceClient: + """Get a Dataplex CatalogServiceClient with minimal necessary arguments. + + Args: + credentials: The credentials to use for the request. + user_agent: Additional user agent string(s) to append. + + Returns: + A Dataplex Client. 
+ """ + + user_agents = [DP_USER_AGENT] + if user_agent: + if isinstance(user_agent, str): + user_agents.append(user_agent) + else: + user_agents.extend([ua for ua in user_agent if ua]) + + client_info = gapic_client_info.ClientInfo( + user_agent=" ".join(user_agents) + ) + + dataplex_client = dataplex_v1.CatalogServiceClient( + credentials=credentials, + client_info=client_info, + ) + + return dataplex_client diff --git a/src/google/adk/tools/bigquery/search_tool.py b/src/google/adk/tools/bigquery/search_tool.py new file mode 100644 index 0000000000..5bb73fb7eb --- /dev/null +++ b/src/google/adk/tools/bigquery/search_tool.py @@ -0,0 +1,130 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +from google.api_core import exceptions as api_exceptions +from google.auth.credentials import Credentials +from google.cloud import dataplex_v1 + +from . 
import client
+from .config import BigQueryToolConfig
+
+def _construct_search_query_helper(predicate: str, operator: str, items: List[str]) -> str:
+  if not items:
+    return ""
+  if len(items) == 1:
+    return f'{predicate}{operator}"{items[0]}"'
+
+  clauses = [f'{predicate}{operator}"{item}"' for item in items]
+  return "(" + " OR ".join(clauses) + ")"
+
+def search_catalog(
+    prompt: str,
+    project_id: str,
+    credentials: Credentials,
+    settings: BigQueryToolConfig,
+    location: str = "global",
+    page_size: int = 10,
+    project_ids_filter: Optional[List[str]] = None,
+    dataset_ids_filter: Optional[List[str]] = None,
+    types_filter: Optional[List[str]] = None,
+) -> Dict[str, Any]:
+  """Search for BigQuery assets within Dataplex.
+
+  Args:
+    prompt (str): The base search query (natural language or keywords).
+    project_id (str): The Google Cloud project ID to scope the search.
+    credentials (Credentials): Credentials for the request.
+    settings (BigQueryToolConfig): BigQuery tool settings.
+    location (str): The Dataplex location to use. Defaults to "global".
+    page_size (int): Maximum number of results.
+    project_ids_filter (Optional[List[str]]): Specific project IDs to include in the search results.
+      If None, defaults to the scoping project_id.
+    dataset_ids_filter (Optional[List[str]]): BigQuery dataset IDs to filter by.
+    types_filter (Optional[List[str]]): Entry types to filter by (e.g., "TABLE", "DATASET").
+
+  Returns:
+    dict: Search results or error. 
+ """ + try: + if not project_id: + return {"status": "ERROR", "error_details": "project_id must be provided."} + + dataplex_client = client.get_dataplex_catalog_client( + credentials=credentials, + user_agent=[settings.application_name, "search_catalog"], + ) + + query_parts = [] + if prompt: + query_parts.append(f"({prompt})") + + # Filter by project IDs + projects_to_filter = project_ids_filter if project_ids_filter else [project_id] + if projects_to_filter: + query_parts.append(_construct_search_query_helper("projectid", "=", projects_to_filter)) + + # Filter by dataset IDs + if dataset_ids_filter: + dataset_resource_filters = [f'linked_resource:"//bigquery.googleapis.com/projects/{pid}/datasets/{did}/*"' for pid in projects_to_filter for did in dataset_ids_filter] + if dataset_resource_filters: + query_parts.append(f"({' OR '.join(dataset_resource_filters)})") + # Filter by entry types + if types_filter: + query_parts.append(_construct_search_query_helper("type", "=", types_filter)) + + # Always scope to BigQuery system + query_parts.append("system=BIGQUERY") + + full_query = " AND ".join(filter(None, query_parts)) + + search_scope = f"projects/{project_id}/locations/{location}" + + request = dataplex_v1.SearchEntriesRequest( + name=search_scope, + query=full_query, + page_size=page_size, + semantic_search=True, + ) + + response = dataplex_client.search_entries(request=request) + + results = [] + for result in response.results: + entry = result.dataplex_entry + source = entry.entry_source + results.append( + { + "name": entry.name, + "display_name": source.display_name or "", + "entry_type": entry.entry_type, + "update_time": str(entry.update_time), + "linked_resource": source.resource or "", + "description": source.description or "", + "location": source.location or "", + } + ) + return {"status": "SUCCESS", "results": results} + + except api_exceptions.GoogleAPICallError as e: + logging.exception("search_catalog tool: API call failed") + return {"status": 
"ERROR", "error_details": f"Dataplex API Error: {str(e)}"} + except Exception as ex: + logging.exception("search_catalog tool: Unexpected error") + return {"status": "ERROR", "error_details": str(ex)} + diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials.py b/tests/unittests/tools/bigquery/test_bigquery_credentials.py index 2342446c2a..9a568071d5 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials.py @@ -46,7 +46,7 @@ def test_valid_credentials_object_auth_credentials(self): assert config.credentials == auth_creds assert config.client_id is None assert config.client_secret is None - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == ["https://www.googleapis.com/auth/bigquery","https://www.googleapis.com/auth/cloud-platform"] def test_valid_credentials_object_oauth2_credentials(self): """Test that providing valid Credentials object works correctly with @@ -86,7 +86,7 @@ def test_valid_client_id_secret_pair_default_scope(self): assert config.credentials is None assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == ["https://www.googleapis.com/auth/bigquery","https://www.googleapis.com/auth/cloud-platform",] def test_valid_client_id_secret_pair_w_scope(self): """Test that providing client ID and secret with explicit scopes works. 
@@ -128,7 +128,7 @@ def test_valid_client_id_secret_pair_w_empty_scope(self): assert config.credentials is None assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == ["https://www.googleapis.com/auth/bigquery","https://www.googleapis.com/auth/cloud-platform"] def test_missing_client_secret_raises_error(self): """Test that missing client secret raises appropriate validation error. diff --git a/tests/unittests/tools/bigquery/test_bigquery_search_tool.py b/tests/unittests/tools/bigquery/test_bigquery_search_tool.py new file mode 100644 index 0000000000..c72632ce06 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_search_tool.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +from unittest import mock +from typing import Any, Dict, List + +import pytest +from google.api_core import exceptions as api_exceptions +from google.auth.credentials import Credentials +from google.cloud import dataplex_v1 + +from google.adk.tools.bigquery import search_tool +from google.adk.tools.bigquery import client +from google.adk.tools.bigquery.config import BigQueryToolConfig + +# Helper function to create mock credentials +def _mock_creds(): + return mock.create_autospec(Credentials, instance=True) + +# Helper function to create mock settings +def _mock_settings(app_name: str | None = "test-app"): + return BigQueryToolConfig(application_name=app_name) + +# Mock response for dataplex_client.search_entries +def _mock_search_entries_response(results: List[Dict[str, Any]]): + mock_response = mock.MagicMock(spec=dataplex_v1.SearchEntriesResponse) + mock_results = [] + for r in results: + mock_result = mock.MagicMock() + mock_entry = mock_result.dataplex_entry + mock_entry.name = r.get("name") + mock_entry.entry_type = r.get("entry_type") + mock_entry.update_time = r.get("update_time", "2026-01-14T05:00:00Z") + mock_source = 
mock_entry.entry_source + mock_source.display_name = r.get("display_name") + mock_source.resource = r.get("linked_resource") + mock_source.description = r.get("description") + mock_source.location = r.get("location") + mock_results.append(mock_result) + mock_response.results = mock_results + return mock_response + +class TestSearchCatalog: + + @pytest.fixture(autouse=True) + def setup_mocks(self): + self.mock_dataplex_client = mock.MagicMock(spec=dataplex_v1.CatalogServiceClient) + self.mock_get_dataplex_client = mock.patch.object( + client, "get_dataplex_catalog_client", autospec=True + ).start() + self.mock_get_dataplex_client.return_value = self.mock_dataplex_client + self.mock_search_request = mock.patch.object( + dataplex_v1, "SearchEntriesRequest", autospec=True + ).start() + + yield + + mock.patch.stopall() + + def test_search_catalog_success(self): + """Test the successful path of search_catalog.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "customer data" + project_id = "test-project" + + mock_api_results = [ + { + "name": "entry1", "entry_type": "TABLE", "display_name": "Cust Table", + "linked_resource": "//bigquery.googleapis.com/projects/p/datasets/d/tables/t1", + "description": "Table 1", "location": "us", + } + ] + self.mock_dataplex_client.search_entries.return_value = _mock_search_entries_response(mock_api_results) + + result = search_tool.search_catalog( + prompt, project_id, creds, settings + ) + + assert result["status"] == "SUCCESS" + assert len(result["results"]) == 1 + assert result["results"][0]["name"] == "entry1" + assert result["results"][0]["display_name"] == "Cust Table" + + self.mock_get_dataplex_client.assert_called_once_with( + credentials=creds, user_agent=["test-app", "search_catalog"] + ) + + expected_query = '(customer data) AND projectid="test-project" AND system=BIGQUERY' + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/global", + query=expected_query, + 
page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once_with( + request=self.mock_search_request.return_value + ) + + def test_search_catalog_no_project_id(self): + """Test search_catalog with missing project_id.""" + result = search_tool.search_catalog( + "test", "", _mock_creds(), _mock_settings() + ) + assert result["status"] == "ERROR" + assert "project_id must be provided" in result["error_details"] + self.mock_get_dataplex_client.assert_not_called() + + def test_search_catalog_api_error(self): + """Test search_catalog handling API exceptions.""" + self.mock_dataplex_client.search_entries.side_effect = api_exceptions.BadRequest("Invalid query") + + result = search_tool.search_catalog( + "test", "test-project", _mock_creds(), _mock_settings() + ) + assert result["status"] == "ERROR" + assert "Dataplex API Error: Invalid query" in result["error_details"] + + def test_search_catalog_other_exception(self): + """Test search_catalog handling unexpected exceptions.""" + self.mock_get_dataplex_client.side_effect = Exception("Something went wrong") + + result = search_tool.search_catalog( + "test", "test-project", _mock_creds(), _mock_settings() + ) + assert result["status"] == "ERROR" + assert "Something went wrong" in result["error_details"] + + @pytest.mark.parametrize( + "prompt, project_ids, dataset_ids, types, expected_query_part", + [ + ("p", None, None, None, 'projectid="test-project"'), + ("p", ["proj1"], None, None, 'projectid="proj1"'), + ("p", ["p1", "p2"], None, None, '(projectid="p1" OR projectid="p2")'), + ("p", None, None, ["TABLE"], 'type="TABLE"'), + ("p", None, None, ["TABLE", "DATASET"], '(type="TABLE" OR type="DATASET")'), + ], + ) + def test_search_catalog_query_construction( + self, prompt, project_ids, dataset_ids, types, expected_query_part + ): + """Test different query constructions based on filters.""" + search_tool.search_catalog( + prompt, "test-project", _mock_creds(), _mock_settings(), + 
project_ids_filter=project_ids, + dataset_ids_filter=dataset_ids, + types_filter=types, + ) + + self.mock_search_request.assert_called_once() + _, kwargs = self.mock_search_request.call_args + query = kwargs["query"] + + if prompt: + assert f"({prompt})" in query + assert "system=BIGQUERY" in query + assert expected_query_part in query + + def test_search_catalog_no_app_name(self): + """Test search_catalog when settings.application_name is None.""" + creds = _mock_creds() + settings = _mock_settings(app_name=None) + search_tool.search_catalog("test", "test-project", creds, settings) + + self.mock_get_dataplex_client.assert_called_once_with( + credentials=creds, user_agent=[None, "search_catalog"] + ) + + def test_search_catalog_multi_project_filter_semantic(self): + """Test semantic search with a multi-project filter.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "What datasets store user profiles?" + project_id = "main-project" + project_filters = ["user-data-proj", "shared-infra-proj"] + location = "global" + + self.mock_dataplex_client.search_entries.return_value = _mock_search_entries_response([]) + + search_tool.search_catalog( + prompt, project_id, creds, settings, location=location, + project_ids_filter=project_filters, + types_filter=["DATASET"] + ) + + expected_query = ( + f"({prompt}) AND " + "(projectid=\"user-data-proj\" OR projectid=\"shared-infra-proj\") AND " + 'type="DATASET" AND system=BIGQUERY' + ) + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/{location}", + query=expected_query, + page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once() + + def test_search_catalog_natural_language_semantic(self): + """Test natural language prompts with semantic search enabled and check output.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "Find tables about football matches" + project_id = "sports-analytics" + location = 
"europe-west1" + + # Mock the results that the API would return for this semantic query + mock_api_results = [ + { + "name": "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb1", + "display_name": "uk_football_premiership", + "entry_type": "projects/655216118709/locations/global/entryTypes/bigquery-table", + "linked_resource": "//bigquery.googleapis.com/projects/sports-analytics/datasets/uk/tables/premiership", + "description": "Stats for UK Premier League matches.", + "location": "europe-west1" + }, + { + "name": "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb2", + "display_name": "serie_a_matches", + "entry_type": "projects/655216118709/locations/global/entryTypes/bigquery-table", + "linked_resource": "//bigquery.googleapis.com/projects/sports-analytics/datasets/italy/tables/serie_a", + "description": "Italian Serie A football results.", + "location": "europe-west1" + } + ] + self.mock_dataplex_client.search_entries.return_value = _mock_search_entries_response(mock_api_results) + + # Call the tool + result = search_tool.search_catalog( + prompt, project_id, creds, settings, location=location + ) + + # Assert the request was made as expected + expected_query = f"({prompt}) AND projectid=\"{project_id}\" AND system=BIGQUERY" + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/{location}", + query=expected_query, + page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once() + + # Assert the output is processed correctly + assert result["status"] == "SUCCESS" + assert len(result["results"]) == 2 + assert result["results"][0]["display_name"] == "uk_football_premiership" + assert result["results"][1]["display_name"] == "serie_a_matches" + assert "UK Premier League" in result["results"][0]["description"] + + def test_query_with_project_and_dataset_filters(self): + creds = _mock_creds() + settings = _mock_settings() + 
project_id = "proj1" + location = "us-central1" # Using a specific location + + search_tool.search_catalog( + prompt="inventory", + project_id=project_id, + credentials=creds, + settings=settings, + project_ids_filter=["proj1", "proj2"], + dataset_ids_filter=["dsetA"], + location=location, + ) + + self.mock_get_dataplex_client.assert_called_once_with( + credentials=creds, + user_agent=["test-app", "search_catalog"] + ) + + expected_query = '(inventory) AND (projectid="proj1" OR projectid="proj2") AND (linked_resource:"//bigquery.googleapis.com/projects/proj1/datasets/dsetA/*" OR linked_resource:"//bigquery.googleapis.com/projects/proj2/datasets/dsetA/*") AND system=BIGQUERY' + expected_search_scope = f"projects/{project_id}/locations/{location}" + self.mock_search_request.assert_called_once_with( + name=expected_search_scope, + query=expected_query, + page_size=10, + semantic_search=True + ) + + self.mock_dataplex_client.search_entries.assert_called_once_with( + request=self.mock_search_request.return_value + ) + + + + diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py index 2d890fb51a..21f7708c12 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py +++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py @@ -41,7 +41,7 @@ async def test_bigquery_toolset_tools_default(): tools = await toolset.get_tools() assert tools is not None - assert len(tools) == 10 + assert len(tools) == 11 assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set([ @@ -55,6 +55,7 @@ async def test_bigquery_toolset_tools_default(): "forecast", "analyze_contribution", "detect_anomalies", + "search_catalog", ]) actual_tool_names = set([tool.name for tool in tools]) assert actual_tool_names == expected_tool_names diff --git a/tests/unittests/tools/bigquery/test_dataplex_catalog_client.py b/tests/unittests/tools/bigquery/test_dataplex_catalog_client.py new file mode 100644 
index 0000000000..a16b0f2993 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_dataplex_catalog_client.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from unittest import mock +from typing import Union, List, Optional + +import google.adk +from google.adk.tools.bigquery.client import get_dataplex_catalog_client, DP_USER_AGENT +from google.cloud import dataplex_v1 +from google.oauth2.credentials import Credentials +from google.api_core.gapic_v1 import client_info as gapic_client_info +from google.api_core import client_options as client_options_lib + +# Mock the CatalogServiceClient class directly +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_default(mock_catalog_service_client): + """Test get_dataplex_catalog_client with default user agent.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + + # Call the function under test + client = get_dataplex_catalog_client(credentials=mock_creds) + + # Assert that CatalogServiceClient constructor was called once + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + + # Check the arguments passed to the CatalogServiceClient constructor + assert kwargs["credentials"] == mock_creds + client_info = kwargs["client_info"] + assert isinstance(client_info, gapic_client_info.ClientInfo) + assert client_info.user_agent == DP_USER_AGENT + + # Ensure the function returns the mock instance + assert client == mock_catalog_service_client.return_value + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_str(mock_catalog_service_client): + """Test get_dataplex_catalog_client with a custom user agent string.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua = "catalog_ua/1.0" + expected_ua = f"{DP_USER_AGENT} {custom_ua}" + + get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua) + + 
mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_list(mock_catalog_service_client): + """Test get_dataplex_catalog_client with a custom user agent list.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua_list = ["catalog_ua", "catalog_ua_2.0"] + expected_ua = f"{DP_USER_AGENT} {' '.join(custom_ua_list)}" + + get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua_list) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_list_with_none(mock_catalog_service_client): + """Test get_dataplex_catalog_client with a list containing None.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua_list = ["catalog_ua", None, "catalog_ua_2.0"] + expected_ua = f"{DP_USER_AGENT} catalog_ua catalog_ua_2.0" + + get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua_list) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua