Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sdks/python/apache_beam/ml/inference/gemini_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Any
from typing import Optional
from typing import Union
from typing import cast

from google import genai
from google.genai import errors
Expand Down Expand Up @@ -73,7 +74,7 @@ def generate_from_string(
call.
"""
return model.models.generate_content(
model=model_name, contents=batch, **inference_args)
model=model_name, contents=cast(Any, batch), **inference_args)


def generate_image_from_strings_and_images(
Expand All @@ -96,7 +97,7 @@ def generate_image_from_strings_and_images(
call.
"""
return model.models.generate_content(
model=model_name, contents=batch, **inference_args)
model=model_name, contents=cast(Any, batch), **inference_args)


class GeminiModelHandler(RemoteModelHandler[Any, PredictionResult,
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

"""Tests for apache_beam.ml.rag.embeddings.vertex_ai."""

import pytest
import shutil
import tempfile
import unittest

import pytest

import apache_beam as beam
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
Expand Down
21 changes: 13 additions & 8 deletions sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@
from typing import cast

import pytest
from pymilvus import CollectionSchema
from pymilvus import DataType
from pymilvus import FieldSchema
from pymilvus import MilvusClient
from pymilvus.exceptions import MilvusException
from pymilvus.milvus_client import IndexParams

import apache_beam as beam
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
Expand All @@ -41,11 +35,21 @@
from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs
from apache_beam.testing.test_pipeline import TestPipeline

# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
from pymilvus import CollectionSchema
from pymilvus import DataType
from pymilvus import FieldSchema
from pymilvus import MilvusClient
from pymilvus.exceptions import MilvusException
from pymilvus.milvus_client import IndexParams

from apache_beam.ml.rag.ingestion.milvus_search import MilvusVectorWriterConfig
from apache_beam.ml.rag.ingestion.milvus_search import MilvusWriteConfig
except ImportError as e:
raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}')
PYMILVUS_AVAILABLE = True
except ImportError:
PYMILVUS_AVAILABLE = False
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports


def _construct_index_params():
Expand Down Expand Up @@ -158,6 +162,7 @@ def drop_collection(client: MilvusClient, collection_name: str):


@pytest.mark.require_docker_in_docker
@unittest.skipIf(not PYMILVUS_AVAILABLE, 'pymilvus is not installed.')
@unittest.skipUnless(
platform.system() == "Linux",
"Test runs only on Linux due to lack of support, as yet, for nested "
Expand Down
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/ml/transforms/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
import PIL
from PIL.Image import Image as PIL_Image
except ImportError:
PIL = None
PIL_Image = Any
PIL = None # type: ignore[assignment]
PIL_Image = Any # type: ignore[misc, assignment]

try:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
try:
from PIL import Image
except ImportError:
Image = None
Image = None # type: ignore[assignment]

_HF_TOKEN = os.environ.get('HF_INFERENCE_TOKEN')
test_query = "This is a test"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubImageEmbeddings
except ImportError:
TensorflowHubImageEmbeddings = None # type: ignore
Image = None
Image = None # type: ignore[assignment]


@unittest.skipIf(
Expand Down
Loading