From 9c4f96333a8f35a480976b940515cb38476b342c Mon Sep 17 00:00:00 2001 From: "codegen-sh[bot]" <131295404+codegen-sh[bot]@users.noreply.github.com> Date: Wed, 19 Mar 2025 21:20:21 +0000 Subject: [PATCH 1/2] Implement Anthropic prompt caching for Claude 3.7 Sonnet --- src/codegen/agents/code_agent.py | 10 +++++++++ src/codegen/extensions/langchain/llm.py | 30 +++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py index 693c0cd44..500136987 100644 --- a/src/codegen/agents/code_agent.py +++ b/src/codegen/agents/code_agent.py @@ -46,6 +46,7 @@ def __init__( agent_config: Optional[AgentConfig] = None, thread_id: Optional[str] = None, logger: Optional[ExternalLogger] = None, + enable_prompt_caching: bool = True, **kwargs, ): """Initialize a CodeAgent. @@ -58,6 +59,10 @@ def __init__( tools: Additional tools to use tags: Tags to add to the agent trace. Must be of the same type. metadata: Metadata to use for the agent. Must be a dictionary. + agent_config: Configuration for the agent + thread_id: Optional thread ID for message history + logger: Optional external logger + enable_prompt_caching: Whether to enable prompt caching for Anthropic models **kwargs: Additional LLM configuration options. Supported options: - temperature: Temperature parameter (0-1) - top_p: Top-p sampling parameter (0-1) @@ -65,6 +70,11 @@ def __init__( - max_tokens: Maximum number of tokens to generate """ self.codebase = codebase + + # Add prompt caching to kwargs if using Anthropic + if model_provider == "anthropic" and enable_prompt_caching: + kwargs["enable_prompt_caching"] = True + self.agent = create_codebase_agent( self.codebase, model_provider=model_provider, diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py index 4c457e46d..ede6fc110 100644 --- a/src/codegen/extensions/langchain/llm.py +++ b/src/codegen/extensions/langchain/llm.py @@ -31,6 +31,11 @@ class LLM(BaseChatModel): top_k: Optional[int] = Field(default=None, description="Top-k sampling parameter.", ge=1) max_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate.", ge=1) + + enable_prompt_caching: bool = Field( + default=False, + description="Whether to enable prompt caching for Anthropic models. Only works with Claude 3.5 Sonnet and Claude 3.0 Haiku." + ) def __init__(self, model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", **kwargs: Any) -> None: """Initialize the LLM. @@ -43,13 +48,17 @@ def __init__(self, model_provider: str = "anthropic", model_name: str = "claude- - top_p: Top-p sampling parameter (0-1) - top_k: Top-k sampling parameter (>= 1) - max_tokens: Maximum number of tokens to generate + - enable_prompt_caching: Whether to enable prompt caching (Anthropic only) """ # Set model provider and name before calling super().__init__ kwargs["model_provider"] = model_provider kwargs["model_name"] = model_name # Filter out unsupported kwargs - supported_kwargs = {"model_provider", "model_name", "temperature", "top_p", "top_k", "max_tokens", "callbacks", "tags", "metadata"} + supported_kwargs = { + "model_provider", "model_name", "temperature", "top_p", "top_k", + "max_tokens", "callbacks", "tags", "metadata", "enable_prompt_caching" + } filtered_kwargs = {k: v for k, v in kwargs.items() if k in supported_kwargs} super().__init__(**filtered_kwargs) @@ -90,7 +99,24 @@ def _get_model(self) -> BaseChatModel: msg = "ANTHROPIC_API_KEY not found in environment. Please set it in your .env file or environment variables." raise ValueError(msg) max_tokens = 16384 if "claude-3-7" in self.model_name else 8192 - return ChatAnthropic(**self._get_model_kwargs(), max_tokens=max_tokens, max_retries=10, timeout=1000) + + # Add prompt caching if enabled + extra_kwargs = {} + if self.enable_prompt_caching: + # Only enable for supported models + if "claude-3-5-sonnet" in self.model_name or "claude-3-haiku" in self.model_name: + extra_kwargs["anthropic_beta"] = "prompt-caching-2024-07-31" + print("Prompt caching enabled for Anthropic model") + else: + print(f"Warning: Prompt caching requested but not supported for model {self.model_name}") + + return ChatAnthropic( + **self._get_model_kwargs(), + max_tokens=max_tokens, + max_retries=10, + timeout=1000, + **extra_kwargs + ) elif self.model_provider == "openai": if not os.getenv("OPENAI_API_KEY"): From 052ee49c80b095164598cc309b790e7555f7de29 Mon Sep 17 00:00:00 2001 From: "codegen-sh[bot]" <131295404+codegen-sh[bot]@users.noreply.github.com> Date: Wed, 19 Mar 2025 21:21:15 +0000 Subject: [PATCH 2/2] Automated pre-commit update --- src/codegen/agents/code_agent.py | 4 ++-- src/codegen/extensions/langchain/llm.py | 24 ++++++------------------ 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py index 500136987..3211965c8 100644 --- a/src/codegen/agents/code_agent.py +++ b/src/codegen/agents/code_agent.py @@ -70,11 +70,11 @@ def __init__( - max_tokens: Maximum number of tokens to generate """ self.codebase = codebase - + # Add prompt caching to kwargs if using Anthropic if model_provider == "anthropic" and enable_prompt_caching: kwargs["enable_prompt_caching"] = True - + self.agent = create_codebase_agent( self.codebase, model_provider=model_provider, diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py index ede6fc110..3079ebc1a 100644 --- a/src/codegen/extensions/langchain/llm.py +++ b/src/codegen/extensions/langchain/llm.py @@ -31,11 +31,8 @@ class LLM(BaseChatModel): top_k: Optional[int] = Field(default=None, description="Top-k sampling parameter.", ge=1) max_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate.", ge=1) - - enable_prompt_caching: bool = Field( - default=False, - description="Whether to enable prompt caching for Anthropic models. Only works with Claude 3.5 Sonnet and Claude 3.0 Haiku." - ) + + enable_prompt_caching: bool = Field(default=False, description="Whether to enable prompt caching for Anthropic models. Only works with Claude 3.5 Sonnet and Claude 3.0 Haiku.") def __init__(self, model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", **kwargs: Any) -> None: """Initialize the LLM. @@ -55,10 +52,7 @@ def __init__(self, model_provider: str = "anthropic", model_name: str = "claude- kwargs["model_name"] = model_name # Filter out unsupported kwargs - supported_kwargs = { - "model_provider", "model_name", "temperature", "top_p", "top_k", - "max_tokens", "callbacks", "tags", "metadata", "enable_prompt_caching" - } + supported_kwargs = {"model_provider", "model_name", "temperature", "top_p", "top_k", "max_tokens", "callbacks", "tags", "metadata", "enable_prompt_caching"} filtered_kwargs = {k: v for k, v in kwargs.items() if k in supported_kwargs} super().__init__(**filtered_kwargs) @@ -99,7 +93,7 @@ def _get_model(self) -> BaseChatModel: msg = "ANTHROPIC_API_KEY not found in environment. Please set it in your .env file or environment variables." raise ValueError(msg) max_tokens = 16384 if "claude-3-7" in self.model_name else 8192 - + # Add prompt caching if enabled extra_kwargs = {} if self.enable_prompt_caching: @@ -109,14 +103,8 @@ def _get_model(self) -> BaseChatModel: print("Prompt caching enabled for Anthropic model") else: print(f"Warning: Prompt caching requested but not supported for model {self.model_name}") - - return ChatAnthropic( - **self._get_model_kwargs(), - max_tokens=max_tokens, - max_retries=10, - timeout=1000, - **extra_kwargs - ) + + return ChatAnthropic(**self._get_model_kwargs(), max_tokens=max_tokens, max_retries=10, timeout=1000, **extra_kwargs) elif self.model_provider == "openai": if not os.getenv("OPENAI_API_KEY"):