5555 "anthropic.claude" ,
5656]
5757
58- # Cache of model IDs that do not support the CountTokens API .
59- _UNSUPPORTED_COUNT_TOKENS_MODELS : set [str ] = set ()
58+ # Cache of model IDs for which CountTokens API calls should be skipped (model unsupported or permission denied).
59+ _SKIP_COUNT_TOKENS_MODELS : set [str ] = set ()
6060
6161
62- def _clear_unsupported_count_tokens_cache () -> None :
63- """Clear the cache of model IDs that do not support the CountTokens API ."""
64- _UNSUPPORTED_COUNT_TOKENS_MODELS .clear ()
62+ def _clear_skip_count_tokens_cache () -> None :
63+ """Clear the cache of model IDs for which CountTokens API calls should be skipped ."""
64+ _SKIP_COUNT_TOKENS_MODELS .clear ()
6565
6666
6767def _suppress_task_exception (task : "asyncio.Task[None]" ) -> None :
@@ -124,8 +124,8 @@ class BedrockConfig(BaseModelConfig, total=False):
124124 temperature: Controls randomness in generation (higher = more random)
125125 top_p: Controls diversity via nucleus sampling (alternative to temperature)
126126 use_native_token_count: Whether to use the native Bedrock CountTokens API.
127- When True (default) , count_tokens() calls the Bedrock API for accurate counts.
128- When False, skips the API call and uses the local estimator.
127+ When True, count_tokens() calls the Bedrock API for accurate counts.
128+ When False or unset (default), skips the API call and uses the local estimator.
129129 """
130130
131131 additional_args : dict [str , Any ] | None
@@ -804,12 +804,12 @@ async def count_tokens(
804804 Returns:
805805 Total input token count.
806806 """
807- if self .config .get ("use_native_token_count" ) is False :
807+ if self .config .get ("use_native_token_count" ) is not True :
808808 return await super ().count_tokens (messages , tool_specs , system_prompt , system_prompt_content )
809809
810810 model_id : str = self .config ["model_id" ]
811811
812- if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS :
812+ if model_id in _SKIP_COUNT_TOKENS_MODELS :
813813 return await super ().count_tokens (messages , tool_specs , system_prompt , system_prompt_content )
814814
815815 try :
@@ -839,6 +839,17 @@ async def count_tokens(
839839 return total_tokens
840840 except Exception as e :
841841 if (
842+ isinstance (e , ClientError )
843+ and e .response .get ("Error" , {}).get ("Code" ) == "AccessDeniedException"
844+ ):
845+ logger .warning (
846+ "model_id=<%s> | bedrock:CountTokens permission denied,"
847+ " falling back to heuristic estimation: %s" ,
848+ model_id ,
849+ e ,
850+ )
851+ _SKIP_COUNT_TOKENS_MODELS .add (model_id )
852+ elif (
842853 isinstance (e , ClientError )
843854 and e .response .get ("Error" , {}).get ("Code" ) == "ValidationException"
844855 and "doesn't support counting tokens" in str (e )
@@ -848,7 +859,7 @@ async def count_tokens(
848859 " falling back to estimation" ,
849860 model_id ,
850861 )
851- _UNSUPPORTED_COUNT_TOKENS_MODELS .add (model_id )
862+ _SKIP_COUNT_TOKENS_MODELS .add (model_id )
852863 else :
853864 logger .debug (
854865 "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation" ,
0 commit comments