diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py
index 693c0cd44..3211965c8 100644
--- a/src/codegen/agents/code_agent.py
+++ b/src/codegen/agents/code_agent.py
@@ -46,6 +46,7 @@ def __init__(
         agent_config: Optional[AgentConfig] = None,
         thread_id: Optional[str] = None,
         logger: Optional[ExternalLogger] = None,
+        enable_prompt_caching: bool = True,
         **kwargs,
     ):
         """Initialize a CodeAgent.
@@ -58,6 +59,10 @@ def __init__(
             tools: Additional tools to use
             tags: Tags to add to the agent trace. Must be of the same type.
             metadata: Metadata to use for the agent. Must be a dictionary.
+            agent_config: Configuration for the agent
+            thread_id: Optional thread ID for message history
+            logger: Optional external logger
+            enable_prompt_caching: Whether to enable prompt caching for Anthropic models
             **kwargs: Additional LLM configuration options. Supported options:
                 - temperature: Temperature parameter (0-1)
                 - top_p: Top-p sampling parameter (0-1)
@@ -65,6 +70,11 @@
                 - max_tokens: Maximum number of tokens to generate
         """
         self.codebase = codebase
+
+        # Add prompt caching to kwargs if using Anthropic
+        if model_provider == "anthropic" and enable_prompt_caching:
+            kwargs["enable_prompt_caching"] = True
+
         self.agent = create_codebase_agent(
             self.codebase,
             model_provider=model_provider,
diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py
index 4c457e46d..3079ebc1a 100644
--- a/src/codegen/extensions/langchain/llm.py
+++ b/src/codegen/extensions/langchain/llm.py
@@ -32,6 +32,8 @@ class LLM(BaseChatModel):
 
     max_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate.", ge=1)
 
+    enable_prompt_caching: bool = Field(default=False, description="Whether to enable prompt caching for Anthropic models. Only works with Claude 3.5 Sonnet and Claude 3.0 Haiku.")
+
     def __init__(self, model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", **kwargs: Any) -> None:
         """Initialize the LLM.
 
@@ -43,13 +45,14 @@ def __init__(self, model_provider: str = "anthropic", model_name: str = "claude-
                 - top_p: Top-p sampling parameter (0-1)
                 - top_k: Top-k sampling parameter (>= 1)
                 - max_tokens: Maximum number of tokens to generate
+                - enable_prompt_caching: Whether to enable prompt caching (Anthropic only)
         """
         # Set model provider and name before calling super().__init__
         kwargs["model_provider"] = model_provider
         kwargs["model_name"] = model_name
 
         # Filter out unsupported kwargs
-        supported_kwargs = {"model_provider", "model_name", "temperature", "top_p", "top_k", "max_tokens", "callbacks", "tags", "metadata"}
+        supported_kwargs = {"model_provider", "model_name", "temperature", "top_p", "top_k", "max_tokens", "callbacks", "tags", "metadata", "enable_prompt_caching"}
         filtered_kwargs = {k: v for k, v in kwargs.items() if k in supported_kwargs}
 
         super().__init__(**filtered_kwargs)
@@ -90,7 +93,18 @@ def _get_model(self) -> BaseChatModel:
                 msg = "ANTHROPIC_API_KEY not found in environment. Please set it in your .env file or environment variables."
                 raise ValueError(msg)
             max_tokens = 16384 if "claude-3-7" in self.model_name else 8192
-            return ChatAnthropic(**self._get_model_kwargs(), max_tokens=max_tokens, max_retries=10, timeout=1000)
+
+            # Add prompt caching if enabled
+            extra_kwargs = {}
+            if self.enable_prompt_caching:
+                # Only enable for supported models
+                if "claude-3-5-sonnet" in self.model_name or "claude-3-haiku" in self.model_name:
+                    extra_kwargs["anthropic_beta"] = "prompt-caching-2024-07-31"
+                    print("Prompt caching enabled for Anthropic model")
+                else:
+                    print(f"Warning: Prompt caching requested but not supported for model {self.model_name}")
+
+            return ChatAnthropic(**self._get_model_kwargs(), max_tokens=max_tokens, max_retries=10, timeout=1000, **extra_kwargs)
 
         elif self.model_provider == "openai":
             if not os.getenv("OPENAI_API_KEY"):
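
For reference, a minimal usage sketch of the new flag (illustrative only, not part of the diff): it constructs a CodeAgent with prompt caching on, using the module path shown above; the Codebase import and the local repository path are assumptions based on this repo's public API rather than something introduced by this change.

# Usage sketch (assumed entry points; not part of the patch)
from codegen import Codebase                      # assumed public import
from codegen.agents.code_agent import CodeAgent   # module path taken from this diff

codebase = Codebase("./")  # hypothetical: point at a local repository

# enable_prompt_caching defaults to True for the Anthropic provider; pass False to opt out.
agent = CodeAgent(
    codebase,
    model_provider="anthropic",
    model_name="claude-3-5-sonnet-latest",
    enable_prompt_caching=True,
)

The flag flows from CodeAgent's kwargs into the LLM wrapper, which only attaches the prompt-caching beta header for the Claude 3.5 Sonnet and Claude 3 Haiku model names checked in _get_model.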