66import os
77import traceback
88from dataclasses import fields
9- from typing import Dict , Optional , Union
9+ from typing import Any , Dict , Optional , Union
1010
1111import oci
1212from oci .data_science .models import UpdateModelDetails , UpdateModelProvenanceDetails
1313
1414from ads import set_auth
1515from ads .aqua import logger
16+ from ads .aqua .common .entities import ModelConfigResult
1617from ads .aqua .common .enums import ConfigFolder , Tags
1718from ads .aqua .common .errors import AquaRuntimeError , AquaValueError
1819from ads .aqua .common .utils import (
@@ -272,24 +273,24 @@ def get_config(
272273 model_id : str ,
273274 config_file_name : str ,
274275 config_folder : Optional [str ] = ConfigFolder .CONFIG ,
275- ) -> Dict :
276- """Gets the config for the given Aqua model.
276+ ) -> ModelConfigResult :
277+ """
278+ Gets the configuration for the given Aqua model along with the model details.
277279
278280 Parameters
279281 ----------
280- model_id: str
282+ model_id : str
281283 The OCID of the Aqua model.
282- config_file_name: str
283- name of the config file
284- config_folder: (str, optional):
285- subfolder path where config_file_name needs to be searched
286- Defaults to `ConfigFolder.CONFIG`.
287- When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT`
284+ config_file_name : str
285+ The name of the configuration file.
286+ config_folder : Optional[str]
287+ The subfolder path where config_file_name is searched.
288+ Defaults to ConfigFolder.CONFIG. For model artifact directories, use ConfigFolder.ARTIFACT.
288289
289290 Returns
290291 -------
291- Dict:
292- A dict of allowed configs .
292+ ModelConfigResult
293+ A Pydantic model containing the model_details (extracted from OCI) and the config dictionary .
293294 """
294295 config_folder = config_folder or ConfigFolder .CONFIG
295296 oci_model = self .ds_client .get_model (model_id ).data
@@ -301,11 +302,11 @@ def get_config(
301302 if oci_model .freeform_tags
302303 else False
303304 )
304-
305305 if not oci_aqua :
306- raise AquaRuntimeError (f"Target model { oci_model .id } is not Aqua model." )
306+ raise AquaRuntimeError (f"Target model { oci_model .id } is not an Aqua model." )
307+
308+ config : Dict [str , Any ] = {}
307309
308- config = {}
309310 # if the current model has a service model tag, then
310311 if Tags .AQUA_SERVICE_MODEL_TAG in oci_model .freeform_tags :
311312 base_model_ocid = oci_model .freeform_tags [Tags .AQUA_SERVICE_MODEL_TAG ]
@@ -325,7 +326,7 @@ def get_config(
325326 logger .debug (
326327 f"Failed to get artifact path from custom metadata for the model: { model_id } "
327328 )
328- return config
329+ return ModelConfigResult ( config = config , model_details = oci_model )
329330
330331 config_path = os .path .join (os .path .dirname (artifact_path ), config_folder )
331332 if not is_path_exists (config_path ):
@@ -350,9 +351,8 @@ def get_config(
350351 f"{ config_file_name } is not available for the model: { model_id } . "
351352 f"Check if the custom metadata has the artifact path set."
352353 )
353- return config
354354
355- return config
355+ return ModelConfigResult ( config = config , model_details = oci_model )
356356
357357 @property
358358 def telemetry (self ):
@@ -374,9 +374,11 @@ def build_cli(self) -> str:
374374 """
375375 cmd = f"ads aqua { self ._command } "
376376 params = [
377- f"--{ field .name } { json .dumps (getattr (self , field .name ))} "
378- if isinstance (getattr (self , field .name ), dict )
379- else f"--{ field .name } { getattr (self , field .name )} "
377+ (
378+ f"--{ field .name } { json .dumps (getattr (self , field .name ))} "
379+ if isinstance (getattr (self , field .name ), dict )
380+ else f"--{ field .name } { getattr (self , field .name )} "
381+ )
380382 for field in fields (self .__class__ )
381383 if getattr (self , field .name ) is not None
382384 ]
0 commit comments