 import os
 import re
 import sys
+import threading
 
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from time import time
@@ -85,8 +86,13 @@ def generate_response(llm_client, context, question, choice_a, choice_b, choice_
         return ""


-def process_sample(search_result, llm_client):
+def process_sample(search_result, llm_client, success_records, record_file, file_lock):
     """Process a single sample: generate answer."""
+    sample_idx = search_result.get("sample_idx")
+    # Skip if already processed
+    if sample_idx is not None and str(sample_idx) in success_records:
+        return None
+
     start = time()
 
     context = search_result.get("context", "")
@@ -96,6 +102,10 @@ def process_sample(search_result, llm_client):
     choice_c = search_result.get("choice_C", "")
     choice_d = search_result.get("choice_D", "")
 
+    # Skip empty/placeholder contexts (e.g., "\n" or whitespace-only)
+    if not context or context.strip() == "":
+        return None
+
     # Generate answer
     response = generate_response(
         llm_client, context, question, choice_a, choice_b, choice_c, choice_d
@@ -106,7 +116,7 @@ def process_sample(search_result, llm_client):
 
     response_duration_ms = (time() - start) * 1000
 
-    return {
+    result = {
         "sample_idx": search_result.get("sample_idx"),
         "_id": search_result.get("_id"),
         "domain": search_result.get("domain"),
@@ -123,10 +133,20 @@ def process_sample(search_result, llm_client):
         "response": response,
         "judge": pred == search_result.get("answer") if pred else False,
         "search_context": context,
+        # Preserve full search results payload (e.g., list of memories)
+        "search_results": search_result.get("search_results"),
         "response_duration_ms": response_duration_ms,
         "search_duration_ms": search_result.get("search_duration_ms", 0),
     }
 
+    # Record successful processing (thread-safe)
+    if sample_idx is not None:
+        with file_lock, open(record_file, "a") as f:
+            f.write(f"{sample_idx}\n")
+            f.flush()
+
+    return result
+
 
 def main(frame, version="default", num_workers=10):
     """Main response generation function."""
@@ -136,10 +156,16 @@ def main(frame, version="default", num_workers=10):
     print(f"🚀 LONGBENCH V2 RESPONSE GENERATION - {frame.upper()} v{version}".center(80))
     print("=" * 80 + "\n")
 
-    # Load search results
-    search_path = (
-        f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_search_results.json"
+    # Initialize checkpoint file for resume functionality
+    checkpoint_dir = os.path.join(
+        ROOT_DIR, "evaluation", "results", "long_bench_v2", f"{frame}-{version}"
     )
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    record_file = os.path.join(checkpoint_dir, "response_success_records.txt")
+    search_path = os.path.join(checkpoint_dir, f"{frame}_longbench_v2_search_results.json")
+    output_path = os.path.join(checkpoint_dir, f"{frame}_longbench_v2_responses.json")
+
+    # Load search results
     if not os.path.exists(search_path):
         print(f"❌ Search results not found: {search_path}")
         print("Please run longbench_v2_search.py first")
@@ -148,6 +174,30 @@ def main(frame, version="default", num_workers=10):
     with open(search_path, encoding="utf-8") as f:
         search_results = json.load(f)
 
+    # Load existing results and success records for resume
+    existing_results = {}
+    success_records = set()
+    if os.path.exists(output_path):
+        with open(output_path, encoding="utf-8") as f:
+            existing_results_list = json.load(f)
+        for result in existing_results_list:
+            sample_idx = result.get("sample_idx")
+            if sample_idx is not None:
+                existing_results[sample_idx] = result
+                success_records.add(str(sample_idx))
+        print(f"📋 Found {len(existing_results)} existing responses (resume mode)")
+    else:
+        print("📋 Starting fresh response generation (no checkpoint found)")
+
+    # Load additional success records from checkpoint file
+    if os.path.exists(record_file):
+        with open(record_file) as f:
+            for line in f:
+                line = line.strip()
+                if line and line not in success_records:
+                    success_records.add(line)
+    print(f"📋 Total {len(success_records)} samples already processed")
+
     # Initialize LLM client
     llm_client = OpenAI(
         api_key=os.getenv("CHAT_MODEL_API_KEY"),
@@ -156,9 +206,15 @@ def main(frame, version="default", num_workers=10):
     print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}")
 
     # Process all samples
-    all_responses = []
+    new_results = []
+    file_lock = threading.Lock()  # Lock for thread-safe file writing
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
-        futures = [executor.submit(process_sample, sample, llm_client) for sample in search_results]
+        futures = [
+            executor.submit(
+                process_sample, sample, llm_client, success_records, record_file, file_lock
+            )
+            for sample in search_results
+        ]
 
         for future in tqdm(
             as_completed(futures),
@@ -167,11 +223,16 @@ def main(frame, version="default", num_workers=10):
         ):
             result = future.result()
             if result:
-                all_responses.append(result)
-
-    # Save responses
-    output_path = f"results/long_bench-v2/{frame}-{version}/{frame}_longbench_v2_responses.json"
-    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+                new_results.append(result)
+                # Update existing results with new result
+                sample_idx = result.get("sample_idx")
+                if sample_idx is not None:
+                    existing_results[sample_idx] = result
+
+    # Merge and save all results
+    all_responses = list(existing_results.values())
+    # Sort by sample_idx to maintain order
+    all_responses.sort(key=lambda x: x.get("sample_idx", 0))
 
     with open(output_path, "w", encoding="utf-8") as f:
         json.dump(all_responses, f, ensure_ascii=False, indent=2)
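
A minimal standalone sketch of the checkpoint/resume pattern this commit adds: each worker appends the `sample_idx` of a successfully processed sample to `response_success_records.txt` under a shared `threading.Lock`, and a later run skips any index already recorded. The driver loop, sample data, and file path below are illustrative assumptions, not part of the actual evaluation script.

```python
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

RECORD_FILE = "response_success_records.txt"  # hypothetical demo path
file_lock = threading.Lock()  # guards concurrent appends to the checkpoint file


def load_success_records(record_file):
    """Return the set of sample indices recorded by earlier runs."""
    if not os.path.exists(record_file):
        return set()
    with open(record_file) as f:
        return {line.strip() for line in f if line.strip()}


def process_sample(sample, success_records):
    """Skip samples seen in a previous run; otherwise 'process' and checkpoint."""
    sample_idx = sample["sample_idx"]
    if str(sample_idx) in success_records:
        return None  # already done, nothing to redo
    result = {"sample_idx": sample_idx, "response": f"answer for {sample_idx}"}
    with file_lock, open(RECORD_FILE, "a") as f:  # thread-safe append
        f.write(f"{sample_idx}\n")
    return result


if __name__ == "__main__":
    samples = [{"sample_idx": i} for i in range(10)]
    success_records = load_success_records(RECORD_FILE)
    new_results = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(process_sample, s, success_records) for s in samples]
        for future in as_completed(futures):
            result = future.result()
            if result:
                new_results.append(result)
    print(f"Processed {len(new_results)} new samples; "
          f"skipped {len(samples) - len(new_results)} already-recorded ones.")
```

Because the record file is append-only and written as soon as each sample finishes, partial progress survives a crash even if the merged JSON output is never rewritten.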