diff --git a/src/winml/modelkit/commands/analyze.py b/src/winml/modelkit/commands/analyze.py index 07c618f6e..7380fa149 100644 --- a/src/winml/modelkit/commands/analyze.py +++ b/src/winml/modelkit/commands/analyze.py @@ -744,6 +744,7 @@ def _build_runtime_debug_output_path(model_path: Path, ep_name: str, device_name @cli_utils.verbosity_options() @cli_utils.build_config_option() @cli_utils.output_option("Save JSON output to file") +@cli_utils.overwrite_option(optional_message="Applies to both --output and --optim-config.") @click.option( "--information/--no-information", default=True, @@ -790,6 +791,7 @@ def analyze( ep: EPNameOrAlias | Literal["all", "auto"] | None, device: str | None, output: Path | None, + overwrite: bool, information: bool, output_format: cli_utils.OutputFormat, verbose: int, @@ -833,6 +835,11 @@ def analyze( verbose, quiet = cli_utils.resolve_verbosity(ctx, verbose, quiet) configure_logging(verbosity=verbose, quiet=quiet) + # Refuse to clobber existing outputs unless the user opted in — fail fast + # before analysis runs. Guards both result JSON and the optim-config dump. + cli_utils.guard_output(output, overwrite) + cli_utils.guard_output(optim_config, overwrite, label="Optimization config") + try: from ..analyze import ONNXStaticAnalyzer @@ -958,9 +965,7 @@ def analyze( sys.exit(2) compatible_eps = resolve_eps(ref_device) if not compatible_eps: - logger.error( - "No execution provider is available for device '%s'.", ref_device - ) + logger.error("No execution provider is available for device '%s'.", ref_device) sys.exit(2) eps = [compatible_eps[0]] else: diff --git a/src/winml/modelkit/commands/catalog.py b/src/winml/modelkit/commands/catalog.py index 4ce102a2a..4839aaf3c 100644 --- a/src/winml/modelkit/commands/catalog.py +++ b/src/winml/modelkit/commands/catalog.py @@ -384,6 +384,7 @@ def _save_json(data: Any, path: Path) -> None: optional_message="If not specified, shows all devices", ) @cli_utils.output_option("Save results to a JSON file.") +@cli_utils.overwrite_option() @cli_utils.format_option() def catalog( model_type: str | None, @@ -391,6 +392,7 @@ def catalog( ep: EPNameOrAlias | None, device: str | None, output: Path | None, + overwrite: bool, output_format: cli_utils.OutputFormat, ) -> None: r"""Browse WinML CLI's curated built-in model catalog. @@ -435,4 +437,5 @@ def catalog( _output_list(models, ep_col_header=ep_col_header, ep_col_fn=ep_col_fn) if output is not None: + cli_utils.guard_output(output, overwrite) _save_json(models, output) diff --git a/src/winml/modelkit/commands/compile.py b/src/winml/modelkit/commands/compile.py index 895d6a4d8..99616e77e 100644 --- a/src/winml/modelkit/commands/compile.py +++ b/src/winml/modelkit/commands/compile.py @@ -51,6 +51,7 @@ "EP context (weight sharing). Required unless --list.", ) @cli_utils.output_option("Output file path (e.g., model_compiled.onnx)") +@cli_utils.overwrite_option() @click.option( "--output-dir", type=click.Path(path_type=Path), @@ -108,6 +109,7 @@ def compile( model: tuple[Path, ...], output: Path | None, output_dir: Path | None, + overwrite: bool, device: str, ep: EPNameOrAlias | None, validate: bool, @@ -243,6 +245,9 @@ def compile( console.print(f"[bold blue]SDK root:[/bold blue] {qnn_sdk_root}") # Resolve output path: -o (file) takes precedence over --output-dir resolved_output = output or output_dir + # Refuse to clobber an existing output unless the user opted in. A file + # blocks when it exists; a directory blocks only when non-empty. + cli_utils.guard_output(resolved_output, overwrite) if output: console.print(f"[bold blue]Output:[/bold blue] {output}") elif output_dir: diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index ab0b58e9b..9414e6e8b 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -111,6 +111,7 @@ def _apply_stage_overrides(cfg: Any, *, no_quant: bool, no_compile: bool) -> Non ) @cli_utils.precision_option() @cli_utils.output_option("Output JSON file path (default: stdout)") +@cli_utils.overwrite_option() @click.option( "--library", "library_name", @@ -141,6 +142,7 @@ def config( ep: EPNameOrAlias | None, precision: str, output: Path | None, + overwrite: bool, library_name: str, verbose: int, quiet: bool, @@ -310,6 +312,7 @@ def config( no_quant=not quant, no_compile=no_compile, output=output, + overwrite=overwrite, console=console, ) return @@ -444,6 +447,7 @@ def config( config_json = json.dumps(output_data, indent=2) if output: + cli_utils.guard_output(output, overwrite) output.parent.mkdir(parents=True, exist_ok=True) output.write_text(config_json) suffix = f" [dim]({_n_modules} submodules)[/dim]" if _n_modules else "" @@ -516,6 +520,7 @@ def _generate_pipeline_configs( no_quant: bool, no_compile: bool, output: Path | None, + overwrite: bool, console: Any, ) -> None: """Generate and save one config file per pipeline sub-component.""" @@ -546,6 +551,7 @@ def _generate_pipeline_configs( if output: suffixed = output.with_stem(f"{output.stem}_{component_name}") + cli_utils.guard_output(suffixed, overwrite) suffixed.parent.mkdir(parents=True, exist_ok=True) tmp = suffixed.with_suffix(".json.tmp") tmp.write_text(config_json) diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index fb13096c2..ca7a91e20 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -140,6 +140,7 @@ help='Path to a JSON file with label mapping: {"label_name": id}.', ) @cli_utils.output_option("Output JSON file path.") +@cli_utils.overwrite_option() @click.option( "--dataset-script", type=str, @@ -194,6 +195,7 @@ def eval( column: tuple[str, ...], label_mapping_path: Path | None, output: Path | None, + overwrite: bool, output_format: cli_utils.OutputFormat, verbose: int, quiet: bool, @@ -250,6 +252,10 @@ def eval( _resolve_label_mapping(cfg) _run_dataset_script(cfg, trust_remote_code) + # Refuse to clobber an existing report unless the user opted in — fail fast + # before the (expensive) evaluation runs. + cli_utils.guard_output(cfg.output_path, overwrite) + if cfg.model_path is not None and cfg.precision != "auto": logger.warning( "--precision %s is ignored for pre-built ONNX inputs " diff --git a/src/winml/modelkit/commands/export.py b/src/winml/modelkit/commands/export.py index 4cae34a86..ee2b93bfa 100644 --- a/src/winml/modelkit/commands/export.py +++ b/src/winml/modelkit/commands/export.py @@ -70,6 +70,7 @@ def _delete_onnx_with_external_data(onnx_path: Path) -> None: help="HuggingFace model name or local path (e.g., prajjwal1/bert-tiny)", ) @cli_utils.output_option("Output ONNX file path (e.g., model.onnx)", required=True) +@cli_utils.overwrite_option() @click.option( "--with-report/--no-with-report", default=False, @@ -136,6 +137,7 @@ def export( ctx: click.Context, model: str, output: Path, + overwrite: bool, verbose: int, quiet: bool, with_report: bool, @@ -237,8 +239,11 @@ def export( if export_config: console.print(f"[bold blue]Export config:[/bold blue] {export_config}") - # Create output directory if needed + # Refuse to clobber an existing output unless the user opted in. output_path = Path(output) + cli_utils.guard_output(output_path, overwrite) + + # Create output directory if needed output_path.parent.mkdir(parents=True, exist_ok=True) # Load export configuration from JSON file if provided, or create default diff --git a/src/winml/modelkit/commands/optimize.py b/src/winml/modelkit/commands/optimize.py index 287f9a423..e0d8747f2 100644 --- a/src/winml/modelkit/commands/optimize.py +++ b/src/winml/modelkit/commands/optimize.py @@ -173,6 +173,7 @@ def capability_options(func: F) -> F: help="Input ONNX model file", ) @cli_utils.output_option("Output path (default: {input}_opt.onnx)") +@cli_utils.overwrite_option() @click.option( "--config", "-c", @@ -189,6 +190,7 @@ def optimize( list_rewrites: bool, model: Path | None, output: Path | None, + overwrite: bool, config: Path | None, verbose: int, quiet: bool, @@ -346,6 +348,9 @@ def optimize( if output is None: output = model.parent / f"{model.stem}_opt.onnx" + # Refuse to clobber an existing output unless the user opted in. + cli_utils.guard_output(output, overwrite) + # Show info console.print(f"[bold blue]Input:[/bold blue] {model}") console.print(f"[bold blue]Output:[/bold blue] {output}") diff --git a/src/winml/modelkit/commands/perf.py b/src/winml/modelkit/commands/perf.py index d64dda258..6c92461d5 100644 --- a/src/winml/modelkit/commands/perf.py +++ b/src/winml/modelkit/commands/perf.py @@ -1473,6 +1473,7 @@ def _run_simple_loop( "Output JSON file path. Defaults to " "'~/.cache/winml/perf/[/]/.json'." ) +@cli_utils.overwrite_option() @click.option( "--batch-size", type=int, @@ -1555,6 +1556,7 @@ def perf( ep: EPNameOrAlias | None, ep_options: tuple[str, ...], output: Path | None, + overwrite: bool, batch_size: int, shape_config_path: Path | None, quant: bool, @@ -1705,6 +1707,9 @@ def perf( if output is None: output = generate_output_path(hf_model) + # Refuse to clobber an existing report unless the user opted in. + cli_utils.guard_output(output, overwrite) + # Create config. The raw device/EP request is passed through unchanged; # PerfBenchmark resolves the concrete device + EP internally (failing fast # before the build), so the CLI does not pre-resolve here. diff --git a/src/winml/modelkit/commands/quantize.py b/src/winml/modelkit/commands/quantize.py index 902ea3144..8e529697a 100644 --- a/src/winml/modelkit/commands/quantize.py +++ b/src/winml/modelkit/commands/quantize.py @@ -48,6 +48,7 @@ help="Input ONNX model file", ) @cli_utils.output_option("Output path (default: {input}_qdq.onnx)") +@cli_utils.overwrite_option() @cli_utils.precision_option( default=None, help_text="Quantization precision: auto, fp16, int4, int8, int16, or w{x}a{y} where " @@ -112,6 +113,7 @@ def quantize( ctx: click.Context, model: Path, output: Path | None, + overwrite: bool, precision: str | None, samples: int, method: str, @@ -244,6 +246,9 @@ def quantize( console.print(f"[bold blue]Dataset:[/bold blue] {_dataset_display}") # ── Shared execution: print header, run, report ────────────── + # Refuse to clobber an existing output unless the user opted in. Runs after + # the per-precision default path is resolved, before any mkdir/work. + cli_utils.guard_output(output, overwrite) output.parent.mkdir(parents=True, exist_ok=True) console.print(f"[bold blue]Input:[/bold blue] {model}") console.print(f"[bold blue]Output:[/bold blue] {output}") diff --git a/src/winml/modelkit/commands/run.py b/src/winml/modelkit/commands/run.py index 193809b4e..f8e86797f 100644 --- a/src/winml/modelkit/commands/run.py +++ b/src/winml/modelkit/commands/run.py @@ -470,6 +470,7 @@ def _print_input_hint(engine: Any) -> None: ) @cli_utils.format_option(short_flag=False) @cli_utils.output_option("Write output to file instead of stdout") +@cli_utils.overwrite_option() @click.option( "--port", default=_DEFAULT_PORT, @@ -506,6 +507,7 @@ def run( show_schema: bool, output_format: cli_utils.OutputFormat, output: Path | None, + overwrite: bool, port: int, connect_host: str, connect: bool, @@ -537,6 +539,10 @@ def run( if ctx.obj and ctx.obj.get("debug"): logging.getLogger("winml.modelkit").setLevel(logging.DEBUG) + # Refuse to clobber an existing output file unless the user opted in — + # fail fast before loading the model / running inference. + cli_utils.guard_output(output, overwrite) + # Parse -P/--param entries pipeline_kwargs: dict[str, Any] = {} for p in params: diff --git a/src/winml/modelkit/utils/cli.py b/src/winml/modelkit/utils/cli.py index 312fff9cc..77802d79e 100644 --- a/src/winml/modelkit/utils/cli.py +++ b/src/winml/modelkit/utils/cli.py @@ -156,6 +156,76 @@ def output_option(help_text: str, required: bool = False) -> Callable[[F], F]: return click.option("--output", "-o", **kwargs) +def overwrite_option(optional_message: str | None = None) -> Callable[[F], F]: + """Add the shared ``--overwrite/--no-overwrite`` toggle (default: no-overwrite). + + Output-producing commands default to *not* clobbering an existing output so + a re-run can't silently destroy a previous result. Pair this with + :func:`guard_output`, which performs the actual existence check. The + decorated function receives the value as the ``overwrite`` parameter. + + Args: + optional_message: Command-specific note appended after the help text. + + Returns: + Decorator function. + """ + help_text = "Overwrite an existing output instead of erroring out" + if optional_message: + help_text = f"{help_text}. {optional_message}" + return click.option( + "--overwrite/--no-overwrite", + "overwrite", + default=False, + show_default=True, + help=help_text, + ) + + +def guard_output( + path: str | Path | None, + overwrite: bool, + *, + label: str = "Output", +) -> None: + """Fail fast when an output path already exists and ``--overwrite`` was not set. + + Shared safety check for every output-producing command so a re-run can't + silently clobber a previous result. Call this *before* any ``mkdir`` / + cleanup / work, with the fully resolved output path (including defaulted + paths like ``{stem}_qdq.onnx``). A ``None`` path (e.g. output goes to + stdout) is a no-op. + + Files block when they exist. Directories block only when they exist *and* + are non-empty, so a freshly-created or empty output directory does not + false-trigger. + + Args: + path: Resolved output file or directory path, or ``None``. + overwrite: When ``True``, the check is skipped (user opted in). + label: Human-readable noun for the error message (e.g. ``"Output dir"``). + + Raises: + click.ClickException: If the path exists (non-empty, for directories) + and ``overwrite`` is ``False``. + """ + if path is None or overwrite: + return + resolved = Path(path) + if not resolved.exists(): + return + if resolved.is_dir(): + if any(resolved.iterdir()): + raise click.ClickException( + f"{label} directory '{resolved}' already exists and is not empty. " + "Re-run with --overwrite to replace its contents." + ) + return + raise click.ClickException( + f"{label} '{resolved}' already exists. Re-run with --overwrite to replace it." + ) + + def format_option( choices: list[OutputFormat] | None = None, default: OutputFormat = "text", diff --git a/tests/unit/commands/test_catalog.py b/tests/unit/commands/test_catalog.py index 94dd73a6f..0f41dedb0 100644 --- a/tests/unit/commands/test_catalog.py +++ b/tests/unit/commands/test_catalog.py @@ -255,6 +255,29 @@ def test_catalog_saves_json_file(runner, patched_catalog, tmp_path): assert "size_mb" in first +def test_catalog_existing_output_blocked_without_overwrite(runner, patched_catalog, tmp_path): + """An existing --output is not clobbered unless --overwrite is passed.""" + out = tmp_path / "catalog.json" + out.write_text("ORIGINAL") + result = runner.invoke(catalog, ["--output", str(out)]) + assert result.exit_code != 0 + assert "already exists" in result.output + assert "--overwrite" in result.output + # The original file is left untouched. + assert out.read_text() == "ORIGINAL" + + +def test_catalog_existing_output_replaced_with_overwrite(runner, patched_catalog, tmp_path): + """--overwrite allows replacing an existing output file.""" + out = tmp_path / "catalog.json" + out.write_text("ORIGINAL") + result = runner.invoke(catalog, ["--output", str(out), "--overwrite"]) + assert result.exit_code == 0, result.output + data = json.loads(out.read_text()) + assert isinstance(data, list) + assert len(data) == 4 + + def test_catalog_filter_model_type(runner, patched_catalog, tmp_path): out = tmp_path / "out.json" result = runner.invoke(catalog, ["--model-type", "bert", "--output", str(out)]) diff --git a/tests/unit/commands/test_compile_quantize_flags.py b/tests/unit/commands/test_compile_quantize_flags.py index b1b22a920..fc150dc4f 100644 --- a/tests/unit/commands/test_compile_quantize_flags.py +++ b/tests/unit/commands/test_compile_quantize_flags.py @@ -573,3 +573,142 @@ def fake_quantize(*_args, **_kwargs): assert r.exit_code != 0, r.output assert "not a supported quantization precision" in r.output assert ran["called"] is False + + +class TestOverwriteGuard: + """The shared --overwrite/--no-overwrite guard on quantize (file) and + compile (directory) outputs. Cross-checks the wiring of + ``cli_utils.guard_output`` into real commands.""" + + @staticmethod + def _quantize(args, tmp_path, *, expect_called: bool): + from click.testing import CliRunner + + from winml.modelkit.commands.quantize import quantize as quantize_cmd + + called: dict[str, bool] = {"v": False} + + def fake_quantize(model_path, output_path=None, config=None, **kwargs): + called["v"] = True + result = MagicMock() + result.success = True + result.output_path = output_path + result.nodes_quantized = 0 + result.total_time_seconds = 0.0 + result.errors = [] + return result + + with patch("winml.modelkit.quant.quantize_onnx", side_effect=fake_quantize): + r = CliRunner().invoke(quantize_cmd, args, obj={}, catch_exceptions=False) + assert called["v"] is expect_called, r.output + return r + + def test_quantize_existing_output_blocked(self, tmp_path): + model, _ = TestQuantizeCliConfigPrecedence._setup(tmp_path) + out = tmp_path / "q.onnx" + out.write_text("ORIGINAL") + # quantize_onnx must NOT run; the guard fires before any work (and before + # the quantizer's destructive stale-sidecar cleanup). + r = self._quantize(["-m", str(model), "-o", str(out)], tmp_path, expect_called=False) + assert r.exit_code != 0 + assert "already exists" in r.output + assert "--overwrite" in r.output + assert out.read_text() == "ORIGINAL" + + def test_quantize_existing_output_allowed_with_overwrite(self, tmp_path): + model, _ = TestQuantizeCliConfigPrecedence._setup(tmp_path) + out = tmp_path / "q.onnx" + out.write_text("ORIGINAL") + r = self._quantize( + ["-m", str(model), "-o", str(out), "--overwrite"], tmp_path, expect_called=True + ) + assert r.exit_code == 0, r.output + + def test_quantize_default_derived_output_guarded(self, tmp_path): + """The guard covers the defaulted ``{stem}_qdq.onnx`` path, not just -o.""" + model, _ = TestQuantizeCliConfigPrecedence._setup(tmp_path) + default_out = model.parent / f"{model.stem}_qdq.onnx" + default_out.write_text("ORIGINAL") + r = self._quantize(["-m", str(model)], tmp_path, expect_called=False) + assert r.exit_code != 0 + assert "already exists" in r.output + + @staticmethod + def _compile(args, tmp_path, *, expect_called: bool): + from click.testing import CliRunner + + from winml.modelkit.commands.compile import compile as compile_cmd + + called: dict[str, bool] = {"v": False} + mock_result = MagicMock() + mock_result.success = True + mock_result.output_path = tmp_path / "model_compiled.onnx" + mock_result.compile_time = 1.0 + mock_result.total_time = 1.5 + + def fake_compile(*_args, **_kwargs): + called["v"] = True + return mock_result + + with ( + patch( + "winml.modelkit.commands.compile.resolve_device", + return_value=("npu", ["npu", "gpu", "cpu"]), + ), + patch("winml.modelkit.commands.compile.is_compiled_onnx", return_value=False), + patch("winml.modelkit.compiler.compile_onnx", side_effect=fake_compile), + ): + r = CliRunner().invoke(compile_cmd, args, catch_exceptions=False) + assert called["v"] is expect_called, r.output + return r + + def test_compile_non_empty_output_dir_blocked(self, tmp_path): + model = tmp_path / "model.onnx" + model.write_bytes(b"fake") + out_dir = tmp_path / "out" + out_dir.mkdir() + (out_dir / "stale.onnx").write_bytes(b"old") + r = self._compile( + ["-m", str(model), "--device", "npu", "--ep", "qnn", "--output-dir", str(out_dir)], + tmp_path, + expect_called=False, + ) + assert r.exit_code != 0 + assert "not empty" in r.output + assert "--overwrite" in r.output + + def test_compile_empty_output_dir_ok(self, tmp_path): + """An existing but empty output dir does not trip the guard.""" + model = tmp_path / "model.onnx" + model.write_bytes(b"fake") + out_dir = tmp_path / "out" + out_dir.mkdir() + r = self._compile( + ["-m", str(model), "--device", "npu", "--ep", "qnn", "--output-dir", str(out_dir)], + tmp_path, + expect_called=True, + ) + assert r.exit_code == 0, r.output + + def test_compile_non_empty_output_dir_allowed_with_overwrite(self, tmp_path): + model = tmp_path / "model.onnx" + model.write_bytes(b"fake") + out_dir = tmp_path / "out" + out_dir.mkdir() + (out_dir / "stale.onnx").write_bytes(b"old") + r = self._compile( + [ + "-m", + str(model), + "--device", + "npu", + "--ep", + "qnn", + "--output-dir", + str(out_dir), + "--overwrite", + ], + tmp_path, + expect_called=True, + ) + assert r.exit_code == 0, r.output diff --git a/tests/unit/utils/test_cli.py b/tests/unit/utils/test_cli.py index 28e12f459..5eb2d08a1 100644 --- a/tests/unit/utils/test_cli.py +++ b/tests/unit/utils/test_cli.py @@ -6,6 +6,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import click import pytest from click.testing import CliRunner @@ -13,15 +15,21 @@ from winml.modelkit.utils.cli import ( analyze_option, build_pipeline_extra_kwargs, + guard_output, ignored_build_flags_warning, max_optim_iterations_option, optimize_option, + overwrite_option, parse_ep_options, precision_option, quant_option, ) +if TYPE_CHECKING: + from pathlib import Path + + class TestParseEpOptions: """Tests for parse_ep_options().""" @@ -314,3 +322,86 @@ def test_max_optim_zero_counts_as_set(self) -> None: msg = ignored_build_flags_warning(skip_build_onnx=True, max_optim_iterations=0) assert msg is not None assert "--max-optim-iterations" in msg + + +class TestOverwriteOption: + """Tests for the shared overwrite_option() decorator.""" + + @staticmethod + def _make_cmd(optional_message: str | None = None) -> click.Command: + @click.command() + @overwrite_option(optional_message=optional_message) + def cmd(overwrite: bool) -> None: + click.echo(repr(overwrite)) + + return cmd + + def test_default_is_false(self) -> None: + assert CliRunner().invoke(self._make_cmd(), []).output.strip() == "False" + + def test_overwrite_flag_sets_true(self) -> None: + assert CliRunner().invoke(self._make_cmd(), ["--overwrite"]).output.strip() == "True" + + def test_no_overwrite_flag_sets_false(self) -> None: + assert CliRunner().invoke(self._make_cmd(), ["--no-overwrite"]).output.strip() == "False" + + def test_optional_message_appended(self) -> None: + result = CliRunner().invoke(self._make_cmd(optional_message="Extra note."), ["--help"]) + joined = " ".join(result.output.split()) + assert "Overwrite an existing output" in joined + assert "Extra note." in joined + + +class TestGuardOutput: + """Tests for the shared guard_output() existence check.""" + + def test_none_path_is_noop(self) -> None: + guard_output(None, overwrite=False) # must not raise + + def test_missing_file_is_noop(self, tmp_path: Path) -> None: + guard_output(tmp_path / "nope.onnx", overwrite=False) # must not raise + + def test_existing_file_raises(self, tmp_path: Path) -> None: + f = tmp_path / "model.onnx" + f.write_text("x") + with pytest.raises(click.ClickException) as exc: + guard_output(f, overwrite=False) + assert "already exists" in str(exc.value) + assert "--overwrite" in str(exc.value) + + def test_existing_file_with_overwrite_is_noop(self, tmp_path: Path) -> None: + f = tmp_path / "model.onnx" + f.write_text("x") + guard_output(f, overwrite=True) # must not raise + + def test_empty_dir_is_noop(self, tmp_path: Path) -> None: + d = tmp_path / "out" + d.mkdir() + guard_output(d, overwrite=False) # empty dir must not raise + + def test_non_empty_dir_raises(self, tmp_path: Path) -> None: + d = tmp_path / "out" + d.mkdir() + (d / "artifact.onnx").write_text("x") + with pytest.raises(click.ClickException) as exc: + guard_output(d, overwrite=False) + assert "not empty" in str(exc.value) + + def test_non_empty_dir_with_overwrite_is_noop(self, tmp_path: Path) -> None: + d = tmp_path / "out" + d.mkdir() + (d / "artifact.onnx").write_text("x") + guard_output(d, overwrite=True) # must not raise + + def test_custom_label_in_message(self, tmp_path: Path) -> None: + f = tmp_path / "cfg.json" + f.write_text("x") + with pytest.raises(click.ClickException) as exc: + guard_output(f, overwrite=False, label="Optimization config") + assert "Optimization config" in str(exc.value) + + def test_accepts_str_path(self, tmp_path: Path) -> None: + f = tmp_path / "model.onnx" + f.write_text("x") + with pytest.raises(click.ClickException): + guard_output(str(f), overwrite=False)