Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/winml/modelkit/commands/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,7 @@ def _build_runtime_debug_output_path(model_path: Path, ep_name: str, device_name
@cli_utils.verbosity_options()
@cli_utils.build_config_option()
@cli_utils.output_option("Save JSON output to file")
@cli_utils.overwrite_option(optional_message="Applies to both --output and --optim-config.")
@click.option(
"--information/--no-information",
default=True,
Expand Down Expand Up @@ -790,6 +791,7 @@ def analyze(
ep: EPNameOrAlias | Literal["all", "auto"] | None,
device: str | None,
output: Path | None,
overwrite: bool,
information: bool,
output_format: cli_utils.OutputFormat,
verbose: int,
Expand Down Expand Up @@ -833,6 +835,11 @@ def analyze(
verbose, quiet = cli_utils.resolve_verbosity(ctx, verbose, quiet)
configure_logging(verbosity=verbose, quiet=quiet)

# Refuse to clobber existing outputs unless the user opted in — fail fast
# before analysis runs. Guards both result JSON and the optim-config dump.
cli_utils.guard_output(output, overwrite)
cli_utils.guard_output(optim_config, overwrite, label="Optimization config")

try:
from ..analyze import ONNXStaticAnalyzer

Expand Down Expand Up @@ -958,9 +965,7 @@ def analyze(
sys.exit(2)
compatible_eps = resolve_eps(ref_device)
if not compatible_eps:
logger.error(
"No execution provider is available for device '%s'.", ref_device
)
logger.error("No execution provider is available for device '%s'.", ref_device)
sys.exit(2)
eps = [compatible_eps[0]]
else:
Expand Down
3 changes: 3 additions & 0 deletions src/winml/modelkit/commands/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,13 +384,15 @@ def _save_json(data: Any, path: Path) -> None:
optional_message="If not specified, shows all devices",
)
@cli_utils.output_option("Save results to a JSON file.")
@cli_utils.overwrite_option()
@cli_utils.format_option()
def catalog(
model_type: str | None,
task: str | None,
ep: EPNameOrAlias | None,
device: str | None,
output: Path | None,
overwrite: bool,
output_format: cli_utils.OutputFormat,
) -> None:
r"""Browse WinML CLI's curated built-in model catalog.
Expand Down Expand Up @@ -435,4 +437,5 @@ def catalog(
_output_list(models, ep_col_header=ep_col_header, ep_col_fn=ep_col_fn)

if output is not None:
cli_utils.guard_output(output, overwrite)
_save_json(models, output)
5 changes: 5 additions & 0 deletions src/winml/modelkit/commands/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"EP context (weight sharing). Required unless --list.",
)
@cli_utils.output_option("Output file path (e.g., model_compiled.onnx)")
@cli_utils.overwrite_option()
@click.option(
"--output-dir",
type=click.Path(path_type=Path),
Expand Down Expand Up @@ -108,6 +109,7 @@ def compile(
model: tuple[Path, ...],
output: Path | None,
output_dir: Path | None,
overwrite: bool,
device: str,
ep: EPNameOrAlias | None,
validate: bool,
Expand Down Expand Up @@ -243,6 +245,9 @@ def compile(
console.print(f"[bold blue]SDK root:[/bold blue] {qnn_sdk_root}")
# Resolve output path: -o (file) takes precedence over --output-dir
resolved_output = output or output_dir
# Refuse to clobber an existing output unless the user opted in. A file
# blocks when it exists; a directory blocks only when non-empty.
cli_utils.guard_output(resolved_output, overwrite)
if output:
console.print(f"[bold blue]Output:[/bold blue] {output}")
elif output_dir:
Expand Down
6 changes: 6 additions & 0 deletions src/winml/modelkit/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _apply_stage_overrides(cfg: Any, *, no_quant: bool, no_compile: bool) -> Non
)
@cli_utils.precision_option()
@cli_utils.output_option("Output JSON file path (default: stdout)")
@cli_utils.overwrite_option()
@click.option(
"--library",
"library_name",
Expand Down Expand Up @@ -141,6 +142,7 @@ def config(
ep: EPNameOrAlias | None,
precision: str,
output: Path | None,
overwrite: bool,
library_name: str,
verbose: int,
quiet: bool,
Expand Down Expand Up @@ -310,6 +312,7 @@ def config(
no_quant=not quant,
no_compile=no_compile,
output=output,
overwrite=overwrite,
console=console,
)
return
Expand Down Expand Up @@ -444,6 +447,7 @@ def config(
config_json = json.dumps(output_data, indent=2)

if output:
cli_utils.guard_output(output, overwrite)
output.parent.mkdir(parents=True, exist_ok=True)
output.write_text(config_json)
suffix = f" [dim]({_n_modules} submodules)[/dim]" if _n_modules else ""
Expand Down Expand Up @@ -516,6 +520,7 @@ def _generate_pipeline_configs(
no_quant: bool,
no_compile: bool,
output: Path | None,
overwrite: bool,
console: Any,
) -> None:
"""Generate and save one config file per pipeline sub-component."""
Expand Down Expand Up @@ -546,6 +551,7 @@ def _generate_pipeline_configs(

if output:
suffixed = output.with_stem(f"{output.stem}_{component_name}")
cli_utils.guard_output(suffixed, overwrite)
suffixed.parent.mkdir(parents=True, exist_ok=True)
tmp = suffixed.with_suffix(".json.tmp")
tmp.write_text(config_json)
Expand Down
6 changes: 6 additions & 0 deletions src/winml/modelkit/commands/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
help='Path to a JSON file with label mapping: {"label_name": id}.',
)
@cli_utils.output_option("Output JSON file path.")
@cli_utils.overwrite_option()
@click.option(
"--dataset-script",
type=str,
Expand Down Expand Up @@ -194,6 +195,7 @@ def eval(
column: tuple[str, ...],
label_mapping_path: Path | None,
output: Path | None,
overwrite: bool,
output_format: cli_utils.OutputFormat,
verbose: int,
quiet: bool,
Expand Down Expand Up @@ -250,6 +252,10 @@ def eval(
_resolve_label_mapping(cfg)
_run_dataset_script(cfg, trust_remote_code)

# Refuse to clobber an existing report unless the user opted in — fail fast
# before the (expensive) evaluation runs.
cli_utils.guard_output(cfg.output_path, overwrite)

if cfg.model_path is not None and cfg.precision != "auto":
logger.warning(
"--precision %s is ignored for pre-built ONNX inputs "
Expand Down
7 changes: 6 additions & 1 deletion src/winml/modelkit/commands/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _delete_onnx_with_external_data(onnx_path: Path) -> None:
help="HuggingFace model name or local path (e.g., prajjwal1/bert-tiny)",
)
@cli_utils.output_option("Output ONNX file path (e.g., model.onnx)", required=True)
@cli_utils.overwrite_option()
@click.option(
"--with-report/--no-with-report",
default=False,
Expand Down Expand Up @@ -136,6 +137,7 @@ def export(
ctx: click.Context,
model: str,
output: Path,
overwrite: bool,
verbose: int,
quiet: bool,
with_report: bool,
Expand Down Expand Up @@ -237,8 +239,11 @@ def export(
if export_config:
console.print(f"[bold blue]Export config:[/bold blue] {export_config}")

# Create output directory if needed
# Refuse to clobber an existing output unless the user opted in.
output_path = Path(output)
cli_utils.guard_output(output_path, overwrite)

# Create output directory if needed
output_path.parent.mkdir(parents=True, exist_ok=True)

# Load export configuration from JSON file if provided, or create default
Expand Down
5 changes: 5 additions & 0 deletions src/winml/modelkit/commands/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def capability_options(func: F) -> F:
help="Input ONNX model file",
)
@cli_utils.output_option("Output path (default: {input}_opt.onnx)")
@cli_utils.overwrite_option()
@click.option(
"--config",
"-c",
Expand All @@ -189,6 +190,7 @@ def optimize(
list_rewrites: bool,
model: Path | None,
output: Path | None,
overwrite: bool,
config: Path | None,
verbose: int,
quiet: bool,
Expand Down Expand Up @@ -346,6 +348,9 @@ def optimize(
if output is None:
output = model.parent / f"{model.stem}_opt.onnx"

# Refuse to clobber an existing output unless the user opted in.
cli_utils.guard_output(output, overwrite)

# Show info
console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Output:[/bold blue] {output}")
Expand Down
5 changes: 5 additions & 0 deletions src/winml/modelkit/commands/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,7 @@ def _run_simple_loop(
"Output JSON file path. Defaults to "
"'~/.cache/winml/perf/<model_slug>[/<module_class>]/<timestamp>.json'."
)
@cli_utils.overwrite_option()
@click.option(
"--batch-size",
type=int,
Expand Down Expand Up @@ -1555,6 +1556,7 @@ def perf(
ep: EPNameOrAlias | None,
ep_options: tuple[str, ...],
output: Path | None,
overwrite: bool,
batch_size: int,
shape_config_path: Path | None,
quant: bool,
Expand Down Expand Up @@ -1705,6 +1707,9 @@ def perf(
if output is None:
output = generate_output_path(hf_model)

# Refuse to clobber an existing report unless the user opted in.
cli_utils.guard_output(output, overwrite)

# Create config. The raw device/EP request is passed through unchanged;
# PerfBenchmark resolves the concrete device + EP internally (failing fast
# before the build), so the CLI does not pre-resolve here.
Expand Down
5 changes: 5 additions & 0 deletions src/winml/modelkit/commands/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
help="Input ONNX model file",
)
@cli_utils.output_option("Output path (default: {input}_qdq.onnx)")
@cli_utils.overwrite_option()
@cli_utils.precision_option(
default=None,
help_text="Quantization precision: auto, fp16, int4, int8, int16, or w{x}a{y} where "
Expand Down Expand Up @@ -112,6 +113,7 @@ def quantize(
ctx: click.Context,
model: Path,
output: Path | None,
overwrite: bool,
precision: str | None,
samples: int,
method: str,
Expand Down Expand Up @@ -244,6 +246,9 @@ def quantize(
console.print(f"[bold blue]Dataset:[/bold blue] {_dataset_display}")

# ── Shared execution: print header, run, report ──────────────
# Refuse to clobber an existing output unless the user opted in. Runs after
# the per-precision default path is resolved, before any mkdir/work.
cli_utils.guard_output(output, overwrite)
output.parent.mkdir(parents=True, exist_ok=True)
console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Output:[/bold blue] {output}")
Expand Down
6 changes: 6 additions & 0 deletions src/winml/modelkit/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def _print_input_hint(engine: Any) -> None:
)
@cli_utils.format_option(short_flag=False)
@cli_utils.output_option("Write output to file instead of stdout")
@cli_utils.overwrite_option()
@click.option(
"--port",
default=_DEFAULT_PORT,
Expand Down Expand Up @@ -506,6 +507,7 @@ def run(
show_schema: bool,
output_format: cli_utils.OutputFormat,
output: Path | None,
overwrite: bool,
port: int,
connect_host: str,
connect: bool,
Expand Down Expand Up @@ -537,6 +539,10 @@ def run(
if ctx.obj and ctx.obj.get("debug"):
logging.getLogger("winml.modelkit").setLevel(logging.DEBUG)

# Refuse to clobber an existing output file unless the user opted in —
# fail fast before loading the model / running inference.
cli_utils.guard_output(output, overwrite)

# Parse -P/--param entries
pipeline_kwargs: dict[str, Any] = {}
for p in params:
Expand Down
70 changes: 70 additions & 0 deletions src/winml/modelkit/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,76 @@ def output_option(help_text: str, required: bool = False) -> Callable[[F], F]:
return click.option("--output", "-o", **kwargs)


def overwrite_option(optional_message: str | None = None) -> Callable[[F], F]:
"""Add the shared ``--overwrite/--no-overwrite`` toggle (default: no-overwrite).

Output-producing commands default to *not* clobbering an existing output so
a re-run can't silently destroy a previous result. Pair this with
:func:`guard_output`, which performs the actual existence check. The
decorated function receives the value as the ``overwrite`` parameter.

Args:
optional_message: Command-specific note appended after the help text.

Returns:
Decorator function.
"""
help_text = "Overwrite an existing output instead of erroring out"
if optional_message:
help_text = f"{help_text}. {optional_message}"
return click.option(
"--overwrite/--no-overwrite",
"overwrite",
default=False,
show_default=True,
help=help_text,
)


def guard_output(
path: str | Path | None,
overwrite: bool,
*,
label: str = "Output",
) -> None:
"""Fail fast when an output path already exists and ``--overwrite`` was not set.

Shared safety check for every output-producing command so a re-run can't
silently clobber a previous result. Call this *before* any ``mkdir`` /
cleanup / work, with the fully resolved output path (including defaulted
paths like ``{stem}_qdq.onnx``). A ``None`` path (e.g. output goes to
stdout) is a no-op.

Files block when they exist. Directories block only when they exist *and*
are non-empty, so a freshly-created or empty output directory does not
false-trigger.

Args:
path: Resolved output file or directory path, or ``None``.
overwrite: When ``True``, the check is skipped (user opted in).
label: Human-readable noun for the error message (e.g. ``"Output dir"``).

Raises:
click.ClickException: If the path exists (non-empty, for directories)
and ``overwrite`` is ``False``.
"""
if path is None or overwrite:
return
resolved = Path(path)
if not resolved.exists():
return
if resolved.is_dir():
if any(resolved.iterdir()):
raise click.ClickException(
f"{label} directory '{resolved}' already exists and is not empty. "
"Re-run with --overwrite to replace its contents."
)
return
raise click.ClickException(
f"{label} '{resolved}' already exists. Re-run with --overwrite to replace it."
)


def format_option(
choices: list[OutputFormat] | None = None,
default: OutputFormat = "text",
Expand Down
Loading
Loading