diff --git a/langfun/core/llms/anthropic.py b/langfun/core/llms/anthropic.py
index 83498e2b..f9890299 100644
--- a/langfun/core/llms/anthropic.py
+++ b/langfun/core/llms/anthropic.py
@@ -54,8 +54,9 @@ class RateLimits(lf.ModelInfo.RateLimits):
 
   @property
   def max_tokens_per_minute(self) -> int:
-    return (self.max_input_tokens_per_minute
-            + self.max_output_tokens_per_minute)
+    return (
+        self.max_input_tokens_per_minute + self.max_output_tokens_per_minute
+    )
 
 
 SUPPORTED_MODELS = [
@@ -839,9 +840,7 @@ class Anthropic(rest.REST):
   """
 
   model: pg.typing.Annotated[
-      pg.typing.Enum(
-          pg.MISSING_VALUE, [m.model_id for m in SUPPORTED_MODELS]
-      ),
+      pg.typing.Enum(pg.MISSING_VALUE, [m.model_id for m in SUPPORTED_MODELS]),
       'The name of the model to use.',
   ]
 
@@ -855,10 +854,7 @@ class Anthropic(rest.REST):
 
   api_endpoint: str = 'https://api.anthropic.com/v1/messages'
 
-  api_version: Annotated[
-      str,
-      'Anthropic API version.'
-  ] = '2023-06-01'
+  api_version: Annotated[str, 'Anthropic API version.'] = '2023-06-01'
 
   thinking: Annotated[
       bool | None,
@@ -912,9 +908,7 @@ def _use_adaptive_thinking(self) -> bool:
     return self.model is not None and 'claude-opus-4-7' in self.model_id
 
   def request(
-      self,
-      prompt: lf.Message,
-      sampling_options: lf.LMSamplingOptions
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
   ) -> dict[str, Any]:
     """Returns the JSON input for a message."""
     request = dict()
@@ -1022,14 +1016,27 @@ def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
 
   def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
     message = lf.Message.from_value(json, format='anthropic')
-    input_tokens = json['usage']['input_tokens']
-    output_tokens = json['usage']['output_tokens']
+    usage = json.get('usage') or {}
+    input_tokens = usage.get('input_tokens') or 0
+    output_tokens = usage.get('output_tokens') or 0
+    cache_read_tokens = usage.get('cache_read_input_tokens') or 0
+    cache_creation_tokens = usage.get('cache_creation_input_tokens') or 0
+
+    # Usage counters may be missing or null, hence `or 0`. `input_tokens`
+    # excludes cached tokens, so add cache reads/writes for the prompt total.
+    prompt_tokens = input_tokens + cache_read_tokens + cache_creation_tokens
+
     return lf.LMSamplingResult(
         [lf.LMSample(message)],
         usage=lf.LMSamplingUsage(
-            prompt_tokens=input_tokens,
+            prompt_tokens=prompt_tokens,
             completion_tokens=output_tokens,
-            total_tokens=input_tokens + output_tokens,
+            total_tokens=prompt_tokens + output_tokens,
+            cached_prompt_tokens=cache_read_tokens,
+            completion_tokens_details={
+                'cache_creation_input_tokens': cache_creation_tokens,
+                'cache_read_input_tokens': cache_read_tokens,
+            },
         ),
     )
 
@@ -1118,21 +1125,25 @@ class Claude35(Anthropic):
 
 class Claude35Sonnet(Claude35):
   """Claude 3.5 Sonnet model (latest)."""
+
   model = 'claude-3-5-sonnet-latest'
 
 
 class Claude35Sonnet_20241022(Claude35):  # pylint: disable=invalid-name
   """Claude 3.5 Sonnet model (10/22/2024)."""
+
   model = 'claude-3-5-sonnet-20241022'
 
 
 class Claude35Haiku(Claude35):
   """Claude 3.5 Haiku model (latest)."""
+
   model = 'claude-3-5-haiku-latest'
 
 
 class Claude35Haiku_20241022(Claude35):  # pylint: disable=invalid-name
   """Claude 3.5 Haiku model (10/22/2024)."""
+
   model = 'claude-3-5-haiku-20241022'
 
 
@@ -1182,4 +1193,5 @@ def _register_anthropic_models():
     if m.provider == 'Anthropic':
       lf.LanguageModel.register(m.model_id, Anthropic)
 
+
 _register_anthropic_models()