From 9c76d7a0fbf80d5c7e865bdb48d3065e52c0b984 Mon Sep 17 00:00:00 2001 From: tawsifkamal Date: Wed, 19 Mar 2025 14:08:22 -0700 Subject: [PATCH 1/2] summarization error stuff --- src/codegen/extensions/langchain/graph.py | 34 +++++++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py index 22a49a78d..7888fff66 100644 --- a/src/codegen/extensions/langchain/graph.py +++ b/src/codegen/extensions/langchain/graph.py @@ -147,15 +147,27 @@ def format_header(header_type: str) -> str: # Format messages with appropriate headers formatted_messages = [] - for msg in to_summarize: # No need for slice when iterating full list + image_urls = [] # Track image URLs for the summary prompt + + for msg in to_summarize: if isinstance(msg, HumanMessage): - formatted_messages.append(format_header("human") + msg.content) + # Now we know content is always a list + for item in msg.content: + if item.get("type") == "text": + text_content = item.get("text", "") + if text_content: + formatted_messages.append(format_header("human") + text_content) + elif item.get("type") == "image_url": + image_url = item.get("image_url", {}).get("url") + if image_url: + # We are not including any string data in the summary for image. The image will be present itself! 
+ image_urls.append({"type": "image_url", "image_url": {"url": image_url}}) elif isinstance(msg, AIMessage): # Check for summary message using additional_kwargs if msg.additional_kwargs.get("is_summary"): formatted_messages.append(format_header("summary") + msg.content) elif isinstance(msg.content, list) and len(msg.content) > 0 and isinstance(msg.content[0], dict): - for item in msg.content: # No need for slice when iterating full list + for item in msg.content: if item.get("type") == "text": formatted_messages.append(format_header("ai") + item["text"]) elif item.get("type") == "tool_use": @@ -165,7 +177,7 @@ def format_header(header_type: str) -> str: elif isinstance(msg, ToolMessage): formatted_messages.append(format_header("tool_response") + msg.content) - conversation = "\n".join(formatted_messages) # No need for slice when joining full list + conversation = "\n".join(formatted_messages) summary_llm = LLM( model_provider="anthropic", @@ -173,8 +185,18 @@ def format_header(header_type: str) -> str: temperature=0.3, ) - chain = ChatPromptTemplate.from_template(SUMMARIZE_CONVERSATION_PROMPT) | summary_llm - new_summary = chain.invoke({"conversation": conversation}).content + # Choose template based on whether we have images + summarizer_content = [{"type": "text", "text": SUMMARIZE_CONVERSATION_PROMPT}] + for image_url in image_urls: + summarizer_content.append({"type": "image_url", "image_url": {"url": image_url}}) + + summarizer_messages = [HumanMessage(content=summarizer_content)] + chain = ChatPromptTemplate.from_messages(summarizer_messages) | summary_llm + new_summary = chain.invoke( + { + "conversation": conversation, + } + ).content return {"messages": {"type": "summarize", "summary": new_summary, "tail": tail, "head": head}} From 9e308d7b89174cad0822a53ec22a69041a1d8d0f Mon Sep 17 00:00:00 2001 From: tawsifkamal Date: Wed, 19 Mar 2025 16:24:46 -0700 Subject: [PATCH 2/2] summarization works now --- src/codegen/extensions/langchain/graph.py | 7 +++---- 1 
file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py index 7888fff66..5c68f7c3c 100644 --- a/src/codegen/extensions/langchain/graph.py +++ b/src/codegen/extensions/langchain/graph.py @@ -188,10 +188,9 @@ def format_header(header_type: str) -> str: # Choose template based on whether we have images summarizer_content = [{"type": "text", "text": SUMMARIZE_CONVERSATION_PROMPT}] for image_url in image_urls: - summarizer_content.append({"type": "image_url", "image_url": {"url": image_url}}) + summarizer_content.append(image_url) - summarizer_messages = [HumanMessage(content=summarizer_content)] - chain = ChatPromptTemplate.from_messages(summarizer_messages) | summary_llm + chain = ChatPromptTemplate([("human", summarizer_content)]) | summary_llm new_summary = chain.invoke( { "conversation": conversation, @@ -213,7 +212,7 @@ def should_continue(self, state: GraphState) -> Literal["tools", "summarize_conv return "summarize_conversation" # Summarize if the last message exceeds the max input tokens of the model - 10000 tokens - elif isinstance(last_message, AIMessage) and not just_summarized and curr_input_tokens > (max_input_tokens - 10000): + elif isinstance(last_message, AIMessage) and not just_summarized and curr_input_tokens > (max_input_tokens - 30000): return "summarize_conversation" elif hasattr(last_message, "tool_calls") and last_message.tool_calls: