Skip to content

Commit 57de9c6

Browse files
committed
fixed unit tests
1 parent 241dfee commit 57de9c6

File tree

4 files changed

+69
-10
lines changed

4 files changed

+69
-10
lines changed

test_mmd.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import ads
2+
from ads.model.datascience_model_group import DataScienceModelGroup
3+
import json
4+
5+
# 1. Set Auth with the specific profile
6+
ads.set_auth(auth="security_token", profile="aryan-ashburn2")
7+
8+
# 2. Get the Model Group created by your deployment (from latest logs)
9+
group_id = "ocid1.datasciencemodelgroupint.oc1.iad.amaaaaaav66vvniafod77ugys4lya3xsq75frfpxzjbjbcipohli6pibik3q"
10+
11+
try:
12+
model_group = DataScienceModelGroup.from_id(group_id)
13+
14+
# 3. Extract the configuration
15+
config_value = model_group.custom_metadata_list.get("MULTI_MODEL_CONFIG").value
16+
config_json = json.loads(config_value)
17+
18+
# 4. Print and Verify
19+
print("\n--- Verification Results ---")
20+
for model in config_json['models']:
21+
print(f"\nModel Name: {model.get('model_name', 'Unknown')}")
22+
print(f"Params: {model['params']}")
23+
24+
if "--max-model-len" in model['params']:
25+
print(">> STATUS: Has SMM Defaults (Expected for 'Llama_Default2')")
26+
else:
27+
print(">> STATUS: Clean / No Defaults (Expected for 'Llama_Clear2')")
28+
29+
except Exception as e:
30+
print(f"Error fetching model group: {e}")
31+

test_mmd2.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import ads
2+
from ads.model.datascience_model_group import DataScienceModelGroup
3+
import json
4+
5+
# 1. Set Auth with the specific profile
6+
ads.set_auth(auth="security_token", profile="aryan-ashburn2")
7+
8+
# 2. Get the Model Group created by your deployment (from latest logs)
9+
group_id = "ocid1.datasciencemodelgroupint.oc1.iad.amaaaaaav66vvniazm2a2ao2u7n65baecxtu6e6lejfvj7gb3ytu3zduq35q"
10+
11+
try:
12+
model_group = DataScienceModelGroup.from_id(group_id)
13+
14+
# 3. Extract the configuration
15+
config_value = model_group.custom_metadata_list.get("MULTI_MODEL_CONFIG").value
16+
config_json = json.loads(config_value)
17+
18+
# 4. Print and Verify
19+
print("\n--- Verification Results ---")
20+
for model in config_json['models']:
21+
print(f"\nModel Name: {model.get('model_name', 'Unknown')}")
22+
print(f"Params: {model['params']}")
23+
24+
if "--max-model-len 1024" in model['params']:
25+
print(">> STATUS: SUCCESS - Custom value used (1024)")
26+
elif "--max-model-len 65536" in model['params']:
27+
print(">> STATUS: FAIL - Defaults merged in (65536)")
28+
else:
29+
print(">> STATUS: FAIL - Param missing entirely")
30+
31+
except Exception as e:
32+
print(f"Error fetching model group: {e}")

tests/unitary/with_extras/aqua/test_common_entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def test_extract_params_from_env_var_missing_env(self):
196196
}
197197
result = AquaMultiModelRef.model_validate(values)
198198
assert result.env_var == {}
199-
assert result.params == {}
199+
assert result.params is None
200200

201201
def test_all_model_ids_no_finetunes(self):
202202
model = AquaMultiModelRef(model_id="ocid1.model.oc1..base")

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ class TestDataset:
556556
"models": [
557557
{
558558
"env_var": {},
559-
"params": {},
559+
"params": None,
560560
"gpu_count": 2,
561561
"model_id": "test_model_id_1",
562562
"model_name": "test_model_1",
@@ -566,7 +566,7 @@ class TestDataset:
566566
},
567567
{
568568
"env_var": {},
569-
"params": {},
569+
"params": None,
570570
"gpu_count": 2,
571571
"model_id": "test_model_id_2",
572572
"model_name": "test_model_2",
@@ -576,7 +576,7 @@ class TestDataset:
576576
},
577577
{
578578
"env_var": {},
579-
"params": {},
579+
"params": None,
580580
"gpu_count": 2,
581581
"model_id": "test_model_id_3",
582582
"model_name": "test_model_3",
@@ -1258,9 +1258,7 @@ def test_get_deployment(self, mock_get_resource_name):
12581258
mock_get_resource_name.side_effect = lambda param: (
12591259
"log-group-name"
12601260
if param.startswith("ocid1.loggroup")
1261-
else "log-name"
1262-
if param.startswith("ocid1.log")
1263-
else ""
1261+
else "log-name" if param.startswith("ocid1.log") else ""
12641262
)
12651263

12661264
result = self.app.get(model_deployment_id=TestDataset.MODEL_DEPLOYMENT_ID)
@@ -1301,9 +1299,7 @@ def test_get_multi_model_deployment(
13011299
mock_get_resource_name.side_effect = lambda param: (
13021300
"log-group-name"
13031301
if param.startswith("ocid1.loggroup")
1304-
else "log-name"
1305-
if param.startswith("ocid1.log")
1306-
else ""
1302+
else "log-name" if param.startswith("ocid1.log") else ""
13071303
)
13081304

13091305
aqua_multi_model = os.path.join(

0 commit comments

Comments
 (0)