Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 63 additions & 4 deletions invokeai/app/api/routers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,26 @@

import torch
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from fastapi import Body, HTTPException
from dynamicprompts.wildcards import WildcardManager
from fastapi import Body, HTTPException, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from pyparsing import ParseException
from transformers import AutoProcessor, AutoTokenizer, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor

from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.image_files.image_files_common import ImageFileNotFoundException
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.wildcards import (
WildcardsResponse,
WildcardValuesResponse,
clean_dynamic_prompt_outputs,
find_missing_wildcard_references,
get_wildcard_values,
get_wildcards_path,
index_wildcards,
)
from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
from invokeai.backend.model_manager.taxonomy import ModelType
from invokeai.backend.text_llm_pipeline import DEFAULT_SYSTEM_PROMPT, TextLLMPipeline
Expand All @@ -31,6 +42,41 @@
class DynamicPromptsResponse(BaseModel):
prompts: list[str]
error: Optional[str] = None
warnings: list[str] = Field(default_factory=list)
missing_wildcards: list[str] = Field(default_factory=list)


@utilities_router.get(
"/wildcards",
operation_id="list_wildcards",
responses={
200: {"model": WildcardsResponse},
},
)
async def list_wildcards(_: CurrentUserOrDefault) -> WildcardsResponse:
"""List local dynamic prompt wildcards from INVOKEAI_ROOT/wildcards."""
wildcards_path = get_wildcards_path(ApiDependencies.invoker.services.configuration.root_path)
return index_wildcards(wildcards_path)


@utilities_router.get(
"/wildcards/values",
operation_id="get_wildcard_values",
responses={
200: {"model": WildcardValuesResponse},
},
)
async def list_wildcard_values(
_: CurrentUserOrDefault,
path: str = Query(description="The relative wildcard path to read values for"),
limit: int = Query(default=200, ge=1, le=1000, description="The max number of wildcard values to return"),
) -> WildcardValuesResponse:
"""List values for a single local dynamic prompt wildcard."""
wildcards_path = get_wildcards_path(ApiDependencies.invoker.services.configuration.root_path)
values = get_wildcard_values(wildcards_path, path, limit)
if values is None:
raise HTTPException(status_code=404, detail=f"Wildcard '{path}' not found")
return values


@utilities_router.post(
Expand All @@ -41,6 +87,7 @@ class DynamicPromptsResponse(BaseModel):
},
)
async def parse_dynamicprompts(
_: CurrentUserOrDefault,
prompt: str = Body(description="The prompt to parse with dynamicprompts"),
max_prompts: int = Body(ge=1, le=10000, default=1000, description="The max number of prompts to generate"),
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
Expand All @@ -49,18 +96,30 @@ async def parse_dynamicprompts(
"""Creates a batch process"""
max_prompts = min(max_prompts, 10000)
generator: Union[RandomPromptGenerator, CombinatorialPromptGenerator]
warnings: list[str] = []
wildcards_path = get_wildcards_path(ApiDependencies.invoker.services.configuration.root_path)
wildcard_index = index_wildcards(wildcards_path)
missing_wildcards = find_missing_wildcard_references(prompt, wildcard_index.wildcards)
if wildcard_index.errors:
warnings.append("Some wildcard files could not be indexed.")
try:
error: Optional[str] = None
wildcard_manager = WildcardManager(wildcards_path) if wildcards_path.is_dir() else None
if combinatorial:
generator = CombinatorialPromptGenerator()
generator = CombinatorialPromptGenerator(wildcard_manager=wildcard_manager)
prompts = generator.generate(prompt, max_prompts=max_prompts)
else:
generator = RandomPromptGenerator(seed=seed)
generator = RandomPromptGenerator(wildcard_manager=wildcard_manager, seed=seed)
prompts = generator.generate(prompt, num_images=max_prompts)
except ParseException as e:
prompts = [prompt]
error = str(e)
return DynamicPromptsResponse(prompts=prompts if prompts else [""], error=error)
return DynamicPromptsResponse(
prompts=clean_dynamic_prompt_outputs(prompts) if prompts else [""],
error=error,
warnings=warnings,
missing_wildcards=missing_wildcards,
)


# --- Expand Prompt ---
Expand Down
10 changes: 7 additions & 3 deletions invokeai/app/invocations/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

import numpy as np
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from dynamicprompts.wildcards import WildcardManager
from pydantic import field_validator

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, UIComponent
from invokeai.app.invocations.primitives import StringCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.wildcards import clean_dynamic_prompt_outputs, get_wildcards_path


@invocation(
Expand All @@ -30,14 +32,16 @@ class DynamicPromptInvocation(BaseInvocation):
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")

def invoke(self, context: InvocationContext) -> StringCollectionOutput:
wildcards_path = get_wildcards_path(context.config.get().root_path)
wildcard_manager = WildcardManager(wildcards_path) if wildcards_path.is_dir() else None
if self.combinatorial:
generator = CombinatorialPromptGenerator()
generator = CombinatorialPromptGenerator(wildcard_manager=wildcard_manager)
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
else:
generator = RandomPromptGenerator()
generator = RandomPromptGenerator(wildcard_manager=wildcard_manager)
prompts = generator.generate(self.prompt, num_images=self.max_prompts)

return StringCollectionOutput(collection=prompts)
return StringCollectionOutput(collection=clean_dynamic_prompt_outputs(prompts))


@invocation(
Expand Down
Loading
Loading