diff --git a/sagemaker-train/src/sagemaker/train/constants.py b/sagemaker-train/src/sagemaker/train/constants.py
index 9a07888064..c2545e79ee 100644
--- a/sagemaker-train/src/sagemaker/train/constants.py
+++ b/sagemaker-train/src/sagemaker/train/constants.py
@@ -49,3 +49,14 @@
     "qwen.qwen3-32b-v1:0",
     "qwen.qwen3-coder-30b-a3b-v1:0"
 ]
+
+# Allowed evaluator models for LLM as Judge evaluator with region restrictions
+_ALLOWED_EVALUATOR_MODELS = {
+    "anthropic.claude-3-5-sonnet-20240620-v1:0": ["us-west-2", "us-east-1", "ap-northeast-1"],
+    "anthropic.claude-3-5-sonnet-20241022-v2:0": ["us-west-2"],
+    "anthropic.claude-3-haiku-20240307-v1:0": ["us-west-2", "us-east-1", "ap-northeast-1", "eu-west-1"],
+    "anthropic.claude-3-5-haiku-20241022-v1:0": ["us-west-2"],
+    "meta.llama3-1-70b-instruct-v1:0": ["us-west-2"],
+    "mistral.mistral-large-2402-v1:0": ["us-west-2", "us-east-1", "eu-west-1"],
+    "amazon.nova-pro-v1:0": ["us-east-1"]
+}
diff --git a/sagemaker-train/src/sagemaker/train/evaluate/llm_as_judge_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/llm_as_judge_evaluator.py
index 16f9405838..3be78ebd6d 100644
--- a/sagemaker-train/src/sagemaker/train/evaluate/llm_as_judge_evaluator.py
+++ b/sagemaker-train/src/sagemaker/train/evaluate/llm_as_judge_evaluator.py
@@ -13,6 +13,7 @@
 from .base_evaluator import BaseEvaluator
 from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
 from sagemaker.core.telemetry.constants import Feature
+from sagemaker.train.constants import _ALLOWED_EVALUATOR_MODELS
 
 _logger = logging.getLogger(__name__)
 
@@ -144,6 +145,30 @@ def _validate_model_compatibility(cls, v, values):
             )
         return v
+
+    @validator('evaluator_model')
+    def _validate_evaluator_model(cls, v, values):
+        """Validate evaluator_model is allowed and check region compatibility."""
+
+        if v not in _ALLOWED_EVALUATOR_MODELS:
+            raise ValueError(
+                f"Invalid evaluator_model '{v}'. "
+                f"Allowed models are: {list(_ALLOWED_EVALUATOR_MODELS.keys())}"
+            )
+
+        # Get current region from session
+        session = values.get('sagemaker_session')
+        if session and hasattr(session, 'boto_region_name'):
+            current_region = session.boto_region_name
+            allowed_regions = _ALLOWED_EVALUATOR_MODELS[v]
+
+            if current_region not in allowed_regions:
+                raise ValueError(
+                    f"Evaluator model '{v}' is not available in region '{current_region}'. "
+                    f"Available regions for this model: {allowed_regions}"
+                )
+
+        return v
 
     def _process_builtin_metrics(self, metrics: Optional[List[str]]) -> List[str]:
         """Process builtin metrics by removing 'Builtin.' prefix if present.
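For reference, a minimal standalone sketch of the model/region check that the new `_validate_evaluator_model` validator performs. The helper name `check_evaluator_model` is illustrative and not part of this change, and only two entries of the mapping are repeated for brevity:

```python
# Sketch (not part of the patch) of the lookup logic used by the new validator.
_ALLOWED_EVALUATOR_MODELS = {
    "anthropic.claude-3-5-sonnet-20240620-v1:0": ["us-west-2", "us-east-1", "ap-northeast-1"],
    "amazon.nova-pro-v1:0": ["us-east-1"],
}


def check_evaluator_model(model_id: str, region: str) -> str:
    """Return the model id if it is allowed in the given region, else raise ValueError."""
    if model_id not in _ALLOWED_EVALUATOR_MODELS:
        raise ValueError(
            f"Invalid evaluator_model '{model_id}'. "
            f"Allowed models are: {list(_ALLOWED_EVALUATOR_MODELS.keys())}"
        )
    allowed_regions = _ALLOWED_EVALUATOR_MODELS[model_id]
    if region not in allowed_regions:
        raise ValueError(
            f"Evaluator model '{model_id}' is not available in region '{region}'. "
            f"Available regions for this model: {allowed_regions}"
        )
    return model_id


check_evaluator_model("amazon.nova-pro-v1:0", "us-east-1")       # passes
# check_evaluator_model("amazon.nova-pro-v1:0", "eu-central-1")  # would raise ValueError
```

Because the check runs inside a pydantic validator, a misconfigured evaluator model fails fast at construction time; the region check is skipped when the session exposes no `boto_region_name`, matching the `hasattr` guard in the patch.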
diff --git a/sagemaker-train/tests/unit/train/evaluate/test_llm_as_judge_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_llm_as_judge_evaluator.py
index 283e6723bf..5af23f7960 100644
--- a/sagemaker-train/tests/unit/train/evaluate/test_llm_as_judge_evaluator.py
+++ b/sagemaker-train/tests/unit/train/evaluate/test_llm_as_judge_evaluator.py
@@ -751,3 +751,116 @@ def test_llm_as_judge_evaluator_with_mlflow_names(mock_artifact, mock_resolve):
 
     assert evaluator.mlflow_experiment_name == "my-experiment"
     assert evaluator.mlflow_run_name == "my-run"
+
+
+@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
+@patch('sagemaker.core.resources.Artifact')
+def test_llm_as_judge_evaluator_valid_evaluator_models(mock_artifact, mock_resolve):
+    """Test LLMAsJudgeEvaluator with valid evaluator models."""
+    valid_models = [
+        "anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "anthropic.claude-3-5-sonnet-20241022-v2:0",
+        "anthropic.claude-3-haiku-20240307-v1:0",
+        "anthropic.claude-3-5-haiku-20241022-v1:0",
+        "meta.llama3-1-70b-instruct-v1:0",
+        "mistral.mistral-large-2402-v1:0",
+    ]
+
+    mock_info = Mock()
+    mock_info.base_model_name = DEFAULT_MODEL
+    mock_info.base_model_arn = DEFAULT_BASE_MODEL_ARN
+    mock_info.source_model_package_arn = None
+    mock_resolve.return_value = mock_info
+
+    mock_artifact.get_all.return_value = iter([])
+    mock_artifact_instance = Mock()
+    mock_artifact_instance.artifact_arn = DEFAULT_ARTIFACT_ARN
+    mock_artifact.create.return_value = mock_artifact_instance
+
+    mock_session = Mock()
+    mock_session.boto_region_name = "us-west-2"  # Region where every model in valid_models is available
+    mock_session.boto_session = Mock()
+    mock_session.get_caller_identity_arn.return_value = DEFAULT_ROLE
+
+    for model in valid_models:
+        evaluator = LLMAsJudgeEvaluator(
+            model=DEFAULT_MODEL,
+            evaluator_model=model,
+            dataset=DEFAULT_DATASET,
+            builtin_metrics=["Correctness"],
+            s3_output_path=DEFAULT_S3_OUTPUT,
+            mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
+            model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
+            sagemaker_session=mock_session,
+        )
+        assert evaluator.evaluator_model == model
+
+
+@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
+@patch('sagemaker.core.resources.Artifact')
+def test_llm_as_judge_evaluator_invalid_evaluator_model(mock_artifact, mock_resolve):
+    """Test LLMAsJudgeEvaluator raises error for invalid evaluator model."""
+    mock_info = Mock()
+    mock_info.base_model_name = DEFAULT_MODEL
+    mock_info.base_model_arn = DEFAULT_BASE_MODEL_ARN
+    mock_info.source_model_package_arn = None
+    mock_resolve.return_value = mock_info
+
+    mock_artifact.get_all.return_value = iter([])
+    mock_artifact_instance = Mock()
+    mock_artifact_instance.artifact_arn = DEFAULT_ARTIFACT_ARN
+    mock_artifact.create.return_value = mock_artifact_instance
+
+    mock_session = Mock()
+    mock_session.boto_region_name = DEFAULT_REGION
+    mock_session.boto_session = Mock()
+    mock_session.get_caller_identity_arn.return_value = DEFAULT_ROLE
+
+    with pytest.raises(ValidationError) as exc_info:
+        LLMAsJudgeEvaluator(
+            model=DEFAULT_MODEL,
+            evaluator_model="invalid-model",
+            dataset=DEFAULT_DATASET,
+            builtin_metrics=["Correctness"],
+            s3_output_path=DEFAULT_S3_OUTPUT,
+            mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
+            model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
+            sagemaker_session=mock_session,
+        )
+    assert "Invalid evaluator_model 'invalid-model'" in str(exc_info.value)
+
+
+@patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session')
+@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
+@patch('sagemaker.core.resources.Artifact')
+def test_llm_as_judge_evaluator_region_restriction(mock_artifact, mock_resolve, mock_get_session):
+    """Test LLMAsJudgeEvaluator raises error for model not available in region."""
+    mock_info = Mock()
+    mock_info.base_model_name = DEFAULT_MODEL
+    mock_info.base_model_arn = DEFAULT_BASE_MODEL_ARN
+    mock_info.source_model_package_arn = None
+    mock_resolve.return_value = mock_info
+
+    mock_artifact.get_all.return_value = iter([])
+    mock_artifact_instance = Mock()
+    mock_artifact_instance.artifact_arn = DEFAULT_ARTIFACT_ARN
+    mock_artifact.create.return_value = mock_artifact_instance
+
+    mock_session = Mock()
+    mock_session.boto_region_name = "eu-central-1"  # Region not supported for nova-pro
+    mock_session.boto_session = Mock()
+    mock_session.get_caller_identity_arn.return_value = DEFAULT_ROLE
+    mock_get_session.return_value = mock_session
+
+    with pytest.raises(ValidationError) as exc_info:
+        LLMAsJudgeEvaluator(
+            model=DEFAULT_MODEL,
+            evaluator_model="amazon.nova-pro-v1:0",
+            dataset=DEFAULT_DATASET,
+            builtin_metrics=["Correctness"],
+            s3_output_path=DEFAULT_S3_OUTPUT,
+            mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
+            model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
+            sagemaker_session=mock_session,
+        )
+    assert "not available in region" in str(exc_info.value)