diff --git a/docs/src/content/docs/features/shapes-tool.mdx b/docs/src/content/docs/features/shapes-tool.mdx
new file mode 100644
index 00000000000..6ad795aed7a
--- /dev/null
+++ b/docs/src/content/docs/features/shapes-tool.mdx
@@ -0,0 +1,96 @@
+---
+title: Shapes Tool
+description: Learn how to draw filled shapes on raster and inpaint mask layers with the Shapes tool.
+lastUpdated: 2026-05-11
+---
+
+import { Card, CardGrid } from '@astrojs/starlight/components';
+
+The Shapes tool is a general-purpose filled-shape drawing tool for the canvas. It replaces the old Rectangle tool and
+adds four shape modes under a single toolbar button:
+
+- **Rect**
+- **Oval**
+- **Polygon**
+- **Freehand**
+
+You can activate the Shapes tool from the canvas toolbar or with the default hotkey U.
+
+## Where Shapes Draws
+
+Shapes always draws into the **active raster target**:
+
+- On a regular raster layer, Shapes adds filled pixels to that layer.
+- On an active inpaint mask layer, Shapes draws directly into the mask.
+
+:::note
+Shapes overlaps with some Lasso workflows on mask layers, but the tools are not identical. Lasso is still the more
+specialized masking tool and can create a new mask layer automatically when one does not already exist.
+:::
+
+## Common Behavior
+
+- Shapes preview live while you draw.
+- The fill color uses the current active color.
+- The active color's alpha is respected when adding pixels.
+- Hold Ctrl on Windows/Linux or Cmd on macOS to switch to **subtractive** mode and cut pixels
+ out of the active layer.
+- In subtractive mode, alpha is ignored and the shape fully clears pixels.
+- Press Esc to cancel the current shape session.
+
+:::tip
+When subtractive mode is active, the canvas cursor shows a small minus badge so you can tell at a glance that the next
+shape will erase instead of fill.
+:::
+
+## Shape Modes
+
+
+
+ Drag to draw a rectangle. Hold Shift to constrain to a square. Hold Alt to draw from the
+ center instead of from a corner.
+
+
+ Drag to draw an ellipse. Hold Shift to constrain to a perfect circle. Hold Alt to draw from
+ the center.
+
+
+ Click to place vertices. Click the first point to close and commit the shape. Hold Shift to snap the
+ pending edge to horizontal, vertical, and 45 degree angles.
+
+
+ Click and drag to sketch a filled freehand contour. Release the pointer to commit the shape.
+
+
+
+## Moving and Panning During Drawing
+
+The Shapes tool supports different Space behavior depending on the current mode:
+
+- **Rect / Oval:** While the pointer is still down, hold Space to move the uncommitted shape instead of
+ resizing it. Release Space to continue resizing.
+- **Polygon / Freehand:** Hold Space during an active session to pan the viewport without discarding the
+ unfinished shape.
+
+This is especially useful when drawing large shapes that extend beyond the current viewport.
+
+## Color Picking While Using Shapes
+
+The Alt key behaves differently depending on the active Shapes mode:
+
+- **Rect / Oval:** Before you start dragging, Alt can be used for the temporary color-picker quick-switch.
+ Once a drag is active, Alt is reserved for drawing from the center.
+- **Polygon:** Alt remains available for the temporary color-picker quick-switch between vertex placements.
+- **Freehand:** Alt is available before the stroke starts, but not during an active stroke.
+
+## Practical Examples
+
+- Use **Rect** or **Oval** to block in clean mask regions quickly.
+- Use **Polygon** when you need straight edges and deliberate corner placement.
+- Use **Freehand** for irregular organic regions.
+- Use **subtractive mode** to cut holes back out of an existing raster or mask layer.
+
+## Summary
+
+The Shapes tool is the fastest way to add filled geometric or freeform regions to canvas layers. Use it for structured
+fills, mask authoring, and precise subtractive edits without switching away from the current raster target.
diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json
index 88a42f8fbcf..f52420ca4b2 100644
--- a/docs/src/generated/settings.json
+++ b/docs/src/generated/settings.json
@@ -590,6 +590,20 @@
"type": "",
"validation": {}
},
+ {
+ "category": "GENERATION",
+ "default": "round_robin",
+ "description": "Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.",
+ "env_var": "INVOKEAI_SESSION_QUEUE_MODE",
+ "literal_values": [
+ "FIFO",
+ "round_robin"
+ ],
+ "name": "session_queue_mode",
+ "required": false,
+ "type": "typing.Literal['FIFO', 'round_robin']",
+ "validation": {}
+ },
{
"category": "GENERATION",
"default": false,
diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py
index 41a5a411c7a..d62cac5095f 100644
--- a/invokeai/app/api/routers/session_queue.py
+++ b/invokeai/app/api/routers/session_queue.py
@@ -141,12 +141,11 @@ async def get_queue_item_ids(
queue_id: str = Path(description="The queue id to perform this operation on"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
) -> ItemIdsResult:
- """Gets all queue item ids that match the given parameters. Non-admin users only see their own items."""
+ """Gets all queue item ids that match the given parameters. The IDs themselves are not sensitive;
+ per-item field redaction is performed when the items are fetched via list_all_queue_items or
+ get_queue_items_by_item_ids."""
try:
- user_id = None if current_user.is_admin else current_user.user_id
- return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(
- queue_id=queue_id, order_dir=order_dir, user_id=user_id
- )
+ return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")
@@ -436,10 +435,15 @@ async def get_queue_status(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
- """Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it."""
+ """Gets the status of the session queue. Returns global counts plus the calling user's own
+ pending/in_progress counts (so the UI can show an X/Y badge). Non-admin users cannot see the
+ current item's identifiers unless they own it."""
try:
- user_id = None if current_user.is_admin else current_user.user_id
- queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id)
+ queue = ApiDependencies.invoker.services.session_queue.get_queue_status(
+ queue_id,
+ user_id=current_user.user_id,
+ is_admin=current_user.is_admin,
+ )
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py
index 5783b804c0b..b02b5bbb067 100644
--- a/invokeai/app/api/sockets.py
+++ b/invokeai/app/api/sockets.py
@@ -260,20 +260,37 @@ async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
+ def _owner_and_admin_sids(self, owner_user_id: str) -> list[str]:
+ """Sids belonging to the event's owner or to any admin.
+
+ Used as `skip_sid` when broadcasting a sanitized companion event to the queue room,
+ so the owner and admins (who already received the full event) don't get a second
+ copy that would clobber their cache with redacted values.
+ """
+ return [
+ sid
+ for sid, info in self._socket_users.items()
+ if info.get("user_id") == owner_user_id or info.get("is_admin")
+ ]
+
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
"""Handle queue events with user isolation.
- All queue item events (invocation events AND QueueItemStatusChangedEvent) are
- private to the owning user and admins. They carry unsanitized user_id, batch_id,
- session_id, origin, destination and error metadata, and must never be broadcast
- to the whole queue room — otherwise any other authenticated subscriber could
- observe cross-user queue activity.
+ Queue events split into two routing paths:
- RecallParametersUpdatedEvent is also private to the owner + admins.
+ 1. The owner and admins receive the full unsanitized event in their `user:{id}` /
+ `admin` rooms. The full payload may include batch_id, session_id, origin,
+ destination, error metadata, etc.
- BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and
- is also routed privately. QueueClearedEvent is the only queue event that
- is still broadcast to the whole queue room.
+ 2. For events that other authenticated users need to know about so their queue list
+ and badge counts stay in sync (QueueItemStatusChangedEvent and BatchEnqueuedEvent),
+ a sanitized companion event is also emitted to the full queue room with the
+ owner's and admins' sids in `skip_sid`. The companion uses `user_id="redacted"`
+ as a sentinel so the frontend handler knows to do tag invalidation only and skip
+ per-session side effects.
+
+ InvocationEventBase events stay private (owner + admins only). RecallParametersUpdatedEvent
+ is also private. QueueClearedEvent has no user identity and is broadcast to the queue room.
IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase
inherits from QueueItemEventBase. The order of isinstance checks matters!
@@ -302,10 +319,51 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room")
- # Other queue item events (QueueItemStatusChangedEvent) carry unsanitized
- # user_id, batch_id, session_id, origin, destination and error metadata.
- # They are private to the owning user + admins — never broadcast to the
- # full queue room.
+ # QueueItemStatusChangedEvent: full to owner+admin, sanitized to everyone else in
+ # the queue room so their queue list, badge, and item caches refresh.
+ elif isinstance(event_data, QueueItemStatusChangedEvent):
+ user_room = f"user:{event_data.user_id}"
+ await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
+ await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
+
+ sanitized = event_data.model_copy(
+ update={
+ "user_id": "redacted",
+ "batch_id": "redacted",
+ "session_id": "redacted",
+ "origin": None,
+ "destination": None,
+ "error_type": None,
+ "error_message": None,
+ "error_traceback": None,
+ }
+ )
+ # Strip identifying fields out of the embedded batch_status / queue_status too.
+ sanitized.batch_status = sanitized.batch_status.model_copy(
+ update={"batch_id": "redacted", "origin": None, "destination": None}
+ )
+ sanitized.queue_status = sanitized.queue_status.model_copy(
+ update={
+ "item_id": None,
+ "session_id": None,
+ "batch_id": None,
+ "user_pending": None,
+ "user_in_progress": None,
+ }
+ )
+ await self._sio.emit(
+ event=event_name,
+ data=sanitized.model_dump(mode="json"),
+ room=event_data.queue_id,
+ skip_sid=self._owner_and_admin_sids(event_data.user_id),
+ )
+
+ logger.debug(
+ f"Emitted queue_item_status_changed: full to {user_room}+admin, sanitized to queue {event_data.queue_id}"
+ )
+
+ # Other queue item events (currently none beyond QueueItemStatusChangedEvent that
+ # carry user_id) stay private to owner + admins.
elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"):
user_room = f"user:{event_data.user_id}"
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
@@ -320,14 +378,25 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room")
- # BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and
- # enqueued counts. Route it privately to the owner + admins so other
- # users do not observe cross-user batch activity.
+ # BatchEnqueuedEvent: full to owner+admin, sanitized to everyone else in the queue
+ # room so their badge total and queue list pick up the new items.
elif isinstance(event_data, BatchEnqueuedEvent):
user_room = f"user:{event_data.user_id}"
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
- logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room")
+
+ sanitized = event_data.model_copy(
+ update={"user_id": "redacted", "batch_id": "redacted", "origin": None}
+ )
+ await self._sio.emit(
+ event=event_name,
+ data=sanitized.model_dump(mode="json"),
+ room=event_data.queue_id,
+ skip_sid=self._owner_and_admin_sids(event_data.user_id),
+ )
+ logger.debug(
+ f"Emitted batch_enqueued: full to {user_room}+admin, sanitized to queue {event_data.queue_id}"
+ )
else:
# For remaining queue events (e.g. QueueClearedEvent) that do not
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 57004efca39..d85b170fbab 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -30,6 +30,7 @@
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
+SESSION_QUEUE_MODE = Literal["FIFO", "round_robin"]
IMAGE_SUBFOLDER_STRATEGY = Literal["flat", "date", "type", "hash"]
CONFIG_SCHEMA_VERSION = "4.0.3"
EXTERNAL_PROVIDER_CONFIG_FIELDS = (
@@ -114,6 +115,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
+ session_queue_mode: Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.
Valid values: `FIFO`, `round_robin`
clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`.
max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.
allow_nodes: List of nodes to allow. Omit to allow all.
@@ -214,6 +216,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
+ session_queue_mode: SESSION_QUEUE_MODE = Field(default="round_robin", description="Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup. If true, disables `max_queue_history`.")
max_queue_history: Optional[int] = Field(default=None, ge=0, description="Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.")
diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py
index d87221fbbae..7472ea07f63 100644
--- a/invokeai/app/services/session_queue/session_queue_common.py
+++ b/invokeai/app/services/session_queue/session_queue_common.py
@@ -309,6 +309,12 @@ class SessionQueueStatus(BaseModel):
failed: int = Field(..., description="Number of queue items with status 'error'")
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
total: int = Field(..., description="Total number of queue items")
+ user_pending: Optional[int] = Field(
+ default=None, description="Number of pending queue items for the calling user (multiuser only)"
+ )
+ user_in_progress: Optional[int] = Field(
+ default=None, description="Number of in-progress queue items for the calling user (multiuser only)"
+ )
class SessionQueueCountsByDestination(BaseModel):
diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py
index a05ed468857..c1bb71e0b74 100644
--- a/invokeai/app/services/session_queue/session_queue_sqlite.py
+++ b/invokeai/app/services/session_queue/session_queue_sqlite.py
@@ -210,9 +210,45 @@ async def enqueue_batch(
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
- with self._db.transaction() as cursor:
- cursor.execute(
- """--sql
+ config = self.__invoker.services.configuration
+ use_round_robin = config.multiuser and config.session_queue_mode == "round_robin"
+
+ if use_round_robin:
+ query = """--sql
+ WITH user_last_served AS (
+ -- Track when each user last had an item started, to determine whose turn it is.
+ SELECT user_id, MAX(started_at) AS last_served_at
+ FROM session_queue
+ WHERE started_at IS NOT NULL
+ GROUP BY user_id
+ ),
+ user_next_item AS (
+ -- For each user, select their single best pending item (highest priority, then oldest).
+ SELECT
+ user_id,
+ item_id,
+ ROW_NUMBER() OVER (
+ PARTITION BY user_id
+ ORDER BY priority DESC, item_id ASC
+ ) AS rn
+ FROM session_queue
+ WHERE status = 'pending'
+ )
+ SELECT
+ sq.*,
+ u.display_name AS user_display_name,
+ u.email AS user_email
+ FROM session_queue sq
+ LEFT JOIN users u ON sq.user_id = u.user_id
+ JOIN user_next_item uni ON sq.item_id = uni.item_id AND uni.rn = 1
+ LEFT JOIN user_last_served uls ON sq.user_id = uls.user_id
+ ORDER BY
+ COALESCE(uls.last_served_at, '1970-01-01') ASC,
+ sq.item_id ASC
+ LIMIT 1
+ """
+ else:
+ query = """--sql
SELECT
sq.*,
u.display_name as user_display_name,
@@ -225,7 +261,9 @@ def dequeue(self) -> Optional[SessionQueueItem]:
sq.item_id ASC
LIMIT 1
"""
- )
+
+ with self._db.transaction() as cursor:
+ cursor.execute(query)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
@@ -860,7 +898,18 @@ def get_queue_status(
acting_user_id: Optional[str] = None,
) -> SessionQueueStatus:
with self._db.transaction() as cursor:
- # When user_id is provided (non-admin), only count that user's items
+ cursor.execute(
+ """--sql
+ SELECT status, count(*)
+ FROM session_queue
+ WHERE queue_id = ?
+ GROUP BY status
+ """,
+ (queue_id,),
+ )
+ counts_result = cast(list[sqlite3.Row], cursor.fetchall())
+
+ user_counts_result: list[sqlite3.Row] = []
if user_id is not None:
cursor.execute(
"""--sql
@@ -871,22 +920,19 @@ def get_queue_status(
""",
(queue_id, user_id),
)
- else:
- cursor.execute(
- """--sql
- SELECT status, count(*)
- FROM session_queue
- WHERE queue_id = ?
- GROUP BY status
- """,
- (queue_id,),
- )
- counts_result = cast(list[sqlite3.Row], cursor.fetchall())
+ user_counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
+ user_pending: Optional[int] = None
+ user_in_progress: Optional[int] = None
+ if user_id is not None:
+ user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result}
+ user_pending = user_counts.get("pending", 0)
+ user_in_progress = user_counts.get("in_progress", 0)
+
# Redaction is decided from the same current_item snapshot used to embed identifiers,
# so a concurrent transition (e.g. B finishing while A's status changes) cannot leave
# stale identifiers in the result. user_id (count filter) and acting_user_id
@@ -909,6 +955,8 @@ def get_queue_status(
failed=counts.get("failed", 0),
canceled=counts.get("canceled", 0),
total=total,
+ user_pending=user_pending,
+ user_in_progress=user_in_progress,
)
def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus:
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index d99bb04a631..c164d1dafe1 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -728,8 +728,8 @@
"desc": "Select the move tool."
},
"selectRectTool": {
- "title": "Rect Tool",
- "desc": "Select the rect tool."
+ "title": "Shapes Tool",
+ "desc": "Select the shapes tool."
},
"selectLassoTool": {
"title": "Lasso Tool",
@@ -2881,6 +2881,10 @@
"polygon": "Polygon",
"polygonHint": "Click to add points, click the first point to close."
},
+ "shape": {
+ "rect": "Rect",
+ "oval": "Oval"
+ },
"modifierHints": {
"keys": {
"control": "Ctrl",
@@ -2896,11 +2900,12 @@
},
"labels": {
"pan": "Pan",
+ "moveShape": "Move shape",
"pickColor": "Pick color",
"straightLine": "Straight line",
"resizeBrush": "Resize brush",
"resizeEraser": "Resize eraser",
- "subtractMask": "Subtract mask",
+ "erase": "Erase",
"snap45Degrees": "Snap to 45deg",
"lockAspectRatio": "Lock ratio",
"unlockAspectRatio": "Unlock ratio",
@@ -2917,6 +2922,7 @@
"tool": {
"brush": "Brush",
"eraser": "Eraser",
+ "shapes": "Shapes",
"rectangle": "Rectangle",
"lasso": "Lasso",
"gradient": "Gradient",
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx
index b09e46d7320..61074015e76 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx
@@ -32,7 +32,7 @@ export const GradientLinearIcon = memo(() => {
const id = useId();
const gradientId = `${id}-gradient-linear-diagonal`;
return (
-
+
@@ -40,15 +40,15 @@ export const GradientLinearIcon = memo(() => {
);
@@ -59,7 +59,7 @@ export const GradientRadialIcon = memo(() => {
const id = useId();
const gradientId = `${id}-gradient-radial`;
return (
-
+
@@ -67,13 +67,13 @@ export const GradientRadialIcon = memo(() => {
);
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx
index 30d82722072..c0291f8e587 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx
@@ -5,7 +5,7 @@ import { ToolColorPickerButton } from 'features/controlLayers/components/Tool/To
import { ToolGradientButton } from 'features/controlLayers/components/Tool/ToolGradientButton';
import { ToolLassoButton } from 'features/controlLayers/components/Tool/ToolLassoButton';
import { ToolMoveButton } from 'features/controlLayers/components/Tool/ToolMoveButton';
-import { ToolRectButton } from 'features/controlLayers/components/Tool/ToolRectButton';
+import { ToolShapesButton } from 'features/controlLayers/components/Tool/ToolShapesButton';
import { ToolTextButton } from 'features/controlLayers/components/Tool/ToolTextButton';
import React from 'react';
@@ -18,7 +18,7 @@ export const ToolChooser: React.FC = () => {
-
+
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapeTypeToggle.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapeTypeToggle.tsx
new file mode 100644
index 00000000000..2e4530e2e27
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapeTypeToggle.tsx
@@ -0,0 +1,65 @@
+import { ButtonGroup, IconButton, Tooltip } from '@invoke-ai/ui-library';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { selectShapeType, settingsShapeTypeChanged } from 'features/controlLayers/store/canvasSettingsSlice';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { PiCircleBold, PiPolygonBold, PiRectangleBold, PiScribbleLoopBold } from 'react-icons/pi';
+
+export const ToolShapeTypeToggle = memo(() => {
+ const { t } = useTranslation();
+ const shapeType = useAppSelector(selectShapeType);
+ const dispatch = useAppDispatch();
+
+ const onRectClick = useCallback(() => dispatch(settingsShapeTypeChanged('rect')), [dispatch]);
+ const onOvalClick = useCallback(() => dispatch(settingsShapeTypeChanged('oval')), [dispatch]);
+ const onPolygonClick = useCallback(() => dispatch(settingsShapeTypeChanged('polygon')), [dispatch]);
+ const onFreehandClick = useCallback(() => dispatch(settingsShapeTypeChanged('freehand')), [dispatch]);
+
+ const rectLabel = t('controlLayers.shape.rect', { defaultValue: 'Rect' });
+ const ovalLabel = t('controlLayers.shape.oval', { defaultValue: 'Oval' });
+ const polygonLabel = t('controlLayers.lasso.polygon', { defaultValue: 'Polygon' });
+ const freehandLabel = t('controlLayers.lasso.freehand', { defaultValue: 'Freehand' });
+
+ return (
+
+
+ }
+ colorScheme={shapeType === 'rect' ? 'invokeBlue' : 'base'}
+ variant="solid"
+ onClick={onRectClick}
+ />
+
+
+ }
+ colorScheme={shapeType === 'oval' ? 'invokeBlue' : 'base'}
+ variant="solid"
+ onClick={onOvalClick}
+ />
+
+
+ }
+ colorScheme={shapeType === 'polygon' ? 'invokeBlue' : 'base'}
+ variant="solid"
+ onClick={onPolygonClick}
+ />
+
+
+ }
+ colorScheme={shapeType === 'freehand' ? 'invokeBlue' : 'base'}
+ variant="solid"
+ onClick={onFreehandClick}
+ />
+
+
+ );
+});
+
+ToolShapeTypeToggle.displayName = 'ToolShapeTypeToggle';
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolRectButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapesButton.tsx
similarity index 58%
rename from invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolRectButton.tsx
rename to invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapesButton.tsx
index 93029390883..3f6c546d2cf 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolRectButton.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapesButton.tsx
@@ -3,32 +3,33 @@ import { useSelectTool, useToolIsSelected } from 'features/controlLayers/compone
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
-import { PiRectangleBold } from 'react-icons/pi';
+import { PiShapesBold } from 'react-icons/pi';
-export const ToolRectButton = memo(() => {
+export const ToolShapesButton = memo(() => {
const { t } = useTranslation();
const isSelected = useToolIsSelected('rect');
- const selectRect = useSelectTool('rect');
+ const selectShapes = useSelectTool('rect');
+ const label = t('controlLayers.tool.shapes', { defaultValue: 'Shapes' });
useRegisteredHotkeys({
id: 'selectRectTool',
category: 'canvas',
- callback: selectRect,
+ callback: selectShapes,
options: { enabled: !isSelected },
- dependencies: [isSelected, selectRect],
+ dependencies: [isSelected, selectShapes],
});
return (
-
+
}
+ aria-label={`${label} (U)`}
+ icon={}
colorScheme={isSelected ? 'invokeBlue' : 'base'}
variant="solid"
- onClick={selectRect}
+ onClick={selectShapes}
/>
);
});
-ToolRectButton.displayName = 'ToolRectButton';
+ToolShapesButton.displayName = 'ToolShapesButton';
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx
index bee8f5d1a34..bd72306e2e3 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx
@@ -7,6 +7,7 @@ import { ToolGradientClipToggle } from 'features/controlLayers/components/Tool/T
import { ToolGradientModeToggle } from 'features/controlLayers/components/Tool/ToolGradientModeToggle';
import { ToolLassoModeToggle } from 'features/controlLayers/components/Tool/ToolLassoModeToggle';
import { ToolOptionsRowContainer } from 'features/controlLayers/components/Tool/ToolOptionsRowContainer';
+import { ToolShapeTypeToggle } from 'features/controlLayers/components/Tool/ToolShapeTypeToggle';
import { ToolWidthPicker } from 'features/controlLayers/components/Tool/ToolWidthPicker';
import { CanvasToolbarFitBboxToLayersButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToLayersButton';
import { CanvasToolbarFitBboxToMasksButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToMasksButton';
@@ -35,6 +36,7 @@ import { memo, useMemo } from 'react';
export const CanvasToolbar = memo(() => {
const isBrushSelected = useToolIsSelected('brush');
const isEraserSelected = useToolIsSelected('eraser');
+ const isShapeSelected = useToolIsSelected('rect');
const isTextSelected = useToolIsSelected('text');
const isLassoSelected = useToolIsSelected('lasso');
const isGradientSelected = useToolIsSelected('gradient');
@@ -56,9 +58,28 @@ export const CanvasToolbar = memo(() => {
useCanvasToggleBboxHotkey();
return (
-
-
+
+
+ {isShapeSelected && (
+
+
+
+ )}
{isGradientSelected && (
@@ -72,21 +93,24 @@ export const CanvasToolbar = memo(() => {
)}
{isTextSelected ? : showToolWithPicker && }
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
);
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts
index 9941761a2ee..b282580fec9 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts
@@ -9,8 +9,11 @@ import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva
import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso';
+import { CanvasObjectOval } from 'features/controlLayers/konva/CanvasObject/CanvasObjectOval';
+import { CanvasObjectPolygon } from 'features/controlLayers/konva/CanvasObject/CanvasObjectPolygon';
import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect';
import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types';
+import { shouldPreserveSuspendableShapesSession } from 'features/controlLayers/konva/CanvasTool/toolHotkeys';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import Konva from 'konva';
import type { Logger } from 'roarr';
@@ -83,6 +86,21 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase {
this.subscriptions.add(
this.manager.tool.$tool.listen(() => {
if (this.hasBuffer() && !this.manager.$isBusy.get()) {
+ const isTemporaryShapesToolSwitch = shouldPreserveSuspendableShapesSession(
+ this.manager.tool.$tool.get(),
+ this.manager.tool.$toolBuffer.get(),
+ this.manager.tool.tools.rect.hasSuspendableSession()
+ );
+
+ if (isTemporaryShapesToolSwitch) {
+ return;
+ }
+
+ if (this.state?.type === 'polygon' && this.state.previewPoint) {
+ this.clearBuffer();
+ return;
+ }
+
this.commitBuffer();
}
})
@@ -153,6 +171,24 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase {
this.konva.group.add(this.renderer.konva.group);
}
+ didRender = this.renderer.update(this.state, true);
+ } else if (this.state.type === 'oval') {
+ assert(this.renderer instanceof CanvasObjectOval || !this.renderer);
+
+ if (!this.renderer) {
+ this.renderer = new CanvasObjectOval(this.state, this);
+ this.konva.group.add(this.renderer.konva.group);
+ }
+
+ didRender = this.renderer.update(this.state, true);
+ } else if (this.state.type === 'polygon') {
+ assert(this.renderer instanceof CanvasObjectPolygon || !this.renderer);
+
+ if (!this.renderer) {
+ this.renderer = new CanvasObjectPolygon(this.state, this);
+ this.konva.group.add(this.renderer.konva.group);
+ }
+
didRender = this.renderer.update(this.state, true);
} else if (this.state.type === 'lasso') {
assert(this.renderer instanceof CanvasObjectLasso || !this.renderer);
@@ -240,28 +276,40 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase {
this.log.trace({ buffer: this.renderer.repr() }, 'Committing buffer');
+ let committedState = this.state;
+
+ // Polygon previews render an outline while they are still live in the buffer.
+ // Clear that preview state before adopting the renderer into the persistent object group.
+ if (committedState.type === 'polygon' && this.renderer instanceof CanvasObjectPolygon) {
+ committedState = { ...committedState, previewPoint: undefined };
+ this.state = null;
+ this.renderer.update(committedState, true);
+ }
+
// Move the buffer to the persistent objects group/renderers
this.parent.renderer.adoptObjectRenderer(this.renderer);
if (pushToState) {
const entityIdentifier = this.parent.entityIdentifier;
- switch (this.state.type) {
+ switch (committedState.type) {
case 'brush_line':
case 'brush_line_with_pressure':
- this.manager.stateApi.addBrushLine({ entityIdentifier, brushLine: this.state });
+ this.manager.stateApi.addBrushLine({ entityIdentifier, brushLine: committedState });
break;
case 'eraser_line':
case 'eraser_line_with_pressure':
- this.manager.stateApi.addEraserLine({ entityIdentifier, eraserLine: this.state });
+ this.manager.stateApi.addEraserLine({ entityIdentifier, eraserLine: committedState });
break;
case 'rect':
- this.manager.stateApi.addRect({ entityIdentifier, rect: this.state });
+ case 'oval':
+ case 'polygon':
+ this.manager.stateApi.addShape({ entityIdentifier, shape: committedState });
break;
case 'lasso':
- this.manager.stateApi.addLasso({ entityIdentifier, lasso: this.state });
+ this.manager.stateApi.addLasso({ entityIdentifier, lasso: committedState });
break;
case 'gradient':
- this.manager.stateApi.addGradient({ entityIdentifier, gradient: this.state });
+ this.manager.stateApi.addGradient({ entityIdentifier, gradient: committedState });
break;
}
}
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts
index 903ccaa772c..f62ce3f9822 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts
@@ -11,6 +11,8 @@ import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva
import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso';
+import { CanvasObjectOval } from 'features/controlLayers/konva/CanvasObject/CanvasObjectOval';
+import { CanvasObjectPolygon } from 'features/controlLayers/konva/CanvasObject/CanvasObjectPolygon';
import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect';
import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types';
import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters';
@@ -398,6 +400,26 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
this.konva.objectGroup.add(renderer.konva.group);
}
+ didRender = renderer.update(objectState, force || isFirstRender);
+ } else if (objectState.type === 'oval') {
+ assert(renderer instanceof CanvasObjectOval || !renderer);
+
+ if (!renderer) {
+ renderer = new CanvasObjectOval(objectState, this);
+ this.renderers.set(renderer.id, renderer);
+ this.konva.objectGroup.add(renderer.konva.group);
+ }
+
+ didRender = renderer.update(objectState, force || isFirstRender);
+ } else if (objectState.type === 'polygon') {
+ assert(renderer instanceof CanvasObjectPolygon || !renderer);
+
+ if (!renderer) {
+ renderer = new CanvasObjectPolygon(objectState, this);
+ this.renderers.set(renderer.id, renderer);
+ this.konva.objectGroup.add(renderer.konva.group);
+ }
+
didRender = renderer.update(objectState, force || isFirstRender);
} else if (objectState.type === 'lasso') {
assert(renderer instanceof CanvasObjectLasso || !renderer);
@@ -455,10 +477,24 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
renderer instanceof CanvasObjectEraserLine || renderer instanceof CanvasObjectEraserLineWithPressure;
const isSubtractingLasso =
renderer instanceof CanvasObjectLasso && renderer.state.compositeOperation === 'destination-out';
+ const isSubtractRect =
+ renderer instanceof CanvasObjectRect && renderer.state.compositeOperation === 'destination-out';
+ const isSubtractOval =
+ renderer instanceof CanvasObjectOval && renderer.state.compositeOperation === 'destination-out';
+ const isSubtractPolygon =
+ renderer instanceof CanvasObjectPolygon && renderer.state.compositeOperation === 'destination-out';
const isImage = renderer instanceof CanvasObjectImage;
const imageIgnoresTransparency = isImage && renderer.state.usePixelBbox === false;
const hasClip = renderer instanceof CanvasObjectBrushLine && renderer.state.clip;
- if (isEraserLine || isSubtractingLasso || hasClip || (isImage && !imageIgnoresTransparency)) {
+ if (
+ isEraserLine ||
+ isSubtractingLasso ||
+ isSubtractRect ||
+ isSubtractOval ||
+ isSubtractPolygon ||
+ hasClip ||
+ (isImage && !imageIgnoresTransparency)
+ ) {
needsPixelBbox = true;
break;
}
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectOval.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectOval.ts
new file mode 100644
index 00000000000..8c06268f768
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectOval.ts
@@ -0,0 +1,88 @@
+import { rgbaColorToString } from 'common/util/colorCodeTransformers';
+import { deepClone } from 'common/util/deepClone';
+import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
+import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
+import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
+import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
+import type { CanvasOvalState } from 'features/controlLayers/store/types';
+import Konva from 'konva';
+import type { Logger } from 'roarr';
+
+export class CanvasObjectOval extends CanvasModuleBase {
+ readonly type = 'object_oval';
+ readonly id: string;
+ readonly path: string[];
+ readonly parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer;
+ readonly manager: CanvasManager;
+ readonly log: Logger;
+
+ state: CanvasOvalState;
+ konva: {
+ group: Konva.Group;
+ ellipse: Konva.Ellipse;
+ };
+
+ constructor(state: CanvasOvalState, parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer) {
+ super();
+ this.id = state.id;
+ this.parent = parent;
+ this.manager = parent.manager;
+ this.path = this.manager.buildPath(this);
+ this.log = this.manager.buildLogger(this);
+
+ this.log.debug({ state }, 'Creating module');
+
+ this.konva = {
+ group: new Konva.Group({ name: `${this.type}:group`, listening: false }),
+ ellipse: new Konva.Ellipse({
+ name: `${this.type}:ellipse`,
+ listening: false,
+ radiusX: 0,
+ radiusY: 0,
+ perfectDrawEnabled: false,
+ }),
+ };
+ this.konva.group.add(this.konva.ellipse);
+ this.state = state;
+ }
+
+ update(state: CanvasOvalState, force = false): boolean {
+ if (force || this.state !== state) {
+ this.log.trace({ state }, 'Updating oval');
+ const { rect, color, compositeOperation } = state;
+ const fill = compositeOperation === 'destination-out' ? 'rgba(255,255,255,1)' : rgbaColorToString(color);
+ this.konva.ellipse.setAttrs({
+ x: rect.x + rect.width / 2,
+ y: rect.y + rect.height / 2,
+ radiusX: rect.width / 2,
+ radiusY: rect.height / 2,
+ fill,
+ globalCompositeOperation: compositeOperation,
+ });
+ this.state = state;
+ return true;
+ }
+
+ return false;
+ }
+
+ setVisibility(isVisible: boolean): void {
+ this.log.trace({ isVisible }, 'Setting oval visibility');
+ this.konva.group.visible(isVisible);
+ }
+
+ destroy = () => {
+ this.log.debug('Destroying module');
+ this.konva.group.destroy();
+ };
+
+ repr = () => {
+ return {
+ id: this.id,
+ type: this.type,
+ path: this.path,
+ parent: this.parent.id,
+ state: deepClone(this.state),
+ };
+ };
+}
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectPolygon.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectPolygon.ts
new file mode 100644
index 00000000000..dc54811569b
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectPolygon.ts
@@ -0,0 +1,113 @@
+import { rgbaColorToString } from 'common/util/colorCodeTransformers';
+import { deepClone } from 'common/util/deepClone';
+import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
+import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
+import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
+import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
+import type { CanvasPolygonState, RgbaColor } from 'features/controlLayers/store/types';
+import Konva from 'konva';
+import type { Logger } from 'roarr';
+
+const getPreviewStrokeColor = (color: RgbaColor) => rgbaColorToString({ ...color, a: Math.max(color.a, 0.9) });
+
+export class CanvasObjectPolygon extends CanvasModuleBase {
+ readonly type = 'object_polygon';
+ readonly id: string;
+ readonly path: string[];
+ readonly parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer;
+ readonly manager: CanvasManager;
+ readonly log: Logger;
+
+ state: CanvasPolygonState;
+ konva: {
+ group: Konva.Group;
+ fillPolygon: Konva.Line;
+ previewStroke: Konva.Line;
+ };
+
+ constructor(state: CanvasPolygonState, parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer) {
+ super();
+ this.id = state.id;
+ this.parent = parent;
+ this.manager = parent.manager;
+ this.path = this.manager.buildPath(this);
+ this.log = this.manager.buildLogger(this);
+
+ this.log.debug({ state }, 'Creating module');
+
+ this.konva = {
+ group: new Konva.Group({ name: `${this.type}:group`, listening: false }),
+ fillPolygon: new Konva.Line({
+ name: `${this.type}:fill_polygon`,
+ listening: false,
+ closed: true,
+ strokeEnabled: false,
+ perfectDrawEnabled: false,
+ }),
+ previewStroke: new Konva.Line({
+ name: `${this.type}:preview_stroke`,
+ listening: false,
+ closed: false,
+ fillEnabled: false,
+ lineCap: 'round',
+ lineJoin: 'round',
+ perfectDrawEnabled: false,
+ strokeWidth: 1,
+ }),
+ };
+ this.konva.group.add(this.konva.fillPolygon, this.konva.previewStroke);
+ this.state = state;
+ }
+
+ update(state: CanvasPolygonState, force = false): boolean {
+ if (force || this.state !== state) {
+ this.log.trace({ state }, 'Updating polygon');
+ const combinedPoints = state.previewPoint
+ ? [...state.points, state.previewPoint.x, state.previewPoint.y]
+ : state.points;
+ const hasRenderablePolygon = combinedPoints.length >= 6;
+ const isLiveBufferPreview = this.parent.type === 'buffer_renderer' && this.parent.state?.id === state.id;
+ const fill =
+ state.compositeOperation === 'destination-out' ? 'rgba(255,255,255,1)' : rgbaColorToString(state.color);
+
+ this.konva.fillPolygon.setAttrs({
+ points: combinedPoints,
+ visible: hasRenderablePolygon,
+ fill,
+ globalCompositeOperation: state.compositeOperation,
+ });
+
+ this.konva.previewStroke.setAttrs({
+ points: combinedPoints,
+ visible: (Boolean(state.previewPoint) || isLiveBufferPreview) && combinedPoints.length >= 4,
+ stroke: getPreviewStrokeColor(state.color),
+ globalCompositeOperation: 'source-over',
+ });
+
+ this.state = state;
+ return true;
+ }
+
+ return false;
+ }
+
+ setVisibility(isVisible: boolean): void {
+ this.log.trace({ isVisible }, 'Setting polygon visibility');
+ this.konva.group.visible(isVisible);
+ }
+
+ destroy = () => {
+ this.log.debug('Destroying module');
+ this.konva.group.destroy();
+ };
+
+ repr = () => {
+ return {
+ id: this.id,
+ type: this.type,
+ path: this.path,
+ parent: this.parent.id,
+ state: deepClone(this.state),
+ };
+ };
+}
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts
index 1ac8e5b5f37..e879dcd35ab 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts
@@ -46,13 +46,15 @@ export class CanvasObjectRect extends CanvasModuleBase {
this.isFirstRender = false;
this.log.trace({ state }, 'Updating rect');
- const { rect, color } = state;
+ const { rect, color, compositeOperation } = state;
+ const fill = compositeOperation === 'destination-out' ? 'rgba(255,255,255,1)' : rgbaColorToString(color);
this.konva.rect.setAttrs({
x: rect.x,
y: rect.y,
width: rect.width,
height: rect.height,
- fill: rgbaColorToString(color),
+ fill,
+ globalCompositeOperation: compositeOperation,
});
this.state = state;
return true;
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts
index f193c0b391e..620842a9426 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts
@@ -5,6 +5,8 @@ import type { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/
import type { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient';
import type { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import type { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso';
+import type { CanvasObjectOval } from 'features/controlLayers/konva/CanvasObject/CanvasObjectOval';
+import type { CanvasObjectPolygon } from 'features/controlLayers/konva/CanvasObject/CanvasObjectPolygon';
import type { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect';
import type {
CanvasBrushLineState,
@@ -14,6 +16,8 @@ import type {
CanvasGradientState,
CanvasImageState,
CanvasLassoState,
+ CanvasOvalState,
+ CanvasPolygonState,
CanvasRectState,
} from 'features/controlLayers/store/types';
@@ -28,6 +32,8 @@ export type AnyObjectRenderer =
| CanvasObjectEraserLineWithPressure
| CanvasObjectRect
| CanvasObjectLasso
+ | CanvasObjectOval
+ | CanvasObjectPolygon
| CanvasObjectImage
| CanvasObjectGradient;
/**
@@ -41,4 +47,6 @@ export type AnyObjectState =
| CanvasImageState
| CanvasRectState
| CanvasLassoState
+ | CanvasOvalState
+ | CanvasPolygonState
| CanvasGradientState;
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts
index 7d4c76b0c06..26abd908e51 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts
@@ -25,8 +25,8 @@ import {
entityMovedBy,
entityMovedTo,
entityRasterized,
- entityRectAdded,
entityReset,
+ entityShapeAdded,
inpaintMaskAdded,
rasterLayerAdded,
rgAdded,
@@ -48,7 +48,7 @@ import type {
EntityMovedByPayload,
EntityMovedToPayload,
EntityRasterizedPayload,
- EntityRectAddedPayload,
+ EntityShapeAddedPayload,
Rect,
RgbaColor,
} from 'features/controlLayers/store/types';
@@ -171,10 +171,10 @@ export class CanvasStateApiModule extends CanvasModuleBase {
};
/**
- * Adds a rectangle to an entity, pushing state to redux.
+ * Adds a shape to an entity, pushing state to redux.
*/
- addRect = (arg: EntityRectAddedPayload) => {
- this.store.dispatch(entityRectAdded(arg));
+ addShape = (arg: EntityShapeAddedPayload) => {
+ this.store.dispatch(entityShapeAdded(arg));
};
/**
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasRectToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasRectToolModule.ts
deleted file mode 100644
index 3f64b0c2fc1..00000000000
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasRectToolModule.ts
+++ /dev/null
@@ -1,102 +0,0 @@
-import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
-import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
-import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
-import { floorCoord, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util';
-import type { KonvaEventObject } from 'konva/lib/Node';
-import type { Logger } from 'roarr';
-
-export class CanvasRectToolModule extends CanvasModuleBase {
- readonly type = 'rect_tool';
- readonly id: string;
- readonly path: string[];
- readonly parent: CanvasToolModule;
- readonly manager: CanvasManager;
- readonly log: Logger;
-
- constructor(parent: CanvasToolModule) {
- super();
- this.id = getPrefixedId(this.type);
- this.parent = parent;
- this.manager = this.parent.manager;
- this.path = this.manager.buildPath(this);
- this.log = this.manager.buildLogger(this);
-
- this.log.debug('Creating module');
- }
-
- syncCursorStyle = () => {
- this.manager.stage.setCursor('crosshair');
- };
-
- onStagePointerDown = async (_e: KonvaEventObject) => {
- const cursorPos = this.parent.$cursorPos.get();
- const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get();
- const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
-
- if (!cursorPos || !isPrimaryPointerDown || !selectedEntity) {
- /**
- * Can't do anything without:
- * - A cursor position: the cursor is not on the stage
- * - The mouse is down: the user is not drawing
- * - A selected entity: there is no entity to draw on
- */
- return;
- }
-
- const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
-
- await selectedEntity.bufferRenderer.setBuffer({
- id: getPrefixedId('rect'),
- type: 'rect',
- rect: { x: Math.round(normalizedPoint.x), y: Math.round(normalizedPoint.y), width: 0, height: 0 },
- color: this.manager.stateApi.getCurrentColor(),
- });
- };
-
- onStagePointerUp = (_e: KonvaEventObject) => {
- const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
- if (!selectedEntity) {
- return;
- }
-
- if (selectedEntity.bufferRenderer.state?.type === 'rect' && selectedEntity.bufferRenderer.hasBuffer()) {
- selectedEntity.bufferRenderer.commitBuffer();
- } else {
- selectedEntity.bufferRenderer.clearBuffer();
- }
- };
-
- onStagePointerMove = async (_e: KonvaEventObject) => {
- const cursorPos = this.parent.$cursorPos.get();
-
- if (!cursorPos) {
- return;
- }
-
- if (!this.parent.$isPrimaryPointerDown.get()) {
- return;
- }
-
- const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
-
- if (!selectedEntity) {
- return;
- }
-
- const bufferState = selectedEntity.bufferRenderer.state;
-
- if (!bufferState) {
- return;
- }
-
- if (bufferState.type !== 'rect') {
- return;
- }
-
- const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position);
- const alignedPoint = floorCoord(normalizedPoint);
- bufferState.rect.width = Math.round(alignedPoint.x - bufferState.rect.x);
- bufferState.rect.height = Math.round(alignedPoint.y - bufferState.rect.y);
- await selectedEntity.bufferRenderer.setBuffer(bufferState);
- };
-}
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasShapeToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasShapeToolModule.ts
new file mode 100644
index 00000000000..89c3f7691eb
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasShapeToolModule.ts
@@ -0,0 +1,895 @@
+import { rgbaColorToString } from 'common/util/colorCodeTransformers';
+import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
+import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
+import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
+import { shouldPreserveSuspendableShapesSession } from 'features/controlLayers/konva/CanvasTool/toolHotkeys';
+import {
+ addCoords,
+ floorCoord,
+ getPrefixedId,
+ isDistanceMoreThanMin,
+ offsetCoord,
+} from 'features/controlLayers/konva/util';
+import { selectShapeType } from 'features/controlLayers/store/canvasSettingsSlice';
+import type {
+ CanvasEntityIdentifier,
+ CanvasPolygonState,
+ CanvasRectState,
+ Coordinate,
+} from 'features/controlLayers/store/types';
+import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
+import Konva from 'konva';
+import type { KonvaEventObject } from 'konva/lib/Node';
+import type { Logger } from 'roarr';
+
+type CanvasShapeToolModuleConfig = {
+ START_POINT_RADIUS_PX: number;
+ START_POINT_STROKE_WIDTH_PX: number;
+ START_POINT_HOVER_RADIUS_DELTA_PX: number;
+ POLYGON_CLOSE_RADIUS_PX: number;
+ MIN_FREEHAND_POINT_DISTANCE_PX: number;
+ MAX_FREEHAND_SEGMENT_LENGTH_PX: number;
+ FREEHAND_SIMPLIFY_MIN_POINTS: number;
+ FREEHAND_SIMPLIFY_TOLERANCE: number;
+ PREVIEW_STROKE_COLOR: string;
+};
+
+const DEFAULT_CONFIG: CanvasShapeToolModuleConfig = {
+ START_POINT_RADIUS_PX: 4,
+ START_POINT_STROKE_WIDTH_PX: 2,
+ START_POINT_HOVER_RADIUS_DELTA_PX: 2,
+ POLYGON_CLOSE_RADIUS_PX: 10,
+ MIN_FREEHAND_POINT_DISTANCE_PX: 1,
+ MAX_FREEHAND_SEGMENT_LENGTH_PX: 2,
+ FREEHAND_SIMPLIFY_MIN_POINTS: 200,
+ FREEHAND_SIMPLIFY_TOLERANCE: 0.6,
+ PREVIEW_STROKE_COLOR: rgbaColorToString({ r: 90, g: 175, b: 255, a: 1 }),
+};
+
+const SUBTRACT_CURSOR = `url("data:image/svg+xml,${encodeURIComponent(
+ ``
+)}") 12 12, crosshair`;
+
+const getAxisSign = (value: number, fallback: number): number => {
+ if (value === 0) {
+ return fallback === 0 ? 1 : Math.sign(fallback);
+ }
+ return Math.sign(value);
+};
+
+export class CanvasShapeToolModule extends CanvasModuleBase {
+ readonly type = 'shape_tool';
+ readonly id: string;
+ readonly path: string[];
+ readonly parent: CanvasToolModule;
+ readonly manager: CanvasManager;
+ readonly log: Logger;
+
+ config: CanvasShapeToolModuleConfig = DEFAULT_CONFIG;
+ subscriptions: Set<() => void> = new Set();
+
+ private activeEntityIdentifier: CanvasEntityIdentifier | null = null;
+ private shapeId: string | null = null;
+ private dragStartPoint: Coordinate | null = null;
+ private dragCurrentPoint: Coordinate | null = null;
+ private translatePreviousPointerPoint: Coordinate | null = null;
+ private freehandPoints: Coordinate[] = [];
+ private isDrawingFreehand = false;
+ private polygonPoints: Coordinate[] = [];
+ private polygonPointer: Coordinate | null = null;
+
+ konva: {
+ group: Konva.Group;
+ startPointIndicator: Konva.Circle;
+ };
+
+ constructor(parent: CanvasToolModule) {
+ super();
+ this.id = getPrefixedId(this.type);
+ this.parent = parent;
+ this.manager = this.parent.manager;
+ this.path = this.manager.buildPath(this);
+ this.log = this.manager.buildLogger(this);
+
+ this.log.debug('Creating module');
+
+ this.konva = {
+ group: new Konva.Group({ name: `${this.type}:group`, listening: false }),
+ startPointIndicator: new Konva.Circle({
+ name: `${this.type}:start_point_indicator`,
+ listening: false,
+ fillEnabled: false,
+ stroke: this.config.PREVIEW_STROKE_COLOR,
+ visible: false,
+ perfectDrawEnabled: false,
+ }),
+ };
+ this.konva.group.add(this.konva.startPointIndicator);
+
+ this.subscriptions.add(this.manager.stateApi.$altKey.listen(this.onModifierChanged));
+ this.subscriptions.add(this.manager.stateApi.$ctrlKey.listen(this.onModifierChanged));
+ this.subscriptions.add(this.manager.stateApi.$metaKey.listen(this.onModifierChanged));
+ this.subscriptions.add(this.manager.stateApi.$shiftKey.listen(this.onModifierChanged));
+ this.subscriptions.add(
+ this.manager.stateApi.createStoreSubscription(selectShapeType, () => {
+ if (this.hasActiveSession()) {
+ this.cancel();
+ }
+ this.render();
+ })
+ );
+ }
+
+ hasActiveSession = (): boolean => {
+ return Boolean(
+ this.dragStartPoint || this.isDrawingFreehand || this.freehandPoints.length || this.polygonPoints.length
+ );
+ };
+
+ hasSuspendableSession = (): boolean => {
+ return Boolean(this.isDrawingFreehand || this.freehandPoints.length || this.polygonPoints.length);
+ };
+
+ hasActiveDragSession = (): boolean => {
+ return Boolean(this.dragStartPoint || this.isDrawingFreehand);
+ };
+
+ hasActiveRectOvalDragSession = (): boolean => {
+ const shapeType = this.manager.stateApi.getSettings().shapeType;
+ return Boolean(this.dragStartPoint && this.dragCurrentPoint && (shapeType === 'rect' || shapeType === 'oval'));
+ };
+
+ hasActivePolygonSession = (): boolean => {
+ return this.polygonPoints.length > 0;
+ };
+
+ isTranslatingDragSession = (): boolean => {
+ return this.translatePreviousPointerPoint !== null;
+ };
+
+ freezePolygonPreview = async () => {
+ if (!this.hasActivePolygonSession()) {
+ return;
+ }
+
+ const activeEntity = this.getActiveEntityAdapter();
+ const cursorPos = this.parent.$cursorPos.get();
+ if (!activeEntity || !cursorPos) {
+ return;
+ }
+
+ const point = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position);
+ this.polygonPointer = point;
+ await this.updatePolygonBuffer();
+ this.render();
+ };
+
+ onToolChanged = () => {
+ const tool = this.parent.$tool.get();
+ const isTemporaryToolSwitch = shouldPreserveSuspendableShapesSession(
+ tool,
+ this.parent.$toolBuffer.get(),
+ this.hasSuspendableSession()
+ );
+ if (tool !== 'rect' && !isTemporaryToolSwitch) {
+ this.cancel();
+ }
+ };
+
+ syncCursorStyle = () => {
+ this.manager.stage.setCursor(this.getCompositeOperation() === 'destination-out' ? SUBTRACT_CURSOR : 'crosshair');
+ };
+
+ render = () => {
+ const tool = this.parent.$tool.get();
+ const isTemporaryToolSwitch = shouldPreserveSuspendableShapesSession(
+ tool,
+ this.parent.$toolBuffer.get(),
+ this.hasSuspendableSession()
+ );
+ if (tool !== 'rect' && !isTemporaryToolSwitch) {
+ this.konva.startPointIndicator.visible(false);
+ return;
+ }
+
+ if (tool === 'rect') {
+ this.syncCursorStyle();
+ }
+
+ this.syncStartPointIndicator();
+ };
+
+ cancel = () => {
+ this.clearActiveBuffer();
+ this.resetState();
+ this.render();
+ };
+
+ startDragTranslation = () => {
+ const activeEntity = this.getActiveEntityAdapter();
+ const cursorPos = this.parent.$cursorPos.get();
+ if (!activeEntity || !cursorPos || !this.hasActiveRectOvalDragSession()) {
+ return;
+ }
+
+ this.translatePreviousPointerPoint = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position);
+ };
+
+ stopDragTranslation = () => {
+ this.translatePreviousPointerPoint = null;
+ };
+
+ onStagePointerDown = async (e: KonvaEventObject) => {
+ const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
+ const cursorPos = this.parent.$cursorPos.get();
+ if (!selectedEntity || !cursorPos) {
+ return;
+ }
+
+ if (e.evt.button !== 0) {
+ return;
+ }
+
+ const shapeType = this.manager.stateApi.getSettings().shapeType;
+ const point = this.getEntityRelativePoint(cursorPos.relative, selectedEntity.state.position);
+
+ if (shapeType === 'polygon') {
+ await this.onPolygonPointerDown(point, selectedEntity.entityIdentifier, e.evt.shiftKey);
+ return;
+ }
+
+ if (shapeType === 'freehand') {
+ if (!this.parent.$isPrimaryPointerDown.get()) {
+ return;
+ }
+
+ await this.startFreehandSession(point, selectedEntity.entityIdentifier);
+ return;
+ }
+
+ if (!this.parent.$isPrimaryPointerDown.get()) {
+ return;
+ }
+
+ this.clearActiveBuffer();
+ this.resetState();
+ this.activeEntityIdentifier = selectedEntity.entityIdentifier;
+ this.shapeId = getPrefixedId(shapeType);
+ this.dragStartPoint = point;
+ this.dragCurrentPoint = point;
+ await this.updateDragBuffer();
+ };
+
+ onStagePointerMove = async (e: KonvaEventObject) => {
+ const shapeType = this.manager.stateApi.getSettings().shapeType;
+ const activeEntity = this.getActiveEntityAdapter();
+ const cursorPos = this.parent.$cursorPos.get();
+
+ if (!activeEntity || !cursorPos) {
+ return;
+ }
+
+ const point = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position);
+
+ if (shapeType === 'polygon') {
+ if (!this.hasActivePolygonSession()) {
+ return;
+ }
+ this.polygonPointer = this.getPolygonPoint(point, e.evt.shiftKey);
+ await this.updatePolygonBuffer();
+ this.render();
+ return;
+ }
+
+ if (shapeType === 'freehand') {
+ await this.handleFreehandPointerMove(point);
+ return;
+ }
+
+ if (!this.parent.$isPrimaryPointerDown.get() || !this.dragStartPoint) {
+ return;
+ }
+
+ if (this.isTranslatingDragSession()) {
+ await this.translateDragShape(point);
+ return;
+ }
+
+ this.dragCurrentPoint = point;
+ await this.updateDragBuffer();
+ };
+
+ onWindowPointerMove = async () => {
+ const shapeType = this.manager.stateApi.getSettings().shapeType;
+ const activeEntity = this.getActiveEntityAdapter();
+ const cursorPos = this.parent.$cursorPos.get();
+
+ if (!activeEntity || !cursorPos || !this.parent.$isPrimaryPointerDown.get()) {
+ return;
+ }
+
+ const point = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position);
+
+ if (shapeType === 'freehand') {
+ await this.handleFreehandPointerMove(point);
+ return;
+ }
+
+ if ((shapeType !== 'rect' && shapeType !== 'oval') || !this.dragStartPoint) {
+ return;
+ }
+
+ if (this.isTranslatingDragSession()) {
+ await this.translateDragShape(point);
+ return;
+ }
+
+ this.dragCurrentPoint = point;
+ await this.updateDragBuffer();
+ };
+
+ onStagePointerUp = async (_e: KonvaEventObject) => {
+ const shapeType = this.manager.stateApi.getSettings().shapeType;
+
+ if (shapeType === 'polygon') {
+ this.render();
+ return;
+ }
+
+ if (shapeType === 'freehand') {
+ await this.commitFreehand();
+ return;
+ }
+
+ this.finishDragShapeSession();
+ };
+
+ onWindowPointerUp = async () => {
+ if (this.isDrawingFreehand) {
+ await this.commitFreehand();
+ return;
+ }
+
+ if (!this.dragStartPoint) {
+ return;
+ }
+
+ this.finishDragShapeSession();
+ };
+
+ repr = () => {
+ return {
+ id: this.id,
+ type: this.type,
+ path: this.path,
+ activeEntityIdentifier: this.activeEntityIdentifier,
+ shapeId: this.shapeId,
+ dragStartPoint: this.dragStartPoint,
+ dragCurrentPoint: this.dragCurrentPoint,
+ translatePreviousPointerPoint: this.translatePreviousPointerPoint,
+ freehandPoints: this.freehandPoints,
+ isDrawingFreehand: this.isDrawingFreehand,
+ polygonPoints: this.polygonPoints,
+ polygonPointer: this.polygonPointer,
+ };
+ };
+
+ destroy = () => {
+ this.log.debug('Destroying module');
+ this.subscriptions.forEach((unsubscribe) => unsubscribe());
+ this.subscriptions.clear();
+ this.konva.group.destroy();
+ };
+
+ private onModifierChanged = () => {
+ const tool = this.parent.$tool.get();
+ const isTemporaryToolSwitch = shouldPreserveSuspendableShapesSession(
+ tool,
+ this.parent.$toolBuffer.get(),
+ this.hasSuspendableSession()
+ );
+ if (tool !== 'rect' && !isTemporaryToolSwitch) {
+ return;
+ }
+
+ if (tool === 'rect') {
+ this.syncCursorStyle();
+ }
+ void this.updateActivePreview();
+ this.render();
+ };
+
+ private updateActivePreview = async () => {
+ if (this.dragStartPoint) {
+ await this.updateDragBuffer();
+ return;
+ }
+
+ if (this.isDrawingFreehand || this.freehandPoints.length > 0) {
+ await this.updateFreehandBuffer();
+ return;
+ }
+
+ if (this.hasActivePolygonSession()) {
+ await this.updatePolygonBuffer();
+ }
+ };
+
+ private startFreehandSession = async (point: Coordinate, entityIdentifier: CanvasEntityIdentifier) => {
+ this.clearActiveBuffer();
+ this.resetState();
+ this.activeEntityIdentifier = entityIdentifier;
+ this.shapeId = getPrefixedId('polygon');
+ this.freehandPoints = [point];
+ this.isDrawingFreehand = true;
+ await this.updateFreehandBuffer();
+ };
+
+ private handleFreehandPointerMove = async (point: Coordinate) => {
+ if (!this.isDrawingFreehand || !this.parent.$isPrimaryPointerDown.get()) {
+ return;
+ }
+
+ const minDistance = this.manager.stage.unscale(this.config.MIN_FREEHAND_POINT_DISTANCE_PX);
+ const lastPoint = this.freehandPoints.at(-1) ?? null;
+ if (!isDistanceMoreThanMin(point, lastPoint, minDistance)) {
+ return;
+ }
+
+ this.appendFreehandPoint(point);
+ await this.updateFreehandBuffer();
+ };
+
+ private translateDragShape = async (point: Coordinate) => {
+ if (!this.translatePreviousPointerPoint || !this.dragStartPoint || !this.dragCurrentPoint) {
+ return;
+ }
+
+ const dx = point.x - this.translatePreviousPointerPoint.x;
+ const dy = point.y - this.translatePreviousPointerPoint.y;
+
+ if (dx === 0 && dy === 0) {
+ return;
+ }
+
+ this.dragStartPoint = {
+ x: this.dragStartPoint.x + dx,
+ y: this.dragStartPoint.y + dy,
+ };
+ this.dragCurrentPoint = {
+ x: this.dragCurrentPoint.x + dx,
+ y: this.dragCurrentPoint.y + dy,
+ };
+ this.translatePreviousPointerPoint = point;
+
+ await this.updateDragBuffer();
+ };
+
+ private onPolygonPointerDown = async (
+ point: Coordinate,
+ entityIdentifier: CanvasEntityIdentifier,
+ shouldSnap: boolean
+ ) => {
+ if (
+ this.activeEntityIdentifier &&
+ (this.activeEntityIdentifier.id !== entityIdentifier.id ||
+ this.activeEntityIdentifier.type !== entityIdentifier.type)
+ ) {
+ this.cancel();
+ }
+
+ this.activeEntityIdentifier = entityIdentifier;
+ this.dragStartPoint = null;
+ this.dragCurrentPoint = null;
+
+ if (this.polygonPoints.length === 0) {
+ this.shapeId = getPrefixedId('polygon');
+ this.polygonPoints = [point];
+ this.polygonPointer = point;
+ await this.updatePolygonBuffer();
+ this.render();
+ return;
+ }
+
+ const startPoint = this.polygonPoints[0];
+ if (!startPoint) {
+ return;
+ }
+
+ if (this.polygonPoints.length >= 3 && this.isPointNearStart(point)) {
+ await this.commitPolygon();
+ return;
+ }
+
+ const polygonPoint = this.getPolygonPoint(point, shouldSnap);
+ this.polygonPoints = [...this.polygonPoints, polygonPoint];
+ this.polygonPointer = polygonPoint;
+ await this.updatePolygonBuffer();
+ this.render();
+ };
+
+ private commitPolygon = async () => {
+ const activeEntity = this.getActiveEntityAdapter();
+ if (!activeEntity || !this.shapeId || this.polygonPoints.length < 3) {
+ this.cancel();
+ return;
+ }
+
+ const polygonState: CanvasPolygonState = {
+ id: this.shapeId,
+ type: 'polygon',
+ points: this.polygonPoints.flatMap((point) => [point.x, point.y]),
+ color: this.manager.stateApi.getCurrentColor(),
+ compositeOperation: this.getCompositeOperation(),
+ };
+
+ await activeEntity.bufferRenderer.setBuffer(polygonState);
+ activeEntity.bufferRenderer.commitBuffer();
+ this.resetState();
+ this.render();
+ };
+
+ private commitFreehand = async () => {
+ if (!this.isDrawingFreehand) {
+ return;
+ }
+
+ const activeEntity = this.getActiveEntityAdapter();
+ if (!activeEntity || !this.shapeId) {
+ this.cancel();
+ return;
+ }
+
+ const simplifiedPoints = this.simplifyFreehandContour(this.freehandPoints);
+ if (simplifiedPoints.length < 3) {
+ activeEntity.bufferRenderer.clearBuffer();
+ this.resetState();
+ this.render();
+ return;
+ }
+
+ const polygonState: CanvasPolygonState = {
+ id: this.shapeId,
+ type: 'polygon',
+ points: simplifiedPoints.flatMap((point) => [point.x, point.y]),
+ color: this.manager.stateApi.getCurrentColor(),
+ compositeOperation: this.getCompositeOperation(),
+ };
+
+ await activeEntity.bufferRenderer.setBuffer(polygonState);
+ activeEntity.bufferRenderer.commitBuffer();
+ this.resetState();
+ this.render();
+ };
+
+ private updateDragBuffer = async () => {
+ const activeEntity = this.getActiveEntityAdapter();
+ if (!activeEntity || !this.dragStartPoint || !this.dragCurrentPoint || !this.shapeId) {
+ return;
+ }
+
+ const shapeType = this.manager.stateApi.getSettings().shapeType;
+ if (shapeType !== 'rect' && shapeType !== 'oval') {
+ return;
+ }
+
+ const rect = this.getDragRect(this.dragStartPoint, this.dragCurrentPoint, {
+ fromCenter: this.manager.stateApi.$altKey.get(),
+ constrainSquare: this.manager.stateApi.$shiftKey.get(),
+ });
+
+ await activeEntity.bufferRenderer.setBuffer({
+ id: this.shapeId,
+ type: shapeType,
+ rect,
+ color: this.manager.stateApi.getCurrentColor(),
+ compositeOperation: this.getCompositeOperation(),
+ });
+ };
+
+ private updatePolygonBuffer = async () => {
+ const activeEntity = this.getActiveEntityAdapter();
+ if (!activeEntity || !this.shapeId || this.polygonPoints.length === 0) {
+ return;
+ }
+
+ await activeEntity.bufferRenderer.setBuffer({
+ id: this.shapeId,
+ type: 'polygon',
+ points: this.polygonPoints.flatMap((point) => [point.x, point.y]),
+ previewPoint: this.polygonPointer ?? this.polygonPoints.at(-1),
+ color: this.manager.stateApi.getCurrentColor(),
+ compositeOperation: this.getCompositeOperation(),
+ });
+ };
+
+ private updateFreehandBuffer = async () => {
+ const activeEntity = this.getActiveEntityAdapter();
+ if (!activeEntity || !this.shapeId || this.freehandPoints.length === 0) {
+ return;
+ }
+
+ await activeEntity.bufferRenderer.setBuffer({
+ id: this.shapeId,
+ type: 'polygon',
+ points: this.freehandPoints.flatMap((point) => [point.x, point.y]),
+ color: this.manager.stateApi.getCurrentColor(),
+ compositeOperation: this.getCompositeOperation(),
+ });
+ };
+
+ private syncStartPointIndicator = () => {
+ const activeEntity = this.getActiveEntityAdapter();
+ const startPoint = this.polygonPoints[0];
+ if (!activeEntity || !startPoint || this.manager.stateApi.getSettings().shapeType !== 'polygon') {
+ this.konva.startPointIndicator.visible(false);
+ return;
+ }
+
+ const isHoveringStartPoint = this.getIsHoveringStartPoint(startPoint, activeEntity.state.position);
+ const baseRadius = this.manager.stage.unscale(this.config.START_POINT_RADIUS_PX);
+ const stagePoint = addCoords(startPoint, activeEntity.state.position);
+
+ this.konva.startPointIndicator.setAttrs({
+ x: stagePoint.x,
+ y: stagePoint.y,
+ radius:
+ baseRadius +
+ (isHoveringStartPoint ? this.manager.stage.unscale(this.config.START_POINT_HOVER_RADIUS_DELTA_PX) : 0),
+ strokeWidth: this.manager.stage.unscale(this.config.START_POINT_STROKE_WIDTH_PX),
+ visible: true,
+ });
+ };
+
+ private getEntityRelativePoint = (point: Coordinate, position: Coordinate): Coordinate => {
+ return floorCoord(offsetCoord(point, position));
+ };
+
+ private getCompositeOperation = (): CanvasRectState['compositeOperation'] => {
+ return this.manager.stateApi.$ctrlKey.get() || this.manager.stateApi.$metaKey.get()
+ ? 'destination-out'
+ : 'source-over';
+ };
+
+ private getPolygonPoint = (point: Coordinate, shouldSnap: boolean): Coordinate => {
+ if (!shouldSnap) {
+ return point;
+ }
+
+ const lastPoint = this.polygonPoints.at(-1);
+ if (!lastPoint) {
+ return point;
+ }
+
+ const dx = point.x - lastPoint.x;
+ const dy = point.y - lastPoint.y;
+ const distance = Math.hypot(dx, dy);
+ if (distance === 0) {
+ return point;
+ }
+
+ const snapAngle = Math.PI / 4;
+ const angle = Math.atan2(dy, dx);
+ const snappedAngle = Math.round(angle / snapAngle) * snapAngle;
+
+ const snappedPoint = {
+ x: lastPoint.x + Math.cos(snappedAngle) * distance,
+ y: lastPoint.y + Math.sin(snappedAngle) * distance,
+ };
+
+ return this.alignPointToStart(snappedPoint);
+ };
+
+ private isPointNearStart = (point: Coordinate): boolean => {
+ const startPoint = this.polygonPoints[0];
+ if (!startPoint) {
+ return false;
+ }
+ return Math.hypot(point.x - startPoint.x, point.y - startPoint.y) <= this.getPolygonCloseRadius();
+ };
+
+ private getPolygonCloseRadius = (): number => {
+ return this.manager.stage.unscale(this.config.POLYGON_CLOSE_RADIUS_PX);
+ };
+
+ private getIsHoveringStartPoint = (startPoint: Coordinate, entityPosition: Coordinate): boolean => {
+ if (this.polygonPoints.length < 3) {
+ return false;
+ }
+
+ const pointerPoint = this.parent.$cursorPos.get()?.relative;
+ if (!pointerPoint) {
+ return false;
+ }
+
+ const entityRelativePointerPoint = this.getEntityRelativePoint(pointerPoint, entityPosition);
+ return (
+ Math.hypot(entityRelativePointerPoint.x - startPoint.x, entityRelativePointerPoint.y - startPoint.y) <=
+ this.getPolygonCloseRadius()
+ );
+ };
+
+ private alignPointToStart = (point: Coordinate): Coordinate => {
+ if (this.polygonPoints.length < 2) {
+ return point;
+ }
+
+ const startPoint = this.polygonPoints[0];
+ if (!startPoint) {
+ return point;
+ }
+
+ const alignThreshold = this.getPolygonCloseRadius();
+ const deltaX = Math.abs(point.x - startPoint.x);
+ const deltaY = Math.abs(point.y - startPoint.y);
+ const canAlignX = deltaX <= alignThreshold;
+ const canAlignY = deltaY <= alignThreshold;
+
+ if (!canAlignX && !canAlignY) {
+ return point;
+ }
+
+ if (canAlignX && canAlignY) {
+ if (deltaX <= deltaY) {
+ return { x: startPoint.x, y: point.y };
+ }
+ return { x: point.x, y: startPoint.y };
+ }
+
+ if (canAlignX) {
+ return { x: startPoint.x, y: point.y };
+ }
+
+ return { x: point.x, y: startPoint.y };
+ };
+
+ private appendFreehandPoint = (point: Coordinate) => {
+ const lastPoint = this.freehandPoints.at(-1) ?? null;
+ if (!lastPoint) {
+ this.freehandPoints.push(point);
+ return;
+ }
+
+ const maxSegmentLength = this.manager.stage.unscale(this.config.MAX_FREEHAND_SEGMENT_LENGTH_PX);
+ const dx = point.x - lastPoint.x;
+ const dy = point.y - lastPoint.y;
+ const distance = Math.hypot(dx, dy);
+
+ if (distance <= maxSegmentLength) {
+ this.freehandPoints.push(point);
+ return;
+ }
+
+ const steps = Math.ceil(distance / maxSegmentLength);
+ for (let i = 1; i <= steps; i++) {
+ const t = i / steps;
+ this.freehandPoints.push({
+ x: lastPoint.x + dx * t,
+ y: lastPoint.y + dy * t,
+ });
+ }
+ };
+
+ private simplifyFreehandContour = (points: Coordinate[]): Coordinate[] => {
+ if (points.length < this.config.FREEHAND_SIMPLIFY_MIN_POINTS) {
+ return points;
+ }
+
+ const simplifiedFlatPoints = simplifyFlatNumbersArray(
+ points.flatMap((point) => [point.x, point.y]),
+ {
+ tolerance: this.config.FREEHAND_SIMPLIFY_TOLERANCE,
+ highestQuality: true,
+ }
+ );
+
+ if (simplifiedFlatPoints.length < 6) {
+ return points;
+ }
+
+ const simplifiedPoints = this.flatNumbersToCoords(simplifiedFlatPoints);
+ if (simplifiedPoints.length < 3) {
+ return points;
+ }
+
+ return simplifiedPoints;
+ };
+
+ private flatNumbersToCoords = (points: number[]): Coordinate[] => {
+ const coords: Coordinate[] = [];
+ for (let i = 0; i < points.length; i += 2) {
+ const x = points[i];
+ const y = points[i + 1];
+ if (x === undefined || y === undefined) {
+ continue;
+ }
+ coords.push({ x, y });
+ }
+ return coords;
+ };
+
+ private getDragRect = (
+ start: Coordinate,
+ end: Coordinate,
+ options: { fromCenter: boolean; constrainSquare: boolean }
+ ): CanvasRectState['rect'] => {
+ let dx = end.x - start.x;
+ let dy = end.y - start.y;
+
+ if (options.constrainSquare) {
+ const size = Math.max(Math.abs(dx), Math.abs(dy));
+ const dxSign = getAxisSign(dx, dy);
+ const dySign = getAxisSign(dy, dx);
+ dx = dxSign * size;
+ dy = dySign * size;
+ }
+
+ const x1 = options.fromCenter ? start.x - dx : start.x;
+ const y1 = options.fromCenter ? start.y - dy : start.y;
+ const x2 = options.fromCenter ? start.x + dx : start.x + dx;
+ const y2 = options.fromCenter ? start.y + dy : start.y + dy;
+
+ return {
+ x: Math.min(x1, x2),
+ y: Math.min(y1, y2),
+ width: Math.abs(x2 - x1),
+ height: Math.abs(y2 - y1),
+ };
+ };
+
+ private getActiveEntityAdapter = () => {
+ if (!this.activeEntityIdentifier) {
+ return null;
+ }
+ return this.manager.getAdapter(this.activeEntityIdentifier);
+ };
+
+ private finishDragShapeSession = () => {
+ const activeEntity = this.getActiveEntityAdapter();
+ if (!activeEntity) {
+ this.resetState();
+ this.render();
+ return;
+ }
+
+ const bufferState = activeEntity.bufferRenderer.state;
+ if (
+ bufferState &&
+ (bufferState.type === 'rect' || bufferState.type === 'oval') &&
+ activeEntity.bufferRenderer.hasBuffer() &&
+ bufferState.rect.width > 0 &&
+ bufferState.rect.height > 0
+ ) {
+ activeEntity.bufferRenderer.commitBuffer();
+ } else {
+ activeEntity.bufferRenderer.clearBuffer();
+ }
+
+ this.resetState();
+ this.render();
+ };
+
+ private clearActiveBuffer = () => {
+ this.getActiveEntityAdapter()?.bufferRenderer.clearBuffer();
+ };
+
+ private resetState = () => {
+ this.activeEntityIdentifier = null;
+ this.shapeId = null;
+ this.dragStartPoint = null;
+ this.dragCurrentPoint = null;
+ this.translatePreviousPointerPoint = null;
+ this.freehandPoints = [];
+ this.isDrawingFreehand = false;
+ this.polygonPoints = [];
+ this.polygonPointer = null;
+ this.konva.startPointIndicator.visible(false);
+ };
+}
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts
index beca4d14a0a..98e1ab7d5d9 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts
@@ -1,5 +1,6 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
+import type { AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types';
import { CanvasBboxToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasBboxToolModule';
import { CanvasBrushToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasBrushToolModule';
import { CanvasColorPickerToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasColorPickerToolModule';
@@ -7,9 +8,15 @@ import { CanvasEraserToolModule } from 'features/controlLayers/konva/CanvasTool/
import { CanvasGradientToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasGradientToolModule';
import { CanvasLassoToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasLassoToolModule';
import { CanvasMoveToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasMoveToolModule';
-import { CanvasRectToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasRectToolModule';
+import { CanvasShapeToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasShapeToolModule';
import { CanvasTextToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasTextToolModule';
import { CanvasViewToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasViewToolModule';
+import {
+ getToolToCancelOnEscape,
+ shouldPreserveSuspendableShapesSession,
+ shouldQuickSwitchToColorPickerOnAlt,
+ shouldTranslateShapeDragOnSpace,
+} from 'features/controlLayers/konva/CanvasTool/toolHotkeys';
import { ZOOM_DRAG_CURSOR } from 'features/controlLayers/konva/cursors/zoomDragCursor';
import {
calculateNewBrushSizeFromWheelDelta,
@@ -64,7 +71,7 @@ export class CanvasToolModule extends CanvasModuleBase {
tools: {
brush: CanvasBrushToolModule;
eraser: CanvasEraserToolModule;
- rect: CanvasRectToolModule;
+ rect: CanvasShapeToolModule;
lasso: CanvasLassoToolModule;
gradient: CanvasGradientToolModule;
colorPicker: CanvasColorPickerToolModule;
@@ -124,7 +131,7 @@ export class CanvasToolModule extends CanvasModuleBase {
this.tools = {
brush: new CanvasBrushToolModule(this),
eraser: new CanvasEraserToolModule(this),
- rect: new CanvasRectToolModule(this),
+ rect: new CanvasShapeToolModule(this),
lasso: new CanvasLassoToolModule(this),
gradient: new CanvasGradientToolModule(this),
colorPicker: new CanvasColorPickerToolModule(this),
@@ -141,6 +148,7 @@ export class CanvasToolModule extends CanvasModuleBase {
this.konva.group.add(this.tools.brush.konva.group);
this.konva.group.add(this.tools.eraser.konva.group);
+ this.konva.group.add(this.tools.rect.konva.group);
this.konva.group.add(this.tools.colorPicker.konva.group);
this.konva.group.add(this.tools.text.konva.group);
this.konva.group.add(this.tools.bbox.konva.group);
@@ -152,17 +160,24 @@ export class CanvasToolModule extends CanvasModuleBase {
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasSlice, this.render));
this.subscriptions.add(
this.$tool.listen((tool, previousTool) => {
- // Preserve pointer state during temporary view switching so lasso sessions can freeze/resume on space.
- const shouldPreservePointerState =
+ // Preserve pointer state during temporary view switching so lasso and shapes sessions can freeze/resume on
+ // space.
+ const shouldPreserveLassoPointerState =
this.$toolBuffer.get() === 'lasso' &&
this.tools.lasso.hasActiveSession() &&
((previousTool === 'lasso' && tool === 'view') || (previousTool === 'view' && tool === 'lasso'));
+ const shouldPreserveShapesPointerState =
+ this.$toolBuffer.get() === 'rect' &&
+ this.tools.rect.hasSuspendableSession() &&
+ ((previousTool === 'rect' && tool === 'view') || (previousTool === 'view' && tool === 'rect'));
+ const shouldPreservePointerState = shouldPreserveLassoPointerState || shouldPreserveShapesPointerState;
if (!shouldPreservePointerState) {
// On tool switch, reset mouse state
this.manager.tool.$isPrimaryPointerDown.set(false);
}
+ this.tools.rect.onToolChanged();
this.tools.lasso.onToolChanged();
void this.tools.text.onToolChanged();
this.render();
@@ -239,6 +254,7 @@ export class CanvasToolModule extends CanvasModuleBase {
this.tools.brush.render();
this.tools.eraser.render();
+ this.tools.rect.render();
this.tools.colorPicker.render();
this.tools.text.render();
this.tools.bbox.render();
@@ -411,9 +427,8 @@ export class CanvasToolModule extends CanvasModuleBase {
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (
- selectedEntity?.bufferRenderer.state?.type !== 'rect' &&
- selectedEntity?.bufferRenderer.state?.type !== 'gradient' &&
- selectedEntity?.bufferRenderer.hasBuffer()
+ selectedEntity?.bufferRenderer.hasBuffer() &&
+ !this.shouldDeferEnterLeaveCommit(selectedEntity.bufferRenderer.state)
) {
selectedEntity.bufferRenderer.commitBuffer();
return;
@@ -467,7 +482,7 @@ export class CanvasToolModule extends CanvasModuleBase {
}
};
- onStagePointerUp = (e: KonvaEventObject) => {
+ onStagePointerUp = async (e: KonvaEventObject) => {
if (e.target !== this.konva.stage) {
return;
}
@@ -490,7 +505,7 @@ export class CanvasToolModule extends CanvasModuleBase {
} else if (tool === 'eraser') {
this.tools.eraser.onStagePointerUp(e);
} else if (tool === 'rect') {
- this.tools.rect.onStagePointerUp(e);
+ await this.tools.rect.onStagePointerUp(e);
} else if (tool === 'lasso') {
void this.tools.lasso.onStagePointerUp(e);
} else if (tool === 'gradient') {
@@ -534,6 +549,8 @@ export class CanvasToolModule extends CanvasModuleBase {
await this.tools.gradient.onStagePointerMove(e);
} else if (tool === 'text') {
// Already handled above
+ } else if (this.isTemporaryShapesToolSwitch()) {
+ // Preserve in-progress polygon/freehand shapes while temporarily switching to view or color picker.
} else {
this.manager.stateApi.getSelectedEntityAdapter()?.bufferRenderer.clearBuffer();
}
@@ -559,9 +576,8 @@ export class CanvasToolModule extends CanvasModuleBase {
if (
selectedEntity &&
- selectedEntity.bufferRenderer.state?.type !== 'rect' &&
- selectedEntity.bufferRenderer.state?.type !== 'gradient' &&
- selectedEntity.bufferRenderer.hasBuffer()
+ selectedEntity.bufferRenderer.hasBuffer() &&
+ !this.shouldDeferEnterLeaveCommit(selectedEntity.bufferRenderer.state)
) {
selectedEntity.bufferRenderer.commitBuffer();
}
@@ -604,20 +620,19 @@ export class CanvasToolModule extends CanvasModuleBase {
this.render();
};
- /**
- * Commit the buffer on window pointer up.
- *
- * The user may start drawing inside the stage and then release the mouse button outside of the stage. To prevent
- * whatever the user was drawing from being lost, or ending up with stale state, we need to commit the buffer
- * on window pointer up.
- */
- onWindowPointerUp = (_: PointerEvent) => {
+ onWindowPointerUp = async (_: PointerEvent) => {
try {
this.$isPrimaryPointerDown.set(false);
void this.tools.lasso.onWindowPointerUp();
+ await this.tools.rect.onWindowPointerUp();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
- if (selectedEntity && selectedEntity.bufferRenderer.hasBuffer() && !this.manager.$isBusy.get()) {
+ if (
+ selectedEntity &&
+ selectedEntity.bufferRenderer.hasBuffer() &&
+ !this.manager.$isBusy.get() &&
+ !this.shouldSkipWindowPointerUpCommit(selectedEntity.bufferRenderer.state)
+ ) {
selectedEntity.bufferRenderer.commitBuffer();
}
} finally {
@@ -625,36 +640,38 @@ export class CanvasToolModule extends CanvasModuleBase {
}
};
- onWindowPointerMove = (e: PointerEvent) => {
+ onWindowPointerMove = async (e: PointerEvent) => {
const target = e.target;
if (target instanceof Node && this.manager.stage.container.contains(target)) {
return;
}
- if (this.$tool.get() !== 'lasso') {
- return;
- }
-
- if (!this.getCanDraw()) {
- return;
- }
-
- if (!this.$isPrimaryPointerDown.get()) {
- return;
- }
-
- if (!this.tools.lasso.hasActiveSession()) {
- return;
- }
-
try {
this.$lastPointerType.set(e.pointerType);
+ if (!this.getCanDraw()) {
+ return;
+ }
+
+ if (!this.$isPrimaryPointerDown.get()) {
+ return;
+ }
+
if (!this.syncCursorPositionsFromWindowEvent(e)) {
return;
}
- this.tools.lasso.onWindowPointerMove(e);
+ if (this.$tool.get() === 'rect') {
+ if (!this.tools.rect.hasActiveDragSession()) {
+ return;
+ }
+ await this.tools.rect.onWindowPointerMove();
+ } else if (this.$tool.get() === 'lasso') {
+ if (!this.tools.lasso.hasActiveSession()) {
+ return;
+ }
+ this.tools.lasso.onWindowPointerMove(e);
+ }
} finally {
this.render();
}
@@ -665,6 +682,8 @@ export class CanvasToolModule extends CanvasModuleBase {
* and the color picker tool is still active when you come back.
*/
onWindowBlur = () => {
+ this.manager.stateApi.$spaceKey.set(false);
+ this.tools.rect.stopDragTranslation();
this.revertToolBuffer();
};
@@ -691,9 +710,25 @@ export class CanvasToolModule extends CanvasModuleBase {
if (e.key === KEY_ESCAPE) {
// Cancel shape drawing on escape
e.preventDefault();
- if (this.$tool.get() === 'lasso') {
+ const tool = this.$tool.get();
+ const toolToCancel = getToolToCancelOnEscape(
+ tool,
+ this.$toolBuffer.get(),
+ this.tools.lasso.hasActiveSession(),
+ this.tools.rect.hasSuspendableSession()
+ );
+
+ this.manager.stateApi.$spaceKey.set(false);
+ this.tools.rect.stopDragTranslation();
+ if (toolToCancel === 'rect') {
+ this.tools.rect.cancel();
+ }
+ if (toolToCancel === 'lasso') {
this.tools.lasso.reset();
}
+ if (toolToCancel && tool !== toolToCancel) {
+ this.revertToolBuffer();
+ }
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (
selectedEntity &&
@@ -707,16 +742,40 @@ export class CanvasToolModule extends CanvasModuleBase {
}
if (isSpaceKey) {
- // Select the view tool on space key down
e.preventDefault();
e.stopPropagation();
const currentTool = this.$tool.get();
- this.$toolBuffer.set(currentTool);
+ const shapeType = this.manager.stateApi.getSettings().shapeType;
+ const hasActiveShapeDragSession = this.tools.rect.hasActiveDragSession();
+ const isPrimaryPointerDown = this.$isPrimaryPointerDown.get();
this.manager.stateApi.$spaceKey.set(true);
+
+ if (shouldTranslateShapeDragOnSpace(currentTool, shapeType, hasActiveShapeDragSession, isPrimaryPointerDown)) {
+ this.tools.rect.startDragTranslation();
+ return;
+ }
+
+ if (currentTool === 'rect' && this.tools.rect.hasActivePolygonSession()) {
+ void this.tools.rect.freezePolygonPreview();
+ }
+
+ // Select the view tool on space key down
+ this.$toolBuffer.set(currentTool);
this.$tool.set('view');
if (currentTool === 'lasso' && this.tools.lasso.hasActiveSession() && this.$isPrimaryPointerDown.get()) {
// Start panning immediately if user is already drawing with freehand lasso.
this.manager.stage.startDragging();
+ } else if (
+ currentTool === 'rect' &&
+ this.tools.rect.hasSuspendableSession() &&
+ this.$isPrimaryPointerDown.get()
+ ) {
+ // Match lasso: allow an in-progress freehand shapes session to freeze and pan immediately on space.
+ this.manager.stage.startDragging();
+ } else if (currentTool === 'rect' && this.tools.rect.hasActivePolygonSession()) {
+ // Match polygon lasso: when a polygon session is active, Space should immediately enter panning without
+ // requiring an extra click on the canvas.
+ this.manager.stage.startDragging();
} else {
this.$cursorPos.set(null);
}
@@ -724,10 +783,17 @@ export class CanvasToolModule extends CanvasModuleBase {
}
if (e.key === KEY_ALT) {
+ const tool = this.$tool.get();
+ const shapeType = this.manager.stateApi.getSettings().shapeType;
+ const hasActiveShapeDragSession = this.tools.rect.hasActiveDragSession();
+ if (!shouldQuickSwitchToColorPickerOnAlt(tool, shapeType, hasActiveShapeDragSession)) {
+ e.preventDefault();
+ return;
+ }
// Select the color picker on alt key down
e.preventDefault();
e.stopPropagation();
- this.$toolBuffer.set(this.$tool.get());
+ this.$toolBuffer.set(tool);
this.$tool.set('colorPicker');
}
};
@@ -747,11 +813,15 @@ export class CanvasToolModule extends CanvasModuleBase {
}
if (e.key === KEY_SPACE || e.code === CODE_SPACE) {
- // Revert the tool to the previous tool on space key up
e.preventDefault();
e.stopPropagation();
- this.revertToolBuffer();
this.manager.stateApi.$spaceKey.set(false);
+ if (this.tools.rect.isTranslatingDragSession()) {
+ this.tools.rect.stopDragTranslation();
+ return;
+ }
+ // Revert the tool to the previous tool on space key up
+ this.revertToolBuffer();
return;
}
@@ -806,4 +876,28 @@ export class CanvasToolModule extends CanvasModuleBase {
}
this.konva.group.destroy();
};
+
+ private shouldDeferEnterLeaveCommit = (state: AnyObjectState | null) => {
+ if (!state) {
+ return false;
+ }
+
+ if (state.type === 'rect' || state.type === 'oval' || state.type === 'gradient') {
+ return true;
+ }
+
+ return state.type === 'polygon';
+ };
+
+ private shouldSkipWindowPointerUpCommit = (state: AnyObjectState | null) => {
+ return Boolean(state?.type === 'polygon' && state.previewPoint);
+ };
+
+ private isTemporaryShapesToolSwitch = () => {
+ return shouldPreserveSuspendableShapesSession(
+ this.$tool.get(),
+ this.$toolBuffer.get(),
+ this.tools.rect.hasSuspendableSession()
+ );
+ };
}
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.test.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.test.ts
new file mode 100644
index 00000000000..e4fd800d2c2
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.test.ts
@@ -0,0 +1,70 @@
+import { describe, expect, it } from 'vitest';
+
+import {
+ getToolToCancelOnEscape,
+ shouldPreserveSuspendableShapesSession,
+ shouldQuickSwitchToColorPickerOnAlt,
+ shouldTranslateShapeDragOnSpace,
+} from './toolHotkeys';
+
+describe('tool hotkeys', () => {
+ it('keeps the color-picker quick-switch available before starting rect and oval drags', () => {
+ expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'rect', false)).toBe(true);
+ expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'oval', false)).toBe(true);
+ });
+
+ it('blocks the color-picker quick-switch while a rect, oval, or freehand drag is active', () => {
+ expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'rect', true)).toBe(false);
+ expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'oval', true)).toBe(false);
+ expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'freehand', true)).toBe(false);
+ });
+
+ it('keeps the color-picker quick-switch for polygon mode and non-shape tools', () => {
+ expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'polygon', false)).toBe(true);
+ expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'polygon', true)).toBe(true);
+ expect(shouldQuickSwitchToColorPickerOnAlt('brush', 'rect', true)).toBe(true);
+ expect(shouldQuickSwitchToColorPickerOnAlt('lasso', 'polygon', false)).toBe(true);
+ });
+
+ it('uses Space to translate active rect and oval drags instead of switching to view', () => {
+ expect(shouldTranslateShapeDragOnSpace('rect', 'rect', true, true)).toBe(true);
+ expect(shouldTranslateShapeDragOnSpace('rect', 'oval', true, true)).toBe(true);
+ });
+
+ it('does not use Space translation outside active rect and oval drags', () => {
+ expect(shouldTranslateShapeDragOnSpace('rect', 'rect', false, true)).toBe(false);
+ expect(shouldTranslateShapeDragOnSpace('rect', 'rect', true, false)).toBe(false);
+ expect(shouldTranslateShapeDragOnSpace('rect', 'polygon', true, true)).toBe(false);
+ expect(shouldTranslateShapeDragOnSpace('rect', 'freehand', true, true)).toBe(false);
+ expect(shouldTranslateShapeDragOnSpace('brush', 'rect', true, true)).toBe(false);
+ });
+
+ it('preserves suspendable shapes sessions across temporary view and color-picker switches', () => {
+ expect(shouldPreserveSuspendableShapesSession('view', 'rect', true)).toBe(true);
+ expect(shouldPreserveSuspendableShapesSession('colorPicker', 'rect', true)).toBe(true);
+ expect(shouldPreserveSuspendableShapesSession('rect', 'rect', true)).toBe(true);
+ });
+
+ it('does not preserve suspendable shapes sessions for unrelated tool switches', () => {
+ expect(shouldPreserveSuspendableShapesSession('brush', 'rect', true)).toBe(false);
+ expect(shouldPreserveSuspendableShapesSession('view', null, true)).toBe(false);
+ expect(shouldPreserveSuspendableShapesSession('colorPicker', 'rect', false)).toBe(false);
+ });
+
+ it('cancels the active drawing tool directly on escape', () => {
+ expect(getToolToCancelOnEscape('rect', null, false, false)).toBe('rect');
+ expect(getToolToCancelOnEscape('lasso', null, false, false)).toBe('lasso');
+ });
+
+ it('cancels preserved drawing sessions while temporarily switched away', () => {
+ expect(getToolToCancelOnEscape('view', 'lasso', true, false)).toBe('lasso');
+ expect(getToolToCancelOnEscape('view', 'rect', false, true)).toBe('rect');
+ expect(getToolToCancelOnEscape('colorPicker', 'rect', false, true)).toBe('rect');
+ });
+
+ it('does not cancel unrelated buffered tools on escape', () => {
+ expect(getToolToCancelOnEscape('view', 'lasso', false, false)).toBeNull();
+ expect(getToolToCancelOnEscape('colorPicker', 'lasso', true, false)).toBeNull();
+ expect(getToolToCancelOnEscape('view', 'brush', false, true)).toBeNull();
+ });
+});
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.ts
new file mode 100644
index 00000000000..84cbdc93dce
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.ts
@@ -0,0 +1,65 @@
+import type { Tool } from 'features/controlLayers/store/types';
+
+type ShapeType = 'rect' | 'oval' | 'polygon' | 'freehand';
+
+export const shouldPreserveSuspendableShapesSession = (
+ tool: Tool,
+ toolBuffer: Tool | null,
+ hasSuspendableShapeSession: boolean
+): boolean => {
+ if (!hasSuspendableShapeSession || toolBuffer !== 'rect') {
+ return false;
+ }
+
+ return tool === 'view' || tool === 'colorPicker' || tool === 'rect';
+};
+
+export const shouldQuickSwitchToColorPickerOnAlt = (
+ tool: Tool,
+ shapeType: ShapeType,
+ hasActiveShapeDragSession: boolean
+): boolean => {
+ if (tool !== 'rect') {
+ return true;
+ }
+
+ if (shapeType === 'polygon') {
+ return true;
+ }
+
+ return !hasActiveShapeDragSession;
+};
+
+export const shouldTranslateShapeDragOnSpace = (
+ tool: Tool,
+ shapeType: ShapeType,
+ hasActiveShapeDragSession: boolean,
+ isPrimaryPointerDown: boolean
+): boolean => {
+ if (tool !== 'rect' || !hasActiveShapeDragSession || !isPrimaryPointerDown) {
+ return false;
+ }
+
+ return shapeType === 'rect' || shapeType === 'oval';
+};
+
+export const getToolToCancelOnEscape = (
+ tool: Tool,
+ toolBuffer: Tool | null,
+ hasActiveLassoSession: boolean,
+ hasSuspendableShapeSession: boolean
+): Tool | null => {
+ if (tool === 'rect' || tool === 'lasso') {
+ return tool;
+ }
+
+ if (tool === 'view' && toolBuffer === 'lasso' && hasActiveLassoSession) {
+ return 'lasso';
+ }
+
+ if ((tool === 'view' || tool === 'colorPicker') && toolBuffer === 'rect' && hasSuspendableShapeSession) {
+ return 'rect';
+ }
+
+ return null;
+};
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts
index 202b70e142d..509aefdaaf2 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts
@@ -14,6 +14,7 @@ export type TransformSmoothingMode = z.infer;
const zGradientType = z.enum(['linear', 'radial']);
const zLassoMode = z.enum(['freehand', 'polygon']);
+const zShapeType = z.enum(['rect', 'oval', 'polygon', 'freehand']);
const zCanvasSettingsState = z.object({
/**
@@ -115,6 +116,10 @@ const zCanvasSettingsState = z.object({
* The gradient tool type.
*/
gradientType: zGradientType.default('linear'),
+ /**
+ * The shape tool type.
+ */
+ shapeType: zShapeType.default('rect'),
/**
* Whether the gradient tool clips to the drag gesture.
*/
@@ -152,6 +157,7 @@ const getInitialState = (): CanvasSettingsState => ({
transformSmoothingEnabled: false,
transformSmoothingMode: 'bicubic',
gradientType: 'linear',
+ shapeType: 'rect',
gradientClipEnabled: true,
lassoMode: 'freehand',
});
@@ -248,6 +254,9 @@ const slice = createSlice({
settingsGradientTypeChanged: (state, action: PayloadAction) => {
state.gradientType = action.payload;
},
+ settingsShapeTypeChanged: (state, action: PayloadAction) => {
+ state.shapeType = action.payload;
+ },
settingsGradientClipToggled: (state) => {
state.gradientClipEnabled = !state.gradientClipEnabled;
},
@@ -284,6 +293,7 @@ export const {
settingsStagingAreaAutoSwitchChanged,
settingsFillColorPickerPinnedSet,
settingsGradientTypeChanged,
+ settingsShapeTypeChanged,
settingsGradientClipToggled,
settingsLassoModeChanged,
} = slice.actions;
@@ -326,5 +336,6 @@ export const selectTransformSmoothingEnabled = createCanvasSettingsSelector(
);
export const selectTransformSmoothingMode = createCanvasSettingsSelector((settings) => settings.transformSmoothingMode);
export const selectGradientType = createCanvasSettingsSelector((settings) => settings.gradientType);
+export const selectShapeType = createCanvasSettingsSelector((settings) => settings.shapeType);
export const selectGradientClipEnabled = createCanvasSettingsSelector((settings) => settings.gradientClipEnabled);
export const selectLassoMode = createCanvasSettingsSelector((settings) => settings.lassoMode);
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts
index 9e639c8e7af..a18a6ed308f 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts
@@ -70,7 +70,7 @@ import type {
EntityLassoAddedPayload,
EntityMovedToPayload,
EntityRasterizedPayload,
- EntityRectAddedPayload,
+ EntityShapeAddedPayload,
IPMethodV2,
T2IAdapterConfig,
ZImageControlConfig,
@@ -1568,8 +1568,8 @@ const slice = createSlice({
points: eraserLine.type === 'eraser_line' ? simplifyFlatNumbersArray(eraserLine.points) : eraserLine.points,
});
},
- entityRectAdded: (state, action: PayloadAction) => {
- const { entityIdentifier, rect } = action.payload;
+ entityShapeAdded: (state, action: PayloadAction) => {
+ const { entityIdentifier, shape } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
@@ -1577,7 +1577,7 @@ const slice = createSlice({
// TODO(psyche): If we add the object without splatting, the renderer will see it as the same object and not
// re-render it (reference equality check). I don't like this behaviour.
- entity.objects.push({ ...rect });
+ entity.objects.push({ ...shape });
},
entityLassoAdded: (state, action: PayloadAction) => {
const { entityIdentifier, lasso } = action.payload;
@@ -1910,7 +1910,7 @@ export const {
entityRasterized,
entityBrushLineAdded,
entityEraserLineAdded,
- entityRectAdded,
+ entityShapeAdded,
entityLassoAdded,
entityGradientAdded,
// Raster layer adjustments
@@ -2046,7 +2046,7 @@ export const canvasSliceConfig: SliceConfig = {
const doNotGroupMatcher = isAnyOf(
entityBrushLineAdded,
entityEraserLineAdded,
- entityRectAdded,
+ entityShapeAdded,
entityLassoAdded,
entityGradientAdded
);
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
index 7a7ebeade71..cbeccdfa930 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
@@ -260,6 +260,7 @@ const zCanvasRectState = z.object({
type: z.literal('rect'),
rect: zRect,
color: zRgbaColor,
+ compositeOperation: z.enum(['source-over', 'destination-out']).default('source-over'),
});
export type CanvasRectState = z.infer;
@@ -277,6 +278,28 @@ const zCanvasLassoState = z.object({
});
export type CanvasLassoState = z.infer;
+const zCanvasOvalState = z.object({
+ id: zId,
+ type: z.literal('oval'),
+ rect: zRect,
+ color: zRgbaColor,
+ compositeOperation: z.enum(['source-over', 'destination-out']).default('source-over'),
+});
+export type CanvasOvalState = z.infer;
+
+const zCanvasPolygonState = z.object({
+ id: zId,
+ type: z.literal('polygon'),
+ points: zPoints,
+ color: zRgbaColor,
+ compositeOperation: z.enum(['source-over', 'destination-out']).default('source-over'),
+ previewPoint: zCoordinate.optional(),
+});
+export type CanvasPolygonState = z.infer;
+
+const zCanvasShapeState = z.union([zCanvasRectState, zCanvasOvalState, zCanvasPolygonState]);
+type CanvasShapeState = z.infer;
+
// Gradient state includes clip metadata so the tool can optionally clip to drag gesture.
const zCanvasLinearGradientState = z.object({
id: zId,
@@ -325,7 +348,7 @@ const zCanvasObjectState = z.union([
zCanvasImageState,
zCanvasBrushLineState,
zCanvasEraserLineState,
- zCanvasRectState,
+ zCanvasShapeState,
zCanvasLassoState,
zCanvasBrushLineWithPressureState,
zCanvasEraserLineWithPressureState,
@@ -1020,8 +1043,8 @@ export type EntityBrushLineAddedPayload = EntityIdentifierPayload<{
export type EntityEraserLineAddedPayload = EntityIdentifierPayload<{
eraserLine: CanvasEraserLineState | CanvasEraserLineWithPressureState;
}>;
-export type EntityRectAddedPayload = EntityIdentifierPayload<{ rect: CanvasRectState }>;
export type EntityLassoAddedPayload = EntityIdentifierPayload<{ lasso: CanvasLassoState }>;
+export type EntityShapeAddedPayload = EntityIdentifierPayload<{ shape: CanvasShapeState }>;
export type EntityGradientAddedPayload = EntityIdentifierPayload<{ gradient: CanvasGradientState }>;
export type EntityRasterizedPayload = EntityIdentifierPayload<{
imageObject: CanvasImageState;
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts
index 8f0e31ef5cd..63179db844a 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts
+++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts
@@ -276,7 +276,7 @@ export const MODEL_FORMAT_TO_LONG_NAME: Record = {
unknown: 'Unknown',
};
-export const SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS: BaseModelType[] = ['flux', 'sd-3', 'z-image'];
+export const SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS: BaseModelType[] = ['flux', 'sd-3'];
export const SUPPORTS_REF_IMAGES_BASE_MODELS: BaseModelType[] = ['sd-1', 'sdxl', 'flux', 'flux2', 'qwen-image'];
diff --git a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx
index e8636466066..1ba2ffd572d 100644
--- a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx
@@ -1,4 +1,6 @@
import { Badge, Portal } from '@invoke-ai/ui-library';
+import { useAppSelector } from 'app/store/storeHooks';
+import { selectIsAuthenticated } from 'features/auth/store/authSlice';
import type { RefObject } from 'react';
import { memo, useEffect, useMemo, useState } from 'react';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
@@ -10,14 +12,24 @@ type Props = {
type SessionQueueStatus = components['schemas']['SessionQueueStatus'];
+const hasUserCounts = (queueData: SessionQueueStatus): boolean => {
+ return (
+ queueData.user_pending !== undefined &&
+ queueData.user_pending !== null &&
+ queueData.user_in_progress !== undefined &&
+ queueData.user_in_progress !== null
+ );
+};
+
/**
- * Calculates the appropriate badge text based on queue status.
+ * Calculates the appropriate badge text based on queue status and authentication state.
* Returns null if badge should be hidden.
*
- * In multiuser mode, the backend already scopes counts to the current user for non-admins,
- * so pending + in_progress reflects the user's own queue items.
+ * In multiuser mode, the badge is "X/Y" where X is the calling user's pending+in_progress count
+ * and Y is the total across all users. In single-user mode (or when user counts are unavailable)
+ * the badge shows the total only.
*/
-const getBadgeText = (queueData: SessionQueueStatus | undefined): string | null => {
+const getBadgeText = (queueData: SessionQueueStatus | undefined, isAuthenticated: boolean): string | null => {
if (!queueData) {
return null;
}
@@ -28,18 +40,24 @@ const getBadgeText = (queueData: SessionQueueStatus | undefined): string | null
return null;
}
+ if (isAuthenticated && hasUserCounts(queueData)) {
+ const userPending = queueData.user_pending! + queueData.user_in_progress!;
+ return `${userPending}/${totalPending}`;
+ }
+
return totalPending.toString();
};
export const QueueCountBadge = memo(({ targetRef }: Props) => {
const [badgePos, setBadgePos] = useState<{ x: string; y: string } | null>(null);
+ const isAuthenticated = useAppSelector(selectIsAuthenticated);
const { queueData } = useGetQueueStatusQuery(undefined, {
selectFromResult: (res) => ({
queueData: res.data?.queue,
}),
});
- const badgeText = useMemo(() => getBadgeText(queueData), [queueData]);
+ const badgeText = useMemo(() => getBadgeText(queueData, isAuthenticated), [queueData, isAuthenticated]);
useEffect(() => {
if (!targetRef.current) {
diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx
index 024e92328f9..d486bc1137f 100644
--- a/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx
+++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx
@@ -3,7 +3,7 @@ import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import type { IDockviewHeaderActionsProps } from 'dockview';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
-import { selectLassoMode } from 'features/controlLayers/store/canvasSettingsSlice';
+import { selectLassoMode, selectShapeType } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Tool } from 'features/controlLayers/store/types';
import { IS_MAC_OS } from 'features/system/components/HotkeysModal/useHotkeyData';
@@ -16,6 +16,7 @@ import { WORKSPACE_PANEL_ID } from './shared';
const $fallbackTool = atom('move');
const $fallbackToolBuffer = atom(null);
+const $fallbackPrimaryPointerDown = atom(false);
const $fallbackTextSession = atom(null);
type CanvasToolModifierHintKey = ReturnType[number]['keys'][number];
@@ -45,10 +46,12 @@ export const DockviewCanvasHeaderActions = memo((props: IDockviewHeaderActionsPr
const { t } = useTranslation();
const canvasManager = useCanvasManagerSafe();
const lassoMode = useAppSelector(selectLassoMode);
+ const shapeType = useAppSelector(selectShapeType);
const bboxAspectRatioLocked = useAppSelector((state) => selectBbox(state).aspectRatio.isLocked);
const tool = useStore(canvasManager?.tool.$tool ?? $fallbackTool);
const toolBuffer = useStore(canvasManager?.tool.$toolBuffer ?? $fallbackToolBuffer);
+ const isPrimaryPointerDown = useStore(canvasManager?.tool.$isPrimaryPointerDown ?? $fallbackPrimaryPointerDown);
const textSession = useStore(canvasManager?.tool.tools.text.$session ?? $fallbackTextSession);
const effectiveTool = useMemo(() => {
@@ -66,10 +69,21 @@ export const DockviewCanvasHeaderActions = memo((props: IDockviewHeaderActionsPr
return getCanvasToolModifierHints({
tool: effectiveTool,
lassoMode,
+ shapeType,
bboxAspectRatioLocked,
hasActiveTextSession: Boolean(textSession),
+ isPrimaryPointerDown,
});
- }, [bboxAspectRatioLocked, canvasManager, effectiveTool, lassoMode, props.activePanel?.id, textSession]);
+ }, [
+ bboxAspectRatioLocked,
+ canvasManager,
+ effectiveTool,
+ isPrimaryPointerDown,
+ lassoMode,
+ props.activePanel?.id,
+ shapeType,
+ textSession,
+ ]);
if (hints.length === 0) {
return null;
diff --git a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts
index 7c599efad23..244634fd381 100644
--- a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts
+++ b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts
@@ -2,88 +2,116 @@ import { describe, expect, it } from 'vitest';
import { getCanvasToolModifierHintIds } from './canvasToolModifierHints';
+const buildArgs = (overrides: Partial[0]> = {}) => ({
+ tool: 'brush' as const,
+ lassoMode: 'freehand' as const,
+ shapeType: 'rect' as const,
+ bboxAspectRatioLocked: false,
+ hasActiveTextSession: false,
+ isPrimaryPointerDown: false,
+ ...overrides,
+});
+
describe('getCanvasToolModifierHintIds', () => {
it('returns brush hints in priority order', () => {
- expect(
- getCanvasToolModifierHintIds({
- tool: 'brush',
- lassoMode: 'freehand',
- bboxAspectRatioLocked: false,
- hasActiveTextSession: false,
- })
- ).toEqual(['shiftStraightLine', 'modWheelResizeBrush', 'spacePan', 'altPickColor']);
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'brush' }))).toEqual([
+ 'shiftStraightLine',
+ 'modWheelResizeBrush',
+ 'spacePan',
+ 'altPickColor',
+ ]);
});
it('omits alt color-picker hint for eraser', () => {
- expect(
- getCanvasToolModifierHintIds({
- tool: 'eraser',
- lassoMode: 'freehand',
- bboxAspectRatioLocked: false,
- hasActiveTextSession: false,
- })
- ).toEqual(['shiftStraightLine', 'modWheelResizeEraser', 'spacePan']);
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'eraser' }))).toEqual([
+ 'shiftStraightLine',
+ 'modWheelResizeEraser',
+ 'spacePan',
+ ]);
});
it('adds polygon snapping for polygon lasso', () => {
- expect(
- getCanvasToolModifierHintIds({
- tool: 'lasso',
- lassoMode: 'polygon',
- bboxAspectRatioLocked: false,
- hasActiveTextSession: false,
- })
- ).toEqual(['modSubtractMask', 'shiftSnap45Degrees', 'spacePan']);
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'lasso', lassoMode: 'polygon' }))).toEqual([
+ 'modErase',
+ 'shiftSnap45Degrees',
+ 'spacePan',
+ ]);
});
it('omits polygon snapping for freehand lasso', () => {
- expect(
- getCanvasToolModifierHintIds({
- tool: 'lasso',
- lassoMode: 'freehand',
- bboxAspectRatioLocked: false,
- hasActiveTextSession: false,
- })
- ).toEqual(['modSubtractMask', 'spacePan']);
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'lasso' }))).toEqual(['modErase', 'spacePan']);
});
it('switches the bbox aspect-ratio hint based on lock state', () => {
- expect(
- getCanvasToolModifierHintIds({
- tool: 'bbox',
- lassoMode: 'freehand',
- bboxAspectRatioLocked: false,
- hasActiveTextSession: false,
- })
- ).toEqual(['shiftLockAspectRatio', 'altScaleFromCenter', 'modFineGrid']);
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'bbox' }))).toEqual([
+ 'shiftLockAspectRatio',
+ 'altScaleFromCenter',
+ 'modFineGrid',
+ ]);
- expect(
- getCanvasToolModifierHintIds({
- tool: 'bbox',
- lassoMode: 'freehand',
- bboxAspectRatioLocked: true,
- hasActiveTextSession: false,
- })
- ).toEqual(['shiftUnlockAspectRatio', 'altScaleFromCenter', 'modFineGrid']);
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'bbox', bboxAspectRatioLocked: true }))).toEqual([
+ 'shiftUnlockAspectRatio',
+ 'altScaleFromCenter',
+ 'modFineGrid',
+ ]);
});
it('only shows text-session hints when a text session is active', () => {
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'text', hasActiveTextSession: true }))).toEqual([
+ 'enterCommitText',
+ 'shiftEnterNewLine',
+ 'escCancelText',
+ 'modDragText',
+ 'shiftSnapRotation',
+ ]);
+
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'text' }))).toEqual(['spacePan', 'altPickColor']);
+ });
+
+ it('shows idle rect and oval shapes hints', () => {
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'rect' }))).toEqual([
+ 'modErase',
+ 'shiftLockAspectRatio',
+ 'spacePan',
+ 'altPickColor',
+ ]);
+
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'oval' }))).toEqual([
+ 'modErase',
+ 'shiftLockAspectRatio',
+ 'spacePan',
+ 'altPickColor',
+ ]);
+ });
+
+ it('shows active rect and oval drag hints', () => {
+ expect(
+ getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'rect', isPrimaryPointerDown: true }))
+ ).toEqual(['modErase', 'shiftLockAspectRatio', 'altScaleFromCenter', 'spaceMoveShape']);
+
expect(
- getCanvasToolModifierHintIds({
- tool: 'text',
- lassoMode: 'freehand',
- bboxAspectRatioLocked: false,
- hasActiveTextSession: true,
- })
- ).toEqual(['enterCommitText', 'shiftEnterNewLine', 'escCancelText', 'modDragText', 'shiftSnapRotation']);
+ getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'oval', isPrimaryPointerDown: true }))
+ ).toEqual(['modErase', 'shiftLockAspectRatio', 'altScaleFromCenter', 'spaceMoveShape']);
+ });
+
+ it('shows polygon shape hints', () => {
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'polygon' }))).toEqual([
+ 'modErase',
+ 'shiftSnap45Degrees',
+ 'spacePan',
+ 'altPickColor',
+ ]);
+ });
+
+ it('omits alt color-picker hint during an active freehand stroke', () => {
+ expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'freehand' }))).toEqual([
+ 'modErase',
+ 'spacePan',
+ 'altPickColor',
+ ]);
expect(
- getCanvasToolModifierHintIds({
- tool: 'text',
- lassoMode: 'freehand',
- bboxAspectRatioLocked: false,
- hasActiveTextSession: false,
- })
- ).toEqual(['spacePan', 'altPickColor']);
+ getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'freehand', isPrimaryPointerDown: true }))
+ ).toEqual(['modErase', 'spacePan']);
});
});
diff --git a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts
index a23682e2072..3543ac4358d 100644
--- a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts
+++ b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts
@@ -1,14 +1,20 @@
+import {
+ shouldQuickSwitchToColorPickerOnAlt,
+ shouldTranslateShapeDragOnSpace,
+} from 'features/controlLayers/konva/CanvasTool/toolHotkeys';
import type { Tool } from 'features/controlLayers/store/types';
+type ShapeType = 'rect' | 'oval' | 'polygon' | 'freehand';
type CanvasToolModifierHintKey = 'mod' | 'shift' | 'alt' | 'space' | 'wheel' | 'arrows' | 'enter' | 'esc';
type CanvasToolModifierHintId =
| 'spacePan'
+ | 'spaceMoveShape'
| 'altPickColor'
| 'shiftStraightLine'
| 'modWheelResizeBrush'
| 'modWheelResizeEraser'
- | 'modSubtractMask'
+ | 'modErase'
| 'shiftSnap45Degrees'
| 'shiftLockAspectRatio'
| 'shiftUnlockAspectRatio'
@@ -35,6 +41,11 @@ const HINTS: Record = {
keys: ['space'],
labelKey: 'controlLayers.modifierHints.labels.pan',
},
+ spaceMoveShape: {
+ id: 'spaceMoveShape',
+ keys: ['space'],
+ labelKey: 'controlLayers.modifierHints.labels.moveShape',
+ },
altPickColor: {
id: 'altPickColor',
keys: ['alt'],
@@ -55,10 +66,10 @@ const HINTS: Record = {
keys: ['mod', 'wheel'],
labelKey: 'controlLayers.modifierHints.labels.resizeEraser',
},
- modSubtractMask: {
- id: 'modSubtractMask',
+ modErase: {
+ id: 'modErase',
keys: ['mod'],
- labelKey: 'controlLayers.modifierHints.labels.subtractMask',
+ labelKey: 'controlLayers.modifierHints.labels.erase',
},
shiftSnap45Degrees: {
id: 'shiftSnap45Degrees',
@@ -120,8 +131,10 @@ const HINTS: Record = {
type GetCanvasToolModifierHintsArg = {
tool: Tool;
lassoMode: 'freehand' | 'polygon';
+ shapeType: ShapeType;
bboxAspectRatioLocked: boolean;
hasActiveTextSession: boolean;
+ isPrimaryPointerDown: boolean;
};
const mapHintIdsToHints = (hintIds: readonly CanvasToolModifierHintId[]): CanvasToolModifierHint[] =>
@@ -130,8 +143,10 @@ const mapHintIdsToHints = (hintIds: readonly CanvasToolModifierHintId[]): Canvas
export const getCanvasToolModifierHintIds = ({
tool,
lassoMode,
+ shapeType,
bboxAspectRatioLocked,
hasActiveTextSession,
+ isPrimaryPointerDown,
}: GetCanvasToolModifierHintsArg): CanvasToolModifierHintId[] => {
// Resolver map: each tool returns the relevant hint ids based on the provided args.
const TOOL_HINT_RESOLVERS: Record<
@@ -141,7 +156,7 @@ export const getCanvasToolModifierHintIds = ({
brush: () => ['shiftStraightLine', 'modWheelResizeBrush', ...SHARED_HINT_IDS],
eraser: () => ['shiftStraightLine', 'modWheelResizeEraser', 'spacePan'],
lasso: ({ lassoMode: lm }) =>
- lm === 'polygon' ? ['modSubtractMask', 'shiftSnap45Degrees', 'spacePan'] : ['modSubtractMask', 'spacePan'],
+ lm === 'polygon' ? ['modErase', 'shiftSnap45Degrees', 'spacePan'] : ['modErase', 'spacePan'],
bbox: ({ bboxAspectRatioLocked: locked }) => [
locked ? 'shiftUnlockAspectRatio' : 'shiftLockAspectRatio',
'altScaleFromCenter',
@@ -155,7 +170,21 @@ export const getCanvasToolModifierHintIds = ({
view: () => ['altPickColor'],
colorPicker: () => ['spacePan'],
gradient: () => [...SHARED_HINT_IDS],
- rect: () => [...SHARED_HINT_IDS],
+ rect: ({ shapeType: st, isPrimaryPointerDown: pointerDown }) => {
+ if (st === 'polygon') {
+ return ['modErase', 'shiftSnap45Degrees', 'spacePan', 'altPickColor'];
+ }
+
+ if (st === 'freehand') {
+ return shouldQuickSwitchToColorPickerOnAlt('rect', st, pointerDown)
+ ? ['modErase', 'spacePan', 'altPickColor']
+ : ['modErase', 'spacePan'];
+ }
+
+ return shouldTranslateShapeDragOnSpace('rect', st, pointerDown, pointerDown)
+ ? ['modErase', 'shiftLockAspectRatio', 'altScaleFromCenter', 'spaceMoveShape']
+ : ['modErase', 'shiftLockAspectRatio', 'spacePan', 'altPickColor'];
+ },
};
const resolver = TOOL_HINT_RESOLVERS[tool];
@@ -163,7 +192,9 @@ export const getCanvasToolModifierHintIds = ({
if (!resolver) {
return [];
}
- return Array.from(resolver({ tool, lassoMode, bboxAspectRatioLocked, hasActiveTextSession }));
+ return Array.from(
+ resolver({ tool, lassoMode, shapeType, bboxAspectRatioLocked, hasActiveTextSession, isPrimaryPointerDown })
+ );
};
export const getCanvasToolModifierHints = (args: GetCanvasToolModifierHintsArg): CanvasToolModifierHint[] =>
diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts
index 68f24a26ec1..ef1d7d92f43 100644
--- a/invokeai/frontend/web/src/services/api/schema.ts
+++ b/invokeai/frontend/web/src/services/api/schema.ts
@@ -1875,7 +1875,9 @@ export type paths = {
};
/**
* Get Queue Item Ids
- * @description Gets all queue item ids that match the given parameters. Non-admin users only see their own items.
+ * @description Gets all queue item ids that match the given parameters. The IDs themselves are not sensitive;
+ * per-item field redaction is performed when the items are fetched via list_all_queue_items or
+ * get_queue_items_by_item_ids.
*/
get: operations["get_queue_item_ids"];
put?: never;
@@ -2135,7 +2137,9 @@ export type paths = {
};
/**
* Get Queue Status
- * @description Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it.
+ * @description Gets the status of the session queue. Returns global counts plus the calling user's own
+ * pending/in_progress counts (so the UI can show an X/Y badge). Non-admin users cannot see the
+ * current item's identifiers unless they own it.
*/
get: operations["get_queue_status"];
put?: never;
@@ -16183,6 +16187,7 @@ export type components = {
* force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
* pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
* max_queue_size: Maximum number of items in the session queue.
+ * session_queue_mode: Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.
Valid values: `FIFO`, `round_robin`
* clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`.
* max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.
* allow_nodes: List of nodes to allow. Omit to allow all.
@@ -16525,6 +16530,13 @@ export type components = {
* @default 10000
*/
max_queue_size?: number;
+ /**
+ * Session Queue Mode
+ * @description Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.
+ * @default round_robin
+ * @enum {string}
+ */
+ session_queue_mode?: "FIFO" | "round_robin";
/**
* Clear Queue On Startup
* @description Empties session queue on startup. If true, disables `max_queue_history`.
@@ -28086,6 +28098,16 @@ export type components = {
* @description Total number of queue items
*/
total: number;
+ /**
+ * User Pending
+ * @description Number of pending queue items for the calling user (multiuser only)
+ */
+ user_pending?: number | null;
+ /**
+ * User In Progress
+ * @description Number of in-progress queue items for the calling user (multiuser only)
+ */
+ user_in_progress?: number | null;
};
/**
* SetupRequest
diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx
index 1e73abb2027..b9d364cf3b4 100644
--- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx
+++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx
@@ -389,6 +389,24 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
});
socket.on('queue_item_status_changed', (data) => {
+ // Sanitized companion event sent to non-owner queue subscribers in multiuser mode. The
+ // backend sets user_id="redacted" and clears identifiers/error fields. We must not run
+ // payload-driven cache mutations or per-session side effects (node state reset, progress
+ // clear, completion bookkeeping) — those belong to the owner. Just invalidate queue tags
+ // so the non-owner's queue list and badge counts refetch with sanitized data.
+ if (data.user_id === 'redacted') {
+ log.trace({ data }, `Sanitized queue_item_status_changed for item ${data.item_id}`);
+ const tags: ApiTagDescription[] = [
+ 'SessionQueueStatus',
+ 'SessionQueueItemIdList',
+ { type: 'SessionQueueItem', id: data.item_id },
+ { type: 'SessionQueueItem', id: LIST_TAG },
+ { type: 'SessionQueueItem', id: LIST_ALL_TAG },
+ ];
+ dispatch(queueApi.util.invalidateTags(tags));
+ return;
+ }
+
if (finishedQueueItemIds.has(data.item_id)) {
log.trace({ data }, `Received event for already-finished queue item ${data.item_id}`);
return;
diff --git a/tests/app/routers/test_multiuser_authorization.py b/tests/app/routers/test_multiuser_authorization.py
index 3461f37e7e9..90bc82b30cf 100644
--- a/tests/app/routers/test_multiuser_authorization.py
+++ b/tests/app/routers/test_multiuser_authorization.py
@@ -1333,14 +1333,31 @@ def test_get_queue_status_hides_current_item_for_non_owner(self):
assert status_obj.session_id is None
assert status_obj.batch_id is None
- def test_session_queue_status_no_user_fields(self):
- """SessionQueueStatus should not have user_pending/user_in_progress fields anymore.
- Non-admin users now get their own counts in the main pending/in_progress fields."""
+ def test_session_queue_status_has_user_fields(self):
+ """SessionQueueStatus exposes user_pending/user_in_progress so the queue badge
+ can render an X/Y count (X = caller's jobs, Y = global total)."""
from invokeai.app.services.session_queue.session_queue_common import SessionQueueStatus
fields = set(SessionQueueStatus.model_fields.keys())
- assert "user_pending" not in fields
- assert "user_in_progress" not in fields
+ assert "user_pending" in fields
+ assert "user_in_progress" in fields
+
+ status_obj = SessionQueueStatus(
+ queue_id="default",
+ item_id=None,
+ session_id=None,
+ batch_id=None,
+ pending=5,
+ in_progress=1,
+ completed=0,
+ failed=0,
+ canceled=0,
+ total=6,
+ user_pending=2,
+ user_in_progress=1,
+ )
+ assert status_obj.user_pending == 2
+ assert status_obj.user_in_progress == 1
# ===========================================================================
@@ -1708,8 +1725,11 @@ def test_batch_enqueued_event_carries_user_id(self) -> None:
assert event.queue_id == "default"
def test_queue_item_status_changed_routed_privately(self, socketio: Any) -> None:
- """Verify that _handle_queue_event emits QueueItemStatusChangedEvent ONLY to
- user:{user_id} and admin rooms, never to the queue_id room."""
+ """_handle_queue_event must emit the FULL QueueItemStatusChangedEvent only to the
+ owner's user room and the admin room. A sanitized companion (user_id="redacted",
+ identifiers stripped) is also emitted to the queue_id room so other users' UIs can
+ refresh, with the owner's and admins' sids in skip_sid so they don't get a duplicate
+ that would clobber their cache."""
import asyncio
from unittest.mock import AsyncMock
@@ -1758,20 +1778,60 @@ def test_queue_item_status_changed_routed_privately(self, socketio: Any) -> None
),
)
+ # Track owner sid so we can verify skip_sid is honored
+ socketio._socket_users["sid-owner"] = {"user_id": "owner-xyz", "is_admin": False}
+ socketio._socket_users["sid-admin"] = {"user_id": "admin-1", "is_admin": True}
+ socketio._socket_users["sid-other"] = {"user_id": "other-user", "is_admin": False}
+
mock_emit = AsyncMock()
socketio._sio.emit = mock_emit
asyncio.run(socketio._handle_queue_event(("queue_item_status_changed", event)))
- rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list]
- assert "user:owner-xyz" in rooms_emitted_to
- assert "admin" in rooms_emitted_to
- # CRITICAL: must NOT emit to the queue_id room — that would leak to other users
- assert "default" not in rooms_emitted_to
+ # Collect (room, payload, skip_sid) for each emit call
+ emits = [
+ (c.kwargs.get("room"), c.kwargs.get("data"), c.kwargs.get("skip_sid")) for c in mock_emit.call_args_list
+ ]
+
+ # Full event must go to owner room and admin room with original sensitive fields
+ owner_emits = [(p, s) for r, p, s in emits if r == "user:owner-xyz"]
+ admin_emits = [(p, s) for r, p, s in emits if r == "admin"]
+ assert len(owner_emits) == 1 and len(admin_emits) == 1
+ for payload, _ in owner_emits + admin_emits:
+ assert payload["user_id"] == "owner-xyz"
+ assert payload["batch_id"] == "batch-private"
+ assert payload["session_id"] == "sess-private"
+ assert payload["destination"] == "canvas"
+
+ # A sanitized companion event must go to the queue_id room with sensitive fields cleared
+ queue_emits = [(p, s) for r, p, s in emits if r == "default"]
+ assert len(queue_emits) == 1, "expected exactly one sanitized emit to queue room"
+ sanitized_payload, skip_sid = queue_emits[0]
+ assert sanitized_payload["user_id"] == "redacted"
+ assert sanitized_payload["batch_id"] == "redacted"
+ assert sanitized_payload["session_id"] == "redacted"
+ assert sanitized_payload["origin"] is None
+ assert sanitized_payload["destination"] is None
+ assert sanitized_payload["error_type"] is None
+ assert sanitized_payload["batch_status"]["batch_id"] == "redacted"
+ assert sanitized_payload["batch_status"]["destination"] is None
+ assert sanitized_payload["queue_status"]["item_id"] is None
+ assert sanitized_payload["queue_status"]["batch_id"] is None
+ assert sanitized_payload["queue_status"]["user_pending"] is None
+ # Owner and admin sids must be skipped so they don't receive the duplicate
+ assert "sid-owner" in skip_sid
+ assert "sid-admin" in skip_sid
+ # Third-party user must NOT be skipped — they need the sanitized event
+ assert "sid-other" not in skip_sid
+ # Status (non-sensitive) is preserved so the non-owner UI knows what changed
+ assert sanitized_payload["status"] == "in_progress"
+ assert sanitized_payload["item_id"] == 1
def test_batch_enqueued_routed_privately(self, socketio: Any) -> None:
- """Verify that _handle_queue_event emits BatchEnqueuedEvent ONLY to
- user:{user_id} and admin rooms, never to the queue_id room."""
+ """_handle_queue_event must emit the FULL BatchEnqueuedEvent only to the owner's
+ user room and the admin room. A sanitized companion (user_id="redacted", batch_id
+ and origin stripped) is also emitted to the queue_id room so other users' badge
+ totals refresh, with owner/admin sids in skip_sid."""
import asyncio
from unittest.mock import AsyncMock
@@ -1792,15 +1852,39 @@ def test_batch_enqueued_routed_privately(self, socketio: Any) -> None:
)
event = BatchEnqueuedEvent.build(enqueue_result, user_id="owner-zzz")
+ socketio._socket_users["sid-owner"] = {"user_id": "owner-zzz", "is_admin": False}
+ socketio._socket_users["sid-admin"] = {"user_id": "admin-1", "is_admin": True}
+ socketio._socket_users["sid-other"] = {"user_id": "other-user", "is_admin": False}
+
mock_emit = AsyncMock()
socketio._sio.emit = mock_emit
asyncio.run(socketio._handle_queue_event(("batch_enqueued", event)))
- rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list]
- assert "user:owner-zzz" in rooms_emitted_to
- assert "admin" in rooms_emitted_to
- assert "default" not in rooms_emitted_to
+ emits = [
+ (c.kwargs.get("room"), c.kwargs.get("data"), c.kwargs.get("skip_sid")) for c in mock_emit.call_args_list
+ ]
+
+ # Full event to owner + admin contains the real batch_id and origin
+ owner_emits = [(p, s) for r, p, s in emits if r == "user:owner-zzz"]
+ admin_emits = [(p, s) for r, p, s in emits if r == "admin"]
+ assert len(owner_emits) == 1 and len(admin_emits) == 1
+ for payload, _ in owner_emits + admin_emits:
+ assert payload["user_id"] == "owner-zzz"
+ assert payload["batch_id"] == "batch-pvt"
+ assert payload["origin"] == "workflows"
+
+ # Sanitized event to queue room: user/batch/origin redacted, owner+admin skipped
+ queue_emits = [(p, s) for r, p, s in emits if r == "default"]
+ assert len(queue_emits) == 1
+ sanitized_payload, skip_sid = queue_emits[0]
+ assert sanitized_payload["user_id"] == "redacted"
+ assert sanitized_payload["batch_id"] == "redacted"
+ assert sanitized_payload["origin"] is None
+ assert sanitized_payload["enqueued"] == 5 # count is non-sensitive
+ assert "sid-owner" in skip_sid
+ assert "sid-admin" in skip_sid
+ assert "sid-other" not in skip_sid
def test_queue_cleared_still_broadcast(self, socketio: Any) -> None:
"""QueueClearedEvent does not carry user identity and should still be broadcast
diff --git a/tests/app/services/session_queue/test_session_queue_dequeue.py b/tests/app/services/session_queue/test_session_queue_dequeue.py
new file mode 100644
index 00000000000..0f82f2babaa
--- /dev/null
+++ b/tests/app/services/session_queue/test_session_queue_dequeue.py
@@ -0,0 +1,214 @@
+"""Tests for session queue dequeue() ordering: FIFO and round-robin modes."""
+
+import json
+import uuid
+from typing import Optional
+
+import pytest
+from pydantic_core import to_jsonable_python
+
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.app.services.invoker import Invoker
+from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
+from invokeai.app.services.shared.graph import Graph, GraphExecutionState
+
+_EMPTY_SESSION_JSON = json.dumps(to_jsonable_python(GraphExecutionState(graph=Graph()).model_dump()))
+
+
+@pytest.fixture
+def session_queue_fifo(mock_invoker: Invoker) -> SqliteSessionQueue:
+ """Queue backed by a single-user (FIFO) invoker."""
+ # Default config has multiuser=False, so FIFO is always used.
+ db = mock_invoker.services.board_records._db
+ queue = SqliteSessionQueue(db=db)
+ queue.start(mock_invoker)
+ return queue
+
+
+@pytest.fixture
+def session_queue_round_robin(mock_invoker: Invoker) -> SqliteSessionQueue:
+ """Queue backed by a multiuser invoker with round_robin mode."""
+ mock_invoker.services.configuration = InvokeAIAppConfig(
+ use_memory_db=True,
+ node_cache_size=0,
+ multiuser=True,
+ session_queue_mode="round_robin",
+ )
+ db = mock_invoker.services.board_records._db
+ queue = SqliteSessionQueue(db=db)
+ queue.start(mock_invoker)
+ return queue
+
+
+def _insert_queue_item(
+ session_queue: SqliteSessionQueue,
+ queue_id: str,
+ user_id: str,
+ priority: int = 0,
+) -> int:
+ """Directly insert a minimal queue item and return its item_id."""
+ session_id = str(uuid.uuid4())
+ batch_id = str(uuid.uuid4())
+ with session_queue._db.transaction() as cursor:
+ cursor.execute(
+ """--sql
+ INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id, user_id)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ """,
+ (queue_id, _EMPTY_SESSION_JSON, session_id, batch_id, None, priority, None, None, None, None, user_id),
+ )
+ return cursor.lastrowid # type: ignore[return-value]
+
+
+def _dequeue_user_ids(session_queue: SqliteSessionQueue, count: int) -> list[Optional[str]]:
+ """Dequeue `count` items and return the list of user_ids in dequeue order."""
+ result = []
+ for _ in range(count):
+ item = session_queue.dequeue()
+ result.append(item.user_id if item is not None else None)
+ return result
+
+
+# ---------------------------------------------------------------------------
+# FIFO tests
+# ---------------------------------------------------------------------------
+
+
+def test_fifo_single_user_order(session_queue_fifo: SqliteSessionQueue) -> None:
+ """FIFO: items from a single user are dequeued in insertion order."""
+ queue_id = "default"
+ _insert_queue_item(session_queue_fifo, queue_id, "user_a")
+ _insert_queue_item(session_queue_fifo, queue_id, "user_a")
+ _insert_queue_item(session_queue_fifo, queue_id, "user_a")
+
+ user_ids = _dequeue_user_ids(session_queue_fifo, 3)
+ assert user_ids == ["user_a", "user_a", "user_a"]
+
+
+def test_fifo_multi_user_preserves_insertion_order(session_queue_fifo: SqliteSessionQueue) -> None:
+ """FIFO: jobs from multiple users are dequeued in strict insertion order, not interleaved."""
+ queue_id = "default"
+ # Insert A1, A2, B1, C1, C2, A3 – FIFO should preserve this exact order.
+ _insert_queue_item(session_queue_fifo, queue_id, "user_a")
+ _insert_queue_item(session_queue_fifo, queue_id, "user_a")
+ _insert_queue_item(session_queue_fifo, queue_id, "user_b")
+ _insert_queue_item(session_queue_fifo, queue_id, "user_c")
+ _insert_queue_item(session_queue_fifo, queue_id, "user_c")
+ _insert_queue_item(session_queue_fifo, queue_id, "user_a")
+
+ user_ids = _dequeue_user_ids(session_queue_fifo, 6)
+ assert user_ids == ["user_a", "user_a", "user_b", "user_c", "user_c", "user_a"]
+
+
+def test_fifo_priority_respected(session_queue_fifo: SqliteSessionQueue) -> None:
+ """FIFO: higher-priority items are dequeued before lower-priority ones."""
+ queue_id = "default"
+ _insert_queue_item(session_queue_fifo, queue_id, "user_a", priority=0)
+ _insert_queue_item(session_queue_fifo, queue_id, "user_a", priority=10)
+
+ user_ids = _dequeue_user_ids(session_queue_fifo, 2)
+ # Both are user_a; second inserted item has higher priority and should come first.
+ assert user_ids == ["user_a", "user_a"]
+
+
+def test_fifo_returns_none_when_empty(session_queue_fifo: SqliteSessionQueue) -> None:
+ """FIFO: dequeue returns None when the queue is empty."""
+ assert session_queue_fifo.dequeue() is None
+
+
+# ---------------------------------------------------------------------------
+# Round-robin tests
+# ---------------------------------------------------------------------------
+
+
+def test_round_robin_interleaves_users(session_queue_round_robin: SqliteSessionQueue) -> None:
+ """Round-robin: jobs from multiple users are interleaved one per user per round.
+
+ Queue insertion order (matching the issue example):
+ A job 1, A job 2, B job 1, C job 1, C job 2, A job 3
+
+ Expected dequeue order:
+ A job 1, B job 1, C job 1, A job 2, C job 2, A job 3
+ """
+ queue_id = "default"
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_a")
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_a")
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_b")
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_c")
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_c")
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_a")
+
+ user_ids = _dequeue_user_ids(session_queue_round_robin, 6)
+ assert user_ids == ["user_a", "user_b", "user_c", "user_a", "user_c", "user_a"]
+
+
+def test_round_robin_single_user_behaves_like_fifo(session_queue_round_robin: SqliteSessionQueue) -> None:
+ """Round-robin with only one user produces the same order as FIFO."""
+ queue_id = "default"
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_a")
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_a")
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_a")
+
+ user_ids = _dequeue_user_ids(session_queue_round_robin, 3)
+ assert user_ids == ["user_a", "user_a", "user_a"]
+
+
+def test_round_robin_handles_user_joining_mid_queue(session_queue_round_robin: SqliteSessionQueue) -> None:
+ """Round-robin: a user who joins later is correctly interleaved."""
+ queue_id = "default"
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_a")
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_a")
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_b")
+
+ user_ids = _dequeue_user_ids(session_queue_round_robin, 3)
+ # Round 1: A (oldest rank-1 item), B (rank-1 item)
+ # Round 2: A (rank-2 item)
+ assert user_ids == ["user_a", "user_b", "user_a"]
+
+
+def test_round_robin_returns_none_when_empty(session_queue_round_robin: SqliteSessionQueue) -> None:
+ """Round-robin: dequeue returns None when the queue is empty."""
+ assert session_queue_round_robin.dequeue() is None
+
+
+def test_round_robin_priority_within_user_respected(session_queue_round_robin: SqliteSessionQueue) -> None:
+ """Round-robin: within a single user's items, higher priority is dequeued first."""
+ queue_id = "default"
+ # Insert low-priority item first, then high-priority for same user.
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_a", priority=0)
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_a", priority=10)
+ _insert_queue_item(session_queue_round_robin, queue_id, "user_b", priority=0)
+
+ # Round 1: user_a's best item (priority 10), user_b's only item.
+ # Round 2: user_a's remaining item (priority 0).
+ items = []
+ for _ in range(3):
+ item = session_queue_round_robin.dequeue()
+ assert item is not None
+ items.append((item.user_id, item.priority))
+
+ assert items[0] == ("user_a", 10)
+ assert items[1] == ("user_b", 0)
+ assert items[2] == ("user_a", 0)
+
+
+def test_round_robin_ignored_in_single_user_mode(mock_invoker: Invoker) -> None:
+ """When multiuser=False, round_robin config is ignored and FIFO is used."""
+ mock_invoker.services.configuration = InvokeAIAppConfig(
+ use_memory_db=True,
+ node_cache_size=0,
+ multiuser=False,
+ session_queue_mode="round_robin",
+ )
+ db = mock_invoker.services.board_records._db
+ queue = SqliteSessionQueue(db=db)
+ queue.start(mock_invoker)
+
+ queue_id = "default"
+ _insert_queue_item(queue, queue_id, "user_a")
+ _insert_queue_item(queue, queue_id, "user_a")
+ _insert_queue_item(queue, queue_id, "user_b")
+
+ # FIFO order: user_a, user_a, user_b
+ user_ids = _dequeue_user_ids(queue, 3)
+ assert user_ids == ["user_a", "user_a", "user_b"]