Skip to content

Commit c976532

Browse files
committed
fixed reflexion pipeline
1 parent e374d88 commit c976532

File tree

4 files changed

+66
-29
lines changed

4 files changed

+66
-29
lines changed

USACOBench/prompts.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ def retrieval_prompt_fn(query, retrieval_type=RetrievalType.EPISODIC):
3232
[END PROBLEM]"""
3333

3434
def reflexion_prompt_fn(query, retrieval=False):
35-
retrieval = ""
35+
retrieval_text = ""
3636
if retrieval:
3737
retrieval_text = "You were also given a couple of similar problems to the problem above along with their solutions to aid you in solving the problem at hand. Here are the similar problems you were given:\n" + query['retrieval_text']
3838

models.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,7 @@ def chatgpt(messages, model="gpt-4", temperature=0.7, max_tokens=2000, n=1, stop
105105
def chatgpt_raw(messages, model="gpt-4", temperature=0.7, max_tokens=2000, n=1, stop=None, **kwargs) -> list:
106106
return chatgpts_raw([messages] * n, model=model, temperature=temperature, max_tokens=max_tokens, stop=stop, **kwargs)[0]
107107

108-
def chatgpts(messages_list, model="gpt-4", temperature=0.7, max_tokens=2000, stop=None, max_messages=200, **kwargs) -> list:
108+
def chatgpts(messages_list, model="gpt-4", temperature=0.7, max_tokens=2000, stop=None, max_messages=400, **kwargs) -> list:
109109
texts = []
110110
for i in range(0, len(messages_list), max_messages):
111111
responses = asyncio.run(generate_from_openai_chat_completion(model=model, messages_list=messages_list[i: i + max_messages], temperature=temperature, max_tokens=max_tokens, top_p=1, stop=stop, **kwargs))
@@ -115,7 +115,7 @@ def chatgpts(messages_list, model="gpt-4", temperature=0.7, max_tokens=2000, sto
115115
# prompt_tokens[model] += sum(x["usage"]["prompt_tokens"] for x in responses if "usage" in x and "prompt_tokens" in x["usage"])
116116
return texts
117117

118-
def chatgpts_raw(messages_list, model="gpt-4", temperature=0.7, max_tokens=2000, stop=None, max_messages=200, **kwargs) -> list:
118+
def chatgpts_raw(messages_list, model="gpt-4", temperature=0.7, max_tokens=2000, stop=None, max_messages=400, **kwargs) -> list:
119119
'''
120120
Returns raw response messages, not just the text content
121121
'''
@@ -128,7 +128,7 @@ def chatgpts_raw(messages_list, model="gpt-4", temperature=0.7, max_tokens=2000,
128128
# prompt_tokens[model] += sum(x["usage"]["prompt_tokens"] for x in responses if "usage" in x and "prompt_tokens" in x["usage"])
129129
return responses_all
130130

131-
def claude(prompts, model="claude-3-sonnet-20240229", temperature=0.7, max_tokens=3000, stop=None, max_messages=200, system_prompt=None, **kwargs) -> list:
131+
def claude(prompts, model="claude-3-sonnet-20240229", temperature=0.7, max_tokens=3000, stop=None, max_messages=400, system_prompt=None, **kwargs) -> list:
132132
texts = []
133133
if system_prompt is not None:
134134
messages_list = [[{'role': 'system', 'content': system_prompt},

run_usaco.py

+44 -6
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
parser.add_argument('-s', '--semantic_retrieval', help='whether to use semantic retrieval', action="store_true", default=False)
2626
parser.add_argument('-r', '--reflexion', help='whether to use reflexion', action="store_true", default=False)
2727
parser.add_argument('-a', '--attempts', help='number of attempts', default=1)
28-
parser.add_argument('-n', '--num_reflexion', help='number of reflexion iterations', default=3)
28+
parser.add_argument('-n', '--num_reflexion', help='number of reflexion iterations', default=2)
2929
args = parser.parse_args()
3030

3131
model_name = args.model_name
@@ -39,9 +39,10 @@
3939
problem_dict = load_problem_dict('usaco_subset307')
4040
model_fn = partial(model_fn, model=model_name)
4141

42+
# A little redundant but it does the job and it's readable...
4243
if not args.episodic_retrieval and not args.semantic_retrieval and not args.reflexion:
4344
rdict, sdict, rs, ss = run_solve(model_fn, model_name, problem_dict, args.attempts)
44-
45+
4546
elif args.episodic_retrieval and not args.semantic_retrieval and not args.reflexion:
4647
rdict, sdict, rs, ss = run_solve(model_fn, model_name, problem_dict, args.attempts)
4748
rdict, sdict, rs, ss = run_retrieval(model_fn, model_name, problem_dict, args.attempts, ss, args.num_retrieved, RetrievalType.EPISODIC)
@@ -55,11 +56,48 @@
5556
rdict, sdict, rs, ss = run_retrieval(model_fn, model_name, problem_dict, args.attempts, ss, args.num_retrieved, RetrievalType.EPISODIC_SEMANTIC)
5657

5758
elif not args.episodic_retrieval and not args.semantic_retrieval and args.reflexion:
58-
rdict, sdict, rs, ss, queries = run_solve(model_fn, model_name, problem_dict, args.attempts, return_queries=True)
59-
reflexions = []
59+
rdict, sdict, rs, ss = run_solve(model_fn, model_name, problem_dict, args.attempts)
60+
reflexions = [rdict]
61+
query_dict = None
62+
for i in range(args.num_reflexion):
63+
rdict, sdict, rs, ss, query_dict = run_reflexion(model_fn, model_name, problem_dict, args.attempts, rdict, sdict, query_dict, i, return_queries=True)
64+
reflexions.append(rdict)
65+
66+
rs = calculate_final_rs(reflexions, problem_dict)
67+
68+
elif args.episodic_retrieval and not args.semantic_retrieval and args.reflexion:
69+
rdict, sdict, rs, ss = run_solve(model_fn, model_name, problem_dict, args.attempts)
70+
rdict, sdict, rs, ss = run_retrieval(model_fn, model_name, problem_dict, args.attempts, ss, args.num_retrieved, RetrievalType.EPISODIC)
71+
72+
reflexions = [rdict]
73+
query_dict = None
74+
for i in range(args.num_reflexion):
75+
rdict, sdict, rs, ss, query_dict = run_reflexion(model_fn, model_name, problem_dict, args.attempts, rdict, sdict, query_dict, i, return_queries=True, retrieval=True)
76+
reflexions.append(rdict)
77+
78+
rs = calculate_final_rs(reflexions, problem_dict)
79+
80+
elif not args.episodic_retrieval and args.semantic_retrieval and args.reflexion:
81+
rdict, sdict, rs, ss = run_solve(model_fn, model_name, problem_dict, args.attempts)
82+
rdict, sdict, rs, ss = run_retrieval(model_fn, model_name, problem_dict, args.attempts, ss, args.num_retrieved, RetrievalType.SEMANTIC)
83+
84+
reflexions = [rdict]
85+
query_dict = None
86+
for i in range(args.num_reflexion):
87+
rdict, sdict, rs, ss, query_dict = run_reflexion(model_fn, model_name, problem_dict, args.attempts, rdict, sdict, query_dict, i, return_queries=True, retrieval=True)
88+
reflexions.append(rdict)
89+
90+
rs = calculate_final_rs(reflexions, problem_dict)
91+
92+
elif args.episodic_retrieval and args.semantic_retrieval and args.reflexion:
93+
rdict, sdict, rs, ss = run_solve(model_fn, model_name, problem_dict, args.attempts)
94+
rdict, sdict, rs, ss = run_retrieval(model_fn, model_name, problem_dict, args.attempts, ss, args.num_retrieved, RetrievalType.EPISODIC_SEMANTIC)
95+
96+
reflexions = [rdict]
97+
query_dict = None
6098
for i in range(args.num_reflexion):
61-
rdict, sdict, rs, ss, queries = run_reflexion(model_fn, model_name, problem_dict, args.attempts, rdict, sdict, queries, i, return_queries=True)
62-
reflexions.append(rs)
99+
rdict, sdict, rs, ss, query_dict = run_reflexion(model_fn, model_name, problem_dict, args.attempts, rdict, sdict, query_dict, i, return_queries=True, retrieval=True)
100+
reflexions.append(rdict)
63101

64102
rs = calculate_final_rs(reflexions, problem_dict)
65103

utils.py

+18 -19
Original file line number | Diff line number | Diff line change
@@ -221,33 +221,32 @@ def generate_episodic_semantic_retrieval_queries(num_problems_fetched, problem_d
221221
save_json(final_queries, 'queries_firstsolve_episodic_semantic')
222222
return final_queries
223223

224-
def generate_reflexion_queries(rdict, sdict, problem_dict, model_name, prev_queries, retrieval=False):
224+
def generate_reflexion_queries(rdict, sdict, problem_dict, model_name, iteration, prev_queries_dict=None, retrieval=False):
225225
reflection_queries_dict = dict()
226-
prev_queries_dict = dict()
227-
for query in prev_queries:
228-
prev_queries_dict[query['problem_id']] = query
229226

230-
# Extracting Original Response
231227
for problem_id in sdict.keys():
232228
if problem_id in problem_dict.keys():
233229
for solution in sdict[problem_id][:1]:
234-
original_response = solution['solution']
230+
prev_buffer = ''
231+
if prev_queries_dict:
232+
prev_buffer = prev_queries_dict[problem_id]['reflection_buffer']
233+
current_response = solution['solution']
234+
current_execution_output = rdict[solution['problem_id']][0]['result_list']
235235
num_samples = problem_dict[problem_id]['description'].count("SAMPLE INPUT")
236-
unparsed_execution_output = rdict[problem_id][0]['result_list']
237236
execution_output = ""
238-
if unparsed_execution_output:
239-
unparsed_execution_output = unparsed_execution_output[:num_samples]
240-
for i, result in enumerate(unparsed_execution_output):
237+
if current_execution_output:
238+
current_execution_output = current_execution_output[:num_samples]
239+
for i, result in enumerate(current_execution_output):
241240
execution_output += f"Test Case {i}\n" + result['status'] + "\n"
242241
else:
243242
execution_output = "No submission, formatting error during judging."
243+
244244
if retrieval:
245245
retrieval_text = prev_queries_dict[problem_id]['retrieval_text']
246246
retrieval_problem_ids = prev_queries_dict[problem_id]['retrieval_problem_ids']
247-
reflection_queries_dict[problem_id] = {'problem_id': problem_id, 'original_response': original_response, 'execution_response': execution_output, 'retrieval_text': retrieval_text, 'retrieval_problem_ids': retrieval_problem_ids, 'problem_description': problem_dict[problem_id]['description']}
247+
reflection_queries_dict[problem_id] = {'problem_id': problem_id, 'reflection_buffer': prev_buffer + f'\n Reflection Response Number {iteration+1}: \n' + current_response + f'\n Reflection Response Execution Output Number {iteration+1}:\n' + execution_output, 'retrieval_text': retrieval_text, 'retrieval_problem_ids': retrieval_problem_ids, 'problem_description': problem_dict[problem_id]['description']}
248248
else:
249-
reflection_queries_dict[problem_id] = {'problem_id': problem_id, 'original_response': original_response, 'execution_response': execution_output, 'problem_description': problem_dict[problem_id]['description']}
250-
249+
reflection_queries_dict[problem_id] = {'problem_id': problem_id, 'reflection_buffer': prev_buffer + f'\n Reflection Response Number {iteration+1} \n' + current_response + f'\n Reflection Response Execution Output Number {iteration+1}:\n' + execution_output, 'problem_description': problem_dict[problem_id]['description']}
251250
if retrieval:
252251
name = f'queries_dict_{model_name}_retrieval_reflexion'
253252
else:
@@ -257,7 +256,7 @@ def generate_reflexion_queries(rdict, sdict, problem_dict, model_name, prev_quer
257256

258257
def calculate_final_rs(reflexions, problem_dict):
259258
rs = []
260-
for problem_id in problem_dict.keys():
259+
for problem_id in reflexions[0].keys():
261260
num_samples = problem_dict[problem_id]['description'].count('SAMPLE INPUT')
262261
for i, reflexion_result in enumerate(reflexions):
263262
if reflexion_result[problem_id][0]['result_list']:
@@ -288,7 +287,7 @@ def run_solve(model_fn, model_name, problem_dict, attempts, return_queries=False
288287
for problem_id in problem_dict.keys():
289288
queries.append({'problem_id': problem_id, 'problem_description': problem_dict[problem_id]['description']})
290289

291-
rdict, sdict, rs, ss = evaluate_model(model_fn, solve_prompt_fn, queries=queries, verbose=True, attempts=attempts, problem_ids=list(problem_dict.keys()))
290+
rdict, sdict, rs, ss = evaluate_model(model_fn, solve_prompt_fn, queries=queries, verbose=True, attempts=attempts, problem_ids=list(problem_dict.keys())[:2])
292291
save_json([rdict, sdict, rs, ss], f'results/results_{model_name}_solve_{attempts}attempts')
293292
return (rdict, sdict, rs, ss) if not return_queries else (rdict, sdict, rs, ss, queries)
294293

@@ -306,9 +305,9 @@ def run_retrieval(model_fn, model_name, problem_dict, attempts, solution_sets, n
306305

307306
return (rdict, sdict, rs, ss) if not return_queries else (rdict, sdict, rs, ss, queries)
308307

309-
def run_reflexion(model_fn, model_name, problem_dict, attempts, prev_result_dict, prev_solution_dict, prev_queries, iteration, return_queries=True):
310-
new_reflexion_queries = generate_reflexion_queries(prev_result_dict, prev_solution_dict, problem_dict, model_name, prev_queries)
311-
rdict, sdict, rs, ss = evaluate_model(model_fn, reflexion_prompt_fn, queries=new_reflexion_queries, verbose=True, attempts=attempts, problem_ids=list(problem_dict.keys()))
308+
def run_reflexion(model_fn, model_name, problem_dict, attempts, prev_result_dict, prev_solution_dict, prev_queries_dict, iteration, return_queries=True, retrieval=False):
309+
new_reflexion_queries_dict = generate_reflexion_queries(prev_result_dict, prev_solution_dict, problem_dict, model_name, iteration, prev_queries_dict=prev_queries_dict, retrieval=retrieval)
310+
rdict, sdict, rs, ss = evaluate_model(model_fn, reflexion_prompt_fn, queries=list(new_reflexion_queries_dict.values()), verbose=True, attempts=attempts)
312311
save_json([rdict, sdict, rs, ss], f'results_{model_name}_reflexion_{str(iteration)}iteration')
313312

314-
return (rdict, sdict, rs, ss) if not return_queries else (rdict, sdict, rs, ss, new_reflexion_queries)
313+
return (rdict, sdict, rs, ss) if not return_queries else (rdict, sdict, rs, ss, new_reflexion_queries_dict)

0 commit comments

Comments (0)