# Source code for kubeflow.hub.api.model_registry_client

# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING

from kubeflow.hub.types.types import StorageConfig

if TYPE_CHECKING:
    from model_registry.types import (
        ModelArtifact,
        ModelVersion,
        RegisteredModel,
        SupportedTypes,
    )


class ModelRegistryClient:
    """Client for Kubeflow Model Registry operations.

    Every operation is delegated to a ``model_registry.ModelRegistry``
    instance created in ``__init__``.

    Requires the model-registry package to be installed. Install it with:

        pip install 'kubeflow[hub]'
    """

    def __init__(
        self,
        base_url: str,
        port: int | None = None,
        *,
        author: str | None = None,
        is_secure: bool | None = None,
        user_token: str | None = None,
        custom_ca: str | None = None,
    ):
        """Initialize the ModelRegistryClient.

        Args:
            base_url: Base URL of the model registry server including scheme.
                Examples: "https://registry.example.com", "http://localhost".
                The scheme is used to infer ``is_secure`` and ``port`` when
                they are not provided explicitly.
            port: Server port. If not provided, inferred from the base_url
                scheme:
                - ``http://`` defaults to 8080
                - ``https://`` or no scheme defaults to 443

        Keyword Args:
            author: Name of the author.
            is_secure: Whether to use a secure connection. If not provided,
                inferred from base_url:
                - ``http://`` sets is_secure=False
                - ``https://`` or no scheme sets is_secure=True
            user_token: The PEM-encoded user token as a string.
            custom_ca: Path to the PEM-encoded root certificates as a string.

        Raises:
            ImportError: If model-registry is not installed.

        Examples:
            ModelRegistryClient("https://example.org", port=456)  # port kwarg
            ModelRegistryClient("https://example.org:456")  # base_url (including port)
            ModelRegistryClient("https://example.org")  # default port (`443` for https, `8080` for http)
        """
        # model-registry is an optional extra, so import it lazily and turn a
        # missing package into an actionable install hint.
        try:
            from model_registry import ModelRegistry
        except ImportError as e:
            raise ImportError(
                "model-registry is not installed. Install it with:\n\n"  # fmt: skip
                " pip install 'kubeflow[hub]'\n"
            ) from e

        # Only an explicit "http://" prefix (matched case-sensitively) selects
        # plain HTTP; "https://" and scheme-less URLs are treated as secure.
        is_http = base_url.startswith("http://")
        if is_secure is None:
            is_secure = not is_http
        # NOTE(review): a port embedded in base_url (e.g. "https://host:456")
        # is not parsed here, so 443/8080 is still passed as ``port`` when the
        # argument is omitted -- confirm how ModelRegistry reconciles the two.
        if port is None:
            port = 8080 if is_http else 443

        self._registry = ModelRegistry(
            server_address=base_url,
            port=port,
            author=author,  # type: ignore[arg-type]
            is_secure=is_secure,
            user_token=user_token,
            custom_ca=custom_ca,
        )

    def register_model(
        self,
        name: str,
        uri: str,
        *,
        version: str,
        model_format_name: str | None = None,
        model_format_version: str | None = None,
        author: str | None = None,
        owner: str | None = None,
        version_description: str | None = None,
        metadata: Mapping[str, SupportedTypes] | None = None,
        storage_config: StorageConfig | None = None,
    ) -> RegisteredModel:
        """Register a model.

        This registers a model in the model registry. The model is not
        downloaded, and has to be stored prior to registration.

        Most models can be registered using their URI, along with an optional
        ``storage_config`` describing how KServe should fetch the model at
        inference time. URI builder utilities are recommended when referring
        to specialized storage; for example ``utils.s3_uri_from`` when using
        S3 object storage data connections.

        Args:
            name: Name of the model.
            uri: URI of the model.

        Keyword Args:
            version: Version of the model. Has to be unique.
            model_format_name: Name of the model format (e.g., "pytorch",
                "tensorflow", "onnx"). Used by KServe to select the
                appropriate serving runtime.
            model_format_version: Version of the model format (e.g., "2.0",
                "1.15").
            author: Author of the model. Defaults to the client author.
            owner: Owner of the model. Defaults to the client author.
            version_description: Description of the model version; forwarded
                to the registry as ``description``.
            metadata: Additional version metadata.
            storage_config: Storage credentials for the model artifact. Groups
                ``storage_key``, ``storage_path``, and ``service_account_name``
                used by KServe's StorageInitializer. See ``StorageConfig`` for
                details. When omitted, a default-constructed ``StorageConfig``
                supplies these fields.

        Returns:
            Registered model.
        """
        storage = storage_config or StorageConfig()
        return self._registry.register_model(
            name=name,
            uri=uri,
            model_format_name=model_format_name,  # type: ignore[arg-type]
            model_format_version=model_format_version,  # type: ignore[arg-type]
            version=version,
            author=author,
            owner=owner,
            description=version_description,
            metadata=metadata,
            storage_key=storage.storage_key,
            storage_path=storage.storage_path,
            service_account_name=storage.service_account_name,
        )

    def update_model(self, model: RegisteredModel) -> RegisteredModel:
        """Update a registered model.

        Args:
            model: The registered model to update. Must have an ID.

        Returns:
            Updated registered model.

        Raises:
            TypeError: If model is not a RegisteredModel instance.
            model_registry.exceptions.StoreError: If model does not have an ID.
        """
        # Runtime import: the module-level import only exists for type checking.
        from model_registry.types import RegisteredModel

        if not isinstance(model, RegisteredModel):
            raise TypeError(f"Expected RegisteredModel, got {type(model).__name__}.")
        return self._registry.update(model)

    def update_model_version(self, model_version: ModelVersion) -> ModelVersion:
        """Update a model version.

        Args:
            model_version: The model version to update. Must have an ID.

        Returns:
            Updated model version.

        Raises:
            TypeError: If model_version is not a ModelVersion instance.
            model_registry.exceptions.StoreError: If model version does not
                have an ID.
        """
        # Runtime import: the module-level import only exists for type checking.
        from model_registry.types import ModelVersion

        if not isinstance(model_version, ModelVersion):
            raise TypeError(f"Expected ModelVersion, got {type(model_version).__name__}.")
        return self._registry.update(model_version)

    def update_model_artifact(self, model_artifact: ModelArtifact) -> ModelArtifact:
        """Update a model artifact.

        Args:
            model_artifact: The model artifact to update. Must have an ID.

        Returns:
            Updated model artifact.

        Raises:
            TypeError: If model_artifact is not a ModelArtifact instance.
            model_registry.exceptions.StoreError: If model artifact does not
                have an ID.
        """
        # Runtime import: the module-level import only exists for type checking.
        from model_registry.types import ModelArtifact

        if not isinstance(model_artifact, ModelArtifact):
            raise TypeError(f"Expected ModelArtifact, got {type(model_artifact).__name__}.")
        return self._registry.update(model_artifact)

    def get_model(self, name: str) -> RegisteredModel:
        """Get a registered model.

        Args:
            name: Name of the model.

        Returns:
            Registered model.

        Raises:
            ValueError: If the model does not exist.
        """
        # The underlying registry signals "not found" with None; surface it as
        # an exception so callers always receive a model.
        model = self._registry.get_registered_model(name)
        if model is None:
            raise ValueError(f"Model {name!r} not found")
        return model

    def get_model_version(self, name: str, version: str) -> ModelVersion:
        """Get a model version.

        Args:
            name: Name of the model.
            version: Version of the model.

        Returns:
            Model version.

        Raises:
            model_registry.exceptions.StoreError: If the model does not exist.
            ValueError: If the version does not exist.
        """
        model_version = self._registry.get_model_version(name, version)
        if model_version is None:
            raise ValueError(f"Model version {version!r} not found for model {name!r}")
        return model_version

    def get_model_artifact(self, name: str, version: str) -> ModelArtifact:
        """Get a model artifact.

        Args:
            name: Name of the model.
            version: Version of the model.

        Returns:
            Model artifact.

        Raises:
            model_registry.exceptions.StoreError: If either the model or the
                version don't exist.
            ValueError: If the artifact does not exist.
        """
        artifact = self._registry.get_model_artifact(name, version)
        if artifact is None:
            raise ValueError(f"Model artifact not found for model {name!r} version {version!r}")
        return artifact

    def list_models(self) -> Iterator[RegisteredModel]:
        """Get an iterator for registered models.

        Yields:
            Registered models.
        """
        # Generator: the registry is not queried until iteration starts.
        yield from self._registry.get_registered_models()

    def list_model_versions(self, name: str) -> Iterator[ModelVersion]:
        """Get an iterator for model versions.

        Args:
            name: Name of the model.

        Returns:
            An iterator over the model's versions.

        Raises:
            model_registry.exceptions.StoreError: If the model does not exist.
                Any error from the registry lookup is raised by this call
                itself, not deferred until the first iteration.
        """
        # Deliberately not a generator: with ``yield from`` the registry lookup
        # (and any error it raises) would be postponed until the caller first
        # iterates, contradicting the documented Raises contract.
        return iter(self._registry.get_model_versions(name))