diff --git a/src/svsbench/search.py b/src/svsbench/search.py index 5114760..9976050 100644 --- a/src/svsbench/search.py +++ b/src/svsbench/search.py @@ -7,6 +7,7 @@ import sys import time from pathlib import Path +from typing import Final import numpy as np import svs @@ -17,6 +18,15 @@ logger = logging.getLogger(__file__) +STR_TO_CALIBRATE_SEARCH_BUFFER: Final[ + dict[str, svs.VamanaSearchBufferOptimization] +] = { + "disable": svs.VamanaSearchBufferOptimization.Disable, + "all": svs.VamanaSearchBufferOptimization.All, + "roionly": svs.VamanaSearchBufferOptimization.ROIOnly, + "roituneup": svs.VamanaSearchBufferOptimization.ROITuneUp, +} + def _read_args(argv: list[str] | None = None) -> argparse.Namespace: """Read command line arguments.""" @@ -111,6 +121,12 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace: action="store_true", help="Load from static index", ) + parser.add_argument("--no_calibrate_prefetchers", action="store_true") + parser.add_argument( + "--calibrate_search_buffer", + choices=STR_TO_CALIBRATE_SEARCH_BUFFER.keys(), + default="all", + ) return parser.parse_args(argv) @@ -141,6 +157,8 @@ def search( calibration_ground_truth_path: Path | None = None, load_from_static: bool = False, lvq_strategy: svs.LVQStrategy | None = None, + train_prefetchers: bool = True, + search_buffer_optimization: svs.VamanaSearchBufferOptimization = svs.VamanaSearchBufferOptimization.All, ): logger.info({"search_args": locals()}) logger.info(utils.read_system_config()) @@ -207,8 +225,17 @@ def search( else: calibration_query = query calibration_ground_truth = ground_truth + calibration_parameters = svs.VamanaCalibrationParameters() + calibration_parameters.search_buffer_optimization = ( + search_buffer_optimization + ) + calibration_parameters.train_prefetchers = train_prefetchers index.experimental_calibrate( - calibration_query, calibration_ground_truth, count, recall + calibration_query, + calibration_ground_truth, + count, + recall, + calibration_parameters, ) logger.info( { @@ -285,10 +312,12 @@ def search( "search_results": { "qps": qps, "qps_mean": np.mean(qps), - "qps_rsd": np.std(qps, ddof=min(1, num_rep - 1)) / np.mean(qps), + "qps_rsd": np.std(qps, ddof=min(1, num_rep - 1)) + / np.mean(qps), "p95": p95s, "p95_mean": np.mean(p95s), - "p95_rsd": np.std(p95s, ddof=min(1, num_rep - 1)) / np.mean(p95s), + "p95_rsd": np.std(p95s, ddof=min(1, num_rep - 1)) + / np.mean(p95s), "search_parameters": { "search_window_size": index.search_parameters.buffer_config.search_window_size, "search_buffer_capacity": index.search_parameters.buffer_config.search_buffer_capacity, @@ -338,6 +367,10 @@ def main(argv: str | None = None) -> None: calibration_ground_truth_path=args.calibration_ground_truth_file, load_from_static=args.load_from_static, lvq_strategy=consts.STR_TO_LVQ_STRATEGY.get(args.lvq_strategy, None), + train_prefetchers=not args.no_calibrate_prefetchers, + search_buffer_optimization=STR_TO_CALIBRATE_SEARCH_BUFFER[ + args.calibrate_search_buffer + ], )