Skip to content

Commit a536a86

Browse files
committed
feat(context): naive token estimation via tiktoken
1 parent 94fc8dd commit a536a86

File tree

3 files changed

+389
-1
lines changed

3 files changed

+389
-1
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ sagemaker = [
5757
"openai>=1.68.0,<3.0.0", # SageMaker uses OpenAI-compatible interface
5858
]
5959
otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"]
60+
# Rename this extra once compression/context management features land #555
61+
token-estimation = ["tiktoken>=0.7.0,<1.0.0"]
6062
docs = [
6163
"sphinx>=5.0.0,<10.0.0",
6264
"sphinx-rtd-theme>=1.0.0,<4.0.0",

src/strands/models/model.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Abstract base class for Agent model providers."""
22

33
import abc
4+
import json
45
import logging
56
from collections.abc import AsyncGenerator, AsyncIterable
67
from dataclasses import dataclass
@@ -10,7 +11,7 @@
1011

1112
from ..hooks.events import AfterInvocationEvent
1213
from ..plugins.plugin import Plugin
13-
from ..types.content import Messages, SystemContentBlock
14+
from ..types.content import ContentBlock, Messages, SystemContentBlock
1415
from ..types.streaming import StreamEvent
1516
from ..types.tools import ToolChoice, ToolSpec
1617

@@ -21,6 +22,110 @@
2122

2223
T = TypeVar("T", bound=BaseModel)
2324

25+
_DEFAULT_ENCODING = "cl100k_base"
26+
_cached_encoding: Any = None
27+
28+
29+
def _get_encoding() -> Any:
30+
"""Get the default tiktoken encoding, caching to avoid repeated lookups."""
31+
global _cached_encoding
32+
if _cached_encoding is None:
33+
try:
34+
import tiktoken
35+
except ImportError as err:
36+
raise ImportError(
37+
"tiktoken is required for token estimation. "
38+
"Install it with: pip install strands-agents[token-estimation]"
39+
) from err
40+
_cached_encoding = tiktoken.get_encoding(_DEFAULT_ENCODING)
41+
return _cached_encoding
42+
43+
44+
def _count_content_block_tokens(block: ContentBlock, encoding: Any) -> int:
45+
"""Count tokens for a single content block."""
46+
total = 0
47+
48+
if "text" in block:
49+
total += len(encoding.encode(block["text"]))
50+
51+
if "toolUse" in block:
52+
tool_use = block["toolUse"]
53+
total += len(encoding.encode(tool_use.get("name", "")))
54+
try:
55+
total += len(encoding.encode(json.dumps(tool_use.get("input", {}))))
56+
except (TypeError, ValueError):
57+
logger.debug(
58+
"tool_name=<%s> | skipping non-serializable toolUse input for token estimation",
59+
tool_use.get("name", "unknown"),
60+
)
61+
62+
if "toolResult" in block:
63+
tool_result = block["toolResult"]
64+
for item in tool_result.get("content", []):
65+
if "text" in item:
66+
total += len(encoding.encode(item["text"]))
67+
68+
if "reasoningContent" in block:
69+
reasoning = block["reasoningContent"]
70+
if "reasoningText" in reasoning:
71+
reasoning_text = reasoning["reasoningText"]
72+
if "text" in reasoning_text:
73+
total += len(encoding.encode(reasoning_text["text"]))
74+
75+
if "guardContent" in block:
76+
guard = block["guardContent"]
77+
if "text" in guard:
78+
total += len(encoding.encode(guard["text"]["text"]))
79+
80+
if "citationsContent" in block:
81+
citations = block["citationsContent"]
82+
if "content" in citations:
83+
for citation_item in citations["content"]:
84+
if "text" in citation_item:
85+
total += len(encoding.encode(citation_item["text"]))
86+
87+
return total
88+
89+
90+
def _estimate_tokens_with_tiktoken(
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
    system_prompt_content: list[SystemContentBlock] | None = None,
) -> int:
    """Serialize the request to text and count tokens with tiktoken.

    Best-effort fallback for providers without a native counting API. Accuracy
    varies per model but is sufficient for threshold-based decisions.
    """
    encoding = _get_encoding()

    # Structured system content takes precedence over the plain string so the
    # prompt is never counted twice (providers wrap system_prompt into
    # system_prompt_content when both are present).
    if system_prompt_content:
        total = sum(
            len(encoding.encode(block["text"])) for block in system_prompt_content if "text" in block
        )
    elif system_prompt:
        total = len(encoding.encode(system_prompt))
    else:
        total = 0

    total += sum(
        _count_content_block_tokens(block, encoding) for message in messages for block in message["content"]
    )

    for spec in tool_specs or []:
        try:
            total += len(encoding.encode(json.dumps(spec)))
        except (TypeError, ValueError):
            logger.debug(
                "tool_name=<%s> | skipping non-serializable tool spec for token estimation",
                spec.get("name", "unknown"),
            )

    return total
128+
24129

25130
@dataclass
26131
class CacheConfig:
@@ -130,6 +235,34 @@ def stream(
130235
"""
131236
pass
132237

238+
def _estimate_tokens(
    self,
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
    system_prompt_content: list[SystemContentBlock] | None = None,
) -> int:
    """Approximate the input token count before invoking the model.

    The default implementation delegates to a naive tiktoken-based estimator
    (cl100k_base encoding). That is adequate for proactive context-management
    decisions (e.g. triggering compression at a threshold) — typically within
    5-10% for most providers — but is not suitable for billing or precise
    quota math. Subclasses may override this with a provider-native counting
    API for better accuracy.

    Args:
        messages: Conversation messages to include in the estimate.
        tool_specs: Optional tool specifications to include.
        system_prompt: Plain-text system prompt; ignored when
            system_prompt_content is supplied.
        system_prompt_content: Structured system prompt blocks; takes
            precedence over system_prompt.

    Returns:
        Estimated total input tokens.
    """
    return _estimate_tokens_with_tiktoken(
        messages, tool_specs, system_prompt, system_prompt_content
    )
265+
133266

134267
class _ModelPlugin(Plugin):
135268
"""Plugin that manages model-related lifecycle hooks."""

0 commit comments

Comments
 (0)