Source code for calute.context.token_counter

# Copyright 2025 The EasyDeL/Calute Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Token counting utilities using provider-specific APIs.

This module provides accurate token counting for various LLM providers.
It supports OpenAI (via tiktoken), Anthropic, Google, and includes
fallback mechanisms when provider-specific libraries are unavailable.

The module includes:

- ProviderTokenCounter: Low-level provider-specific token counting
- SmartTokenCounter: High-level automatic provider detection and counting

Token counting is essential for managing context limits, estimating
costs, and implementing effective compaction strategies.

Example:
    >>> from calute.context import SmartTokenCounter
    >>> counter = SmartTokenCounter(model="gpt-4")
    >>> token_count = counter.count_tokens("Hello, world!")
    >>> remaining = counter.count_remaining_capacity("Hello", max_tokens=4096)
"""

from typing import Any

try:
    import tiktoken

    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False

try:
    import anthropic  # type:ignore #noqa

    ANTHROPIC_AVAILABLE = True
except ImportError:
    ANTHROPIC_AVAILABLE = False


[docs]class ProviderTokenCounter: """Token counter that uses actual provider APIs. Provides static methods for counting tokens using provider-specific implementations. Supports OpenAI (tiktoken), Anthropic, and Google, with fallback to character-based estimation when libraries are unavailable. This class is designed for internal use; prefer SmartTokenCounter for most use cases as it provides automatic provider detection. """
[docs] @classmethod def count_tokens_for_provider( cls, text: str | list[dict[str, str]], provider: str | None = None, model: str | None = None, llm_client: Any | None = None, ) -> int: """Count tokens using the appropriate provider's API. Args: text: Text string or messages list to count provider: Provider name (openai, anthropic, etc.) model: Model name llm_client: Optional LLM client instance Returns: Token count """ if isinstance(text, list): text = cls._messages_to_text(text) if not provider and model: provider = cls._detect_provider(model) if provider == "openai": return cls._count_openai_tokens(text, model) elif provider == "anthropic": return cls._count_anthropic_tokens(text, model, llm_client) elif provider == "google": return cls._count_google_tokens(text, model, llm_client) else: return cls._count_fallback_tokens(text, model)
@classmethod def _detect_provider(cls, model: str) -> str | None: """Detect provider from model name. Args: model: Model identifier string. Returns: Provider name (openai, anthropic, google, etc.) or None. """ if not model: return None model_lower = model.lower() if "gpt" in model_lower or "o1" in model_lower: return "openai" elif "claude" in model_lower: return "anthropic" elif "gemini" in model_lower or "palm" in model_lower: return "google" elif "llama" in model_lower: return "meta" elif "mistral" in model_lower or "mixtral" in model_lower: return "mistral" return None @classmethod def _messages_to_text(cls, messages: list[dict[str, str]]) -> str: """Convert messages list to plain text for token counting. Args: messages: List of message dictionaries with role and content. Returns: Concatenated text representation of all messages. """ text_parts = [] for msg in messages: role = msg.get("role", "") content = msg.get("content", "") text_parts.append(f"{role}: {content}") return "\n".join(text_parts) @classmethod def _count_openai_tokens(cls, text: str, model: str | None = None) -> int: """Count tokens for OpenAI models using tiktoken. Uses the appropriate encoding based on the model version. Falls back to character estimation if tiktoken is unavailable. Args: text: Text to count tokens for. model: Optional OpenAI model name for encoding selection. Returns: Token count for the given text. """ if not TIKTOKEN_AVAILABLE: return len(text) // 4 try: if model: if "gpt-4o" in model or "o1" in model: encoding = tiktoken.get_encoding("o200k_base") elif "gpt-4" in model: encoding = tiktoken.get_encoding("cl100k_base") elif "gpt-3.5" in model: encoding = tiktoken.get_encoding("cl100k_base") else: encoding = tiktoken.get_encoding("cl100k_base") else: encoding = tiktoken.get_encoding("cl100k_base") return len(encoding.encode(text)) except Exception as e: print(f"Error counting OpenAI tokens: {e}") return len(text) // 4 @classmethod def _count_anthropic_tokens(cls, text: str, model: str | None = None, client: Any | None = None) -> int: """Count tokens for Anthropic models. Attempts to use the Anthropic client's token counting API if available, otherwise falls back to tiktoken or character estimation. Args: text: Text to count tokens for. model: Optional Anthropic model name. client: Optional Anthropic client instance with count_tokens method. Returns: Token count for the given text. """ if ANTHROPIC_AVAILABLE and client: try: if hasattr(client, "count_tokens"): return client.count_tokens(text) except Exception: pass if TIKTOKEN_AVAILABLE: try: encoding = tiktoken.get_encoding("cl100k_base") return len(encoding.encode(text)) except Exception: pass return len(text) // 4 @classmethod def _count_google_tokens(cls, text: str, model: str | None = None, client: Any | None = None) -> int: """Count tokens for Google models (Gemini, PaLM). Uses tiktoken with a 1.1x adjustment factor to approximate Google's tokenization, which tends to produce slightly more tokens than OpenAI's encoding. Args: text: Text to count tokens for. model: Optional Google model name. client: Optional Google AI client instance. Returns: Estimated token count for the given text. """ if TIKTOKEN_AVAILABLE: try: encoding = tiktoken.get_encoding("cl100k_base") token_count = len(encoding.encode(text)) return int(token_count * 1.1) except Exception: pass return len(text) // 4 @classmethod def _count_fallback_tokens(cls, text: str, model: str | None = None) -> int: """Fallback token counting when provider-specific methods aren't available. Uses tiktoken if available, otherwise estimates based on character count (approximately 4 characters per token). Args: text: Text to count tokens for. model: Optional model name (unused, kept for API consistency). Returns: Estimated token count for the given text. """ if TIKTOKEN_AVAILABLE: try: encoding = tiktoken.get_encoding("cl100k_base") return len(encoding.encode(text)) except Exception: pass return len(text) // 4
[docs]class SmartTokenCounter: """Smart token counter that automatically selects the best counting method. Provides a high-level interface for token counting with automatic provider detection based on model name. Supports all major LLM providers and includes utility methods for capacity calculation and compression ratio estimation. Attributes: provider: Detected or specified provider name. model: Model name for token counting. llm_client: Optional LLM client for provider-specific counting. """ def __init__( self, provider: str | None = None, model: str | None = None, llm_client: Any | None = None, ): """Initialize the smart token counter. Args: provider: Provider name model: Model name llm_client: Optional LLM client with token counting capabilities """ self.provider = provider self.model = model self.llm_client = llm_client if not self.provider and self.model: self.provider = ProviderTokenCounter._detect_provider(self.model)
[docs] def count_tokens(self, text: str | list[dict[str, str]]) -> int: """Count tokens using the best available method. Args: text: Text string or list of message dictionaries. Returns: Token count for the given input. """ return ProviderTokenCounter.count_tokens_for_provider( text=text, provider=self.provider, model=self.model, llm_client=self.llm_client, )
[docs] def count_remaining_capacity(self, text: str | list[dict[str, str]], max_tokens: int) -> int: """Calculate remaining token capacity. Args: text: Current text or messages to count. max_tokens: Maximum token limit. Returns: Number of tokens remaining before reaching the limit. """ used_tokens = self.count_tokens(text) return max(0, max_tokens - used_tokens)
[docs] def estimate_compression_ratio(self, original_text: str, compressed_text: str) -> float: """Estimate the compression ratio achieved. Args: original_text: Original text before compression. compressed_text: Text after compression. Returns: Compression ratio as a float between 0.0 and 1.0, where 1.0 means complete compression and 0.0 means no compression. """ original_tokens = self.count_tokens(original_text) compressed_tokens = self.count_tokens(compressed_text) if original_tokens == 0: return 0.0 return 1.0 - (compressed_tokens / original_tokens)