diff --git a/docs/src/content/docs/features/shapes-tool.mdx b/docs/src/content/docs/features/shapes-tool.mdx new file mode 100644 index 00000000000..6ad795aed7a --- /dev/null +++ b/docs/src/content/docs/features/shapes-tool.mdx @@ -0,0 +1,96 @@ +--- +title: Shapes Tool +description: Learn how to draw filled shapes on raster and inpaint mask layers with the Shapes tool. +lastUpdated: 2026-05-11 +--- + +import { Card, CardGrid } from '@astrojs/starlight/components'; + +The Shapes tool is a general-purpose filled-shape drawing tool for the canvas. It replaces the old Rectangle tool and +adds four shape modes under a single toolbar button: + +- **Rect** +- **Oval** +- **Polygon** +- **Freehand** + +You can activate the Shapes tool from the canvas toolbar or with the default hotkey U. + +## Where Shapes Draws + +Shapes always draws into the **active raster target**: + +- On a regular raster layer, Shapes adds filled pixels to that layer. +- On an active inpaint mask layer, Shapes draws directly into the mask. + +:::note +Shapes overlaps with some Lasso workflows on mask layers, but the tools are not identical. Lasso is still the more +specialized masking tool and can create a new mask layer automatically when one does not already exist. +::: + +## Common Behavior + +- Shapes preview live while you draw. +- The fill color uses the current active color. +- The active color's alpha is respected when adding pixels. +- Hold Ctrl on Windows/Linux or Cmd on macOS to switch to **subtractive** mode and cut pixels + out of the active layer. +- In subtractive mode, alpha is ignored and the shape fully clears pixels. +- Press Esc to cancel the current shape session. + +:::tip +When subtractive mode is active, the canvas cursor shows a small minus badge so you can tell at a glance that the next +shape will erase instead of fill. +::: + +## Shape Modes + + + + Drag to draw a rectangle. Hold Shift to constrain to a square. Hold Alt to draw from the + center instead of from a corner. + + + Drag to draw an ellipse. Hold Shift to constrain to a perfect circle. Hold Alt to draw from + the center. + + + Click to place vertices. Click the first point to close and commit the shape. Hold Shift to snap the + pending edge to horizontal, vertical, and 45 degree angles. + + + Click and drag to sketch a filled freehand contour. Release the pointer to commit the shape. + + + +## Moving and Panning During Drawing + +The Shapes tool supports different Space behavior depending on the current mode: + +- **Rect / Oval:** While the pointer is still down, hold Space to move the uncommitted shape instead of + resizing it. Release Space to continue resizing. +- **Polygon / Freehand:** Hold Space during an active session to pan the viewport without discarding the + unfinished shape. + +This is especially useful when drawing large shapes that extend beyond the current viewport. + +## Color Picking While Using Shapes + +The Alt key behaves differently depending on the active Shapes mode: + +- **Rect / Oval:** Before you start dragging, Alt can be used for the temporary color-picker quick-switch. + Once a drag is active, Alt is reserved for drawing from the center. +- **Polygon:** Alt remains available for the temporary color-picker quick-switch between vertex placements. +- **Freehand:** Alt is available before the stroke starts, but not during an active stroke. + +## Practical Examples + +- Use **Rect** or **Oval** to block in clean mask regions quickly. +- Use **Polygon** when you need straight edges and deliberate corner placement. +- Use **Freehand** for irregular organic regions. +- Use **subtractive mode** to cut holes back out of an existing raster or mask layer. + +## Summary + +The Shapes tool is the fastest way to add filled geometric or freeform regions to canvas layers. Use it for structured +fills, mask authoring, and precise subtractive edits without switching away from the current raster target. diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json index 88a42f8fbcf..f52420ca4b2 100644 --- a/docs/src/generated/settings.json +++ b/docs/src/generated/settings.json @@ -590,6 +590,20 @@ "type": "", "validation": {} }, + { + "category": "GENERATION", + "default": "round_robin", + "description": "Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.", + "env_var": "INVOKEAI_SESSION_QUEUE_MODE", + "literal_values": [ + "FIFO", + "round_robin" + ], + "name": "session_queue_mode", + "required": false, + "type": "typing.Literal['FIFO', 'round_robin']", + "validation": {} + }, { "category": "GENERATION", "default": false, diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 41a5a411c7a..d62cac5095f 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -141,12 +141,11 @@ async def get_queue_item_ids( queue_id: str = Path(description="The queue id to perform this operation on"), order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"), ) -> ItemIdsResult: - """Gets all queue item ids that match the given parameters. Non-admin users only see their own items.""" + """Gets all queue item ids that match the given parameters. The IDs themselves are not sensitive; + per-item field redaction is performed when the items are fetched via list_all_queue_items or + get_queue_items_by_item_ids.""" try: - user_id = None if current_user.is_admin else current_user.user_id - return ApiDependencies.invoker.services.session_queue.get_queue_item_ids( - queue_id=queue_id, order_dir=order_dir, user_id=user_id - ) + return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}") @@ -436,10 +435,15 @@ async def get_queue_status( current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), ) -> SessionQueueAndProcessorStatus: - """Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it.""" + """Gets the status of the session queue. Returns global counts plus the calling user's own + pending/in_progress counts (so the UI can show an X/Y badge). Non-admin users cannot see the + current item's identifiers unless they own it.""" try: - user_id = None if current_user.is_admin else current_user.user_id - queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id) + queue = ApiDependencies.invoker.services.session_queue.get_queue_status( + queue_id, + user_id=current_user.user_id, + is_admin=current_user.is_admin, + ) processor = ApiDependencies.invoker.services.session_processor.get_status() return SessionQueueAndProcessorStatus(queue=queue, processor=processor) except Exception as e: diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 5783b804c0b..b02b5bbb067 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -260,20 +260,37 @@ async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) + def _owner_and_admin_sids(self, owner_user_id: str) -> list[str]: + """Sids belonging to the event's owner or to any admin. + + Used as `skip_sid` when broadcasting a sanitized companion event to the queue room, + so the owner and admins (who already received the full event) don't get a second + copy that would clobber their cache with redacted values. + """ + return [ + sid + for sid, info in self._socket_users.items() + if info.get("user_id") == owner_user_id or info.get("is_admin") + ] + async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): """Handle queue events with user isolation. - All queue item events (invocation events AND QueueItemStatusChangedEvent) are - private to the owning user and admins. They carry unsanitized user_id, batch_id, - session_id, origin, destination and error metadata, and must never be broadcast - to the whole queue room — otherwise any other authenticated subscriber could - observe cross-user queue activity. + Queue events split into two routing paths: - RecallParametersUpdatedEvent is also private to the owner + admins. + 1. The owner and admins receive the full unsanitized event in their `user:{id}` / + `admin` rooms. The full payload may include batch_id, session_id, origin, + destination, error metadata, etc. - BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and - is also routed privately. QueueClearedEvent is the only queue event that - is still broadcast to the whole queue room. + 2. For events that other authenticated users need to know about so their queue list + and badge counts stay in sync (QueueItemStatusChangedEvent and BatchEnqueuedEvent), + a sanitized companion event is also emitted to the full queue room with the + owner's and admins' sids in `skip_sid`. The companion uses `user_id="redacted"` + as a sentinel so the frontend handler knows to do tag invalidation only and skip + per-session side effects. + + InvocationEventBase events stay private (owner + admins only). RecallParametersUpdatedEvent + is also private. QueueClearedEvent has no user identity and is broadcast to the queue room. IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase inherits from QueueItemEventBase. The order of isinstance checks matters! @@ -302,10 +319,51 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room") - # Other queue item events (QueueItemStatusChangedEvent) carry unsanitized - # user_id, batch_id, session_id, origin, destination and error metadata. - # They are private to the owning user + admins — never broadcast to the - # full queue room. + # QueueItemStatusChangedEvent: full to owner+admin, sanitized to everyone else in + # the queue room so their queue list, badge, and item caches refresh. + elif isinstance(event_data, QueueItemStatusChangedEvent): + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + + sanitized = event_data.model_copy( + update={ + "user_id": "redacted", + "batch_id": "redacted", + "session_id": "redacted", + "origin": None, + "destination": None, + "error_type": None, + "error_message": None, + "error_traceback": None, + } + ) + # Strip identifying fields out of the embedded batch_status / queue_status too. + sanitized.batch_status = sanitized.batch_status.model_copy( + update={"batch_id": "redacted", "origin": None, "destination": None} + ) + sanitized.queue_status = sanitized.queue_status.model_copy( + update={ + "item_id": None, + "session_id": None, + "batch_id": None, + "user_pending": None, + "user_in_progress": None, + } + ) + await self._sio.emit( + event=event_name, + data=sanitized.model_dump(mode="json"), + room=event_data.queue_id, + skip_sid=self._owner_and_admin_sids(event_data.user_id), + ) + + logger.debug( + f"Emitted queue_item_status_changed: full to {user_room}+admin, sanitized to queue {event_data.queue_id}" + ) + + # Other queue item events (currently none beyond QueueItemStatusChangedEvent that + # carry user_id) stay private to owner + admins. elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"): user_room = f"user:{event_data.user_id}" await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) @@ -320,14 +378,25 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room") - # BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and - # enqueued counts. Route it privately to the owner + admins so other - # users do not observe cross-user batch activity. + # BatchEnqueuedEvent: full to owner+admin, sanitized to everyone else in the queue + # room so their badge total and queue list pick up the new items. elif isinstance(event_data, BatchEnqueuedEvent): user_room = f"user:{event_data.user_id}" await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") - logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room") + + sanitized = event_data.model_copy( + update={"user_id": "redacted", "batch_id": "redacted", "origin": None} + ) + await self._sio.emit( + event=event_name, + data=sanitized.model_dump(mode="json"), + room=event_data.queue_id, + skip_sid=self._owner_and_admin_sids(event_data.user_id), + ) + logger.debug( + f"Emitted batch_enqueued: full to {user_room}+admin, sanitized to queue {event_data.queue_id}" + ) else: # For remaining queue events (e.g. QueueClearedEvent) that do not diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 57004efca39..d85b170fbab 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -30,6 +30,7 @@ ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"] LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"] +SESSION_QUEUE_MODE = Literal["FIFO", "round_robin"] IMAGE_SUBFOLDER_STRATEGY = Literal["flat", "date", "type", "hash"] CONFIG_SCHEMA_VERSION = "4.0.3" EXTERNAL_PROVIDER_CONFIG_FIELDS = ( @@ -114,6 +115,7 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. max_queue_size: Maximum number of items in the session queue. + session_queue_mode: Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.
Valid values: `FIFO`, `round_robin` clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`. max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. allow_nodes: List of nodes to allow. Omit to allow all. @@ -214,6 +216,7 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") + session_queue_mode: SESSION_QUEUE_MODE = Field(default="round_robin", description="Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.") clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup. If true, disables `max_queue_history`.") max_queue_history: Optional[int] = Field(default=None, ge=0, description="Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.") diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index d87221fbbae..7472ea07f63 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -309,6 +309,12 @@ class SessionQueueStatus(BaseModel): failed: int = Field(..., description="Number of queue items with status 'error'") canceled: int = Field(..., description="Number of queue items with status 'canceled'") total: int = Field(..., description="Total number of queue items") + user_pending: Optional[int] = Field( + default=None, description="Number of pending queue items for the calling user (multiuser only)" + ) + user_in_progress: Optional[int] = Field( + default=None, description="Number of in-progress queue items for the calling user (multiuser only)" + ) class SessionQueueCountsByDestination(BaseModel): diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index a05ed468857..c1bb71e0b74 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -210,9 +210,45 @@ async def enqueue_batch( return enqueue_result def dequeue(self) -> Optional[SessionQueueItem]: - with self._db.transaction() as cursor: - cursor.execute( - """--sql + config = self.__invoker.services.configuration + use_round_robin = config.multiuser and config.session_queue_mode == "round_robin" + + if use_round_robin: + query = """--sql + WITH user_last_served AS ( + -- Track when each user last had an item started, to determine whose turn it is. + SELECT user_id, MAX(started_at) AS last_served_at + FROM session_queue + WHERE started_at IS NOT NULL + GROUP BY user_id + ), + user_next_item AS ( + -- For each user, select their single best pending item (highest priority, then oldest). + SELECT + user_id, + item_id, + ROW_NUMBER() OVER ( + PARTITION BY user_id + ORDER BY priority DESC, item_id ASC + ) AS rn + FROM session_queue + WHERE status = 'pending' + ) + SELECT + sq.*, + u.display_name AS user_display_name, + u.email AS user_email + FROM session_queue sq + LEFT JOIN users u ON sq.user_id = u.user_id + JOIN user_next_item uni ON sq.item_id = uni.item_id AND uni.rn = 1 + LEFT JOIN user_last_served uls ON sq.user_id = uls.user_id + ORDER BY + COALESCE(uls.last_served_at, '1970-01-01') ASC, + sq.item_id ASC + LIMIT 1 + """ + else: + query = """--sql SELECT sq.*, u.display_name as user_display_name, @@ -225,7 +261,9 @@ def dequeue(self) -> Optional[SessionQueueItem]: sq.item_id ASC LIMIT 1 """ - ) + + with self._db.transaction() as cursor: + cursor.execute(query) result = cast(Union[sqlite3.Row, None], cursor.fetchone()) if result is None: return None @@ -860,7 +898,18 @@ def get_queue_status( acting_user_id: Optional[str] = None, ) -> SessionQueueStatus: with self._db.transaction() as cursor: - # When user_id is provided (non-admin), only count that user's items + cursor.execute( + """--sql + SELECT status, count(*) + FROM session_queue + WHERE queue_id = ? + GROUP BY status + """, + (queue_id,), + ) + counts_result = cast(list[sqlite3.Row], cursor.fetchall()) + + user_counts_result: list[sqlite3.Row] = [] if user_id is not None: cursor.execute( """--sql @@ -871,22 +920,19 @@ def get_queue_status( """, (queue_id, user_id), ) - else: - cursor.execute( - """--sql - SELECT status, count(*) - FROM session_queue - WHERE queue_id = ? - GROUP BY status - """, - (queue_id,), - ) - counts_result = cast(list[sqlite3.Row], cursor.fetchall()) + user_counts_result = cast(list[sqlite3.Row], cursor.fetchall()) current_item = self.get_current(queue_id=queue_id) total = sum(row[1] or 0 for row in counts_result) counts: dict[str, int] = {row[0]: row[1] for row in counts_result} + user_pending: Optional[int] = None + user_in_progress: Optional[int] = None + if user_id is not None: + user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result} + user_pending = user_counts.get("pending", 0) + user_in_progress = user_counts.get("in_progress", 0) + # Redaction is decided from the same current_item snapshot used to embed identifiers, # so a concurrent transition (e.g. B finishing while A's status changes) cannot leave # stale identifiers in the result. user_id (count filter) and acting_user_id @@ -909,6 +955,8 @@ def get_queue_status( failed=counts.get("failed", 0), canceled=counts.get("canceled", 0), total=total, + user_pending=user_pending, + user_in_progress=user_in_progress, ) def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus: diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index d99bb04a631..c164d1dafe1 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -728,8 +728,8 @@ "desc": "Select the move tool." }, "selectRectTool": { - "title": "Rect Tool", - "desc": "Select the rect tool." + "title": "Shapes Tool", + "desc": "Select the shapes tool." }, "selectLassoTool": { "title": "Lasso Tool", @@ -2881,6 +2881,10 @@ "polygon": "Polygon", "polygonHint": "Click to add points, click the first point to close." }, + "shape": { + "rect": "Rect", + "oval": "Oval" + }, "modifierHints": { "keys": { "control": "Ctrl", @@ -2896,11 +2900,12 @@ }, "labels": { "pan": "Pan", + "moveShape": "Move shape", "pickColor": "Pick color", "straightLine": "Straight line", "resizeBrush": "Resize brush", "resizeEraser": "Resize eraser", - "subtractMask": "Subtract mask", + "erase": "Erase", "snap45Degrees": "Snap to 45deg", "lockAspectRatio": "Lock ratio", "unlockAspectRatio": "Unlock ratio", @@ -2917,6 +2922,7 @@ "tool": { "brush": "Brush", "eraser": "Eraser", + "shapes": "Shapes", "rectangle": "Rectangle", "lasso": "Lasso", "gradient": "Gradient", diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx index b09e46d7320..61074015e76 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx @@ -32,7 +32,7 @@ export const GradientLinearIcon = memo(() => { const id = useId(); const gradientId = `${id}-gradient-linear-diagonal`; return ( - + @@ -40,15 +40,15 @@ export const GradientLinearIcon = memo(() => { ); @@ -59,7 +59,7 @@ export const GradientRadialIcon = memo(() => { const id = useId(); const gradientId = `${id}-gradient-radial`; return ( - + @@ -67,13 +67,13 @@ export const GradientRadialIcon = memo(() => { ); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx index 30d82722072..c0291f8e587 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx @@ -5,7 +5,7 @@ import { ToolColorPickerButton } from 'features/controlLayers/components/Tool/To import { ToolGradientButton } from 'features/controlLayers/components/Tool/ToolGradientButton'; import { ToolLassoButton } from 'features/controlLayers/components/Tool/ToolLassoButton'; import { ToolMoveButton } from 'features/controlLayers/components/Tool/ToolMoveButton'; -import { ToolRectButton } from 'features/controlLayers/components/Tool/ToolRectButton'; +import { ToolShapesButton } from 'features/controlLayers/components/Tool/ToolShapesButton'; import { ToolTextButton } from 'features/controlLayers/components/Tool/ToolTextButton'; import React from 'react'; @@ -18,7 +18,7 @@ export const ToolChooser: React.FC = () => { - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapeTypeToggle.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapeTypeToggle.tsx new file mode 100644 index 00000000000..2e4530e2e27 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapeTypeToggle.tsx @@ -0,0 +1,65 @@ +import { ButtonGroup, IconButton, Tooltip } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { selectShapeType, settingsShapeTypeChanged } from 'features/controlLayers/store/canvasSettingsSlice'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiCircleBold, PiPolygonBold, PiRectangleBold, PiScribbleLoopBold } from 'react-icons/pi'; + +export const ToolShapeTypeToggle = memo(() => { + const { t } = useTranslation(); + const shapeType = useAppSelector(selectShapeType); + const dispatch = useAppDispatch(); + + const onRectClick = useCallback(() => dispatch(settingsShapeTypeChanged('rect')), [dispatch]); + const onOvalClick = useCallback(() => dispatch(settingsShapeTypeChanged('oval')), [dispatch]); + const onPolygonClick = useCallback(() => dispatch(settingsShapeTypeChanged('polygon')), [dispatch]); + const onFreehandClick = useCallback(() => dispatch(settingsShapeTypeChanged('freehand')), [dispatch]); + + const rectLabel = t('controlLayers.shape.rect', { defaultValue: 'Rect' }); + const ovalLabel = t('controlLayers.shape.oval', { defaultValue: 'Oval' }); + const polygonLabel = t('controlLayers.lasso.polygon', { defaultValue: 'Polygon' }); + const freehandLabel = t('controlLayers.lasso.freehand', { defaultValue: 'Freehand' }); + + return ( + + + } + colorScheme={shapeType === 'rect' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={onRectClick} + /> + + + } + colorScheme={shapeType === 'oval' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={onOvalClick} + /> + + + } + colorScheme={shapeType === 'polygon' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={onPolygonClick} + /> + + + } + colorScheme={shapeType === 'freehand' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={onFreehandClick} + /> + + + ); +}); + +ToolShapeTypeToggle.displayName = 'ToolShapeTypeToggle'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolRectButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapesButton.tsx similarity index 58% rename from invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolRectButton.tsx rename to invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapesButton.tsx index 93029390883..3f6c546d2cf 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolRectButton.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapesButton.tsx @@ -3,32 +3,33 @@ import { useSelectTool, useToolIsSelected } from 'features/controlLayers/compone import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiRectangleBold } from 'react-icons/pi'; +import { PiShapesBold } from 'react-icons/pi'; -export const ToolRectButton = memo(() => { +export const ToolShapesButton = memo(() => { const { t } = useTranslation(); const isSelected = useToolIsSelected('rect'); - const selectRect = useSelectTool('rect'); + const selectShapes = useSelectTool('rect'); + const label = t('controlLayers.tool.shapes', { defaultValue: 'Shapes' }); useRegisteredHotkeys({ id: 'selectRectTool', category: 'canvas', - callback: selectRect, + callback: selectShapes, options: { enabled: !isSelected }, - dependencies: [isSelected, selectRect], + dependencies: [isSelected, selectShapes], }); return ( - + } + aria-label={`${label} (U)`} + icon={} colorScheme={isSelected ? 'invokeBlue' : 'base'} variant="solid" - onClick={selectRect} + onClick={selectShapes} /> ); }); -ToolRectButton.displayName = 'ToolRectButton'; +ToolShapesButton.displayName = 'ToolShapesButton'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx index bee8f5d1a34..bd72306e2e3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx @@ -7,6 +7,7 @@ import { ToolGradientClipToggle } from 'features/controlLayers/components/Tool/T import { ToolGradientModeToggle } from 'features/controlLayers/components/Tool/ToolGradientModeToggle'; import { ToolLassoModeToggle } from 'features/controlLayers/components/Tool/ToolLassoModeToggle'; import { ToolOptionsRowContainer } from 'features/controlLayers/components/Tool/ToolOptionsRowContainer'; +import { ToolShapeTypeToggle } from 'features/controlLayers/components/Tool/ToolShapeTypeToggle'; import { ToolWidthPicker } from 'features/controlLayers/components/Tool/ToolWidthPicker'; import { CanvasToolbarFitBboxToLayersButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToLayersButton'; import { CanvasToolbarFitBboxToMasksButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToMasksButton'; @@ -35,6 +36,7 @@ import { memo, useMemo } from 'react'; export const CanvasToolbar = memo(() => { const isBrushSelected = useToolIsSelected('brush'); const isEraserSelected = useToolIsSelected('eraser'); + const isShapeSelected = useToolIsSelected('rect'); const isTextSelected = useToolIsSelected('text'); const isLassoSelected = useToolIsSelected('lasso'); const isGradientSelected = useToolIsSelected('gradient'); @@ -56,9 +58,28 @@ export const CanvasToolbar = memo(() => { useCanvasToggleBboxHotkey(); return ( - - + + + {isShapeSelected && ( + + + + )} {isGradientSelected && ( @@ -72,21 +93,24 @@ export const CanvasToolbar = memo(() => { )} {isTextSelected ? : showToolWithPicker && } - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + ); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts index 9941761a2ee..b282580fec9 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts @@ -9,8 +9,11 @@ import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; +import { CanvasObjectOval } from 'features/controlLayers/konva/CanvasObject/CanvasObjectOval'; +import { CanvasObjectPolygon } from 'features/controlLayers/konva/CanvasObject/CanvasObjectPolygon'; import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types'; +import { shouldPreserveSuspendableShapesSession } from 'features/controlLayers/konva/CanvasTool/toolHotkeys'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import Konva from 'konva'; import type { Logger } from 'roarr'; @@ -83,6 +86,21 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase { this.subscriptions.add( this.manager.tool.$tool.listen(() => { if (this.hasBuffer() && !this.manager.$isBusy.get()) { + const isTemporaryShapesToolSwitch = shouldPreserveSuspendableShapesSession( + this.manager.tool.$tool.get(), + this.manager.tool.$toolBuffer.get(), + this.manager.tool.tools.rect.hasSuspendableSession() + ); + + if (isTemporaryShapesToolSwitch) { + return; + } + + if (this.state?.type === 'polygon' && this.state.previewPoint) { + this.clearBuffer(); + return; + } + this.commitBuffer(); } }) @@ -153,6 +171,24 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase { this.konva.group.add(this.renderer.konva.group); } + didRender = this.renderer.update(this.state, true); + } else if (this.state.type === 'oval') { + assert(this.renderer instanceof CanvasObjectOval || !this.renderer); + + if (!this.renderer) { + this.renderer = new CanvasObjectOval(this.state, this); + this.konva.group.add(this.renderer.konva.group); + } + + didRender = this.renderer.update(this.state, true); + } else if (this.state.type === 'polygon') { + assert(this.renderer instanceof CanvasObjectPolygon || !this.renderer); + + if (!this.renderer) { + this.renderer = new CanvasObjectPolygon(this.state, this); + this.konva.group.add(this.renderer.konva.group); + } + didRender = this.renderer.update(this.state, true); } else if (this.state.type === 'lasso') { assert(this.renderer instanceof CanvasObjectLasso || !this.renderer); @@ -240,28 +276,40 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase { this.log.trace({ buffer: this.renderer.repr() }, 'Committing buffer'); + let committedState = this.state; + + // Polygon previews render an outline while they are still live in the buffer. + // Clear that preview state before adopting the renderer into the persistent object group. + if (committedState.type === 'polygon' && this.renderer instanceof CanvasObjectPolygon) { + committedState = { ...committedState, previewPoint: undefined }; + this.state = null; + this.renderer.update(committedState, true); + } + // Move the buffer to the persistent objects group/renderers this.parent.renderer.adoptObjectRenderer(this.renderer); if (pushToState) { const entityIdentifier = this.parent.entityIdentifier; - switch (this.state.type) { + switch (committedState.type) { case 'brush_line': case 'brush_line_with_pressure': - this.manager.stateApi.addBrushLine({ entityIdentifier, brushLine: this.state }); + this.manager.stateApi.addBrushLine({ entityIdentifier, brushLine: committedState }); break; case 'eraser_line': case 'eraser_line_with_pressure': - this.manager.stateApi.addEraserLine({ entityIdentifier, eraserLine: this.state }); + this.manager.stateApi.addEraserLine({ entityIdentifier, eraserLine: committedState }); break; case 'rect': - this.manager.stateApi.addRect({ entityIdentifier, rect: this.state }); + case 'oval': + case 'polygon': + this.manager.stateApi.addShape({ entityIdentifier, shape: committedState }); break; case 'lasso': - this.manager.stateApi.addLasso({ entityIdentifier, lasso: this.state }); + this.manager.stateApi.addLasso({ entityIdentifier, lasso: committedState }); break; case 'gradient': - this.manager.stateApi.addGradient({ entityIdentifier, gradient: this.state }); + this.manager.stateApi.addGradient({ entityIdentifier, gradient: committedState }); break; } } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts index 903ccaa772c..f62ce3f9822 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts @@ -11,6 +11,8 @@ import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; +import { CanvasObjectOval } from 'features/controlLayers/konva/CanvasObject/CanvasObjectOval'; +import { CanvasObjectPolygon } from 'features/controlLayers/konva/CanvasObject/CanvasObjectPolygon'; import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types'; import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters'; @@ -398,6 +400,26 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase { this.konva.objectGroup.add(renderer.konva.group); } + didRender = renderer.update(objectState, force || isFirstRender); + } else if (objectState.type === 'oval') { + assert(renderer instanceof CanvasObjectOval || !renderer); + + if (!renderer) { + renderer = new CanvasObjectOval(objectState, this); + this.renderers.set(renderer.id, renderer); + this.konva.objectGroup.add(renderer.konva.group); + } + + didRender = renderer.update(objectState, force || isFirstRender); + } else if (objectState.type === 'polygon') { + assert(renderer instanceof CanvasObjectPolygon || !renderer); + + if (!renderer) { + renderer = new CanvasObjectPolygon(objectState, this); + this.renderers.set(renderer.id, renderer); + this.konva.objectGroup.add(renderer.konva.group); + } + didRender = renderer.update(objectState, force || isFirstRender); } else if (objectState.type === 'lasso') { assert(renderer instanceof CanvasObjectLasso || !renderer); @@ -455,10 +477,24 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase { renderer instanceof CanvasObjectEraserLine || renderer instanceof CanvasObjectEraserLineWithPressure; const isSubtractingLasso = renderer instanceof CanvasObjectLasso && renderer.state.compositeOperation === 'destination-out'; + const isSubtractRect = + renderer instanceof CanvasObjectRect && renderer.state.compositeOperation === 'destination-out'; + const isSubtractOval = + renderer instanceof CanvasObjectOval && renderer.state.compositeOperation === 'destination-out'; + const isSubtractPolygon = + renderer instanceof CanvasObjectPolygon && renderer.state.compositeOperation === 'destination-out'; const isImage = renderer instanceof CanvasObjectImage; const imageIgnoresTransparency = isImage && renderer.state.usePixelBbox === false; const hasClip = renderer instanceof CanvasObjectBrushLine && renderer.state.clip; - if (isEraserLine || isSubtractingLasso || hasClip || (isImage && !imageIgnoresTransparency)) { + if ( + isEraserLine || + isSubtractingLasso || + isSubtractRect || + isSubtractOval || + isSubtractPolygon || + hasClip || + (isImage && !imageIgnoresTransparency) + ) { needsPixelBbox = true; break; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectOval.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectOval.ts new file mode 100644 index 00000000000..8c06268f768 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectOval.ts @@ -0,0 +1,88 @@ +import { rgbaColorToString } from 'common/util/colorCodeTransformers'; +import { deepClone } from 'common/util/deepClone'; +import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer'; +import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { CanvasOvalState } from 'features/controlLayers/store/types'; +import Konva from 'konva'; +import type { Logger } from 'roarr'; + +export class CanvasObjectOval extends CanvasModuleBase { + readonly type = 'object_oval'; + readonly id: string; + readonly path: string[]; + readonly parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer; + readonly manager: CanvasManager; + readonly log: Logger; + + state: CanvasOvalState; + konva: { + group: Konva.Group; + ellipse: Konva.Ellipse; + }; + + constructor(state: CanvasOvalState, parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer) { + super(); + this.id = state.id; + this.parent = parent; + this.manager = parent.manager; + this.path = this.manager.buildPath(this); + this.log = this.manager.buildLogger(this); + + this.log.debug({ state }, 'Creating module'); + + this.konva = { + group: new Konva.Group({ name: `${this.type}:group`, listening: false }), + ellipse: new Konva.Ellipse({ + name: `${this.type}:ellipse`, + listening: false, + radiusX: 0, + radiusY: 0, + perfectDrawEnabled: false, + }), + }; + this.konva.group.add(this.konva.ellipse); + this.state = state; + } + + update(state: CanvasOvalState, force = false): boolean { + if (force || this.state !== state) { + this.log.trace({ state }, 'Updating oval'); + const { rect, color, compositeOperation } = state; + const fill = compositeOperation === 'destination-out' ? 'rgba(255,255,255,1)' : rgbaColorToString(color); + this.konva.ellipse.setAttrs({ + x: rect.x + rect.width / 2, + y: rect.y + rect.height / 2, + radiusX: rect.width / 2, + radiusY: rect.height / 2, + fill, + globalCompositeOperation: compositeOperation, + }); + this.state = state; + return true; + } + + return false; + } + + setVisibility(isVisible: boolean): void { + this.log.trace({ isVisible }, 'Setting oval visibility'); + this.konva.group.visible(isVisible); + } + + destroy = () => { + this.log.debug('Destroying module'); + this.konva.group.destroy(); + }; + + repr = () => { + return { + id: this.id, + type: this.type, + path: this.path, + parent: this.parent.id, + state: deepClone(this.state), + }; + }; +} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectPolygon.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectPolygon.ts new file mode 100644 index 00000000000..dc54811569b --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectPolygon.ts @@ -0,0 +1,113 @@ +import { rgbaColorToString } from 'common/util/colorCodeTransformers'; +import { deepClone } from 'common/util/deepClone'; +import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer'; +import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { CanvasPolygonState, RgbaColor } from 'features/controlLayers/store/types'; +import Konva from 'konva'; +import type { Logger } from 'roarr'; + +const getPreviewStrokeColor = (color: RgbaColor) => rgbaColorToString({ ...color, a: Math.max(color.a, 0.9) }); + +export class CanvasObjectPolygon extends CanvasModuleBase { + readonly type = 'object_polygon'; + readonly id: string; + readonly path: string[]; + readonly parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer; + readonly manager: CanvasManager; + readonly log: Logger; + + state: CanvasPolygonState; + konva: { + group: Konva.Group; + fillPolygon: Konva.Line; + previewStroke: Konva.Line; + }; + + constructor(state: CanvasPolygonState, parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer) { + super(); + this.id = state.id; + this.parent = parent; + this.manager = parent.manager; + this.path = this.manager.buildPath(this); + this.log = this.manager.buildLogger(this); + + this.log.debug({ state }, 'Creating module'); + + this.konva = { + group: new Konva.Group({ name: `${this.type}:group`, listening: false }), + fillPolygon: new Konva.Line({ + name: `${this.type}:fill_polygon`, + listening: false, + closed: true, + strokeEnabled: false, + perfectDrawEnabled: false, + }), + previewStroke: new Konva.Line({ + name: `${this.type}:preview_stroke`, + listening: false, + closed: false, + fillEnabled: false, + lineCap: 'round', + lineJoin: 'round', + perfectDrawEnabled: false, + strokeWidth: 1, + }), + }; + this.konva.group.add(this.konva.fillPolygon, this.konva.previewStroke); + this.state = state; + } + + update(state: CanvasPolygonState, force = false): boolean { + if (force || this.state !== state) { + this.log.trace({ state }, 'Updating polygon'); + const combinedPoints = state.previewPoint + ? [...state.points, state.previewPoint.x, state.previewPoint.y] + : state.points; + const hasRenderablePolygon = combinedPoints.length >= 6; + const isLiveBufferPreview = this.parent.type === 'buffer_renderer' && this.parent.state?.id === state.id; + const fill = + state.compositeOperation === 'destination-out' ? 'rgba(255,255,255,1)' : rgbaColorToString(state.color); + + this.konva.fillPolygon.setAttrs({ + points: combinedPoints, + visible: hasRenderablePolygon, + fill, + globalCompositeOperation: state.compositeOperation, + }); + + this.konva.previewStroke.setAttrs({ + points: combinedPoints, + visible: (Boolean(state.previewPoint) || isLiveBufferPreview) && combinedPoints.length >= 4, + stroke: getPreviewStrokeColor(state.color), + globalCompositeOperation: 'source-over', + }); + + this.state = state; + return true; + } + + return false; + } + + setVisibility(isVisible: boolean): void { + this.log.trace({ isVisible }, 'Setting polygon visibility'); + this.konva.group.visible(isVisible); + } + + destroy = () => { + this.log.debug('Destroying module'); + this.konva.group.destroy(); + }; + + repr = () => { + return { + id: this.id, + type: this.type, + path: this.path, + parent: this.parent.id, + state: deepClone(this.state), + }; + }; +} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts index 1ac8e5b5f37..e879dcd35ab 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts @@ -46,13 +46,15 @@ export class CanvasObjectRect extends CanvasModuleBase { this.isFirstRender = false; this.log.trace({ state }, 'Updating rect'); - const { rect, color } = state; + const { rect, color, compositeOperation } = state; + const fill = compositeOperation === 'destination-out' ? 'rgba(255,255,255,1)' : rgbaColorToString(color); this.konva.rect.setAttrs({ x: rect.x, y: rect.y, width: rect.width, height: rect.height, - fill: rgbaColorToString(color), + fill, + globalCompositeOperation: compositeOperation, }); this.state = state; return true; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts index f193c0b391e..620842a9426 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts @@ -5,6 +5,8 @@ import type { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/ import type { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import type { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; import type { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; +import type { CanvasObjectOval } from 'features/controlLayers/konva/CanvasObject/CanvasObjectOval'; +import type { CanvasObjectPolygon } from 'features/controlLayers/konva/CanvasObject/CanvasObjectPolygon'; import type { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { CanvasBrushLineState, @@ -14,6 +16,8 @@ import type { CanvasGradientState, CanvasImageState, CanvasLassoState, + CanvasOvalState, + CanvasPolygonState, CanvasRectState, } from 'features/controlLayers/store/types'; @@ -28,6 +32,8 @@ export type AnyObjectRenderer = | CanvasObjectEraserLineWithPressure | CanvasObjectRect | CanvasObjectLasso + | CanvasObjectOval + | CanvasObjectPolygon | CanvasObjectImage | CanvasObjectGradient; /** @@ -41,4 +47,6 @@ export type AnyObjectState = | CanvasImageState | CanvasRectState | CanvasLassoState + | CanvasOvalState + | CanvasPolygonState | CanvasGradientState; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts index 7d4c76b0c06..26abd908e51 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts @@ -25,8 +25,8 @@ import { entityMovedBy, entityMovedTo, entityRasterized, - entityRectAdded, entityReset, + entityShapeAdded, inpaintMaskAdded, rasterLayerAdded, rgAdded, @@ -48,7 +48,7 @@ import type { EntityMovedByPayload, EntityMovedToPayload, EntityRasterizedPayload, - EntityRectAddedPayload, + EntityShapeAddedPayload, Rect, RgbaColor, } from 'features/controlLayers/store/types'; @@ -171,10 +171,10 @@ export class CanvasStateApiModule extends CanvasModuleBase { }; /** - * Adds a rectangle to an entity, pushing state to redux. + * Adds a shape to an entity, pushing state to redux. */ - addRect = (arg: EntityRectAddedPayload) => { - this.store.dispatch(entityRectAdded(arg)); + addShape = (arg: EntityShapeAddedPayload) => { + this.store.dispatch(entityShapeAdded(arg)); }; /** diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasRectToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasRectToolModule.ts deleted file mode 100644 index 3f64b0c2fc1..00000000000 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasRectToolModule.ts +++ /dev/null @@ -1,102 +0,0 @@ -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; -import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; -import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule'; -import { floorCoord, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util'; -import type { KonvaEventObject } from 'konva/lib/Node'; -import type { Logger } from 'roarr'; - -export class CanvasRectToolModule extends CanvasModuleBase { - readonly type = 'rect_tool'; - readonly id: string; - readonly path: string[]; - readonly parent: CanvasToolModule; - readonly manager: CanvasManager; - readonly log: Logger; - - constructor(parent: CanvasToolModule) { - super(); - this.id = getPrefixedId(this.type); - this.parent = parent; - this.manager = this.parent.manager; - this.path = this.manager.buildPath(this); - this.log = this.manager.buildLogger(this); - - this.log.debug('Creating module'); - } - - syncCursorStyle = () => { - this.manager.stage.setCursor('crosshair'); - }; - - onStagePointerDown = async (_e: KonvaEventObject) => { - const cursorPos = this.parent.$cursorPos.get(); - const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get(); - const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); - - if (!cursorPos || !isPrimaryPointerDown || !selectedEntity) { - /** - * Can't do anything without: - * - A cursor position: the cursor is not on the stage - * - The mouse is down: the user is not drawing - * - A selected entity: there is no entity to draw on - */ - return; - } - - const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position); - - await selectedEntity.bufferRenderer.setBuffer({ - id: getPrefixedId('rect'), - type: 'rect', - rect: { x: Math.round(normalizedPoint.x), y: Math.round(normalizedPoint.y), width: 0, height: 0 }, - color: this.manager.stateApi.getCurrentColor(), - }); - }; - - onStagePointerUp = (_e: KonvaEventObject) => { - const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); - if (!selectedEntity) { - return; - } - - if (selectedEntity.bufferRenderer.state?.type === 'rect' && selectedEntity.bufferRenderer.hasBuffer()) { - selectedEntity.bufferRenderer.commitBuffer(); - } else { - selectedEntity.bufferRenderer.clearBuffer(); - } - }; - - onStagePointerMove = async (_e: KonvaEventObject) => { - const cursorPos = this.parent.$cursorPos.get(); - - if (!cursorPos) { - return; - } - - if (!this.parent.$isPrimaryPointerDown.get()) { - return; - } - - const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); - - if (!selectedEntity) { - return; - } - - const bufferState = selectedEntity.bufferRenderer.state; - - if (!bufferState) { - return; - } - - if (bufferState.type !== 'rect') { - return; - } - - const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position); - const alignedPoint = floorCoord(normalizedPoint); - bufferState.rect.width = Math.round(alignedPoint.x - bufferState.rect.x); - bufferState.rect.height = Math.round(alignedPoint.y - bufferState.rect.y); - await selectedEntity.bufferRenderer.setBuffer(bufferState); - }; -} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasShapeToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasShapeToolModule.ts new file mode 100644 index 00000000000..89c3f7691eb --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasShapeToolModule.ts @@ -0,0 +1,895 @@ +import { rgbaColorToString } from 'common/util/colorCodeTransformers'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule'; +import { shouldPreserveSuspendableShapesSession } from 'features/controlLayers/konva/CanvasTool/toolHotkeys'; +import { + addCoords, + floorCoord, + getPrefixedId, + isDistanceMoreThanMin, + offsetCoord, +} from 'features/controlLayers/konva/util'; +import { selectShapeType } from 'features/controlLayers/store/canvasSettingsSlice'; +import type { + CanvasEntityIdentifier, + CanvasPolygonState, + CanvasRectState, + Coordinate, +} from 'features/controlLayers/store/types'; +import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify'; +import Konva from 'konva'; +import type { KonvaEventObject } from 'konva/lib/Node'; +import type { Logger } from 'roarr'; + +type CanvasShapeToolModuleConfig = { + START_POINT_RADIUS_PX: number; + START_POINT_STROKE_WIDTH_PX: number; + START_POINT_HOVER_RADIUS_DELTA_PX: number; + POLYGON_CLOSE_RADIUS_PX: number; + MIN_FREEHAND_POINT_DISTANCE_PX: number; + MAX_FREEHAND_SEGMENT_LENGTH_PX: number; + FREEHAND_SIMPLIFY_MIN_POINTS: number; + FREEHAND_SIMPLIFY_TOLERANCE: number; + PREVIEW_STROKE_COLOR: string; +}; + +const DEFAULT_CONFIG: CanvasShapeToolModuleConfig = { + START_POINT_RADIUS_PX: 4, + START_POINT_STROKE_WIDTH_PX: 2, + START_POINT_HOVER_RADIUS_DELTA_PX: 2, + POLYGON_CLOSE_RADIUS_PX: 10, + MIN_FREEHAND_POINT_DISTANCE_PX: 1, + MAX_FREEHAND_SEGMENT_LENGTH_PX: 2, + FREEHAND_SIMPLIFY_MIN_POINTS: 200, + FREEHAND_SIMPLIFY_TOLERANCE: 0.6, + PREVIEW_STROKE_COLOR: rgbaColorToString({ r: 90, g: 175, b: 255, a: 1 }), +}; + +const SUBTRACT_CURSOR = `url("data:image/svg+xml,${encodeURIComponent( + ` + + + + + + + ` +)}") 12 12, crosshair`; + +const getAxisSign = (value: number, fallback: number): number => { + if (value === 0) { + return fallback === 0 ? 1 : Math.sign(fallback); + } + return Math.sign(value); +}; + +export class CanvasShapeToolModule extends CanvasModuleBase { + readonly type = 'shape_tool'; + readonly id: string; + readonly path: string[]; + readonly parent: CanvasToolModule; + readonly manager: CanvasManager; + readonly log: Logger; + + config: CanvasShapeToolModuleConfig = DEFAULT_CONFIG; + subscriptions: Set<() => void> = new Set(); + + private activeEntityIdentifier: CanvasEntityIdentifier | null = null; + private shapeId: string | null = null; + private dragStartPoint: Coordinate | null = null; + private dragCurrentPoint: Coordinate | null = null; + private translatePreviousPointerPoint: Coordinate | null = null; + private freehandPoints: Coordinate[] = []; + private isDrawingFreehand = false; + private polygonPoints: Coordinate[] = []; + private polygonPointer: Coordinate | null = null; + + konva: { + group: Konva.Group; + startPointIndicator: Konva.Circle; + }; + + constructor(parent: CanvasToolModule) { + super(); + this.id = getPrefixedId(this.type); + this.parent = parent; + this.manager = this.parent.manager; + this.path = this.manager.buildPath(this); + this.log = this.manager.buildLogger(this); + + this.log.debug('Creating module'); + + this.konva = { + group: new Konva.Group({ name: `${this.type}:group`, listening: false }), + startPointIndicator: new Konva.Circle({ + name: `${this.type}:start_point_indicator`, + listening: false, + fillEnabled: false, + stroke: this.config.PREVIEW_STROKE_COLOR, + visible: false, + perfectDrawEnabled: false, + }), + }; + this.konva.group.add(this.konva.startPointIndicator); + + this.subscriptions.add(this.manager.stateApi.$altKey.listen(this.onModifierChanged)); + this.subscriptions.add(this.manager.stateApi.$ctrlKey.listen(this.onModifierChanged)); + this.subscriptions.add(this.manager.stateApi.$metaKey.listen(this.onModifierChanged)); + this.subscriptions.add(this.manager.stateApi.$shiftKey.listen(this.onModifierChanged)); + this.subscriptions.add( + this.manager.stateApi.createStoreSubscription(selectShapeType, () => { + if (this.hasActiveSession()) { + this.cancel(); + } + this.render(); + }) + ); + } + + hasActiveSession = (): boolean => { + return Boolean( + this.dragStartPoint || this.isDrawingFreehand || this.freehandPoints.length || this.polygonPoints.length + ); + }; + + hasSuspendableSession = (): boolean => { + return Boolean(this.isDrawingFreehand || this.freehandPoints.length || this.polygonPoints.length); + }; + + hasActiveDragSession = (): boolean => { + return Boolean(this.dragStartPoint || this.isDrawingFreehand); + }; + + hasActiveRectOvalDragSession = (): boolean => { + const shapeType = this.manager.stateApi.getSettings().shapeType; + return Boolean(this.dragStartPoint && this.dragCurrentPoint && (shapeType === 'rect' || shapeType === 'oval')); + }; + + hasActivePolygonSession = (): boolean => { + return this.polygonPoints.length > 0; + }; + + isTranslatingDragSession = (): boolean => { + return this.translatePreviousPointerPoint !== null; + }; + + freezePolygonPreview = async () => { + if (!this.hasActivePolygonSession()) { + return; + } + + const activeEntity = this.getActiveEntityAdapter(); + const cursorPos = this.parent.$cursorPos.get(); + if (!activeEntity || !cursorPos) { + return; + } + + const point = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position); + this.polygonPointer = point; + await this.updatePolygonBuffer(); + this.render(); + }; + + onToolChanged = () => { + const tool = this.parent.$tool.get(); + const isTemporaryToolSwitch = shouldPreserveSuspendableShapesSession( + tool, + this.parent.$toolBuffer.get(), + this.hasSuspendableSession() + ); + if (tool !== 'rect' && !isTemporaryToolSwitch) { + this.cancel(); + } + }; + + syncCursorStyle = () => { + this.manager.stage.setCursor(this.getCompositeOperation() === 'destination-out' ? SUBTRACT_CURSOR : 'crosshair'); + }; + + render = () => { + const tool = this.parent.$tool.get(); + const isTemporaryToolSwitch = shouldPreserveSuspendableShapesSession( + tool, + this.parent.$toolBuffer.get(), + this.hasSuspendableSession() + ); + if (tool !== 'rect' && !isTemporaryToolSwitch) { + this.konva.startPointIndicator.visible(false); + return; + } + + if (tool === 'rect') { + this.syncCursorStyle(); + } + + this.syncStartPointIndicator(); + }; + + cancel = () => { + this.clearActiveBuffer(); + this.resetState(); + this.render(); + }; + + startDragTranslation = () => { + const activeEntity = this.getActiveEntityAdapter(); + const cursorPos = this.parent.$cursorPos.get(); + if (!activeEntity || !cursorPos || !this.hasActiveRectOvalDragSession()) { + return; + } + + this.translatePreviousPointerPoint = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position); + }; + + stopDragTranslation = () => { + this.translatePreviousPointerPoint = null; + }; + + onStagePointerDown = async (e: KonvaEventObject) => { + const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); + const cursorPos = this.parent.$cursorPos.get(); + if (!selectedEntity || !cursorPos) { + return; + } + + if (e.evt.button !== 0) { + return; + } + + const shapeType = this.manager.stateApi.getSettings().shapeType; + const point = this.getEntityRelativePoint(cursorPos.relative, selectedEntity.state.position); + + if (shapeType === 'polygon') { + await this.onPolygonPointerDown(point, selectedEntity.entityIdentifier, e.evt.shiftKey); + return; + } + + if (shapeType === 'freehand') { + if (!this.parent.$isPrimaryPointerDown.get()) { + return; + } + + await this.startFreehandSession(point, selectedEntity.entityIdentifier); + return; + } + + if (!this.parent.$isPrimaryPointerDown.get()) { + return; + } + + this.clearActiveBuffer(); + this.resetState(); + this.activeEntityIdentifier = selectedEntity.entityIdentifier; + this.shapeId = getPrefixedId(shapeType); + this.dragStartPoint = point; + this.dragCurrentPoint = point; + await this.updateDragBuffer(); + }; + + onStagePointerMove = async (e: KonvaEventObject) => { + const shapeType = this.manager.stateApi.getSettings().shapeType; + const activeEntity = this.getActiveEntityAdapter(); + const cursorPos = this.parent.$cursorPos.get(); + + if (!activeEntity || !cursorPos) { + return; + } + + const point = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position); + + if (shapeType === 'polygon') { + if (!this.hasActivePolygonSession()) { + return; + } + this.polygonPointer = this.getPolygonPoint(point, e.evt.shiftKey); + await this.updatePolygonBuffer(); + this.render(); + return; + } + + if (shapeType === 'freehand') { + await this.handleFreehandPointerMove(point); + return; + } + + if (!this.parent.$isPrimaryPointerDown.get() || !this.dragStartPoint) { + return; + } + + if (this.isTranslatingDragSession()) { + await this.translateDragShape(point); + return; + } + + this.dragCurrentPoint = point; + await this.updateDragBuffer(); + }; + + onWindowPointerMove = async () => { + const shapeType = this.manager.stateApi.getSettings().shapeType; + const activeEntity = this.getActiveEntityAdapter(); + const cursorPos = this.parent.$cursorPos.get(); + + if (!activeEntity || !cursorPos || !this.parent.$isPrimaryPointerDown.get()) { + return; + } + + const point = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position); + + if (shapeType === 'freehand') { + await this.handleFreehandPointerMove(point); + return; + } + + if ((shapeType !== 'rect' && shapeType !== 'oval') || !this.dragStartPoint) { + return; + } + + if (this.isTranslatingDragSession()) { + await this.translateDragShape(point); + return; + } + + this.dragCurrentPoint = point; + await this.updateDragBuffer(); + }; + + onStagePointerUp = async (_e: KonvaEventObject) => { + const shapeType = this.manager.stateApi.getSettings().shapeType; + + if (shapeType === 'polygon') { + this.render(); + return; + } + + if (shapeType === 'freehand') { + await this.commitFreehand(); + return; + } + + this.finishDragShapeSession(); + }; + + onWindowPointerUp = async () => { + if (this.isDrawingFreehand) { + await this.commitFreehand(); + return; + } + + if (!this.dragStartPoint) { + return; + } + + this.finishDragShapeSession(); + }; + + repr = () => { + return { + id: this.id, + type: this.type, + path: this.path, + activeEntityIdentifier: this.activeEntityIdentifier, + shapeId: this.shapeId, + dragStartPoint: this.dragStartPoint, + dragCurrentPoint: this.dragCurrentPoint, + translatePreviousPointerPoint: this.translatePreviousPointerPoint, + freehandPoints: this.freehandPoints, + isDrawingFreehand: this.isDrawingFreehand, + polygonPoints: this.polygonPoints, + polygonPointer: this.polygonPointer, + }; + }; + + destroy = () => { + this.log.debug('Destroying module'); + this.subscriptions.forEach((unsubscribe) => unsubscribe()); + this.subscriptions.clear(); + this.konva.group.destroy(); + }; + + private onModifierChanged = () => { + const tool = this.parent.$tool.get(); + const isTemporaryToolSwitch = shouldPreserveSuspendableShapesSession( + tool, + this.parent.$toolBuffer.get(), + this.hasSuspendableSession() + ); + if (tool !== 'rect' && !isTemporaryToolSwitch) { + return; + } + + if (tool === 'rect') { + this.syncCursorStyle(); + } + void this.updateActivePreview(); + this.render(); + }; + + private updateActivePreview = async () => { + if (this.dragStartPoint) { + await this.updateDragBuffer(); + return; + } + + if (this.isDrawingFreehand || this.freehandPoints.length > 0) { + await this.updateFreehandBuffer(); + return; + } + + if (this.hasActivePolygonSession()) { + await this.updatePolygonBuffer(); + } + }; + + private startFreehandSession = async (point: Coordinate, entityIdentifier: CanvasEntityIdentifier) => { + this.clearActiveBuffer(); + this.resetState(); + this.activeEntityIdentifier = entityIdentifier; + this.shapeId = getPrefixedId('polygon'); + this.freehandPoints = [point]; + this.isDrawingFreehand = true; + await this.updateFreehandBuffer(); + }; + + private handleFreehandPointerMove = async (point: Coordinate) => { + if (!this.isDrawingFreehand || !this.parent.$isPrimaryPointerDown.get()) { + return; + } + + const minDistance = this.manager.stage.unscale(this.config.MIN_FREEHAND_POINT_DISTANCE_PX); + const lastPoint = this.freehandPoints.at(-1) ?? null; + if (!isDistanceMoreThanMin(point, lastPoint, minDistance)) { + return; + } + + this.appendFreehandPoint(point); + await this.updateFreehandBuffer(); + }; + + private translateDragShape = async (point: Coordinate) => { + if (!this.translatePreviousPointerPoint || !this.dragStartPoint || !this.dragCurrentPoint) { + return; + } + + const dx = point.x - this.translatePreviousPointerPoint.x; + const dy = point.y - this.translatePreviousPointerPoint.y; + + if (dx === 0 && dy === 0) { + return; + } + + this.dragStartPoint = { + x: this.dragStartPoint.x + dx, + y: this.dragStartPoint.y + dy, + }; + this.dragCurrentPoint = { + x: this.dragCurrentPoint.x + dx, + y: this.dragCurrentPoint.y + dy, + }; + this.translatePreviousPointerPoint = point; + + await this.updateDragBuffer(); + }; + + private onPolygonPointerDown = async ( + point: Coordinate, + entityIdentifier: CanvasEntityIdentifier, + shouldSnap: boolean + ) => { + if ( + this.activeEntityIdentifier && + (this.activeEntityIdentifier.id !== entityIdentifier.id || + this.activeEntityIdentifier.type !== entityIdentifier.type) + ) { + this.cancel(); + } + + this.activeEntityIdentifier = entityIdentifier; + this.dragStartPoint = null; + this.dragCurrentPoint = null; + + if (this.polygonPoints.length === 0) { + this.shapeId = getPrefixedId('polygon'); + this.polygonPoints = [point]; + this.polygonPointer = point; + await this.updatePolygonBuffer(); + this.render(); + return; + } + + const startPoint = this.polygonPoints[0]; + if (!startPoint) { + return; + } + + if (this.polygonPoints.length >= 3 && this.isPointNearStart(point)) { + await this.commitPolygon(); + return; + } + + const polygonPoint = this.getPolygonPoint(point, shouldSnap); + this.polygonPoints = [...this.polygonPoints, polygonPoint]; + this.polygonPointer = polygonPoint; + await this.updatePolygonBuffer(); + this.render(); + }; + + private commitPolygon = async () => { + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity || !this.shapeId || this.polygonPoints.length < 3) { + this.cancel(); + return; + } + + const polygonState: CanvasPolygonState = { + id: this.shapeId, + type: 'polygon', + points: this.polygonPoints.flatMap((point) => [point.x, point.y]), + color: this.manager.stateApi.getCurrentColor(), + compositeOperation: this.getCompositeOperation(), + }; + + await activeEntity.bufferRenderer.setBuffer(polygonState); + activeEntity.bufferRenderer.commitBuffer(); + this.resetState(); + this.render(); + }; + + private commitFreehand = async () => { + if (!this.isDrawingFreehand) { + return; + } + + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity || !this.shapeId) { + this.cancel(); + return; + } + + const simplifiedPoints = this.simplifyFreehandContour(this.freehandPoints); + if (simplifiedPoints.length < 3) { + activeEntity.bufferRenderer.clearBuffer(); + this.resetState(); + this.render(); + return; + } + + const polygonState: CanvasPolygonState = { + id: this.shapeId, + type: 'polygon', + points: simplifiedPoints.flatMap((point) => [point.x, point.y]), + color: this.manager.stateApi.getCurrentColor(), + compositeOperation: this.getCompositeOperation(), + }; + + await activeEntity.bufferRenderer.setBuffer(polygonState); + activeEntity.bufferRenderer.commitBuffer(); + this.resetState(); + this.render(); + }; + + private updateDragBuffer = async () => { + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity || !this.dragStartPoint || !this.dragCurrentPoint || !this.shapeId) { + return; + } + + const shapeType = this.manager.stateApi.getSettings().shapeType; + if (shapeType !== 'rect' && shapeType !== 'oval') { + return; + } + + const rect = this.getDragRect(this.dragStartPoint, this.dragCurrentPoint, { + fromCenter: this.manager.stateApi.$altKey.get(), + constrainSquare: this.manager.stateApi.$shiftKey.get(), + }); + + await activeEntity.bufferRenderer.setBuffer({ + id: this.shapeId, + type: shapeType, + rect, + color: this.manager.stateApi.getCurrentColor(), + compositeOperation: this.getCompositeOperation(), + }); + }; + + private updatePolygonBuffer = async () => { + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity || !this.shapeId || this.polygonPoints.length === 0) { + return; + } + + await activeEntity.bufferRenderer.setBuffer({ + id: this.shapeId, + type: 'polygon', + points: this.polygonPoints.flatMap((point) => [point.x, point.y]), + previewPoint: this.polygonPointer ?? this.polygonPoints.at(-1), + color: this.manager.stateApi.getCurrentColor(), + compositeOperation: this.getCompositeOperation(), + }); + }; + + private updateFreehandBuffer = async () => { + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity || !this.shapeId || this.freehandPoints.length === 0) { + return; + } + + await activeEntity.bufferRenderer.setBuffer({ + id: this.shapeId, + type: 'polygon', + points: this.freehandPoints.flatMap((point) => [point.x, point.y]), + color: this.manager.stateApi.getCurrentColor(), + compositeOperation: this.getCompositeOperation(), + }); + }; + + private syncStartPointIndicator = () => { + const activeEntity = this.getActiveEntityAdapter(); + const startPoint = this.polygonPoints[0]; + if (!activeEntity || !startPoint || this.manager.stateApi.getSettings().shapeType !== 'polygon') { + this.konva.startPointIndicator.visible(false); + return; + } + + const isHoveringStartPoint = this.getIsHoveringStartPoint(startPoint, activeEntity.state.position); + const baseRadius = this.manager.stage.unscale(this.config.START_POINT_RADIUS_PX); + const stagePoint = addCoords(startPoint, activeEntity.state.position); + + this.konva.startPointIndicator.setAttrs({ + x: stagePoint.x, + y: stagePoint.y, + radius: + baseRadius + + (isHoveringStartPoint ? this.manager.stage.unscale(this.config.START_POINT_HOVER_RADIUS_DELTA_PX) : 0), + strokeWidth: this.manager.stage.unscale(this.config.START_POINT_STROKE_WIDTH_PX), + visible: true, + }); + }; + + private getEntityRelativePoint = (point: Coordinate, position: Coordinate): Coordinate => { + return floorCoord(offsetCoord(point, position)); + }; + + private getCompositeOperation = (): CanvasRectState['compositeOperation'] => { + return this.manager.stateApi.$ctrlKey.get() || this.manager.stateApi.$metaKey.get() + ? 'destination-out' + : 'source-over'; + }; + + private getPolygonPoint = (point: Coordinate, shouldSnap: boolean): Coordinate => { + if (!shouldSnap) { + return point; + } + + const lastPoint = this.polygonPoints.at(-1); + if (!lastPoint) { + return point; + } + + const dx = point.x - lastPoint.x; + const dy = point.y - lastPoint.y; + const distance = Math.hypot(dx, dy); + if (distance === 0) { + return point; + } + + const snapAngle = Math.PI / 4; + const angle = Math.atan2(dy, dx); + const snappedAngle = Math.round(angle / snapAngle) * snapAngle; + + const snappedPoint = { + x: lastPoint.x + Math.cos(snappedAngle) * distance, + y: lastPoint.y + Math.sin(snappedAngle) * distance, + }; + + return this.alignPointToStart(snappedPoint); + }; + + private isPointNearStart = (point: Coordinate): boolean => { + const startPoint = this.polygonPoints[0]; + if (!startPoint) { + return false; + } + return Math.hypot(point.x - startPoint.x, point.y - startPoint.y) <= this.getPolygonCloseRadius(); + }; + + private getPolygonCloseRadius = (): number => { + return this.manager.stage.unscale(this.config.POLYGON_CLOSE_RADIUS_PX); + }; + + private getIsHoveringStartPoint = (startPoint: Coordinate, entityPosition: Coordinate): boolean => { + if (this.polygonPoints.length < 3) { + return false; + } + + const pointerPoint = this.parent.$cursorPos.get()?.relative; + if (!pointerPoint) { + return false; + } + + const entityRelativePointerPoint = this.getEntityRelativePoint(pointerPoint, entityPosition); + return ( + Math.hypot(entityRelativePointerPoint.x - startPoint.x, entityRelativePointerPoint.y - startPoint.y) <= + this.getPolygonCloseRadius() + ); + }; + + private alignPointToStart = (point: Coordinate): Coordinate => { + if (this.polygonPoints.length < 2) { + return point; + } + + const startPoint = this.polygonPoints[0]; + if (!startPoint) { + return point; + } + + const alignThreshold = this.getPolygonCloseRadius(); + const deltaX = Math.abs(point.x - startPoint.x); + const deltaY = Math.abs(point.y - startPoint.y); + const canAlignX = deltaX <= alignThreshold; + const canAlignY = deltaY <= alignThreshold; + + if (!canAlignX && !canAlignY) { + return point; + } + + if (canAlignX && canAlignY) { + if (deltaX <= deltaY) { + return { x: startPoint.x, y: point.y }; + } + return { x: point.x, y: startPoint.y }; + } + + if (canAlignX) { + return { x: startPoint.x, y: point.y }; + } + + return { x: point.x, y: startPoint.y }; + }; + + private appendFreehandPoint = (point: Coordinate) => { + const lastPoint = this.freehandPoints.at(-1) ?? null; + if (!lastPoint) { + this.freehandPoints.push(point); + return; + } + + const maxSegmentLength = this.manager.stage.unscale(this.config.MAX_FREEHAND_SEGMENT_LENGTH_PX); + const dx = point.x - lastPoint.x; + const dy = point.y - lastPoint.y; + const distance = Math.hypot(dx, dy); + + if (distance <= maxSegmentLength) { + this.freehandPoints.push(point); + return; + } + + const steps = Math.ceil(distance / maxSegmentLength); + for (let i = 1; i <= steps; i++) { + const t = i / steps; + this.freehandPoints.push({ + x: lastPoint.x + dx * t, + y: lastPoint.y + dy * t, + }); + } + }; + + private simplifyFreehandContour = (points: Coordinate[]): Coordinate[] => { + if (points.length < this.config.FREEHAND_SIMPLIFY_MIN_POINTS) { + return points; + } + + const simplifiedFlatPoints = simplifyFlatNumbersArray( + points.flatMap((point) => [point.x, point.y]), + { + tolerance: this.config.FREEHAND_SIMPLIFY_TOLERANCE, + highestQuality: true, + } + ); + + if (simplifiedFlatPoints.length < 6) { + return points; + } + + const simplifiedPoints = this.flatNumbersToCoords(simplifiedFlatPoints); + if (simplifiedPoints.length < 3) { + return points; + } + + return simplifiedPoints; + }; + + private flatNumbersToCoords = (points: number[]): Coordinate[] => { + const coords: Coordinate[] = []; + for (let i = 0; i < points.length; i += 2) { + const x = points[i]; + const y = points[i + 1]; + if (x === undefined || y === undefined) { + continue; + } + coords.push({ x, y }); + } + return coords; + }; + + private getDragRect = ( + start: Coordinate, + end: Coordinate, + options: { fromCenter: boolean; constrainSquare: boolean } + ): CanvasRectState['rect'] => { + let dx = end.x - start.x; + let dy = end.y - start.y; + + if (options.constrainSquare) { + const size = Math.max(Math.abs(dx), Math.abs(dy)); + const dxSign = getAxisSign(dx, dy); + const dySign = getAxisSign(dy, dx); + dx = dxSign * size; + dy = dySign * size; + } + + const x1 = options.fromCenter ? start.x - dx : start.x; + const y1 = options.fromCenter ? start.y - dy : start.y; + const x2 = options.fromCenter ? start.x + dx : start.x + dx; + const y2 = options.fromCenter ? start.y + dy : start.y + dy; + + return { + x: Math.min(x1, x2), + y: Math.min(y1, y2), + width: Math.abs(x2 - x1), + height: Math.abs(y2 - y1), + }; + }; + + private getActiveEntityAdapter = () => { + if (!this.activeEntityIdentifier) { + return null; + } + return this.manager.getAdapter(this.activeEntityIdentifier); + }; + + private finishDragShapeSession = () => { + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity) { + this.resetState(); + this.render(); + return; + } + + const bufferState = activeEntity.bufferRenderer.state; + if ( + bufferState && + (bufferState.type === 'rect' || bufferState.type === 'oval') && + activeEntity.bufferRenderer.hasBuffer() && + bufferState.rect.width > 0 && + bufferState.rect.height > 0 + ) { + activeEntity.bufferRenderer.commitBuffer(); + } else { + activeEntity.bufferRenderer.clearBuffer(); + } + + this.resetState(); + this.render(); + }; + + private clearActiveBuffer = () => { + this.getActiveEntityAdapter()?.bufferRenderer.clearBuffer(); + }; + + private resetState = () => { + this.activeEntityIdentifier = null; + this.shapeId = null; + this.dragStartPoint = null; + this.dragCurrentPoint = null; + this.translatePreviousPointerPoint = null; + this.freehandPoints = []; + this.isDrawingFreehand = false; + this.polygonPoints = []; + this.polygonPointer = null; + this.konva.startPointIndicator.visible(false); + }; +} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts index beca4d14a0a..98e1ab7d5d9 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts @@ -1,5 +1,6 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types'; import { CanvasBboxToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasBboxToolModule'; import { CanvasBrushToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasBrushToolModule'; import { CanvasColorPickerToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasColorPickerToolModule'; @@ -7,9 +8,15 @@ import { CanvasEraserToolModule } from 'features/controlLayers/konva/CanvasTool/ import { CanvasGradientToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasGradientToolModule'; import { CanvasLassoToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasLassoToolModule'; import { CanvasMoveToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasMoveToolModule'; -import { CanvasRectToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasRectToolModule'; +import { CanvasShapeToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasShapeToolModule'; import { CanvasTextToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasTextToolModule'; import { CanvasViewToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasViewToolModule'; +import { + getToolToCancelOnEscape, + shouldPreserveSuspendableShapesSession, + shouldQuickSwitchToColorPickerOnAlt, + shouldTranslateShapeDragOnSpace, +} from 'features/controlLayers/konva/CanvasTool/toolHotkeys'; import { ZOOM_DRAG_CURSOR } from 'features/controlLayers/konva/cursors/zoomDragCursor'; import { calculateNewBrushSizeFromWheelDelta, @@ -64,7 +71,7 @@ export class CanvasToolModule extends CanvasModuleBase { tools: { brush: CanvasBrushToolModule; eraser: CanvasEraserToolModule; - rect: CanvasRectToolModule; + rect: CanvasShapeToolModule; lasso: CanvasLassoToolModule; gradient: CanvasGradientToolModule; colorPicker: CanvasColorPickerToolModule; @@ -124,7 +131,7 @@ export class CanvasToolModule extends CanvasModuleBase { this.tools = { brush: new CanvasBrushToolModule(this), eraser: new CanvasEraserToolModule(this), - rect: new CanvasRectToolModule(this), + rect: new CanvasShapeToolModule(this), lasso: new CanvasLassoToolModule(this), gradient: new CanvasGradientToolModule(this), colorPicker: new CanvasColorPickerToolModule(this), @@ -141,6 +148,7 @@ export class CanvasToolModule extends CanvasModuleBase { this.konva.group.add(this.tools.brush.konva.group); this.konva.group.add(this.tools.eraser.konva.group); + this.konva.group.add(this.tools.rect.konva.group); this.konva.group.add(this.tools.colorPicker.konva.group); this.konva.group.add(this.tools.text.konva.group); this.konva.group.add(this.tools.bbox.konva.group); @@ -152,17 +160,24 @@ export class CanvasToolModule extends CanvasModuleBase { this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasSlice, this.render)); this.subscriptions.add( this.$tool.listen((tool, previousTool) => { - // Preserve pointer state during temporary view switching so lasso sessions can freeze/resume on space. - const shouldPreservePointerState = + // Preserve pointer state during temporary view switching so lasso and shapes sessions can freeze/resume on + // space. + const shouldPreserveLassoPointerState = this.$toolBuffer.get() === 'lasso' && this.tools.lasso.hasActiveSession() && ((previousTool === 'lasso' && tool === 'view') || (previousTool === 'view' && tool === 'lasso')); + const shouldPreserveShapesPointerState = + this.$toolBuffer.get() === 'rect' && + this.tools.rect.hasSuspendableSession() && + ((previousTool === 'rect' && tool === 'view') || (previousTool === 'view' && tool === 'rect')); + const shouldPreservePointerState = shouldPreserveLassoPointerState || shouldPreserveShapesPointerState; if (!shouldPreservePointerState) { // On tool switch, reset mouse state this.manager.tool.$isPrimaryPointerDown.set(false); } + this.tools.rect.onToolChanged(); this.tools.lasso.onToolChanged(); void this.tools.text.onToolChanged(); this.render(); @@ -239,6 +254,7 @@ export class CanvasToolModule extends CanvasModuleBase { this.tools.brush.render(); this.tools.eraser.render(); + this.tools.rect.render(); this.tools.colorPicker.render(); this.tools.text.render(); this.tools.bbox.render(); @@ -411,9 +427,8 @@ export class CanvasToolModule extends CanvasModuleBase { const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); if ( - selectedEntity?.bufferRenderer.state?.type !== 'rect' && - selectedEntity?.bufferRenderer.state?.type !== 'gradient' && - selectedEntity?.bufferRenderer.hasBuffer() + selectedEntity?.bufferRenderer.hasBuffer() && + !this.shouldDeferEnterLeaveCommit(selectedEntity.bufferRenderer.state) ) { selectedEntity.bufferRenderer.commitBuffer(); return; @@ -467,7 +482,7 @@ export class CanvasToolModule extends CanvasModuleBase { } }; - onStagePointerUp = (e: KonvaEventObject) => { + onStagePointerUp = async (e: KonvaEventObject) => { if (e.target !== this.konva.stage) { return; } @@ -490,7 +505,7 @@ export class CanvasToolModule extends CanvasModuleBase { } else if (tool === 'eraser') { this.tools.eraser.onStagePointerUp(e); } else if (tool === 'rect') { - this.tools.rect.onStagePointerUp(e); + await this.tools.rect.onStagePointerUp(e); } else if (tool === 'lasso') { void this.tools.lasso.onStagePointerUp(e); } else if (tool === 'gradient') { @@ -534,6 +549,8 @@ export class CanvasToolModule extends CanvasModuleBase { await this.tools.gradient.onStagePointerMove(e); } else if (tool === 'text') { // Already handled above + } else if (this.isTemporaryShapesToolSwitch()) { + // Preserve in-progress polygon/freehand shapes while temporarily switching to view or color picker. } else { this.manager.stateApi.getSelectedEntityAdapter()?.bufferRenderer.clearBuffer(); } @@ -559,9 +576,8 @@ export class CanvasToolModule extends CanvasModuleBase { if ( selectedEntity && - selectedEntity.bufferRenderer.state?.type !== 'rect' && - selectedEntity.bufferRenderer.state?.type !== 'gradient' && - selectedEntity.bufferRenderer.hasBuffer() + selectedEntity.bufferRenderer.hasBuffer() && + !this.shouldDeferEnterLeaveCommit(selectedEntity.bufferRenderer.state) ) { selectedEntity.bufferRenderer.commitBuffer(); } @@ -604,20 +620,19 @@ export class CanvasToolModule extends CanvasModuleBase { this.render(); }; - /** - * Commit the buffer on window pointer up. - * - * The user may start drawing inside the stage and then release the mouse button outside of the stage. To prevent - * whatever the user was drawing from being lost, or ending up with stale state, we need to commit the buffer - * on window pointer up. - */ - onWindowPointerUp = (_: PointerEvent) => { + onWindowPointerUp = async (_: PointerEvent) => { try { this.$isPrimaryPointerDown.set(false); void this.tools.lasso.onWindowPointerUp(); + await this.tools.rect.onWindowPointerUp(); const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); - if (selectedEntity && selectedEntity.bufferRenderer.hasBuffer() && !this.manager.$isBusy.get()) { + if ( + selectedEntity && + selectedEntity.bufferRenderer.hasBuffer() && + !this.manager.$isBusy.get() && + !this.shouldSkipWindowPointerUpCommit(selectedEntity.bufferRenderer.state) + ) { selectedEntity.bufferRenderer.commitBuffer(); } } finally { @@ -625,36 +640,38 @@ export class CanvasToolModule extends CanvasModuleBase { } }; - onWindowPointerMove = (e: PointerEvent) => { + onWindowPointerMove = async (e: PointerEvent) => { const target = e.target; if (target instanceof Node && this.manager.stage.container.contains(target)) { return; } - if (this.$tool.get() !== 'lasso') { - return; - } - - if (!this.getCanDraw()) { - return; - } - - if (!this.$isPrimaryPointerDown.get()) { - return; - } - - if (!this.tools.lasso.hasActiveSession()) { - return; - } - try { this.$lastPointerType.set(e.pointerType); + if (!this.getCanDraw()) { + return; + } + + if (!this.$isPrimaryPointerDown.get()) { + return; + } + if (!this.syncCursorPositionsFromWindowEvent(e)) { return; } - this.tools.lasso.onWindowPointerMove(e); + if (this.$tool.get() === 'rect') { + if (!this.tools.rect.hasActiveDragSession()) { + return; + } + await this.tools.rect.onWindowPointerMove(); + } else if (this.$tool.get() === 'lasso') { + if (!this.tools.lasso.hasActiveSession()) { + return; + } + this.tools.lasso.onWindowPointerMove(e); + } } finally { this.render(); } @@ -665,6 +682,8 @@ export class CanvasToolModule extends CanvasModuleBase { * and the color picker tool is still active when you come back. */ onWindowBlur = () => { + this.manager.stateApi.$spaceKey.set(false); + this.tools.rect.stopDragTranslation(); this.revertToolBuffer(); }; @@ -691,9 +710,25 @@ export class CanvasToolModule extends CanvasModuleBase { if (e.key === KEY_ESCAPE) { // Cancel shape drawing on escape e.preventDefault(); - if (this.$tool.get() === 'lasso') { + const tool = this.$tool.get(); + const toolToCancel = getToolToCancelOnEscape( + tool, + this.$toolBuffer.get(), + this.tools.lasso.hasActiveSession(), + this.tools.rect.hasSuspendableSession() + ); + + this.manager.stateApi.$spaceKey.set(false); + this.tools.rect.stopDragTranslation(); + if (toolToCancel === 'rect') { + this.tools.rect.cancel(); + } + if (toolToCancel === 'lasso') { this.tools.lasso.reset(); } + if (toolToCancel && tool !== toolToCancel) { + this.revertToolBuffer(); + } const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); if ( selectedEntity && @@ -707,16 +742,40 @@ export class CanvasToolModule extends CanvasModuleBase { } if (isSpaceKey) { - // Select the view tool on space key down e.preventDefault(); e.stopPropagation(); const currentTool = this.$tool.get(); - this.$toolBuffer.set(currentTool); + const shapeType = this.manager.stateApi.getSettings().shapeType; + const hasActiveShapeDragSession = this.tools.rect.hasActiveDragSession(); + const isPrimaryPointerDown = this.$isPrimaryPointerDown.get(); this.manager.stateApi.$spaceKey.set(true); + + if (shouldTranslateShapeDragOnSpace(currentTool, shapeType, hasActiveShapeDragSession, isPrimaryPointerDown)) { + this.tools.rect.startDragTranslation(); + return; + } + + if (currentTool === 'rect' && this.tools.rect.hasActivePolygonSession()) { + void this.tools.rect.freezePolygonPreview(); + } + + // Select the view tool on space key down + this.$toolBuffer.set(currentTool); this.$tool.set('view'); if (currentTool === 'lasso' && this.tools.lasso.hasActiveSession() && this.$isPrimaryPointerDown.get()) { // Start panning immediately if user is already drawing with freehand lasso. this.manager.stage.startDragging(); + } else if ( + currentTool === 'rect' && + this.tools.rect.hasSuspendableSession() && + this.$isPrimaryPointerDown.get() + ) { + // Match lasso: allow an in-progress freehand shapes session to freeze and pan immediately on space. + this.manager.stage.startDragging(); + } else if (currentTool === 'rect' && this.tools.rect.hasActivePolygonSession()) { + // Match polygon lasso: when a polygon session is active, Space should immediately enter panning without + // requiring an extra click on the canvas. + this.manager.stage.startDragging(); } else { this.$cursorPos.set(null); } @@ -724,10 +783,17 @@ export class CanvasToolModule extends CanvasModuleBase { } if (e.key === KEY_ALT) { + const tool = this.$tool.get(); + const shapeType = this.manager.stateApi.getSettings().shapeType; + const hasActiveShapeDragSession = this.tools.rect.hasActiveDragSession(); + if (!shouldQuickSwitchToColorPickerOnAlt(tool, shapeType, hasActiveShapeDragSession)) { + e.preventDefault(); + return; + } // Select the color picker on alt key down e.preventDefault(); e.stopPropagation(); - this.$toolBuffer.set(this.$tool.get()); + this.$toolBuffer.set(tool); this.$tool.set('colorPicker'); } }; @@ -747,11 +813,15 @@ export class CanvasToolModule extends CanvasModuleBase { } if (e.key === KEY_SPACE || e.code === CODE_SPACE) { - // Revert the tool to the previous tool on space key up e.preventDefault(); e.stopPropagation(); - this.revertToolBuffer(); this.manager.stateApi.$spaceKey.set(false); + if (this.tools.rect.isTranslatingDragSession()) { + this.tools.rect.stopDragTranslation(); + return; + } + // Revert the tool to the previous tool on space key up + this.revertToolBuffer(); return; } @@ -806,4 +876,28 @@ export class CanvasToolModule extends CanvasModuleBase { } this.konva.group.destroy(); }; + + private shouldDeferEnterLeaveCommit = (state: AnyObjectState | null) => { + if (!state) { + return false; + } + + if (state.type === 'rect' || state.type === 'oval' || state.type === 'gradient') { + return true; + } + + return state.type === 'polygon'; + }; + + private shouldSkipWindowPointerUpCommit = (state: AnyObjectState | null) => { + return Boolean(state?.type === 'polygon' && state.previewPoint); + }; + + private isTemporaryShapesToolSwitch = () => { + return shouldPreserveSuspendableShapesSession( + this.$tool.get(), + this.$toolBuffer.get(), + this.tools.rect.hasSuspendableSession() + ); + }; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.test.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.test.ts new file mode 100644 index 00000000000..e4fd800d2c2 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.test.ts @@ -0,0 +1,70 @@ +import { describe, expect, it } from 'vitest'; + +import { + getToolToCancelOnEscape, + shouldPreserveSuspendableShapesSession, + shouldQuickSwitchToColorPickerOnAlt, + shouldTranslateShapeDragOnSpace, +} from './toolHotkeys'; + +describe('tool hotkeys', () => { + it('keeps the color-picker quick-switch available before starting rect and oval drags', () => { + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'rect', false)).toBe(true); + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'oval', false)).toBe(true); + }); + + it('blocks the color-picker quick-switch while a rect, oval, or freehand drag is active', () => { + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'rect', true)).toBe(false); + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'oval', true)).toBe(false); + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'freehand', true)).toBe(false); + }); + + it('keeps the color-picker quick-switch for polygon mode and non-shape tools', () => { + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'polygon', false)).toBe(true); + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'polygon', true)).toBe(true); + expect(shouldQuickSwitchToColorPickerOnAlt('brush', 'rect', true)).toBe(true); + expect(shouldQuickSwitchToColorPickerOnAlt('lasso', 'polygon', false)).toBe(true); + }); + + it('uses Space to translate active rect and oval drags instead of switching to view', () => { + expect(shouldTranslateShapeDragOnSpace('rect', 'rect', true, true)).toBe(true); + expect(shouldTranslateShapeDragOnSpace('rect', 'oval', true, true)).toBe(true); + }); + + it('does not use Space translation outside active rect and oval drags', () => { + expect(shouldTranslateShapeDragOnSpace('rect', 'rect', false, true)).toBe(false); + expect(shouldTranslateShapeDragOnSpace('rect', 'rect', true, false)).toBe(false); + expect(shouldTranslateShapeDragOnSpace('rect', 'polygon', true, true)).toBe(false); + expect(shouldTranslateShapeDragOnSpace('rect', 'freehand', true, true)).toBe(false); + expect(shouldTranslateShapeDragOnSpace('brush', 'rect', true, true)).toBe(false); + }); + + it('preserves suspendable shapes sessions across temporary view and color-picker switches', () => { + expect(shouldPreserveSuspendableShapesSession('view', 'rect', true)).toBe(true); + expect(shouldPreserveSuspendableShapesSession('colorPicker', 'rect', true)).toBe(true); + expect(shouldPreserveSuspendableShapesSession('rect', 'rect', true)).toBe(true); + }); + + it('does not preserve suspendable shapes sessions for unrelated tool switches', () => { + expect(shouldPreserveSuspendableShapesSession('brush', 'rect', true)).toBe(false); + expect(shouldPreserveSuspendableShapesSession('view', null, true)).toBe(false); + expect(shouldPreserveSuspendableShapesSession('colorPicker', 'rect', false)).toBe(false); + }); + + it('cancels the active drawing tool directly on escape', () => { + expect(getToolToCancelOnEscape('rect', null, false, false)).toBe('rect'); + expect(getToolToCancelOnEscape('lasso', null, false, false)).toBe('lasso'); + }); + + it('cancels preserved drawing sessions while temporarily switched away', () => { + expect(getToolToCancelOnEscape('view', 'lasso', true, false)).toBe('lasso'); + expect(getToolToCancelOnEscape('view', 'rect', false, true)).toBe('rect'); + expect(getToolToCancelOnEscape('colorPicker', 'rect', false, true)).toBe('rect'); + }); + + it('does not cancel unrelated buffered tools on escape', () => { + expect(getToolToCancelOnEscape('view', 'lasso', false, false)).toBeNull(); + expect(getToolToCancelOnEscape('colorPicker', 'lasso', true, false)).toBeNull(); + expect(getToolToCancelOnEscape('view', 'brush', false, true)).toBeNull(); + }); +}); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.ts new file mode 100644 index 00000000000..84cbdc93dce --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.ts @@ -0,0 +1,65 @@ +import type { Tool } from 'features/controlLayers/store/types'; + +type ShapeType = 'rect' | 'oval' | 'polygon' | 'freehand'; + +export const shouldPreserveSuspendableShapesSession = ( + tool: Tool, + toolBuffer: Tool | null, + hasSuspendableShapeSession: boolean +): boolean => { + if (!hasSuspendableShapeSession || toolBuffer !== 'rect') { + return false; + } + + return tool === 'view' || tool === 'colorPicker' || tool === 'rect'; +}; + +export const shouldQuickSwitchToColorPickerOnAlt = ( + tool: Tool, + shapeType: ShapeType, + hasActiveShapeDragSession: boolean +): boolean => { + if (tool !== 'rect') { + return true; + } + + if (shapeType === 'polygon') { + return true; + } + + return !hasActiveShapeDragSession; +}; + +export const shouldTranslateShapeDragOnSpace = ( + tool: Tool, + shapeType: ShapeType, + hasActiveShapeDragSession: boolean, + isPrimaryPointerDown: boolean +): boolean => { + if (tool !== 'rect' || !hasActiveShapeDragSession || !isPrimaryPointerDown) { + return false; + } + + return shapeType === 'rect' || shapeType === 'oval'; +}; + +export const getToolToCancelOnEscape = ( + tool: Tool, + toolBuffer: Tool | null, + hasActiveLassoSession: boolean, + hasSuspendableShapeSession: boolean +): Tool | null => { + if (tool === 'rect' || tool === 'lasso') { + return tool; + } + + if (tool === 'view' && toolBuffer === 'lasso' && hasActiveLassoSession) { + return 'lasso'; + } + + if ((tool === 'view' || tool === 'colorPicker') && toolBuffer === 'rect' && hasSuspendableShapeSession) { + return 'rect'; + } + + return null; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts index 202b70e142d..509aefdaaf2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts @@ -14,6 +14,7 @@ export type TransformSmoothingMode = z.infer; const zGradientType = z.enum(['linear', 'radial']); const zLassoMode = z.enum(['freehand', 'polygon']); +const zShapeType = z.enum(['rect', 'oval', 'polygon', 'freehand']); const zCanvasSettingsState = z.object({ /** @@ -115,6 +116,10 @@ const zCanvasSettingsState = z.object({ * The gradient tool type. */ gradientType: zGradientType.default('linear'), + /** + * The shape tool type. + */ + shapeType: zShapeType.default('rect'), /** * Whether the gradient tool clips to the drag gesture. */ @@ -152,6 +157,7 @@ const getInitialState = (): CanvasSettingsState => ({ transformSmoothingEnabled: false, transformSmoothingMode: 'bicubic', gradientType: 'linear', + shapeType: 'rect', gradientClipEnabled: true, lassoMode: 'freehand', }); @@ -248,6 +254,9 @@ const slice = createSlice({ settingsGradientTypeChanged: (state, action: PayloadAction) => { state.gradientType = action.payload; }, + settingsShapeTypeChanged: (state, action: PayloadAction) => { + state.shapeType = action.payload; + }, settingsGradientClipToggled: (state) => { state.gradientClipEnabled = !state.gradientClipEnabled; }, @@ -284,6 +293,7 @@ export const { settingsStagingAreaAutoSwitchChanged, settingsFillColorPickerPinnedSet, settingsGradientTypeChanged, + settingsShapeTypeChanged, settingsGradientClipToggled, settingsLassoModeChanged, } = slice.actions; @@ -326,5 +336,6 @@ export const selectTransformSmoothingEnabled = createCanvasSettingsSelector( ); export const selectTransformSmoothingMode = createCanvasSettingsSelector((settings) => settings.transformSmoothingMode); export const selectGradientType = createCanvasSettingsSelector((settings) => settings.gradientType); +export const selectShapeType = createCanvasSettingsSelector((settings) => settings.shapeType); export const selectGradientClipEnabled = createCanvasSettingsSelector((settings) => settings.gradientClipEnabled); export const selectLassoMode = createCanvasSettingsSelector((settings) => settings.lassoMode); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 9e639c8e7af..a18a6ed308f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -70,7 +70,7 @@ import type { EntityLassoAddedPayload, EntityMovedToPayload, EntityRasterizedPayload, - EntityRectAddedPayload, + EntityShapeAddedPayload, IPMethodV2, T2IAdapterConfig, ZImageControlConfig, @@ -1568,8 +1568,8 @@ const slice = createSlice({ points: eraserLine.type === 'eraser_line' ? simplifyFlatNumbersArray(eraserLine.points) : eraserLine.points, }); }, - entityRectAdded: (state, action: PayloadAction) => { - const { entityIdentifier, rect } = action.payload; + entityShapeAdded: (state, action: PayloadAction) => { + const { entityIdentifier, shape } = action.payload; const entity = selectEntity(state, entityIdentifier); if (!entity) { return; @@ -1577,7 +1577,7 @@ const slice = createSlice({ // TODO(psyche): If we add the object without splatting, the renderer will see it as the same object and not // re-render it (reference equality check). I don't like this behaviour. - entity.objects.push({ ...rect }); + entity.objects.push({ ...shape }); }, entityLassoAdded: (state, action: PayloadAction) => { const { entityIdentifier, lasso } = action.payload; @@ -1910,7 +1910,7 @@ export const { entityRasterized, entityBrushLineAdded, entityEraserLineAdded, - entityRectAdded, + entityShapeAdded, entityLassoAdded, entityGradientAdded, // Raster layer adjustments @@ -2046,7 +2046,7 @@ export const canvasSliceConfig: SliceConfig = { const doNotGroupMatcher = isAnyOf( entityBrushLineAdded, entityEraserLineAdded, - entityRectAdded, + entityShapeAdded, entityLassoAdded, entityGradientAdded ); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 7a7ebeade71..cbeccdfa930 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -260,6 +260,7 @@ const zCanvasRectState = z.object({ type: z.literal('rect'), rect: zRect, color: zRgbaColor, + compositeOperation: z.enum(['source-over', 'destination-out']).default('source-over'), }); export type CanvasRectState = z.infer; @@ -277,6 +278,28 @@ const zCanvasLassoState = z.object({ }); export type CanvasLassoState = z.infer; +const zCanvasOvalState = z.object({ + id: zId, + type: z.literal('oval'), + rect: zRect, + color: zRgbaColor, + compositeOperation: z.enum(['source-over', 'destination-out']).default('source-over'), +}); +export type CanvasOvalState = z.infer; + +const zCanvasPolygonState = z.object({ + id: zId, + type: z.literal('polygon'), + points: zPoints, + color: zRgbaColor, + compositeOperation: z.enum(['source-over', 'destination-out']).default('source-over'), + previewPoint: zCoordinate.optional(), +}); +export type CanvasPolygonState = z.infer; + +const zCanvasShapeState = z.union([zCanvasRectState, zCanvasOvalState, zCanvasPolygonState]); +type CanvasShapeState = z.infer; + // Gradient state includes clip metadata so the tool can optionally clip to drag gesture. const zCanvasLinearGradientState = z.object({ id: zId, @@ -325,7 +348,7 @@ const zCanvasObjectState = z.union([ zCanvasImageState, zCanvasBrushLineState, zCanvasEraserLineState, - zCanvasRectState, + zCanvasShapeState, zCanvasLassoState, zCanvasBrushLineWithPressureState, zCanvasEraserLineWithPressureState, @@ -1020,8 +1043,8 @@ export type EntityBrushLineAddedPayload = EntityIdentifierPayload<{ export type EntityEraserLineAddedPayload = EntityIdentifierPayload<{ eraserLine: CanvasEraserLineState | CanvasEraserLineWithPressureState; }>; -export type EntityRectAddedPayload = EntityIdentifierPayload<{ rect: CanvasRectState }>; export type EntityLassoAddedPayload = EntityIdentifierPayload<{ lasso: CanvasLassoState }>; +export type EntityShapeAddedPayload = EntityIdentifierPayload<{ shape: CanvasShapeState }>; export type EntityGradientAddedPayload = EntityIdentifierPayload<{ gradient: CanvasGradientState }>; export type EntityRasterizedPayload = EntityIdentifierPayload<{ imageObject: CanvasImageState; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index 8f0e31ef5cd..63179db844a 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -276,7 +276,7 @@ export const MODEL_FORMAT_TO_LONG_NAME: Record = { unknown: 'Unknown', }; -export const SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS: BaseModelType[] = ['flux', 'sd-3', 'z-image']; +export const SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS: BaseModelType[] = ['flux', 'sd-3']; export const SUPPORTS_REF_IMAGES_BASE_MODELS: BaseModelType[] = ['sd-1', 'sdxl', 'flux', 'flux2', 'qwen-image']; diff --git a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx index e8636466066..1ba2ffd572d 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx @@ -1,4 +1,6 @@ import { Badge, Portal } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectIsAuthenticated } from 'features/auth/store/authSlice'; import type { RefObject } from 'react'; import { memo, useEffect, useMemo, useState } from 'react'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; @@ -10,14 +12,24 @@ type Props = { type SessionQueueStatus = components['schemas']['SessionQueueStatus']; +const hasUserCounts = (queueData: SessionQueueStatus): boolean => { + return ( + queueData.user_pending !== undefined && + queueData.user_pending !== null && + queueData.user_in_progress !== undefined && + queueData.user_in_progress !== null + ); +}; + /** - * Calculates the appropriate badge text based on queue status. + * Calculates the appropriate badge text based on queue status and authentication state. * Returns null if badge should be hidden. * - * In multiuser mode, the backend already scopes counts to the current user for non-admins, - * so pending + in_progress reflects the user's own queue items. + * In multiuser mode, the badge is "X/Y" where X is the calling user's pending+in_progress count + * and Y is the total across all users. In single-user mode (or when user counts are unavailable) + * the badge shows the total only. */ -const getBadgeText = (queueData: SessionQueueStatus | undefined): string | null => { +const getBadgeText = (queueData: SessionQueueStatus | undefined, isAuthenticated: boolean): string | null => { if (!queueData) { return null; } @@ -28,18 +40,24 @@ const getBadgeText = (queueData: SessionQueueStatus | undefined): string | null return null; } + if (isAuthenticated && hasUserCounts(queueData)) { + const userPending = queueData.user_pending! + queueData.user_in_progress!; + return `${userPending}/${totalPending}`; + } + return totalPending.toString(); }; export const QueueCountBadge = memo(({ targetRef }: Props) => { const [badgePos, setBadgePos] = useState<{ x: string; y: string } | null>(null); + const isAuthenticated = useAppSelector(selectIsAuthenticated); const { queueData } = useGetQueueStatusQuery(undefined, { selectFromResult: (res) => ({ queueData: res.data?.queue, }), }); - const badgeText = useMemo(() => getBadgeText(queueData), [queueData]); + const badgeText = useMemo(() => getBadgeText(queueData, isAuthenticated), [queueData, isAuthenticated]); useEffect(() => { if (!targetRef.current) { diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx index 024e92328f9..d486bc1137f 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx @@ -3,7 +3,7 @@ import { useStore } from '@nanostores/react'; import { useAppSelector } from 'app/store/storeHooks'; import type { IDockviewHeaderActionsProps } from 'dockview'; import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; -import { selectLassoMode } from 'features/controlLayers/store/canvasSettingsSlice'; +import { selectLassoMode, selectShapeType } from 'features/controlLayers/store/canvasSettingsSlice'; import { selectBbox } from 'features/controlLayers/store/selectors'; import type { Tool } from 'features/controlLayers/store/types'; import { IS_MAC_OS } from 'features/system/components/HotkeysModal/useHotkeyData'; @@ -16,6 +16,7 @@ import { WORKSPACE_PANEL_ID } from './shared'; const $fallbackTool = atom('move'); const $fallbackToolBuffer = atom(null); +const $fallbackPrimaryPointerDown = atom(false); const $fallbackTextSession = atom(null); type CanvasToolModifierHintKey = ReturnType[number]['keys'][number]; @@ -45,10 +46,12 @@ export const DockviewCanvasHeaderActions = memo((props: IDockviewHeaderActionsPr const { t } = useTranslation(); const canvasManager = useCanvasManagerSafe(); const lassoMode = useAppSelector(selectLassoMode); + const shapeType = useAppSelector(selectShapeType); const bboxAspectRatioLocked = useAppSelector((state) => selectBbox(state).aspectRatio.isLocked); const tool = useStore(canvasManager?.tool.$tool ?? $fallbackTool); const toolBuffer = useStore(canvasManager?.tool.$toolBuffer ?? $fallbackToolBuffer); + const isPrimaryPointerDown = useStore(canvasManager?.tool.$isPrimaryPointerDown ?? $fallbackPrimaryPointerDown); const textSession = useStore(canvasManager?.tool.tools.text.$session ?? $fallbackTextSession); const effectiveTool = useMemo(() => { @@ -66,10 +69,21 @@ export const DockviewCanvasHeaderActions = memo((props: IDockviewHeaderActionsPr return getCanvasToolModifierHints({ tool: effectiveTool, lassoMode, + shapeType, bboxAspectRatioLocked, hasActiveTextSession: Boolean(textSession), + isPrimaryPointerDown, }); - }, [bboxAspectRatioLocked, canvasManager, effectiveTool, lassoMode, props.activePanel?.id, textSession]); + }, [ + bboxAspectRatioLocked, + canvasManager, + effectiveTool, + isPrimaryPointerDown, + lassoMode, + props.activePanel?.id, + shapeType, + textSession, + ]); if (hints.length === 0) { return null; diff --git a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts index 7c599efad23..244634fd381 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts +++ b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts @@ -2,88 +2,116 @@ import { describe, expect, it } from 'vitest'; import { getCanvasToolModifierHintIds } from './canvasToolModifierHints'; +const buildArgs = (overrides: Partial[0]> = {}) => ({ + tool: 'brush' as const, + lassoMode: 'freehand' as const, + shapeType: 'rect' as const, + bboxAspectRatioLocked: false, + hasActiveTextSession: false, + isPrimaryPointerDown: false, + ...overrides, +}); + describe('getCanvasToolModifierHintIds', () => { it('returns brush hints in priority order', () => { - expect( - getCanvasToolModifierHintIds({ - tool: 'brush', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['shiftStraightLine', 'modWheelResizeBrush', 'spacePan', 'altPickColor']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'brush' }))).toEqual([ + 'shiftStraightLine', + 'modWheelResizeBrush', + 'spacePan', + 'altPickColor', + ]); }); it('omits alt color-picker hint for eraser', () => { - expect( - getCanvasToolModifierHintIds({ - tool: 'eraser', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['shiftStraightLine', 'modWheelResizeEraser', 'spacePan']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'eraser' }))).toEqual([ + 'shiftStraightLine', + 'modWheelResizeEraser', + 'spacePan', + ]); }); it('adds polygon snapping for polygon lasso', () => { - expect( - getCanvasToolModifierHintIds({ - tool: 'lasso', - lassoMode: 'polygon', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['modSubtractMask', 'shiftSnap45Degrees', 'spacePan']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'lasso', lassoMode: 'polygon' }))).toEqual([ + 'modErase', + 'shiftSnap45Degrees', + 'spacePan', + ]); }); it('omits polygon snapping for freehand lasso', () => { - expect( - getCanvasToolModifierHintIds({ - tool: 'lasso', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['modSubtractMask', 'spacePan']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'lasso' }))).toEqual(['modErase', 'spacePan']); }); it('switches the bbox aspect-ratio hint based on lock state', () => { - expect( - getCanvasToolModifierHintIds({ - tool: 'bbox', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['shiftLockAspectRatio', 'altScaleFromCenter', 'modFineGrid']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'bbox' }))).toEqual([ + 'shiftLockAspectRatio', + 'altScaleFromCenter', + 'modFineGrid', + ]); - expect( - getCanvasToolModifierHintIds({ - tool: 'bbox', - lassoMode: 'freehand', - bboxAspectRatioLocked: true, - hasActiveTextSession: false, - }) - ).toEqual(['shiftUnlockAspectRatio', 'altScaleFromCenter', 'modFineGrid']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'bbox', bboxAspectRatioLocked: true }))).toEqual([ + 'shiftUnlockAspectRatio', + 'altScaleFromCenter', + 'modFineGrid', + ]); }); it('only shows text-session hints when a text session is active', () => { + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'text', hasActiveTextSession: true }))).toEqual([ + 'enterCommitText', + 'shiftEnterNewLine', + 'escCancelText', + 'modDragText', + 'shiftSnapRotation', + ]); + + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'text' }))).toEqual(['spacePan', 'altPickColor']); + }); + + it('shows idle rect and oval shapes hints', () => { + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'rect' }))).toEqual([ + 'modErase', + 'shiftLockAspectRatio', + 'spacePan', + 'altPickColor', + ]); + + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'oval' }))).toEqual([ + 'modErase', + 'shiftLockAspectRatio', + 'spacePan', + 'altPickColor', + ]); + }); + + it('shows active rect and oval drag hints', () => { + expect( + getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'rect', isPrimaryPointerDown: true })) + ).toEqual(['modErase', 'shiftLockAspectRatio', 'altScaleFromCenter', 'spaceMoveShape']); + expect( - getCanvasToolModifierHintIds({ - tool: 'text', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: true, - }) - ).toEqual(['enterCommitText', 'shiftEnterNewLine', 'escCancelText', 'modDragText', 'shiftSnapRotation']); + getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'oval', isPrimaryPointerDown: true })) + ).toEqual(['modErase', 'shiftLockAspectRatio', 'altScaleFromCenter', 'spaceMoveShape']); + }); + + it('shows polygon shape hints', () => { + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'polygon' }))).toEqual([ + 'modErase', + 'shiftSnap45Degrees', + 'spacePan', + 'altPickColor', + ]); + }); + + it('omits alt color-picker hint during an active freehand stroke', () => { + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'freehand' }))).toEqual([ + 'modErase', + 'spacePan', + 'altPickColor', + ]); expect( - getCanvasToolModifierHintIds({ - tool: 'text', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['spacePan', 'altPickColor']); + getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'freehand', isPrimaryPointerDown: true })) + ).toEqual(['modErase', 'spacePan']); }); }); diff --git a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts index a23682e2072..3543ac4358d 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts +++ b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts @@ -1,14 +1,20 @@ +import { + shouldQuickSwitchToColorPickerOnAlt, + shouldTranslateShapeDragOnSpace, +} from 'features/controlLayers/konva/CanvasTool/toolHotkeys'; import type { Tool } from 'features/controlLayers/store/types'; +type ShapeType = 'rect' | 'oval' | 'polygon' | 'freehand'; type CanvasToolModifierHintKey = 'mod' | 'shift' | 'alt' | 'space' | 'wheel' | 'arrows' | 'enter' | 'esc'; type CanvasToolModifierHintId = | 'spacePan' + | 'spaceMoveShape' | 'altPickColor' | 'shiftStraightLine' | 'modWheelResizeBrush' | 'modWheelResizeEraser' - | 'modSubtractMask' + | 'modErase' | 'shiftSnap45Degrees' | 'shiftLockAspectRatio' | 'shiftUnlockAspectRatio' @@ -35,6 +41,11 @@ const HINTS: Record = { keys: ['space'], labelKey: 'controlLayers.modifierHints.labels.pan', }, + spaceMoveShape: { + id: 'spaceMoveShape', + keys: ['space'], + labelKey: 'controlLayers.modifierHints.labels.moveShape', + }, altPickColor: { id: 'altPickColor', keys: ['alt'], @@ -55,10 +66,10 @@ const HINTS: Record = { keys: ['mod', 'wheel'], labelKey: 'controlLayers.modifierHints.labels.resizeEraser', }, - modSubtractMask: { - id: 'modSubtractMask', + modErase: { + id: 'modErase', keys: ['mod'], - labelKey: 'controlLayers.modifierHints.labels.subtractMask', + labelKey: 'controlLayers.modifierHints.labels.erase', }, shiftSnap45Degrees: { id: 'shiftSnap45Degrees', @@ -120,8 +131,10 @@ const HINTS: Record = { type GetCanvasToolModifierHintsArg = { tool: Tool; lassoMode: 'freehand' | 'polygon'; + shapeType: ShapeType; bboxAspectRatioLocked: boolean; hasActiveTextSession: boolean; + isPrimaryPointerDown: boolean; }; const mapHintIdsToHints = (hintIds: readonly CanvasToolModifierHintId[]): CanvasToolModifierHint[] => @@ -130,8 +143,10 @@ const mapHintIdsToHints = (hintIds: readonly CanvasToolModifierHintId[]): Canvas export const getCanvasToolModifierHintIds = ({ tool, lassoMode, + shapeType, bboxAspectRatioLocked, hasActiveTextSession, + isPrimaryPointerDown, }: GetCanvasToolModifierHintsArg): CanvasToolModifierHintId[] => { // Resolver map: each tool returns the relevant hint ids based on the provided args. const TOOL_HINT_RESOLVERS: Record< @@ -141,7 +156,7 @@ export const getCanvasToolModifierHintIds = ({ brush: () => ['shiftStraightLine', 'modWheelResizeBrush', ...SHARED_HINT_IDS], eraser: () => ['shiftStraightLine', 'modWheelResizeEraser', 'spacePan'], lasso: ({ lassoMode: lm }) => - lm === 'polygon' ? ['modSubtractMask', 'shiftSnap45Degrees', 'spacePan'] : ['modSubtractMask', 'spacePan'], + lm === 'polygon' ? ['modErase', 'shiftSnap45Degrees', 'spacePan'] : ['modErase', 'spacePan'], bbox: ({ bboxAspectRatioLocked: locked }) => [ locked ? 'shiftUnlockAspectRatio' : 'shiftLockAspectRatio', 'altScaleFromCenter', @@ -155,7 +170,21 @@ export const getCanvasToolModifierHintIds = ({ view: () => ['altPickColor'], colorPicker: () => ['spacePan'], gradient: () => [...SHARED_HINT_IDS], - rect: () => [...SHARED_HINT_IDS], + rect: ({ shapeType: st, isPrimaryPointerDown: pointerDown }) => { + if (st === 'polygon') { + return ['modErase', 'shiftSnap45Degrees', 'spacePan', 'altPickColor']; + } + + if (st === 'freehand') { + return shouldQuickSwitchToColorPickerOnAlt('rect', st, pointerDown) + ? ['modErase', 'spacePan', 'altPickColor'] + : ['modErase', 'spacePan']; + } + + return shouldTranslateShapeDragOnSpace('rect', st, pointerDown, pointerDown) + ? ['modErase', 'shiftLockAspectRatio', 'altScaleFromCenter', 'spaceMoveShape'] + : ['modErase', 'shiftLockAspectRatio', 'spacePan', 'altPickColor']; + }, }; const resolver = TOOL_HINT_RESOLVERS[tool]; @@ -163,7 +192,9 @@ export const getCanvasToolModifierHintIds = ({ if (!resolver) { return []; } - return Array.from(resolver({ tool, lassoMode, bboxAspectRatioLocked, hasActiveTextSession })); + return Array.from( + resolver({ tool, lassoMode, shapeType, bboxAspectRatioLocked, hasActiveTextSession, isPrimaryPointerDown }) + ); }; export const getCanvasToolModifierHints = (args: GetCanvasToolModifierHintsArg): CanvasToolModifierHint[] => diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 68f24a26ec1..ef1d7d92f43 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1875,7 +1875,9 @@ export type paths = { }; /** * Get Queue Item Ids - * @description Gets all queue item ids that match the given parameters. Non-admin users only see their own items. + * @description Gets all queue item ids that match the given parameters. The IDs themselves are not sensitive; + * per-item field redaction is performed when the items are fetched via list_all_queue_items or + * get_queue_items_by_item_ids. */ get: operations["get_queue_item_ids"]; put?: never; @@ -2135,7 +2137,9 @@ export type paths = { }; /** * Get Queue Status - * @description Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it. + * @description Gets the status of the session queue. Returns global counts plus the calling user's own + * pending/in_progress counts (so the UI can show an X/Y badge). Non-admin users cannot see the + * current item's identifiers unless they own it. */ get: operations["get_queue_status"]; put?: never; @@ -16183,6 +16187,7 @@ export type components = { * force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). * pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. * max_queue_size: Maximum number of items in the session queue. + * session_queue_mode: Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.
Valid values: `FIFO`, `round_robin` * clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`. * max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. * allow_nodes: List of nodes to allow. Omit to allow all. @@ -16525,6 +16530,13 @@ export type components = { * @default 10000 */ max_queue_size?: number; + /** + * Session Queue Mode + * @description Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting. + * @default round_robin + * @enum {string} + */ + session_queue_mode?: "FIFO" | "round_robin"; /** * Clear Queue On Startup * @description Empties session queue on startup. If true, disables `max_queue_history`. @@ -28086,6 +28098,16 @@ export type components = { * @description Total number of queue items */ total: number; + /** + * User Pending + * @description Number of pending queue items for the calling user (multiuser only) + */ + user_pending?: number | null; + /** + * User In Progress + * @description Number of in-progress queue items for the calling user (multiuser only) + */ + user_in_progress?: number | null; }; /** * SetupRequest diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 1e73abb2027..b9d364cf3b4 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -389,6 +389,24 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis }); socket.on('queue_item_status_changed', (data) => { + // Sanitized companion event sent to non-owner queue subscribers in multiuser mode. The + // backend sets user_id="redacted" and clears identifiers/error fields. We must not run + // payload-driven cache mutations or per-session side effects (node state reset, progress + // clear, completion bookkeeping) — those belong to the owner. Just invalidate queue tags + // so the non-owner's queue list and badge counts refetch with sanitized data. + if (data.user_id === 'redacted') { + log.trace({ data }, `Sanitized queue_item_status_changed for item ${data.item_id}`); + const tags: ApiTagDescription[] = [ + 'SessionQueueStatus', + 'SessionQueueItemIdList', + { type: 'SessionQueueItem', id: data.item_id }, + { type: 'SessionQueueItem', id: LIST_TAG }, + { type: 'SessionQueueItem', id: LIST_ALL_TAG }, + ]; + dispatch(queueApi.util.invalidateTags(tags)); + return; + } + if (finishedQueueItemIds.has(data.item_id)) { log.trace({ data }, `Received event for already-finished queue item ${data.item_id}`); return; diff --git a/tests/app/routers/test_multiuser_authorization.py b/tests/app/routers/test_multiuser_authorization.py index 3461f37e7e9..90bc82b30cf 100644 --- a/tests/app/routers/test_multiuser_authorization.py +++ b/tests/app/routers/test_multiuser_authorization.py @@ -1333,14 +1333,31 @@ def test_get_queue_status_hides_current_item_for_non_owner(self): assert status_obj.session_id is None assert status_obj.batch_id is None - def test_session_queue_status_no_user_fields(self): - """SessionQueueStatus should not have user_pending/user_in_progress fields anymore. - Non-admin users now get their own counts in the main pending/in_progress fields.""" + def test_session_queue_status_has_user_fields(self): + """SessionQueueStatus exposes user_pending/user_in_progress so the queue badge + can render an X/Y count (X = caller's jobs, Y = global total).""" from invokeai.app.services.session_queue.session_queue_common import SessionQueueStatus fields = set(SessionQueueStatus.model_fields.keys()) - assert "user_pending" not in fields - assert "user_in_progress" not in fields + assert "user_pending" in fields + assert "user_in_progress" in fields + + status_obj = SessionQueueStatus( + queue_id="default", + item_id=None, + session_id=None, + batch_id=None, + pending=5, + in_progress=1, + completed=0, + failed=0, + canceled=0, + total=6, + user_pending=2, + user_in_progress=1, + ) + assert status_obj.user_pending == 2 + assert status_obj.user_in_progress == 1 # =========================================================================== @@ -1708,8 +1725,11 @@ def test_batch_enqueued_event_carries_user_id(self) -> None: assert event.queue_id == "default" def test_queue_item_status_changed_routed_privately(self, socketio: Any) -> None: - """Verify that _handle_queue_event emits QueueItemStatusChangedEvent ONLY to - user:{user_id} and admin rooms, never to the queue_id room.""" + """_handle_queue_event must emit the FULL QueueItemStatusChangedEvent only to the + owner's user room and the admin room. A sanitized companion (user_id="redacted", + identifiers stripped) is also emitted to the queue_id room so other users' UIs can + refresh, with the owner's and admins' sids in skip_sid so they don't get a duplicate + that would clobber their cache.""" import asyncio from unittest.mock import AsyncMock @@ -1758,20 +1778,60 @@ def test_queue_item_status_changed_routed_privately(self, socketio: Any) -> None ), ) + # Track owner sid so we can verify skip_sid is honored + socketio._socket_users["sid-owner"] = {"user_id": "owner-xyz", "is_admin": False} + socketio._socket_users["sid-admin"] = {"user_id": "admin-1", "is_admin": True} + socketio._socket_users["sid-other"] = {"user_id": "other-user", "is_admin": False} + mock_emit = AsyncMock() socketio._sio.emit = mock_emit asyncio.run(socketio._handle_queue_event(("queue_item_status_changed", event))) - rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list] - assert "user:owner-xyz" in rooms_emitted_to - assert "admin" in rooms_emitted_to - # CRITICAL: must NOT emit to the queue_id room — that would leak to other users - assert "default" not in rooms_emitted_to + # Collect (room, payload, skip_sid) for each emit call + emits = [ + (c.kwargs.get("room"), c.kwargs.get("data"), c.kwargs.get("skip_sid")) for c in mock_emit.call_args_list + ] + + # Full event must go to owner room and admin room with original sensitive fields + owner_emits = [(p, s) for r, p, s in emits if r == "user:owner-xyz"] + admin_emits = [(p, s) for r, p, s in emits if r == "admin"] + assert len(owner_emits) == 1 and len(admin_emits) == 1 + for payload, _ in owner_emits + admin_emits: + assert payload["user_id"] == "owner-xyz" + assert payload["batch_id"] == "batch-private" + assert payload["session_id"] == "sess-private" + assert payload["destination"] == "canvas" + + # A sanitized companion event must go to the queue_id room with sensitive fields cleared + queue_emits = [(p, s) for r, p, s in emits if r == "default"] + assert len(queue_emits) == 1, "expected exactly one sanitized emit to queue room" + sanitized_payload, skip_sid = queue_emits[0] + assert sanitized_payload["user_id"] == "redacted" + assert sanitized_payload["batch_id"] == "redacted" + assert sanitized_payload["session_id"] == "redacted" + assert sanitized_payload["origin"] is None + assert sanitized_payload["destination"] is None + assert sanitized_payload["error_type"] is None + assert sanitized_payload["batch_status"]["batch_id"] == "redacted" + assert sanitized_payload["batch_status"]["destination"] is None + assert sanitized_payload["queue_status"]["item_id"] is None + assert sanitized_payload["queue_status"]["batch_id"] is None + assert sanitized_payload["queue_status"]["user_pending"] is None + # Owner and admin sids must be skipped so they don't receive the duplicate + assert "sid-owner" in skip_sid + assert "sid-admin" in skip_sid + # Third-party user must NOT be skipped — they need the sanitized event + assert "sid-other" not in skip_sid + # Status (non-sensitive) is preserved so the non-owner UI knows what changed + assert sanitized_payload["status"] == "in_progress" + assert sanitized_payload["item_id"] == 1 def test_batch_enqueued_routed_privately(self, socketio: Any) -> None: - """Verify that _handle_queue_event emits BatchEnqueuedEvent ONLY to - user:{user_id} and admin rooms, never to the queue_id room.""" + """_handle_queue_event must emit the FULL BatchEnqueuedEvent only to the owner's + user room and the admin room. A sanitized companion (user_id="redacted", batch_id + and origin stripped) is also emitted to the queue_id room so other users' badge + totals refresh, with owner/admin sids in skip_sid.""" import asyncio from unittest.mock import AsyncMock @@ -1792,15 +1852,39 @@ def test_batch_enqueued_routed_privately(self, socketio: Any) -> None: ) event = BatchEnqueuedEvent.build(enqueue_result, user_id="owner-zzz") + socketio._socket_users["sid-owner"] = {"user_id": "owner-zzz", "is_admin": False} + socketio._socket_users["sid-admin"] = {"user_id": "admin-1", "is_admin": True} + socketio._socket_users["sid-other"] = {"user_id": "other-user", "is_admin": False} + mock_emit = AsyncMock() socketio._sio.emit = mock_emit asyncio.run(socketio._handle_queue_event(("batch_enqueued", event))) - rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list] - assert "user:owner-zzz" in rooms_emitted_to - assert "admin" in rooms_emitted_to - assert "default" not in rooms_emitted_to + emits = [ + (c.kwargs.get("room"), c.kwargs.get("data"), c.kwargs.get("skip_sid")) for c in mock_emit.call_args_list + ] + + # Full event to owner + admin contains the real batch_id and origin + owner_emits = [(p, s) for r, p, s in emits if r == "user:owner-zzz"] + admin_emits = [(p, s) for r, p, s in emits if r == "admin"] + assert len(owner_emits) == 1 and len(admin_emits) == 1 + for payload, _ in owner_emits + admin_emits: + assert payload["user_id"] == "owner-zzz" + assert payload["batch_id"] == "batch-pvt" + assert payload["origin"] == "workflows" + + # Sanitized event to queue room: user/batch/origin redacted, owner+admin skipped + queue_emits = [(p, s) for r, p, s in emits if r == "default"] + assert len(queue_emits) == 1 + sanitized_payload, skip_sid = queue_emits[0] + assert sanitized_payload["user_id"] == "redacted" + assert sanitized_payload["batch_id"] == "redacted" + assert sanitized_payload["origin"] is None + assert sanitized_payload["enqueued"] == 5 # count is non-sensitive + assert "sid-owner" in skip_sid + assert "sid-admin" in skip_sid + assert "sid-other" not in skip_sid def test_queue_cleared_still_broadcast(self, socketio: Any) -> None: """QueueClearedEvent does not carry user identity and should still be broadcast diff --git a/tests/app/services/session_queue/test_session_queue_dequeue.py b/tests/app/services/session_queue/test_session_queue_dequeue.py new file mode 100644 index 00000000000..0f82f2babaa --- /dev/null +++ b/tests/app/services/session_queue/test_session_queue_dequeue.py @@ -0,0 +1,214 @@ +"""Tests for session queue dequeue() ordering: FIFO and round-robin modes.""" + +import json +import uuid +from typing import Optional + +import pytest +from pydantic_core import to_jsonable_python + +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue +from invokeai.app.services.shared.graph import Graph, GraphExecutionState + +_EMPTY_SESSION_JSON = json.dumps(to_jsonable_python(GraphExecutionState(graph=Graph()).model_dump())) + + +@pytest.fixture +def session_queue_fifo(mock_invoker: Invoker) -> SqliteSessionQueue: + """Queue backed by a single-user (FIFO) invoker.""" + # Default config has multiuser=False, so FIFO is always used. + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +@pytest.fixture +def session_queue_round_robin(mock_invoker: Invoker) -> SqliteSessionQueue: + """Queue backed by a multiuser invoker with round_robin mode.""" + mock_invoker.services.configuration = InvokeAIAppConfig( + use_memory_db=True, + node_cache_size=0, + multiuser=True, + session_queue_mode="round_robin", + ) + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +def _insert_queue_item( + session_queue: SqliteSessionQueue, + queue_id: str, + user_id: str, + priority: int = 0, +) -> int: + """Directly insert a minimal queue item and return its item_id.""" + session_id = str(uuid.uuid4()) + batch_id = str(uuid.uuid4()) + with session_queue._db.transaction() as cursor: + cursor.execute( + """--sql + INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id, user_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + (queue_id, _EMPTY_SESSION_JSON, session_id, batch_id, None, priority, None, None, None, None, user_id), + ) + return cursor.lastrowid # type: ignore[return-value] + + +def _dequeue_user_ids(session_queue: SqliteSessionQueue, count: int) -> list[Optional[str]]: + """Dequeue `count` items and return the list of user_ids in dequeue order.""" + result = [] + for _ in range(count): + item = session_queue.dequeue() + result.append(item.user_id if item is not None else None) + return result + + +# --------------------------------------------------------------------------- +# FIFO tests +# --------------------------------------------------------------------------- + + +def test_fifo_single_user_order(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: items from a single user are dequeued in insertion order.""" + queue_id = "default" + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_fifo, 3) + assert user_ids == ["user_a", "user_a", "user_a"] + + +def test_fifo_multi_user_preserves_insertion_order(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: jobs from multiple users are dequeued in strict insertion order, not interleaved.""" + queue_id = "default" + # Insert A1, A2, B1, C1, C2, A3 – FIFO should preserve this exact order. + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_b") + _insert_queue_item(session_queue_fifo, queue_id, "user_c") + _insert_queue_item(session_queue_fifo, queue_id, "user_c") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_fifo, 6) + assert user_ids == ["user_a", "user_a", "user_b", "user_c", "user_c", "user_a"] + + +def test_fifo_priority_respected(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: higher-priority items are dequeued before lower-priority ones.""" + queue_id = "default" + _insert_queue_item(session_queue_fifo, queue_id, "user_a", priority=0) + _insert_queue_item(session_queue_fifo, queue_id, "user_a", priority=10) + + user_ids = _dequeue_user_ids(session_queue_fifo, 2) + # Both are user_a; second inserted item has higher priority and should come first. + assert user_ids == ["user_a", "user_a"] + + +def test_fifo_returns_none_when_empty(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: dequeue returns None when the queue is empty.""" + assert session_queue_fifo.dequeue() is None + + +# --------------------------------------------------------------------------- +# Round-robin tests +# --------------------------------------------------------------------------- + + +def test_round_robin_interleaves_users(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: jobs from multiple users are interleaved one per user per round. + + Queue insertion order (matching the issue example): + A job 1, A job 2, B job 1, C job 1, C job 2, A job 3 + + Expected dequeue order: + A job 1, B job 1, C job 1, A job 2, C job 2, A job 3 + """ + queue_id = "default" + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_b") + _insert_queue_item(session_queue_round_robin, queue_id, "user_c") + _insert_queue_item(session_queue_round_robin, queue_id, "user_c") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_round_robin, 6) + assert user_ids == ["user_a", "user_b", "user_c", "user_a", "user_c", "user_a"] + + +def test_round_robin_single_user_behaves_like_fifo(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin with only one user produces the same order as FIFO.""" + queue_id = "default" + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_round_robin, 3) + assert user_ids == ["user_a", "user_a", "user_a"] + + +def test_round_robin_handles_user_joining_mid_queue(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: a user who joins later is correctly interleaved.""" + queue_id = "default" + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_b") + + user_ids = _dequeue_user_ids(session_queue_round_robin, 3) + # Round 1: A (oldest rank-1 item), B (rank-1 item) + # Round 2: A (rank-2 item) + assert user_ids == ["user_a", "user_b", "user_a"] + + +def test_round_robin_returns_none_when_empty(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: dequeue returns None when the queue is empty.""" + assert session_queue_round_robin.dequeue() is None + + +def test_round_robin_priority_within_user_respected(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: within a single user's items, higher priority is dequeued first.""" + queue_id = "default" + # Insert low-priority item first, then high-priority for same user. + _insert_queue_item(session_queue_round_robin, queue_id, "user_a", priority=0) + _insert_queue_item(session_queue_round_robin, queue_id, "user_a", priority=10) + _insert_queue_item(session_queue_round_robin, queue_id, "user_b", priority=0) + + # Round 1: user_a's best item (priority 10), user_b's only item. + # Round 2: user_a's remaining item (priority 0). + items = [] + for _ in range(3): + item = session_queue_round_robin.dequeue() + assert item is not None + items.append((item.user_id, item.priority)) + + assert items[0] == ("user_a", 10) + assert items[1] == ("user_b", 0) + assert items[2] == ("user_a", 0) + + +def test_round_robin_ignored_in_single_user_mode(mock_invoker: Invoker) -> None: + """When multiuser=False, round_robin config is ignored and FIFO is used.""" + mock_invoker.services.configuration = InvokeAIAppConfig( + use_memory_db=True, + node_cache_size=0, + multiuser=False, + session_queue_mode="round_robin", + ) + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + + queue_id = "default" + _insert_queue_item(queue, queue_id, "user_a") + _insert_queue_item(queue, queue_id, "user_a") + _insert_queue_item(queue, queue_id, "user_b") + + # FIFO order: user_a, user_a, user_b + user_ids = _dequeue_user_ids(queue, 3) + assert user_ids == ["user_a", "user_a", "user_b"]