Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Anthropic prompt caching for Claude 3.7 Sonnet #912

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/codegen/agents/code_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
agent_config: Optional[AgentConfig] = None,
thread_id: Optional[str] = None,
logger: Optional[ExternalLogger] = None,
enable_prompt_caching: bool = True,
**kwargs,
):
"""Initialize a CodeAgent.
Expand All @@ -58,13 +59,22 @@
tools: Additional tools to use
tags: Tags to add to the agent trace. Must be of the same type.
metadata: Metadata to use for the agent. Must be a dictionary.
agent_config: Configuration for the agent
thread_id: Optional thread ID for message history
logger: Optional external logger
enable_prompt_caching: Whether to enable prompt caching for Anthropic models
**kwargs: Additional LLM configuration options. Supported options:
- temperature: Temperature parameter (0-1)
- top_p: Top-p sampling parameter (0-1)
- top_k: Top-k sampling parameter (>= 1)
- max_tokens: Maximum number of tokens to generate
"""
self.codebase = codebase

# Add prompt caching to kwargs if using Anthropic
if model_provider == "anthropic" and enable_prompt_caching:
kwargs["enable_prompt_caching"] = True

self.agent = create_codebase_agent(
self.codebase,
model_provider=model_provider,
Expand All @@ -87,14 +97,14 @@
print(f"Using LangSmith project: {self.project_name}")

# Store SWEBench metadata if provided
self.run_id = metadata.get("run_id")

Check failure on line 100 in src/codegen/agents/code_agent.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "None" of "dict[Any, Any] | None" has no attribute "get" [union-attr]
self.instance_id = metadata.get("instance_id")

Check failure on line 101 in src/codegen/agents/code_agent.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "None" of "dict[Any, Any] | None" has no attribute "get" [union-attr]
# Extract difficulty value from "difficulty_X" format
difficulty_str = metadata.get("difficulty", "")

Check failure on line 103 in src/codegen/agents/code_agent.py

View workflow job for this annotation

GitHub Actions / mypy

error: Item "None" of "dict[Any, Any] | None" has no attribute "get" [union-attr]
self.difficulty = int(difficulty_str.split("_")[1]) if difficulty_str and "_" in difficulty_str else None

# Initialize tags for agent trace
self.tags = [*tags, self.model_name]

Check failure on line 107 in src/codegen/agents/code_agent.py

View workflow job for this annotation

GitHub Actions / mypy

error: Expected iterable as variadic argument [misc]

# set logger if provided
self.logger = logger
Expand All @@ -103,7 +113,7 @@
self.metadata = {
"project": self.project_name,
"model": self.model_name,
**metadata,

Check failure on line 116 in src/codegen/agents/code_agent.py

View workflow job for this annotation

GitHub Actions / mypy

error: Unpacked dict entry 2 has incompatible type "dict[Any, Any] | None"; expected "SupportsKeysAndGetItem[Any, Any]" [dict-item]
}

def run(self, prompt: str, image_urls: Optional[list[str]] = None) -> str:
Expand All @@ -128,24 +138,24 @@
# Prepare content with prompt and images if provided
content = [{"type": "text", "text": prompt}]
if image_urls:
content += [{"type": "image_url", "image_url": {"url": image_url}} for image_url in image_urls]

Check failure on line 141 in src/codegen/agents/code_agent.py

View workflow job for this annotation

GitHub Actions / mypy

error: Dict entry 1 has incompatible type "str": "dict[str, str]"; expected "str": "str" [dict-item]

config = RunnableConfig(configurable={"thread_id": self.thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=200)
# we stream the steps instead of invoke because it allows us to access intermediate nodes

stream = self.agent.stream({"messages": [HumanMessage(content=content)]}, config=config, stream_mode="values")

Check failure on line 146 in src/codegen/agents/code_agent.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument "content" to "HumanMessage" has incompatible type "list[dict[str, str]]"; expected "str | list[str | dict[Any, Any]]" [arg-type]

_tracer = MessageStreamTracer(logger=self.logger)

# Process the stream with the tracer
traced_stream = _tracer.process_stream(stream)

Check failure on line 151 in src/codegen/agents/code_agent.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 to "process_stream" of "MessageStreamTracer" has incompatible type "Iterator[dict[str, Any] | Any]"; expected "Generator[Any, None, None]" [arg-type]

# Keep track of run IDs from the stream
run_ids = []

for s in traced_stream:
if len(s["messages"]) == 0 or isinstance(s["messages"][-1], HumanMessage):
message = HumanMessage(content=content)

Check failure on line 158 in src/codegen/agents/code_agent.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument "content" to "HumanMessage" has incompatible type "list[dict[str, str]]"; expected "str | list[str | dict[Any, Any]]" [arg-type]
else:
message = s["messages"][-1]

Expand Down
18 changes: 16 additions & 2 deletions src/codegen/extensions/langchain/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

max_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate.", ge=1)

enable_prompt_caching: bool = Field(default=False, description="Whether to enable prompt caching for Anthropic models. Only works with Claude 3.5 Sonnet and Claude 3.0 Haiku.")

def __init__(self, model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", **kwargs: Any) -> None:
"""Initialize the LLM.

Expand All @@ -43,13 +45,14 @@
- top_p: Top-p sampling parameter (0-1)
- top_k: Top-k sampling parameter (>= 1)
- max_tokens: Maximum number of tokens to generate
- enable_prompt_caching: Whether to enable prompt caching (Anthropic only)
"""
# Set model provider and name before calling super().__init__
kwargs["model_provider"] = model_provider
kwargs["model_name"] = model_name

# Filter out unsupported kwargs
supported_kwargs = {"model_provider", "model_name", "temperature", "top_p", "top_k", "max_tokens", "callbacks", "tags", "metadata"}
supported_kwargs = {"model_provider", "model_name", "temperature", "top_p", "top_k", "max_tokens", "callbacks", "tags", "metadata", "enable_prompt_caching"}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in supported_kwargs}

super().__init__(**filtered_kwargs)
Expand Down Expand Up @@ -90,7 +93,18 @@
msg = "ANTHROPIC_API_KEY not found in environment. Please set it in your .env file or environment variables."
raise ValueError(msg)
max_tokens = 16384 if "claude-3-7" in self.model_name else 8192
return ChatAnthropic(**self._get_model_kwargs(), max_tokens=max_tokens, max_retries=10, timeout=1000)

# Add prompt caching if enabled
extra_kwargs = {}
if self.enable_prompt_caching:
# Only enable for supported models
if "claude-3-5-sonnet" in self.model_name or "claude-3-haiku" in self.model_name:
extra_kwargs["anthropic_beta"] = "prompt-caching-2024-07-31"
print("Prompt caching enabled for Anthropic model")
else:
print(f"Warning: Prompt caching requested but not supported for model {self.model_name}")

return ChatAnthropic(**self._get_model_kwargs(), max_tokens=max_tokens, max_retries=10, timeout=1000, **extra_kwargs)

elif self.model_provider == "openai":
if not os.getenv("OPENAI_API_KEY"):
Expand Down Expand Up @@ -129,7 +143,7 @@

def bind_tools(
self,
tools: Sequence[BaseTool],

Check failure on line 146 in src/codegen/extensions/langchain/llm.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 of "bind_tools" is incompatible with supertype "BaseChatModel"; supertype defines the argument type as "Sequence[dict[str, Any] | type | Callable[..., Any] | BaseTool]" [override]
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the underlying model.
Expand Down