@@ -221,33 +221,32 @@ def generate_episodic_semantic_retrieval_queries(num_problems_fetched, problem_d
     save_json(final_queries, 'queries_firstsolve_episodic_semantic')
     return final_queries
 
-def generate_reflexion_queries(rdict, sdict, problem_dict, model_name, prev_queries, retrieval=False):
+def generate_reflexion_queries(rdict, sdict, problem_dict, model_name, iteration, prev_queries_dict=None, retrieval=False):
     reflection_queries_dict = dict()
-    prev_queries_dict = dict()
-    for query in prev_queries:
-        prev_queries_dict[query['problem_id']] = query
 
-    # Extracting Original Response
     for problem_id in sdict.keys():
         if problem_id in problem_dict.keys():
             for solution in sdict[problem_id][:1]:
-                original_response = solution['solution']
+                prev_buffer = ''
+                if prev_queries_dict:
+                    prev_buffer = prev_queries_dict[problem_id]['reflection_buffer']
+                current_response = solution['solution']
+                current_execution_output = rdict[solution['problem_id']][0]['result_list']
                 num_samples = problem_dict[problem_id]['description'].count("SAMPLE INPUT")
-                unparsed_execution_output = rdict[problem_id][0]['result_list']
                 execution_output = ""
-                if unparsed_execution_output:
-                    unparsed_execution_output = unparsed_execution_output[:num_samples]
-                    for i, result in enumerate(unparsed_execution_output):
+                if current_execution_output:
+                    current_execution_output = current_execution_output[:num_samples]
+                    for i, result in enumerate(current_execution_output):
                         execution_output += f"Test Case {i}\n" + result['status'] + "\n"
                 else:
                     execution_output = "No submission, formatting error during judging."
+
                 if retrieval:
                     retrieval_text = prev_queries_dict[problem_id]['retrieval_text']
                     retrieval_problem_ids = prev_queries_dict[problem_id]['retrieval_problem_ids']
-                    reflection_queries_dict[problem_id] = {'problem_id': problem_id, 'original_response': original_response, 'execution_response': execution_output, 'retrieval_text': retrieval_text, 'retrieval_problem_ids': retrieval_problem_ids, 'problem_description': problem_dict[problem_id]['description']}
+                    reflection_queries_dict[problem_id] = {'problem_id': problem_id, 'reflection_buffer': prev_buffer + f'\nReflection Response Number {iteration + 1}:\n' + current_response + f'\nReflection Response Execution Output Number {iteration + 1}:\n' + execution_output, 'retrieval_text': retrieval_text, 'retrieval_problem_ids': retrieval_problem_ids, 'problem_description': problem_dict[problem_id]['description']}
                 else:
-                    reflection_queries_dict[problem_id] = {'problem_id': problem_id, 'original_response': original_response, 'execution_response': execution_output, 'problem_description': problem_dict[problem_id]['description']}
-
+                    reflection_queries_dict[problem_id] = {'problem_id': problem_id, 'reflection_buffer': prev_buffer + f'\nReflection Response Number {iteration + 1}:\n' + current_response + f'\nReflection Response Execution Output Number {iteration + 1}:\n' + execution_output, 'problem_description': problem_dict[problem_id]['description']}
     if retrieval:
         name = f'queries_dict_{model_name}_retrieval_reflexion'
     else:
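For reference, the reworked generate_reflexion_queries threads a growing per-problem reflection_buffer through each iteration instead of keeping only the latest response and its execution output. The snippet below is an illustrative, standalone sketch of that accumulation pattern; the helper name append_to_buffer and the sample data are hypothetical, and only the buffer format (Reflection Response Number / Execution Output Number) follows the diff above.

# Illustrative sketch of the reflection-buffer accumulation (not code from this commit).
def append_to_buffer(prev_queries_dict, problem_id, iteration, response, execution_output):
    prev_buffer = ''
    if prev_queries_dict:
        prev_buffer = prev_queries_dict[problem_id]['reflection_buffer']
    return (prev_buffer
            + f'\nReflection Response Number {iteration + 1}:\n' + response
            + f'\nReflection Response Execution Output Number {iteration + 1}:\n' + execution_output)

buffer_0 = append_to_buffer(None, 'p1', 0, 'def solve(): ...', 'Test Case 0\nWrong Answer')
queries_after_iter_0 = {'p1': {'reflection_buffer': buffer_0}}
buffer_1 = append_to_buffer(queries_after_iter_0, 'p1', 1, 'def solve(): ...', 'Test Case 0\nAccepted')
print(buffer_1)  # contains the responses and execution outputs of both iterations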
@@ -257,7 +256,7 @@ def generate_reflexion_queries(rdict, sdict, problem_dict, model_name, prev_quer
 
 def calculate_final_rs(reflexions, problem_dict):
     rs = []
-    for problem_id in problem_dict.keys():
+    for problem_id in reflexions[0].keys():
         num_samples = problem_dict[problem_id]['description'].count('SAMPLE INPUT')
         for i, reflexion_result in enumerate(reflexions):
             if reflexion_result[problem_id][0]['result_list']:
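The loop change here keys the scoring off the problems actually present in the reflexion results rather than the full problem_dict, so problems that were never attempted are skipped instead of triggering KeyError. Below is a rough standalone illustration; the result structure is a simplified, made-up stand-in for the rdict-style entries used above.

# Hypothetical, simplified stand-in for the per-iteration result dicts
# (reflexions[i][problem_id][0]['result_list']); the data is illustrative only.
reflexions = [
    {'p1': [{'result_list': [{'status': 'Wrong Answer'}]}]},  # iteration 0
    {'p1': [{'result_list': [{'status': 'Accepted'}]}]},      # iteration 1
]
problem_dict = {'p1': {'description': '...'}, 'p2': {'description': '...'}}  # 'p2' never attempted

# Iterating over problem_dict.keys() would hit 'p2' and raise KeyError;
# reflexions[0].keys() restricts scoring to attempted problems.
for problem_id in reflexions[0].keys():
    solved = any(r[problem_id][0]['result_list'][0]['status'] == 'Accepted' for r in reflexions)
    print(problem_id, 'solved:', solved)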
@@ -288,7 +287,7 @@ def run_solve(model_fn, model_name, problem_dict, attempts, return_queries=False
     for problem_id in problem_dict.keys():
         queries.append({'problem_id': problem_id, 'problem_description': problem_dict[problem_id]['description']})
 
-    rdict, sdict, rs, ss = evaluate_model(model_fn, solve_prompt_fn, queries=queries, verbose=True, attempts=attempts, problem_ids=list(problem_dict.keys()))
+    rdict, sdict, rs, ss = evaluate_model(model_fn, solve_prompt_fn, queries=queries, verbose=True, attempts=attempts, problem_ids=list(problem_dict.keys())[:2])
     save_json([rdict, sdict, rs, ss], f'results/results_{model_name}_solve_{attempts}attempts')
     return (rdict, sdict, rs, ss) if not return_queries else (rdict, sdict, rs, ss, queries)
 
@@ -306,9 +305,9 @@ def run_retrieval(model_fn, model_name, problem_dict, attempts, solution_sets, n
 
     return (rdict, sdict, rs, ss) if not return_queries else (rdict, sdict, rs, ss, queries)
 
-def run_reflexion(model_fn, model_name, problem_dict, attempts, prev_result_dict, prev_solution_dict, prev_queries, iteration, return_queries=True):
-    new_reflexion_queries = generate_reflexion_queries(prev_result_dict, prev_solution_dict, problem_dict, model_name, prev_queries)
-    rdict, sdict, rs, ss = evaluate_model(model_fn, reflexion_prompt_fn, queries=new_reflexion_queries, verbose=True, attempts=attempts, problem_ids=list(problem_dict.keys()))
+def run_reflexion(model_fn, model_name, problem_dict, attempts, prev_result_dict, prev_solution_dict, prev_queries_dict, iteration, return_queries=True, retrieval=False):
+    new_reflexion_queries_dict = generate_reflexion_queries(prev_result_dict, prev_solution_dict, problem_dict, model_name, iteration, prev_queries_dict=prev_queries_dict, retrieval=retrieval)
+    rdict, sdict, rs, ss = evaluate_model(model_fn, reflexion_prompt_fn, queries=list(new_reflexion_queries_dict.values()), verbose=True, attempts=attempts)
     save_json([rdict, sdict, rs, ss], f'results_{model_name}_reflexion_{str(iteration)}iteration')
 
-    return (rdict, sdict, rs, ss) if not return_queries else (rdict, sdict, rs, ss, new_reflexion_queries)
+    return (rdict, sdict, rs, ss) if not return_queries else (rdict, sdict, rs, ss, new_reflexion_queries_dict)
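Taken together, run_solve and the updated run_reflexion can be chained so that the queries dict returned by one iteration seeds the next. The driver below is a hedged usage sketch: the run_solve, run_reflexion, and calculate_final_rs signatures follow this diff, but the loop itself (run_reflexion_loop, num_iterations) is an assumption, not code from this commit.

# Hypothetical driver loop; only the function signatures come from the code above.
def run_reflexion_loop(model_fn, model_name, problem_dict, attempts=1, num_iterations=3):
    rdict, sdict, rs, ss, _ = run_solve(model_fn, model_name, problem_dict, attempts,
                                        return_queries=True)
    prev_queries_dict = None  # first reflexion pass starts with an empty buffer
    reflexion_rdicts = []
    for iteration in range(num_iterations):
        rdict, sdict, rs, ss, prev_queries_dict = run_reflexion(
            model_fn, model_name, problem_dict, attempts,
            rdict, sdict, prev_queries_dict, iteration,
            return_queries=True, retrieval=False)
        reflexion_rdicts.append(rdict)
    return calculate_final_rs(reflexion_rdicts, problem_dict)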