Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 169 additions & 9 deletions plugins/flytekit-spark/flytekitplugins/spark/connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import http
import json
import logging
import os
import typing
from dataclasses import dataclass
Expand All @@ -13,12 +14,14 @@
from flytekit.extend.backend.utils import convert_to_flyte_phase, get_connector_secret
from flytekit.models.core.execution import TaskLog
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.models.task import TaskExecutionMetadata, TaskTemplate

from .utils import is_serverless_config as _is_serverless_config

aiohttp = lazy_module("aiohttp")

logger = logging.getLogger(__name__)

DATABRICKS_API_ENDPOINT = "/api/2.1/jobs"
DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE"
DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY = "FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER"
Expand All @@ -28,6 +31,7 @@
class DatabricksJobMetadata(ResourceMeta):
    """Resource metadata identifying a submitted Databricks job run."""

    # Databricks workspace host, e.g. "<account>.cloud.databricks.com".
    databricks_instance: str
    # Run id returned by the runs/submit API (stored as a string).
    run_id: str
    # NOTE(review): persisting a bearer token inside resource metadata may
    # expose it wherever this metadata is serialized or stored — confirm this
    # is acceptable, or re-resolve the token on each get/delete instead.
    auth_token: Optional[str] = None  # Store auth token for get/delete operations


def _configure_serverless(databricks_job: dict, envs: dict) -> str:
Expand Down Expand Up @@ -252,7 +256,11 @@ def __init__(self):
super().__init__(task_type_name="spark", metadata_type=DatabricksJobMetadata)

async def create(
self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs
self,
task_template: TaskTemplate,
inputs: Optional[LiteralMap] = None,
task_execution_metadata: Optional[TaskExecutionMetadata] = None,
**kwargs,
) -> DatabricksJobMetadata:
data = json.dumps(_get_databricks_job_spec(task_template))
databricks_instance = task_template.custom.get(
Expand All @@ -264,24 +272,43 @@ async def create(
f"Missing databricks instance. Please set the value through the task config or set the {DEFAULT_DATABRICKS_INSTANCE_ENV_KEY} environment variable in the connector."
)

# Get workflow-specific token or fall back to default
namespace = task_execution_metadata.namespace if task_execution_metadata else None

# Extract custom secret name from task template (if provided)
custom_secret_name = task_template.custom.get("databricksTokenSecret")

logger.info(f"Creating Databricks job for namespace: {namespace or 'unknown'}")
if custom_secret_name:
logger.info(f"Using custom secret name: {custom_secret_name}")

auth_token = get_databricks_token(
namespace=namespace, task_template=task_template, secret_name=custom_secret_name
)
databricks_url = f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/submit"

async with aiohttp.ClientSession() as session:
async with session.post(databricks_url, headers=get_header(), data=data) as resp:
async with session.post(databricks_url, headers=get_header(auth_token=auth_token), data=data) as resp:
response = await resp.json()
if resp.status != http.HTTPStatus.OK:
raise RuntimeError(f"Failed to create databricks job with error: {response}")

return DatabricksJobMetadata(databricks_instance=databricks_instance, run_id=str(response["run_id"]))
logger.info(f"Successfully created Databricks job with run_id: {response['run_id']}")
return DatabricksJobMetadata(
databricks_instance=databricks_instance, run_id=str(response["run_id"]), auth_token=auth_token
)

async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource:
databricks_instance = resource_meta.databricks_instance
databricks_url = (
f"https://{databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/get?run_id={resource_meta.run_id}"
)

# Use the stored auth token if available, otherwise fall back to default
headers = get_header(auth_token=resource_meta.auth_token)

async with aiohttp.ClientSession() as session:
async with session.get(databricks_url, headers=get_header()) as resp:
async with session.get(databricks_url, headers=headers) as resp:
if resp.status != http.HTTPStatus.OK:
raise RuntimeError(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}")
response = await resp.json()
Expand Down Expand Up @@ -312,8 +339,11 @@ async def delete(self, resource_meta: DatabricksJobMetadata, **kwargs):
databricks_url = f"https://{resource_meta.databricks_instance}{DATABRICKS_API_ENDPOINT}/runs/cancel"
data = json.dumps({"run_id": resource_meta.run_id})

# Use the stored auth token if available, otherwise fall back to default
headers = get_header(auth_token=resource_meta.auth_token)

async with aiohttp.ClientSession() as session:
async with session.post(databricks_url, headers=get_header(), data=data) as resp:
async with session.post(databricks_url, headers=headers, data=data) as resp:
if resp.status != http.HTTPStatus.OK:
raise RuntimeError(
f"Failed to cancel databricks job {resource_meta.run_id} with error: {resp.reason}"
Expand All @@ -334,9 +364,139 @@ def __init__(self):
super(DatabricksConnector, self).__init__(task_type_name="databricks", metadata_type=DatabricksJobMetadata)


def get_header() -> typing.Dict[str, str]:
token = get_connector_secret("FLYTE_DATABRICKS_ACCESS_TOKEN")
return {"Authorization": f"Bearer {token}", "content-type": "application/json"}
def get_secret_from_k8s(secret_name: str, secret_key: str, namespace: str) -> Optional[str]:
    """Fetch and decode one value from a Kubernetes secret.

    Loads in-cluster configuration when available, otherwise falls back to the
    local kubeconfig. Every failure mode (missing `kubernetes` package,
    unloadable config, absent secret or key, API error) is logged and reported
    as ``None`` rather than raised.

    Args:
        secret_name (str): Name of the Kubernetes secret (e.g., "databricks-token").
        secret_key (str): Key within the secret (e.g., "token").
        namespace (str): Kubernetes namespace where the secret is stored.

    Returns:
        Optional[str]: The decoded secret value, or None if not found.
    """
    try:
        import base64

        from kubernetes import client, config

        # Prefer in-cluster credentials; outside a pod fall back to kubeconfig.
        try:
            config.load_incluster_config()
        except config.ConfigException:
            try:
                config.load_kube_config()
            except Exception as exc:
                logger.warning(f"Failed to load Kubernetes config: {exc}")
                return None

        api = client.CoreV1Api()
        try:
            k8s_secret = api.read_namespaced_secret(name=secret_name, namespace=namespace)
        except client.exceptions.ApiException as exc:
            if exc.status == 404:
                logger.debug(f"Secret '{secret_name}' not found in namespace '{namespace}'")
            else:
                logger.warning(f"Error reading secret '{secret_name}' from namespace '{namespace}': {exc}")
            return None

        payload = k8s_secret.data or {}
        if secret_key not in payload:
            logger.debug(
                f"Secret '{secret_name}' exists but key '{secret_key}' not found in namespace '{namespace}'"
            )
            return None

        # Kubernetes secrets are base64 encoded.
        return base64.b64decode(payload[secret_key]).decode("utf-8")

    except ImportError:
        logger.warning("kubernetes Python package not installed - cannot read namespace secrets")
        return None
    except Exception as exc:
        logger.warning(f"Unexpected error reading K8s secret: {exc}")
        return None


def get_databricks_token(
    namespace: Optional[str] = None, task_template: Optional["TaskTemplate"] = None, secret_name: Optional[str] = None
) -> str:
    """Get the Databricks access token with multi-tenant support.

    Token resolution: namespace K8s secret -> FLYTE_DATABRICKS_ACCESS_TOKEN env var.

    Args:
        namespace (Optional[str]): Kubernetes namespace for workflow-specific token lookup.
        task_template (Optional[TaskTemplate]): Optional TaskTemplate (kept for API compatibility).
        secret_name (Optional[str]): Custom secret name. Defaults to 'databricks-token'.

    Returns:
        str: The Databricks access token.

    Raises:
        ValueError: If no token is found from any source.
    """
    token: Optional[str] = None
    token_source = "unknown"

    # Use custom secret name or default to 'databricks-token'
    k8s_secret_name = secret_name or "databricks-token"

    # Step 1: Try namespace-specific K8s secret (cross-namespace lookup)
    if namespace:
        logger.info(f"Looking for Databricks token in workflow namespace: {namespace} (secret: {k8s_secret_name})")
        # Normalize an empty-string secret value to None so the env-var
        # fallback below still applies (previously an empty secret skipped
        # the fallback and raised "Databricks token is empty").
        token = get_secret_from_k8s(secret_name=k8s_secret_name, secret_key="token", namespace=namespace) or None

        if token:
            logger.info(f"Found Databricks token in namespace '{namespace}' from secret '{k8s_secret_name}'")
            token_source = f"k8s_namespace:{namespace}/secret:{k8s_secret_name}"
        else:
            logger.info(
                f"Databricks token not found in secret '{k8s_secret_name}' in namespace '{namespace}' - trying fallback"
            )
    else:
        logger.info("No namespace provided for cross-namespace lookup")

    # Step 2: Fall back to environment variable (backward compatibility)
    if token is None:
        logger.info("Falling back to default Databricks token (FLYTE_DATABRICKS_ACCESS_TOKEN)")
        try:
            token = get_connector_secret("FLYTE_DATABRICKS_ACCESS_TOKEN")
            token_source = "env_variable"
        except Exception as e:
            logger.error(f"Failed to get default Databricks token: {e}")
            raise ValueError(
                "No Databricks token found from any source:\n"
                f"1. Namespace-specific K8s secret '{k8s_secret_name}'\n"
                "2. FLYTE_DATABRICKS_ACCESS_TOKEN environment variable\n"
                f"Workflow namespace: {namespace or 'N/A'}"
            )

    if not token:
        raise ValueError("Databricks token is empty")

    # Log token info without exposing the actual token value
    token_preview = f"{token[:8]}..." if len(token) > 8 else "***"
    logger.info(f"Using Databricks token from: {token_source} (preview: {token_preview})")

    return token


def get_header(task_template: Optional["TaskTemplate"] = None, auth_token: Optional[str] = None) -> typing.Dict[str, str]:
    """Get the authorization header for Databricks API calls.

    Args:
        task_template (Optional[TaskTemplate]): TaskTemplate with workflow-specific secret requests.
        auth_token (Optional[str]): Pre-fetched auth token to use directly.

    Returns:
        typing.Dict[str, str]: Authorization and content-type headers.
    """
    if auth_token is None:
        # BUG FIX: pass by keyword — the first positional parameter of
        # get_databricks_token is `namespace`, so the previous positional call
        # silently treated the TaskTemplate object as a namespace string.
        auth_token = get_databricks_token(task_template=task_template)

    return {"Authorization": f"Bearer {auth_token}", "content-type": "application/json"}


def result_state_is_available(life_cycle_state: str) -> bool:
Expand Down
17 changes: 10 additions & 7 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class DatabricksV2(Spark):
Use the form <account>.cloud.databricks.com.
databricks_service_credential_provider (Optional[str]): Provider name for Databricks
Service Credentials for S3 access. Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var.
databricks_token_secret (Optional[str]): Custom name for the K8s secret containing
the Databricks token. Defaults to 'databricks-token' if not specified.
notebook_path (Optional[str]): Path to Databricks notebook
(e.g., "/Users/user@example.com/notebook").
notebook_base_parameters (Optional[Dict[str, str]]): Parameters to pass to the notebook.
Expand Down Expand Up @@ -194,12 +196,11 @@ class DatabricksV2(Spark):
"""

databricks_conf: Optional[Dict[str, Union[str, dict]]] = None
databricks_instance: Optional[str] = None # Falls back to FLYTE_DATABRICKS_INSTANCE env var
databricks_service_credential_provider: Optional[str] = (
None # Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var
)
notebook_path: Optional[str] = None # Path to Databricks notebook (e.g., "/Users/user@example.com/notebook")
notebook_base_parameters: Optional[Dict[str, str]] = None # Parameters to pass to the notebook
databricks_instance: Optional[str] = None
databricks_service_credential_provider: Optional[str] = None
databricks_token_secret: Optional[str] = None
notebook_path: Optional[str] = None
notebook_base_parameters: Optional[Dict[str, str]] = None


# This method does not reset the SparkSession since it's a bit hard to handle multiple
Expand Down Expand Up @@ -311,6 +312,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
cfg = cast(DatabricksV2, self.task_config)
if cfg.databricks_service_credential_provider:
custom_dict["databricksServiceCredentialProvider"] = cfg.databricks_service_credential_provider
if cfg.databricks_token_secret:
custom_dict["databricksTokenSecret"] = cfg.databricks_token_secret
if cfg.notebook_path:
custom_dict["notebookPath"] = cfg.notebook_path
if cfg.notebook_base_parameters:
Expand Down Expand Up @@ -479,7 +482,7 @@ def execute(self, **kwargs) -> Any:
if ctx.execution_state and ctx.execution_state.is_local_execution():
return AsyncConnectorExecutorMixin.execute(self, **kwargs)
except Exception as e:
click.secho(f"Connector failed to run the task with error: {e}", fg="red")
click.secho(f"Connector failed to run the task with error: {e}", fg="red")
click.secho("Falling back to local execution", fg="red")
return PythonFunctionTask.execute(self, **kwargs)

Expand Down
Loading
Loading