Source code for kubeflow.trainer.types.types

# Copyright 2024 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from urllib.parse import urlparse

import kubeflow.common.constants as common_constants
from kubeflow.trainer.constants import constants


# Configuration for the Custom Trainer.
[docs] @dataclass class CustomTrainer: """Custom Trainer configuration. Configure the self-contained function that encapsulates the entire model training process. Args: func (`Callable`): The function that encapsulates the entire model training process. func_args (`Optional[dict]`): The arguments to pass to the function. image (`Optional[str]`): The optional container image to use in TrainJob. packages_to_install (`Optional[list[str]]`): A list of Python packages to install before running the function. pip_index_urls (`list[str]`): The PyPI URLs from which to install Python packages. The first URL will be the index-url, and remaining ones are extra-index-urls. num_nodes (`Optional[int]`): The number of nodes to use for training. resources_per_node (`Optional[dict]`): The computing resources to allocate per node. ```python resources_per_node = {"gpu": 4, "cpu": 5, "memory": "10G"} ``` If your compute supports fractional GPUs (e.g. multi-instance GPU), you can set the resources as follows (request 1 GPU slice of 5Gb) : ```python resources_per_node = {"mig-1g.5gb": 1} ``` env (`Optional[dict[str, str]]`): The environment variables to set in the training nodes. """ func: Callable func_args: dict | None = None image: str | None = None packages_to_install: list[str] | None = None pip_index_urls: list[str] = field( default_factory=lambda: list(constants.DEFAULT_PIP_INDEX_URLS) ) num_nodes: int | None = None resources_per_node: dict | None = None env: dict[str, str] | None = None
# Configuration for the Custom Trainer Container.
[docs] @dataclass class CustomTrainerContainer: """Custom Trainer Container configuration. Configure the container image that encapsulates the entire model training process. Args: image (`str`): The container image that encapsulates the entire model training process. num_nodes (`Optional[int]`): The number of nodes to use for training. resources_per_node (`Optional[dict]`): The computing resources to allocate per node. ```python resources_per_node = {"gpu": 4, "cpu": 5, "memory": "10G"} ``` If your compute supports fractional GPUs (e.g. multi-instance GPU), you can set the resources as follows (request 1 GPU slice of 5Gb) : ```python resources_per_node = {"mig-1g.5gb": 1} ``` env (`Optional[dict[str, str]]`): The environment variables to set in the training nodes. """ image: str num_nodes: int | None = None resources_per_node: dict | None = None env: dict[str, str] | None = None
# TODO(Electronic-Waste): Add more loss functions. # Loss function for the TorchTune LLM Trainer. class Loss(Enum): """Loss function for the TorchTune LLM Trainer.""" CEWithChunkedOutputLoss = "torchtune.modules.loss.CEWithChunkedOutputLoss" # Data type for the TorchTune LLM Trainer. class DataType(Enum): """Data type for the TorchTune LLM Trainer.""" BF16 = "bf16" FP32 = "fp32" # Data file type for the TorchTune LLM Trainer. class DataFormat(Enum): """Data file type for the TorchTune LLM Trainer.""" JSON = "json" CSV = "csv" PARQUET = "parquet" ARROW = "arrow" TEXT = "text" XML = "xml" # Configuration for the TorchTune Instruct dataset. @dataclass class TorchTuneInstructDataset: """ Configuration for the custom dataset with user instruction prompts and model responses. REF: https://pytorch.org/torchtune/main/generated/torchtune.datasets.instruct_dataset.html Args: source (`Optional[DataFormat]`): Data file type. split (`Optional[str]`): The split of the dataset to use. You can use this argument to load a subset of a given split, e.g. split="train[:10%]". Default is `train`. train_on_input (`Optional[bool]`): Whether the model is trained on the user prompt or not. Default is False. new_system_prompt (`Optional[str]`): The new system prompt to use. If specified, prepend a system message. This can serve as instructions to guide the model response. Default is None. column_map (`Optional[Dict[str, str]]`): A mapping to change the expected "input" and "output" column names to the actual column names in the dataset. Keys should be "input" and "output" and values should be the actual column names. Default is None, keeping the default "input" and "output" column names. """ source: DataFormat | None = None split: str | None = None train_on_input: bool | None = None new_system_prompt: str | None = None column_map: dict[str, str] | None = None @dataclass class LoraConfig: """Configuration for the LoRA/QLoRA/DoRA. 
REF: https://meta-pytorch.org/torchtune/main/tutorials/memory_optimizations.html Args: apply_lora_to_mlp (`Optional[bool]`): Whether to apply LoRA to the MLP in each transformer layer. apply_lora_to_output (`Optional[bool]`): Whether to apply LoRA to the model's final output projection. lora_attn_modules (`list[str]`): A list of strings specifying which layers of the model to apply LoRA, default is ["q_proj", "v_proj", "output_proj"]: 1. "q_proj" applies LoRA to the query projection layer. 2. "k_proj" applies LoRA to the key projection layer. 3. "v_proj" applies LoRA to the value projection layer. 4. "output_proj" applies LoRA to the attention output projection layer. lora_rank (`Optional[int]`): The rank of the low rank decomposition. lora_alpha (`Optional[int]`): The scaling factor that adjusts the magnitude of the low-rank matrices' output. lora_dropout (`Optional[float]`): The probability of applying Dropout to the low rank updates. quantize_base (`Optional[bool]`): Whether to enable model quantization. use_dora (`Optional[bool]`): Whether to enable DoRA. """ apply_lora_to_mlp: bool | None = None apply_lora_to_output: bool | None = None lora_attn_modules: list[str] = field( default_factory=lambda: ["q_proj", "v_proj", "output_proj"] ) lora_rank: int | None = None lora_alpha: int | None = None lora_dropout: float | None = None quantize_base: bool | None = None use_dora: bool | None = None # Configuration for the TorchTune LLM Trainer. @dataclass class TorchTuneConfig: """TorchTune LLM Trainer configuration. Configure the parameters in the TorchTune LLM Trainer that already includes the fine-tuning logic. Args: dtype (`Optional[Dtype]`): The underlying data type used to represent the model and optimizer parameters. Currently, we only support `bf16` and `fp32`. batch_size (`Optional[int]`): The number of samples processed before updating model weights. epochs (`Optional[int]`): The number of complete passes over the training dataset. 
loss (`Optional[Loss]`): The loss algorithm we use to fine-tune the LLM, e.g. `torchtune.modules.loss.CEWithChunkedOutputLoss`. num_nodes (`Optional[int]`): The number of nodes to use for training. peft_config (`Optional[LoraConfig]`): Configuration for the PEFT(Parameter-Efficient Fine-Tuning), including LoRA/QLoRA/DoRA, etc. dataset_preprocess_config (`Optional[TorchTuneInstructDataset]`): Configuration for the dataset preprocessing. resources_per_node (`Optional[Dict]`): The computing resources to allocate per node. """ dtype: DataType | None = None batch_size: int | None = None epochs: int | None = None loss: Loss | None = None num_nodes: int | None = None peft_config: LoraConfig | None = None dataset_preprocess_config: TorchTuneInstructDataset | None = None resources_per_node: dict | None = None # Configuration for the Builtin Trainer.
@dataclass
class BuiltinTrainer:
    """Configuration for a builtin trainer.

    A builtin trainer already includes the fine-tuning logic, so only its
    parameters need to be adjusted.

    Args:
        config (`TorchTuneConfig`): The configuration for the builtin trainer.
    """

    # Must stay a real class reference (not a string annotation): module-level
    # code reads `BuiltinTrainer.__annotations__["config"].__name__`.
    config: TorchTuneConfig
# Change it to list: BUILTIN_CONFIGS, once we support more Builtin Trainer configs.
# Derived from the class name of BuiltinTrainer's `config` annotation:
# lower-cased, with the substring "config" removed.
TORCH_TUNE = (
    BuiltinTrainer.__annotations__["config"].__name__.lower().replace("config", "")
)


class TrainerType(Enum):
    """Kinds of trainer; each value is the name of the matching trainer class."""

    CUSTOM_TRAINER = CustomTrainer.__name__
    BUILTIN_TRAINER = BuiltinTrainer.__name__


# Representation for the Trainer of the runtime.
@dataclass
class RuntimeTrainer:
    """Trainer description attached to a runtime.

    The trainer command is not an ``__init__`` argument; it is stored through
    `set_command()` and read back through the `command` property.
    """

    trainer_type: TrainerType
    framework: str
    image: str
    num_nodes: int = 1  # The default value is set in the APIs.
    device: str = common_constants.UNKNOWN
    device_count: str = common_constants.UNKNOWN
    # Name-mangled private field, excluded from __init__ and repr.
    # NOTE(review): it has no default, so it is unset until set_command() runs;
    # reading `command` (or comparing instances) before that raises AttributeError.
    __command: tuple[str, ...] = field(init=False, repr=False)

    @property
    def command(self) -> tuple[str, ...]:
        """Return the command previously stored with `set_command()`."""
        return self.__command

    def set_command(self, command: tuple[str, ...]):
        """Store the trainer command."""
        self.__command = command


# Representation for the Training Runtime.
@dataclass
class Runtime:
    """A training runtime: its name, its trainer and an optional pretrained model."""

    name: str
    trainer: RuntimeTrainer
    pretrained_model: str | None = None


# Representation for the TrainJob steps.
@dataclass
class Step:
    """A single step of a TrainJob."""

    name: str
    status: str | None
    pod_name: str
    device: str = common_constants.UNKNOWN
    device_count: str = common_constants.UNKNOWN


# Representation for the TrainJob.
@dataclass
class TrainJob:
    """A TrainJob together with its runtime and steps."""

    name: str
    runtime: Runtime
    steps: list[Step]
    num_nodes: int
    creation_timestamp: datetime
    status: str = common_constants.UNKNOWN


# Representation for TrainJob events.
@dataclass
class Event:
    """A Kubernetes event related to a TrainJob.

    Args:
        involved_object_kind (`str`): The kind of object this event is about
            (e.g., 'TrainJob', 'Pod').
        involved_object_name (`str`): The name of the object this event is about.
        message (`str`): Human-readable description of the event.
        reason (`str`): Short, machine understandable string describing why
            this event was generated.
        event_time (`datetime`): The time at which the event was first recorded.
    """

    involved_object_kind: str
    involved_object_name: str
    message: str
    reason: str
    event_time: datetime


@dataclass
class BaseInitializer(abc.ABC):
    """Common base for all initializers; holds the shared `storage_uri` field."""

    storage_uri: str
@dataclass
class HuggingFaceDatasetInitializer(BaseInitializer):
    """Settings for fetching a dataset from the HuggingFace Hub.

    Args:
        storage_uri (`str`): The HuggingFace Hub dataset identifier in the
            format 'hf://username/repo_name'.
        ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore
            during download.
        access_token (`Optional[str]`): HuggingFace Hub access token for private
            datasets.
    """

    ignore_patterns: list[str] | None = None
    access_token: str | None = None

    def __post_init__(self):
        """Check `storage_uri` after dataclass initialisation.

        Raises:
            ValueError: If `storage_uri` does not start with 'hf://', or if its
                parsed URL path component is empty.
        """
        uri = self.storage_uri

        if not uri.startswith("hf://"):
            raise ValueError(f"storage_uri must start with 'hf://', got {self.storage_uri}")

        # An empty path means nothing follows the first segment after 'hf://'.
        has_path = urlparse(uri).path != ""
        if not has_path:
            raise ValueError(
                "storage_uri: must have absolute path with 'hf://<user_name>/<dataset_name>', got "
                f"{self.storage_uri}"
            )
@dataclass
class S3DatasetInitializer(BaseInitializer):
    """Settings for fetching a dataset from S3-compatible storage.

    Args:
        storage_uri (`str`): The S3 URI for the dataset in the format
            's3://bucket-name/path/to/dataset'.
        ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore
            during download.
        endpoint (`Optional[str]`): Custom S3 endpoint URL.
        access_key_id (`Optional[str]`): Access key for authentication.
        secret_access_key (`Optional[str]`): Secret key for authentication.
        region (`Optional[str]`): Region used in instantiating the client.
        role_arn (`Optional[str]`): The ARN of the role you want to assume.
    """

    ignore_patterns: list[str] | None = None
    endpoint: str | None = None
    access_key_id: str | None = None
    secret_access_key: str | None = None
    region: str | None = None
    role_arn: str | None = None

    def __post_init__(self):
        """Check `storage_uri` after dataclass initialisation.

        Raises:
            ValueError: If `storage_uri` does not start with 's3://'.
        """
        if self.storage_uri.startswith("s3://"):
            return
        raise ValueError(f"storage_uri must start with 's3://', got {self.storage_uri}")
@dataclass
class DataCacheInitializer(BaseInitializer):
    """Configuration for distributed data caching system for training workloads.

    Args:
        storage_uri (`str`): The URI for the cached data in the format
            'cache://<SCHEMA_NAME>/<TABLE_NAME>'. This specifies the location
            where the data cache will be stored and accessed. Both the schema
            name and the table name must be non-empty.
        metadata_loc (`str`): The metadata file path of an iceberg table.
        num_data_nodes (`int`): The number of data nodes in the distributed
            cache system. Must be greater than 1.
        head_cpu (`Optional[str]`): The CPU resources to allocate for the cache
            head node.
        head_mem (`Optional[str]`): The memory resources to allocate for the
            cache head node.
        worker_cpu (`Optional[str]`): The CPU resources to allocate for each
            cache worker node.
        worker_mem (`Optional[str]`): The memory resources to allocate for each
            cache worker node.
        iam_role (`Optional[str]`): The IAM role to use for accessing
            metadata_loc file.
    """

    metadata_loc: str
    num_data_nodes: int
    head_cpu: str | None = None
    head_mem: str | None = None
    worker_cpu: str | None = None
    worker_mem: str | None = None
    iam_role: str | None = None

    def __post_init__(self):
        """Validate DataCacheInitializer parameters.

        Raises:
            ValueError: If `num_data_nodes` is not greater than 1, if
                `storage_uri` does not start with 'cache://', or if the rest of
                the URI is not exactly two non-empty '/'-separated components.
        """
        if self.num_data_nodes <= 1:
            raise ValueError(f"num_data_nodes must be greater than 1, got {self.num_data_nodes}")

        # Validate storage_uri format
        if not self.storage_uri.startswith("cache://"):
            raise ValueError(f"storage_uri must start with 'cache://', got {self.storage_uri}")

        uri_path = self.storage_uri[len("cache://") :]
        parts = uri_path.split("/")
        # `split` yields empty strings for leading/trailing/adjacent slashes, so
        # a bare length check would accept 'cache://schema/' or 'cache:///table'.
        # Require both components to be non-empty as well.
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"storage_uri must be in format "
                f"'cache://<SCHEMA_NAME>/<TABLE_NAME>', got {self.storage_uri}"
            )
@dataclass
class HuggingFaceModelInitializer(BaseInitializer):
    """Configuration for downloading models from HuggingFace Hub.

    Args:
        storage_uri (`str`): The HuggingFace Hub model identifier in the format
            'hf://username/repo_name'.
        ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore
            during download. Defaults to a copy of
            `constants.INITIALIZER_DEFAULT_IGNORE_PATTERNS`.
        access_token (`Optional[str]`): HuggingFace Hub access token.
    """

    # Copy the module-level default: returning the constant itself would make
    # every instance share (and be able to mutate) the same list object.
    ignore_patterns: list[str] | None = field(
        default_factory=lambda: list(constants.INITIALIZER_DEFAULT_IGNORE_PATTERNS)
    )
    access_token: str | None = None

    def __post_init__(self):
        """Validate HuggingFaceModelInitializer parameters.

        Raises:
            ValueError: If `storage_uri` does not start with 'hf://'.
        """
        if not self.storage_uri.startswith("hf://"):
            raise ValueError(f"storage_uri must start with 'hf://', got {self.storage_uri}")
@dataclass
class S3ModelInitializer(BaseInitializer):
    """Configuration for downloading models from S3-compatible storage.

    Args:
        storage_uri (`str`): The S3 URI for the model in the format
            's3://bucket-name/path/to/model'.
        ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore
            during download. Defaults to
            `['*.msgpack', '*.h5', '*.bin', '*.pt', '*.pth']`.
        endpoint (`Optional[str]`): Custom S3 endpoint URL.
        access_key_id (`Optional[str]`): Access key for authentication.
        secret_access_key (`Optional[str]`): Secret key for authentication.
        region (`Optional[str]`): Region used in instantiating the client.
        role_arn (`Optional[str]`): The ARN of the role you want to assume.
    """

    # Copy the module-level default: returning the constant itself would make
    # every instance share (and be able to mutate) the same list object.
    ignore_patterns: list[str] | None = field(
        default_factory=lambda: list(constants.INITIALIZER_DEFAULT_IGNORE_PATTERNS)
    )
    endpoint: str | None = None
    access_key_id: str | None = None
    secret_access_key: str | None = None
    region: str | None = None
    role_arn: str | None = None

    def __post_init__(self):
        """Validate S3ModelInitializer parameters.

        Raises:
            ValueError: If `storage_uri` does not start with 's3://'.
        """
        if not self.storage_uri.startswith("s3://"):
            raise ValueError(f"storage_uri must start with 's3://', got {self.storage_uri}")
@dataclass
class Initializer:
    """Dataset and pre-trained model initialization settings.

    Args:
        dataset: The configuration for one of the supported dataset
            initializers (`HuggingFaceDatasetInitializer`,
            `S3DatasetInitializer` or `DataCacheInitializer`), or None.
        model: The configuration for one of the supported model initializers
            (`HuggingFaceModelInitializer` or `S3ModelInitializer`), or None.
    """

    dataset: HuggingFaceDatasetInitializer | S3DatasetInitializer | DataCacheInitializer | None = (
        None
    )
    model: HuggingFaceModelInitializer | S3ModelInitializer | None = None
# TODO (andreyvelich): Add train() and optimize() methods to this class.
@dataclass
class TrainJobTemplate:
    """Template describing a TrainJob.

    Args:
        trainer (`CustomTrainer`): Configuration for a CustomTrainer.
        runtime (`Optional[Union[str, Runtime]]`): Optional, reference to one of
            the existing runtimes. It can accept the runtime name or Runtime
            object from the `get_runtime()` API. Defaults to the
            torch-distributed runtime if not provided.
        initializer (`Optional[Initializer]`): Optional configuration for the
            dataset and model initializers.
    """

    trainer: CustomTrainer
    runtime: str | Runtime | None = None
    initializer: Initializer | None = None

    # `keys()` plus `__getitem__` is the minimal mapping protocol Python needs
    # to unpack an object with `**`.
    def keys(self):
        """Return the field names exposed through item access."""
        return ["trainer", "runtime", "initializer"]

    def __getitem__(self, key):
        """Return the attribute named `key`."""
        return getattr(self, key)