# Source code for kubeflow.spark.api.spark_client

# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SparkClient for Kubeflow SDK."""

from collections.abc import Iterator
import logging

from pyspark.sql import SparkSession

from kubeflow.common.types import KubernetesBackendConfig
from kubeflow.spark.backends.kubernetes import KubernetesBackend
from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url
from kubeflow.spark.types.types import Driver, Executor, SparkConnectInfo

# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)


class SparkClient:
    """Stateless Spark client for Kubeflow.

    All session operations are delegated to the backend created in
    ``__init__`` (currently only ``KubernetesBackend``).
    """

    def __init__(self, backend_config: KubernetesBackendConfig | None = None):
        """Initialize SparkClient.

        Args:
            backend_config: Backend configuration. When omitted, a default
                ``KubernetesBackendConfig()`` is constructed.

        Raises:
            ValueError: If ``backend_config`` is not a ``KubernetesBackendConfig``.
        """
        if backend_config is None:
            backend_config = KubernetesBackendConfig()
        if isinstance(backend_config, KubernetesBackendConfig):
            self.backend = KubernetesBackend(backend_config)
        else:
            raise ValueError(f"Invalid backend config: {type(backend_config)}")

    def connect(
        self,
        base_url: str | None = None,
        token: str | None = None,
        num_executors: int | None = None,
        resources_per_executor: dict[str, str] | None = None,
        spark_conf: dict[str, str] | None = None,
        driver: Driver | None = None,
        executor: Executor | None = None,
        options: list | None = None,
        timeout: int = 300,
        connect_timeout: int = 120,
    ) -> SparkSession:
        """Connect to or create a SparkConnect session (KEP-107 lines 298-347).

        This method supports two modes based on parameters:

        - **Connect mode**: When `base_url` is provided, connects to an existing
          Spark Connect server.
        - **Create mode**: When `base_url` is not provided, creates a new Spark
          Connect session.

        Arguments that do not apply to the selected mode are ignored, and a
        warning naming them is logged.

        Args:
            base_url: Optional URL to existing Spark Connect server
                (e.g., "sc://server:15002"). If provided, connects to existing
                server. If None, creates new session.
            token: Optional authentication token for existing server
                (connect mode only).
            num_executors: Number of executor instances (create mode only).
            resources_per_executor: Resource requirements per executor as dict.
                Format: `{"cpu": "5", "memory": "10Gi"}` (create mode only).
            spark_conf: Spark configuration dictionary (create mode only).
            driver: Driver configuration object (create mode only).
            executor: Executor configuration object (create mode only).
            options: List of configuration options (create mode only).
                Use Name option for custom session name.
            timeout: Timeout in seconds to wait for session ready.
            connect_timeout: Timeout in seconds for SparkSession.getOrCreate()
                (create mode only).

        Returns:
            SparkSession connected to Spark (self-managing).

        Raises:
            Any exception raised by ``validate_spark_connect_url`` for an
            invalid ``base_url`` propagates unchanged.

        Examples:
            # Connect to existing server
            spark = client.connect(base_url="sc://server:15002")

            # Create with simple parameters
            spark = client.connect(
                num_executors=5,
                resources_per_executor={"cpu": "5", "memory": "10Gi"},
                spark_conf={"spark.sql.adaptive.enabled": "true"}
            )

            # Create with custom name
            from kubeflow.spark.types.options import Name
            spark = client.connect(options=[Name("my-session")])

            # Create with advanced configuration
            spark = client.connect(
                driver=Driver(resources={"cpu": "2", "memory": "4Gi"}),
                executor=Executor(
                    num_instances=5,
                    resources_per_executor={"cpu": "4", "memory": "8Gi"}
                )
            )

            # Minimal - use all defaults (auto-generated name)
            spark = client.connect()

        Note:
            Server port defaults to 15002 (Spark Connect gRPC). PySpark and
            server Spark major.minor should match; see constants and
            pyproject.toml [spark].
        """
        if base_url:
            validate_spark_connect_url(base_url)

            # The create-mode arguments are never forwarded in connect mode.
            # Report them rather than dropping caller input silently.
            ignored = [
                arg_name
                for arg_name, arg_value in (
                    ("num_executors", num_executors),
                    ("resources_per_executor", resources_per_executor),
                    ("spark_conf", spark_conf),
                    ("driver", driver),
                    ("executor", executor),
                    ("options", options),
                )
                if arg_value is not None
            ]
            if ignored:
                logger.warning(
                    "base_url is set; ignoring create-mode-only arguments: %s",
                    ", ".join(ignored),
                )

            builder = SparkSession.builder.remote(base_url)
            if token:
                # NOTE(review): Spark Connect clients usually carry the token in
                # the connection string (";token=..."). Confirm this config key
                # is honored client-side.
                builder = builder.config("spark.connect.authenticate.token", token)
            return builder.getOrCreate()

        if token:
            # The token is only applied when attaching to an existing server.
            logger.warning("token is only used when base_url is set; ignoring it")

        return self.backend.create_and_connect(
            num_executors=num_executors,
            resources_per_executor=resources_per_executor,
            spark_conf=spark_conf,
            driver=driver,
            executor=executor,
            options=options,
            timeout=timeout,
            connect_timeout=connect_timeout,
        )

    def list_sessions(self) -> list[SparkConnectInfo]:
        """List all SparkConnect sessions.

        Returns:
            The session info objects reported by the backend.
        """
        return self.backend.list_sessions()

    def get_session(self, name: str) -> SparkConnectInfo:
        """Get session info by name.

        Args:
            name: Name of the session to look up.

        Returns:
            The backend's info object for the named session.
        """
        return self.backend.get_session(name)

    def delete_session(self, name: str) -> None:
        """Delete a SparkConnect session.

        Args:
            name: Name of the session to delete.
        """
        self.backend.delete_session(name)

    def get_session_logs(self, name: str, follow: bool = False) -> Iterator[str]:
        """Get logs from a session.

        Args:
            name: Name of the session whose logs are read.
            follow: Forwarded to the backend; presumably streams new lines as
                they arrive when True (backend behavior, not verified here).

        Returns:
            An iterator of log lines produced by the backend.
        """
        return self.backend.get_session_logs(name, follow=follow)