diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py
index 03ea70040..436cc7d30 100644
--- a/src/codegen/extensions/langchain/graph.py
+++ b/src/codegen/extensions/langchain/graph.py
@@ -155,15 +155,27 @@ def format_header(header_type: str) -> str:
 
         # Format messages with appropriate headers
        formatted_messages = []
-        for msg in to_summarize:  # No need for slice when iterating full list
+        image_urls = []  # Track image URLs for the summary prompt
+
+        for msg in to_summarize:
             if isinstance(msg, HumanMessage):
-                formatted_messages.append(format_header("human") + msg.content)
+                # Now we know content is always a list
+                for item in msg.content:
+                    if item.get("type") == "text":
+                        text_content = item.get("text", "")
+                        if text_content:
+                            formatted_messages.append(format_header("human") + text_content)
+                    elif item.get("type") == "image_url":
+                        image_url = item.get("image_url", {}).get("url")
+                        if image_url:
+                            # No text placeholder for images in the summary; the image itself will be included in the prompt
+                            image_urls.append({"type": "image_url", "image_url": {"url": image_url}})
             elif isinstance(msg, AIMessage):
                 # Check for summary message using additional_kwargs
                 if msg.additional_kwargs.get("is_summary"):
                     formatted_messages.append(format_header("summary") + msg.content)
                 elif isinstance(msg.content, list) and len(msg.content) > 0 and isinstance(msg.content[0], dict):
-                    for item in msg.content:  # No need for slice when iterating full list
+                    for item in msg.content:
                         if item.get("type") == "text":
                             formatted_messages.append(format_header("ai") + item["text"])
                         elif item.get("type") == "tool_use":
@@ -173,7 +185,7 @@ def format_header(header_type: str) -> str:
             elif isinstance(msg, ToolMessage):
                 formatted_messages.append(format_header("tool_response") + msg.content)
 
-        conversation = "\n".join(formatted_messages)  # No need for slice when joining full list
+        conversation = "\n".join(formatted_messages)
 
         summary_llm = LLM(
             model_provider="anthropic",
@@ -181,8 +193,17 @@ def format_header(header_type: str) -> str:
             temperature=0.3,
         )
 
-        chain = ChatPromptTemplate.from_template(SUMMARIZE_CONVERSATION_PROMPT) | summary_llm
-        new_summary = chain.invoke({"conversation": conversation}).content
+        # Build the summarizer prompt content, attaching any collected images
+        summarizer_content = [{"type": "text", "text": SUMMARIZE_CONVERSATION_PROMPT}]
+        for image_url in image_urls:
+            summarizer_content.append(image_url)
+
+        chain = ChatPromptTemplate([("human", summarizer_content)]) | summary_llm
+        new_summary = chain.invoke(
+            {
+                "conversation": conversation,
+            }
+        ).content
 
         return {"messages": {"type": "summarize", "summary": new_summary, "tail": tail, "head": head}}
 
@@ -199,7 +220,7 @@ def should_continue(self, state: GraphState) -> Literal["tools", "summarize_conv
             return "summarize_conversation"
 
-        # Summarize if the last message exceeds the max input tokens of the model - 10000 tokens
-        elif isinstance(last_message, AIMessage) and not just_summarized and curr_input_tokens > (max_input_tokens - 10000):
+        # Summarize if the last message exceeds the max input tokens of the model - 30000 tokens
+        elif isinstance(last_message, AIMessage) and not just_summarized and curr_input_tokens > (max_input_tokens - 30000):
             return "summarize_conversation"
 
         elif hasattr(last_message, "tool_calls") and last_message.tool_calls:
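
Note: the rewritten HumanMessage branch assumes `msg.content` is a list of typed blocks rather than a plain string. A minimal sketch of the shape it expects and the extraction it performs; the `>>> HUMAN:` header and the URL below are hypothetical stand-ins (the real header comes from `format_header` in graph.py):

```python
from langchain_core.messages import HumanMessage

# Hypothetical multimodal message; the ">>> HUMAN:" header and the URL are
# stand-ins (the real header comes from format_header in graph.py).
msg = HumanMessage(
    content=[
        {"type": "text", "text": "What does this screenshot show?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/screenshot.png"}},
    ]
)

formatted_messages, image_urls = [], []
for item in msg.content:
    if item.get("type") == "text":
        if item.get("text"):
            formatted_messages.append(">>> HUMAN:\n" + item["text"])
    elif item.get("type") == "image_url":
        url = item.get("image_url", {}).get("url")
        if url:
            # No text stands in for the image in the transcript; the block is
            # collected so the image itself can be attached to the summary prompt.
            image_urls.append({"type": "image_url", "image_url": {"url": url}})
```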
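The new chain replaces `ChatPromptTemplate.from_template` with a single human turn whose content mixes the text template and the collected image blocks. A sketch of how that renders, assuming a stand-in `SUMMARIZE_CONVERSATION_PROMPT` containing a `{conversation}` placeholder (the real prompt lives elsewhere in the repo):

```python
from langchain_core.prompts import ChatPromptTemplate

# Stand-ins: the real SUMMARIZE_CONVERSATION_PROMPT and the collected image
# blocks come from graph.py; only the "{conversation}" placeholder matters here.
SUMMARIZE_CONVERSATION_PROMPT = "Summarize this conversation:\n\n{conversation}"
image_urls = [{"type": "image_url", "image_url": {"url": "https://example.com/screenshot.png"}}]

summarizer_content = [{"type": "text", "text": SUMMARIZE_CONVERSATION_PROMPT}, *image_urls]
prompt = ChatPromptTemplate([("human", summarizer_content)])

# Only the text block is templated at invoke time; image_url blocks pass
# through unchanged, so the model sees transcript and images in one turn.
rendered = prompt.invoke({"conversation": ">>> HUMAN:\nWhat does this error mean?"})
print(rendered.to_messages()[0].content)
```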
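The routing change in `should_continue` widens the headroom below the model's input limit from 10000 to 30000 tokens, so summarization triggers earlier. A toy illustration of the check, with assumed token counts:

```python
# Toy illustration of the widened headroom check (all numbers are assumptions).
max_input_tokens = 200_000   # model's context window
curr_input_tokens = 175_000  # running input-token count tracked by the graph

# With 30000 tokens of headroom the check fires (200_000 - 30_000 = 170_000 < 175_000),
# whereas the old 10000-token margin (threshold 190_000) would not have triggered yet.
if curr_input_tokens > (max_input_tokens - 30000):
    next_node = "summarize_conversation"
```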