Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 104 additions & 1 deletion invokeai/app/api/routers/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import List, Optional, Type

import huggingface_hub
import requests
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.routing import APIRouter
Expand Down Expand Up @@ -41,8 +42,13 @@
Main_Checkpoint_SDXLRefiner_Config,
)
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.civitai import CivitaiMetadataFetch, is_civitai_model_version_url
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.metadata.metadata_base import (
CivitaiMetadata,
ModelMetadataWithFiles,
UnknownMetadataException,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.starter_models import (
Expand Down Expand Up @@ -110,6 +116,59 @@ def prepare_model_config_for_response(config: AnyModelConfig, dependencies: Type
return add_cover_image_to_model_config(config, dependencies)


def _get_civitai_source_urls(config: AnyModelConfig) -> list[str]:
    """Collect the version-specific CivitAI URLs recorded on *config*.

    ``source_url`` is consulted before ``source`` so lookup order is stable,
    and duplicate or non-CivitAI values are dropped.
    """
    collected: list[str] = []
    for candidate in (config.source_url, config.source):
        if not candidate or candidate in collected:
            continue
        if is_civitai_model_version_url(candidate):
            collected.append(candidate)
    return collected


def _get_civitai_hash(config: AnyModelConfig) -> str | None:
    """Return the hash string to use with CivitAI's by-hash lookup.

    An algorithm prefix (e.g. ``"blake3:"``) is stripped when a digest
    follows it; otherwise the stored hash is returned as-is. ``None`` when
    the config carries no hash at all.
    """
    stored = config.hash
    if not stored:
        return None
    digest = stored.partition(":")[2]
    return digest if digest else stored


def _fetch_civitai_metadata_for_config(config: AnyModelConfig) -> CivitaiMetadata:
    """Look up CivitAI metadata for *config*, preferring live sources.

    Resolution order: stored CivitAI version URLs, then the model hash via
    CivitAI's by-hash endpoint, and finally the cached API response saved at
    install time. Raises ``UnknownMetadataException`` carrying the collected
    failure details when every strategy fails.
    """
    fetcher = CivitaiMetadataFetch()
    failures: list[str] = []

    # 1) Try each version-specific URL recorded on the config.
    for source_url in _get_civitai_source_urls(config):
        try:
            fetched = fetcher.from_url(source_url)  # type: ignore[arg-type]
        except (UnknownMetadataException, requests.RequestException, ValueError) as exc:
            failures.append(str(exc))
        else:
            if isinstance(fetched, CivitaiMetadata):
                return fetched

    # 2) Fall back to CivitAI's by-hash lookup.
    digest = _get_civitai_hash(config)
    if digest:
        try:
            return fetcher.from_hash(digest)
        except (UnknownMetadataException, requests.RequestException, ValueError) as exc:
            failures.append(str(exc))

    # 3) Last resort: re-parse the API response cached at install time.
    if config.source_api_response:
        try:
            return fetcher.from_api_response(config.source_api_response)
        except (UnknownMetadataException, ValueError) as exc:
            failures.append(str(exc))

    detail = "; ".join(failures) if failures else "No CivitAI source URL or hash is available"
    raise UnknownMetadataException(detail)


def _get_refreshed_civitai_source_url(config: AnyModelConfig, metadata: CivitaiMetadata) -> str | None:
    """Choose the source URL to persist after a CivitAI metadata refresh.

    A version-specific CivitAI URL already on the config wins; otherwise the
    URL reported by the fetched metadata is preferred, falling back to
    whatever the config previously stored.
    """
    existing = config.source_url
    if existing and is_civitai_model_version_url(existing):
        return existing
    return metadata.source_url or existing


##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.
Expand Down Expand Up @@ -325,6 +384,50 @@ async def reidentify_model(
raise HTTPException(status_code=404, detail=str(e))


@model_manager_router.post(
    "/i/{key}/refresh_trigger_phrases",
    operation_id="refresh_model_trigger_phrases",
    responses={
        200: {
            "description": "The model trigger phrases were refreshed successfully",
            "content": {"application/json": {"example": example_model_config}},
        },
        400: {"description": "Bad request"},
        404: {"description": "The model or CivitAI metadata could not be found"},
    },
)
async def refresh_model_trigger_phrases(
    key: Annotated[str, Path(description="Key of the model to refresh trigger phrases for.")],
    current_admin: AdminUserOrDefault,
) -> AnyModelConfig:
    """Refresh a LoRA model's trigger phrases from CivitAI metadata.

    Resolves CivitAI metadata for the model (live URL/hash lookup first,
    cached API response as fallback), merges any trained words into the
    existing trigger phrases, and persists the refreshed metadata.

    Raises:
        HTTPException: 404 if the model key is unknown or CivitAI metadata
            cannot be resolved; 400 if the model is not a LoRA.
    """
    record_store = ApiDependencies.invoker.services.model_manager.store
    try:
        config = record_store.get_model(key)
    except UnknownModelException as e:
        # Chain the cause (B904) so the original traceback is preserved.
        raise HTTPException(status_code=404, detail=str(e)) from e

    if config.type != ModelType.LoRA:
        raise HTTPException(status_code=400, detail="Trigger phrase refresh is only supported for LoRA models")

    try:
        metadata = _fetch_civitai_metadata_for_config(config)
    except UnknownMetadataException as e:
        raise HTTPException(status_code=404, detail=f"Unable to resolve CivitAI metadata: {e}") from e

    # Merge rather than replace: phrases the user added by hand are kept
    # alongside the trained words reported by CivitAI.
    existing_phrases = set(config.trigger_phrases or [])
    changes_kwargs = {
        "source_api_response": metadata.api_response,
        "source_url": _get_refreshed_civitai_source_url(config, metadata),
    }
    if metadata.trained_words:
        changes_kwargs["trigger_phrases"] = existing_phrases.union(metadata.trained_words)

    changes = ModelRecordChanges(**changes_kwargs)
    updated_config = record_store.update_model(key, changes=changes, allow_class_change=True)
    return prepare_model_config_for_response(updated_config, ApiDependencies)


class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
Expand Down
31 changes: 26 additions & 5 deletions invokeai/app/services/model_install/model_install_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from huggingface_hub import get_token as hf_get_token
from pydantic.networks import AnyHttpUrl
from pydantic_core import Url
from requests import Session
from requests import RequestException, Session

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
Expand Down Expand Up @@ -54,12 +54,18 @@
from invokeai.backend.model_manager.configs.unknown import Unknown_Config
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
HuggingFaceMetadataFetch,
ModelMetadataFetchBase,
ModelMetadataWithFiles,
RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.metadata.fetch.civitai import is_civitai_model_version_url
from invokeai.backend.model_manager.metadata.metadata_base import (
CivitaiMetadata,
HuggingFaceMetadata,
UnknownMetadataException,
)
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
Expand Down Expand Up @@ -755,12 +761,20 @@ def _remote_files_from_source(
if isinstance(source, URLModelSource):
try:
fetcher = self.get_fetcher_from_url(str(source.url))
except ValueError:
return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None

try:
kwargs: dict[str, Any] = {"session": self._session}
metadata = fetcher(**kwargs).from_url(source.url)
assert isinstance(metadata, ModelMetadataWithFiles)
return metadata.download_urls(session=self._session), metadata
except ValueError:
pass
except (UnknownMetadataException, RequestException, ValueError) as e:
if fetcher is not CivitaiMetadataFetch:
raise
self._logger.warning(
f"Unable to fetch metadata for {source.url}: {e}. Falling back to direct download."
)

return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None

Expand Down Expand Up @@ -881,8 +895,13 @@ def _register_or_install(self, job: ModelInstallJob) -> None:
job.config_in.source = str(job.source)
job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
if isinstance(job.source_metadata, (HuggingFaceMetadata, CivitaiMetadata)):
job.config_in.source_api_response = job.source_metadata.api_response
if isinstance(job.source_metadata, CivitaiMetadata):
if not job.config_in.source_url:
job.config_in.source_url = job.source_metadata.source_url
if job.source_metadata.trained_words and not job.config_in.trigger_phrases:
job.config_in.trigger_phrases = set(job.source_metadata.trained_words)

if job._install_tmpdir is not None:
self._delete_install_marker(job._install_tmpdir)
Expand Down Expand Up @@ -1487,4 +1506,6 @@ def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
"""
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
if is_civitai_model_version_url(url):
return CivitaiMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")
9 changes: 8 additions & 1 deletion invokeai/backend/model_manager/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
assert isinstance(data, HuggingFaceMetadata)
"""

from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch, ModelMetadataFetchBase
from invokeai.backend.model_manager.metadata.fetch import (
CivitaiMetadataFetch,
HuggingFaceMetadataFetch,
ModelMetadataFetchBase,
)
from invokeai.backend.model_manager.metadata.metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
BaseMetadata,
CivitaiMetadata,
HuggingFaceMetadata,
ModelMetadataWithFiles,
RemoteModelFile,
Expand All @@ -30,6 +35,8 @@
__all__ = [
"AnyModelRepoMetadata",
"AnyModelRepoMetadataValidator",
"CivitaiMetadata",
"CivitaiMetadataFetch",
"HuggingFaceMetadata",
"HuggingFaceMetadataFetch",
"ModelMetadataFetchBase",
Expand Down
3 changes: 2 additions & 1 deletion invokeai/backend/model_manager/metadata/fetch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
assert isinstance(data, HuggingFaceMetadata)
"""

from invokeai.backend.model_manager.metadata.fetch.civitai import CivitaiMetadataFetch
from invokeai.backend.model_manager.metadata.fetch.fetch_base import ModelMetadataFetchBase
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch

__all__ = ["ModelMetadataFetchBase", "HuggingFaceMetadataFetch"]
__all__ = ["ModelMetadataFetchBase", "HuggingFaceMetadataFetch", "CivitaiMetadataFetch"]
Loading
Loading