Source code for kubeflow.trainer.options.kubernetes

# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Kubernetes-specific training options for the Kubeflow Trainer SDK."""

import dataclasses
from dataclasses import dataclass
from typing import Any

from kubeflow.trainer.backends.base import RuntimeBackend
from kubeflow.trainer.types.types import BuiltinTrainer, CustomTrainer, CustomTrainerContainer


[docs] @dataclass class ContainerPatch: """Configuration for patching a specific container in a pod. Args: name: Name of the container to patch (must exist in the Runtime). env: Environment variables to add/merge with the container. Each dict should have 'name' and 'value' or 'valueFrom' keys. volume_mounts: Volume mounts to add/merge with the container. Each dict should have 'name' and 'mountPath' keys at minimum. security_context: Security context for the container. """ name: str env: list[dict] | None = None volume_mounts: list[dict] | None = None security_context: dict | None = None
[docs] def __post_init__(self): """Validate the container patch configuration.""" if not self.name or not self.name.strip(): raise ValueError("Container name must be a non-empty string") if self.env is not None: if not isinstance(self.env, list): raise ValueError("env must be a list of dictionaries") for env_var in self.env: if not isinstance(env_var, dict): raise ValueError("Each env entry must be a dictionary") if "name" not in env_var: raise ValueError("Each env entry must have a 'name' key") if not env_var.get("name"): raise ValueError("env 'name' must be a non-empty string") if "value" not in env_var and "valueFrom" not in env_var: raise ValueError("Each env entry must have either 'value' or 'valueFrom' key") if "valueFrom" in env_var: value_from = env_var["valueFrom"] if not isinstance(value_from, dict): raise ValueError("env 'valueFrom' must be a dictionary") valid_keys = {"configMapKeyRef", "secretKeyRef", "fieldRef", "resourceFieldRef"} if not any(key in value_from for key in valid_keys): raise ValueError( f"env 'valueFrom' must contain one of: {', '.join(valid_keys)}" ) if self.volume_mounts is not None: if not isinstance(self.volume_mounts, list): raise ValueError("volume_mounts must be a list of dictionaries") for mount in self.volume_mounts: if not isinstance(mount, dict): raise ValueError("Each volume_mounts entry must be a dictionary") if "name" not in mount: raise ValueError("Each volume_mounts entry must have a 'name' key") if not mount.get("name"): raise ValueError("volume_mounts 'name' must be a non-empty string") if "mountPath" not in mount: raise ValueError("Each volume_mounts entry must have a 'mountPath' key") mount_path = mount.get("mountPath") if not mount_path or not isinstance(mount_path, str): raise ValueError("volume_mounts 'mountPath' must be a non-empty string") if not mount_path.startswith("/"): raise ValueError( f"volume_mounts 'mountPath' must be an absolute path " f"(start with /): {mount_path}" )
@dataclass
class PodSpecPatch:
    """Pod spec fields that a manager is permitted to patch.

    Every field is optional and defaults to ``None``.

    Args:
        service_account_name: Service account the pods should use.
        volumes: Volumes to add to / merge with the pod.
        init_containers: Init-container patches to add to / merge with the pod.
        containers: Container patches to add to / merge with the pod.
        image_pull_secrets: Image pull secrets for the pods.
        security_context: Pod-level security context.
        node_selector: Node selector used to place pods on specific nodes.
        affinity: Affinity rules for pod scheduling.
        tolerations: Tolerations for pod scheduling.
        scheduling_gates: Scheduling gates for the pods.
    """

    # Identity and storage.
    service_account_name: str | None = None
    volumes: list[dict] | None = None

    # Per-container patches.
    init_containers: list[ContainerPatch] | None = None
    containers: list[ContainerPatch] | None = None

    # Image access and pod-level security.
    image_pull_secrets: list[dict] | None = None
    security_context: dict | None = None

    # Scheduling controls.
    node_selector: dict[str, str] | None = None
    affinity: dict | None = None
    tolerations: list[dict] | None = None
    scheduling_gates: list[dict] | None = None
@dataclass
class PodTemplatePatch:
    """Patch for the Pod template that lives inside a Job.

    Args:
        metadata: Metadata patches (labels, annotations) for the Pod template.
        spec: Patches for the Pod spec.
    """

    # Both parts are optional; ``None`` means that part is not patched.
    metadata: dict | None = None
    spec: PodSpecPatch | None = None
@dataclass
class JobSpecPatch:
    """Patch for a Job spec.

    Args:
        template: Pod template patches for this Job.
    """

    # Optional; defaults to ``None``.
    template: PodTemplatePatch | None = None
@dataclass
class JobTemplatePatch:
    """Patch for the Job template that lives inside a replicated job.

    Args:
        metadata: Metadata patches (labels, annotations) for the Job template.
        spec: Patches for the Job spec.
    """

    # Both parts are optional; ``None`` means that part is not patched.
    metadata: dict | None = None
    spec: JobSpecPatch | None = None
@dataclass
class ReplicatedJobPatch:
    """Patch aimed at one named replicated job within the JobSet.

    Args:
        name: Name of the replicated job to patch (e.g. "node", "launcher").
        template: Job template patches for that replicated job.
    """

    # Required: identifies which replicated job the patch targets.
    name: str
    # Optional; defaults to ``None``.
    template: JobTemplatePatch | None = None
@dataclass
class JobSetSpecPatch:
    """Patch for the JobSet spec.

    Args:
        replicated_jobs: List of per-job patches; each entry names the
            replicated job it applies to through its ``name`` field.
    """

    # Optional; defaults to ``None``.
    replicated_jobs: list[ReplicatedJobPatch] | None = None
@dataclass
class JobSetTemplatePatch:
    """Patch for the JobSet template.

    Args:
        metadata: Metadata patches (labels, annotations) for the JobSet.
        spec: Patches for the JobSet spec.
    """

    # Both parts are optional; ``None`` means that part is not patched.
    metadata: dict | None = None
    spec: JobSetSpecPatch | None = None
@dataclass
class TrainingRuntimeSpecPatch:
    """Patch for the TrainingRuntime spec.

    Args:
        template: JobSet template patches.
    """

    # Optional; defaults to ``None``.
    template: JobSetTemplatePatch | None = None
@dataclass
class RuntimePatch:
    """Add runtime patches to the TrainJob (.spec.runtimePatches).

    Runtime patches let controllers, admission webhooks, and custom clients
    attach structured patches to a TrainJob without conflicting with each
    other. Each patch carries a ``manager`` field; the SDK fixes it to
    "trainer.kubeflow.org/kubeflow-sdk" and it cannot be passed to the
    constructor.

    Supported backends:
        - Kubernetes

    Args:
        training_runtime_spec: Allowed patches for ClusterTrainingRuntime or
            TrainingRuntime-based jobs.
    """

    training_runtime_spec: TrainingRuntimeSpecPatch | None = None
    # Not an __init__ argument and hidden from repr; always the SDK's manager id.
    manager: str = dataclasses.field(
        init=False,
        repr=False,
        default="trainer.kubeflow.org/kubeflow-sdk",
    )

    def __call__(
        self,
        job_spec: dict[str, Any],
        trainer: CustomTrainer | BuiltinTrainer | None,
        backend: RuntimeBackend,
    ) -> None:
        """Append this patch to ``job_spec["spec"]["runtimePatches"]``.

        Args:
            job_spec: Job specification dictionary, modified in place.
            trainer: Optional trainer instance for context (unused here).
            backend: Backend instance; must be a KubernetesBackend.

        Raises:
            ValueError: If backend does not support runtime patches.
        """
        # Imported inside the method, as in the original code.
        from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend

        if isinstance(backend, KubernetesBackend):
            # Create the intermediate containers on demand, then append.
            patches = job_spec.setdefault("spec", {}).setdefault("runtimePatches", [])
            patches.append(_patch_to_dict(self))
            return

        backend_name = type(backend).__name__
        raise ValueError(
            f"RuntimePatch option is not compatible with {backend_name}. "
            "Supported backends: KubernetesBackend"
        )
def _to_camel_case(snake_str: str) -> str:
    """Turn a snake_case identifier into camelCase.

    The first underscore-separated segment is kept unchanged; each later
    segment is passed through ``str.capitalize`` and appended.
    """
    head, *tail = snake_str.split("_")
    return head + "".join(map(str.capitalize, tail))


def _patch_to_dict(obj: Any) -> Any:
    """Recursively render a patch dataclass as its API dict representation.

    Field names are converted from snake_case to camelCase. Fields whose value
    is ``None`` are skipped, as are list and nested-dataclass fields that
    render to something empty. Any other value (dict, string, number, ...) is
    copied through unchanged. Anything that is not a dataclass *instance* is
    returned as-is.
    """
    is_dataclass_instance = dataclasses.is_dataclass(obj) and not isinstance(obj, type)
    if not is_dataclass_instance:
        return obj

    api_dict: dict[str, Any] = {}
    for spec in dataclasses.fields(obj):
        raw = getattr(obj, spec.name)
        if raw is None:
            continue

        if isinstance(raw, list):
            rendered = [_patch_to_dict(entry) for entry in raw]
        elif dataclasses.is_dataclass(raw):
            rendered = _patch_to_dict(raw)
        else:
            # Plain values are stored without an emptiness check.
            api_dict[_to_camel_case(spec.name)] = raw
            continue

        # Lists and nested dataclasses are dropped when they render empty.
        if rendered:
            api_dict[_to_camel_case(spec.name)] = rendered

    return api_dict
@dataclass
class Labels:
    """Add labels to the TrainJob resource metadata (.metadata.labels).

    Labels are merged into whatever ``job_spec`` already holds under
    ``metadata.labels``; on a key collision the value from this option wins.

    Supported backends:
        - Kubernetes

    Args:
        labels: Dictionary of label key-value pairs to add to TrainJob metadata.
    """

    labels: dict[str, str]

    def __call__(
        self,
        job_spec: dict[str, Any],
        trainer: CustomTrainer | BuiltinTrainer | None,
        backend: RuntimeBackend,
    ) -> None:
        """Apply labels to the job specification.

        Args:
            job_spec: Job specification dictionary to modify.
            trainer: Optional trainer instance for context.
            backend: Backend instance for validation.

        Raises:
            ValueError: If backend does not support labels.
        """
        from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend

        if not isinstance(backend, KubernetesBackend):
            raise ValueError(
                f"Labels option is not compatible with {type(backend).__name__}. "
                f"Supported backends: KubernetesBackend"
            )

        metadata = job_spec.setdefault("metadata", {})
        # Merge instead of assigning: a plain assignment would discard labels
        # already present in job_spec (e.g. from an earlier Labels option).
        # Building a new dict also keeps the spec from aliasing self.labels.
        existing = metadata.get("labels") or {}
        metadata["labels"] = {**existing, **(self.labels or {})}
@dataclass
class Annotations:
    """Add annotations to the TrainJob resource metadata (.metadata.annotations).

    Annotations are merged into whatever ``job_spec`` already holds under
    ``metadata.annotations``; on a key collision the value from this option
    wins.

    Supported backends:
        - Kubernetes

    Args:
        annotations: Dictionary of annotation key-value pairs to add to TrainJob metadata.
    """

    annotations: dict[str, str]

    def __call__(
        self,
        job_spec: dict[str, Any],
        trainer: CustomTrainer | BuiltinTrainer | None,
        backend: RuntimeBackend,
    ) -> None:
        """Apply annotations to the job specification.

        Args:
            job_spec: Job specification dictionary to modify.
            trainer: Optional trainer instance for context.
            backend: Backend instance for validation.

        Raises:
            ValueError: If backend does not support annotations.
        """
        from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend

        if not isinstance(backend, KubernetesBackend):
            raise ValueError(
                f"Annotations option is not compatible with {type(backend).__name__}. "
                f"Supported backends: KubernetesBackend"
            )

        metadata = job_spec.setdefault("metadata", {})
        # Merge instead of assigning: a plain assignment would discard
        # annotations already present in job_spec (e.g. from an earlier
        # Annotations option). Building a new dict also keeps the spec from
        # aliasing self.annotations.
        existing = metadata.get("annotations") or {}
        metadata["annotations"] = {**existing, **(self.annotations or {})}
@dataclass
class TrainerCommand:
    """Override the trainer container command (.spec.trainer.command).

    Only valid together with CustomTrainerContainer: CustomTrainer generates
    its own command from the function, and BuiltinTrainer uses pre-configured
    commands.

    Supported backends:
        - Kubernetes

    Args:
        command: List of command strings to override the default trainer command.
    """

    command: list[str]

    def __call__(
        self,
        job_spec: dict[str, Any],
        trainer: CustomTrainer | BuiltinTrainer | CustomTrainerContainer | None,
        backend: RuntimeBackend,
    ) -> None:
        """Write the command override into ``job_spec["spec"]["trainer"]``.

        Args:
            job_spec: The job specification to modify (in place).
            trainer: Optional trainer context; when given it must be a
                CustomTrainerContainer.
            backend: Backend instance; must be a KubernetesBackend.

        Raises:
            ValueError: If the backend is unsupported or the trainer type conflicts.
        """
        from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend

        # The backend check runs first so its error takes priority.
        if not isinstance(backend, KubernetesBackend):
            backend_name = type(backend).__name__
            raise ValueError(
                f"TrainerCommand option is not compatible with {backend_name}. "
                "Supported backends: KubernetesBackend"
            )

        # A missing trainer is accepted; any other trainer type is rejected.
        trainer_ok = trainer is None or isinstance(trainer, CustomTrainerContainer)
        if not trainer_ok:
            raise ValueError(
                "TrainerCommand can only be used with CustomTrainerContainer. "
                "CustomTrainer generates its own command from the function, and "
                "BuiltinTrainer uses pre-configured commands."
            )

        job_spec.setdefault("spec", {}).setdefault("trainer", {})["command"] = self.command
@dataclass
class TrainerArgs:
    """Override the trainer container arguments (.spec.trainer.args).

    Only valid together with CustomTrainerContainer: CustomTrainer generates
    its own arguments from the function, and BuiltinTrainer uses
    pre-configured arguments.

    Supported backends:
        - Kubernetes

    Args:
        args: List of argument strings to override the default trainer arguments.
    """

    args: list[str]

    def __call__(
        self,
        job_spec: dict[str, Any],
        trainer: CustomTrainer | BuiltinTrainer | CustomTrainerContainer | None,
        backend: RuntimeBackend,
    ) -> None:
        """Write the args override into ``job_spec["spec"]["trainer"]``.

        Args:
            job_spec: The job specification to modify (in place).
            trainer: Optional trainer context; when given it must be a
                CustomTrainerContainer.
            backend: Backend instance; must be a KubernetesBackend.

        Raises:
            ValueError: If the backend is unsupported or the trainer type conflicts.
        """
        from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend

        # The backend check runs first so its error takes priority.
        if not isinstance(backend, KubernetesBackend):
            backend_name = type(backend).__name__
            raise ValueError(
                f"TrainerArgs option is not compatible with {backend_name}. "
                "Supported backends: KubernetesBackend"
            )

        # A missing trainer is accepted; any other trainer type is rejected.
        if not (trainer is None or isinstance(trainer, CustomTrainerContainer)):
            raise ValueError(
                "TrainerArgs can only be used with CustomTrainerContainer. "
                "CustomTrainer generates its own arguments from the function, and "
                "BuiltinTrainer uses pre-configured arguments."
            )

        trainer_section = job_spec.setdefault("spec", {}).setdefault("trainer", {})
        trainer_section["args"] = self.args