
Commit 2c2b9ff

Author: Yifu Cai
Commit message: run black
Parent: 2a9f811

9 files changed: 65 additions, 46 deletions


aide/agent.py

Lines changed: 16 additions & 12 deletions
@@ -15,9 +15,11 @@
 
 logger = logging.getLogger("aide")
 
+
 def format_time(time_in_sec: int):
     return f"{time_in_sec // 3600}hrs {(time_in_sec % 3600) // 60}mins {time_in_sec % 60}secs"
-
+
+
 ExecCallbackType = Callable[[str, bool], ExecutionResult]
 
 review_func_spec = FunctionSpec(
@@ -65,7 +67,7 @@ def __init__(
         self.current_step = 0
         if self.acfg.cost_limit:
             self.token_counter = TokenCounter(
-                max_cost=self.acfg.cost_limit,
+                cost_limit=self.acfg.cost_limit,
             )
         else:
             self.token_counter = None
@@ -121,7 +123,7 @@ def _prompt_environment(self):
        pkg_str = ", ".join([f"`{p}`" for p in pkgs])
 
        ts_pksg = [
-          "sktime",
+            "sktime",
            "statsforecast",
            "tsfresh",
            "neuralforecast",
@@ -142,16 +144,18 @@ def _prompt_impl_guideline(self):
        exec_timeout = int(min(self.cfg.exec.timeout, tot_time_remaining))
 
        if self.acfg.remind_resource_limit:
-            impl_guideline = [f"<TOTAL_TIME_REMAINING: {format_time(tot_time_remaining)}>",
-                              f"<TOTAL_STEPS_REMAINING: {self.acfg.steps - self.current_step}>"]
-
+            impl_guideline = [
+                f"<TOTAL_TIME_REMAINING: {format_time(tot_time_remaining)}>",
+                f"<TOTAL_STEPS_REMAINING: {self.acfg.steps - self.current_step}>",
+            ]
+
            if self.token_counter:
                impl_guideline.append(
                    f"<OUTPUT_TOKEN_LIMIT_REMAINING: {self.token_counter.remaining_output_tokens(self.acfg.code.model)}>"
                )
        else:
            impl_guideline = []
-
+
        impl_guideline += [
            "The code should **implement the proposed solution** and **print the value of the evaluation metric computed on a hold-out validation set**.",
            "**AND MOST IMPORTANTLY SAVE PREDICTIONS ON THE PROVIDED UNLABELED TEST DATA IN REQUIRED FILE FORMAT IN THE ./submission/ DIRECTORY.**",
@@ -352,7 +356,9 @@ def step(self, exec_callback: ExecCallbackType) -> bool:
            for item in submission_idr.iterdir():
                if item.is_file():
                    shutil.copy(item, best_submission_dir / item.name)
-                    logger.info(f"Copied {item.name} to {best_submission_dir / item.name}")
+                    logger.info(
+                        f"Copied {item.name} to {best_submission_dir / item.name}"
+                    )
            # copy solution.py and relevant node id to best_solution/
            with open(best_solution_dir / "solution.py", "w") as f:
                f.write(result_node.code)
@@ -366,7 +372,7 @@ def step(self, exec_callback: ExecCallbackType) -> bool:
 
            exceed_budget_limit = self.token_counter.exceed_budget_limit()
            return exceed_budget_limit
-
+
 
    def parse_exec_result(self, node: Node, exec_result: ExecutionResult):
        logger.info(f"Agent is parsing execution results for node {node.id}")
@@ -407,9 +413,7 @@ def parse_exec_result(self, node: Node, exec_result: ExecutionResult):
        )
 
        if node.is_buggy:
-            logger.info(
-                f"Parsed results: Node {node.id} is buggy"
-            )
+            logger.info(f"Parsed results: Node {node.id} is buggy")
            node.metric = WorstMetricValue()
        else:
            logger.info(f"Parsed results: Node {node.id} is not buggy")

aide/backend/__init__.py

Lines changed: 37 additions & 27 deletions
@@ -7,15 +7,21 @@
 
 logger = logging.getLogger("aide")
 
-#cost per input/output token for each model
+# cost per input/output token for each model
 MODEL_COST = {
-    "gpt-4o-2024-08-06": {"input": 2.5/1000000, "output": 10/1000000},
-    "o3-mini": {"input": 1.1/1000000, "output": 4.4/1000000},
-    "o3": {"input": 10/1000000, "output": 40/1000000},
+    "gpt-4o-2024-08-06": {"input": 2.5 / 1000000, "output": 10 / 1000000},
+    "o3-mini": {"input": 1.1 / 1000000, "output": 4.4 / 1000000},
+    "o3": {"input": 10 / 1000000, "output": 40 / 1000000},
 }
 
+
 def determine_provider(model: str) -> str:
-    if model.startswith("gpt-") or model.startswith("o1-") or model.startswith("o3-") or model.startswith("o4-"):
+    if (
+        model.startswith("gpt-")
+        or model.startswith("o1-")
+        or model.startswith("o3-")
+        or model.startswith("o4-")
+    ):
         return "openai"
     elif model.startswith("claude-"):
         return "anthropic"
@@ -33,67 +39,69 @@ def determine_provider(model: str) -> str:
    "openrouter": backend_openrouter.query,
 }
 
+
 class TokenCounter:
-    def __init__(self, cost_limit:int):
+    def __init__(self, cost_limit: int):
        self.cost_limit = cost_limit
        self.total_input_tokens = defaultdict(int)
        self.total_output_tokens = defaultdict(int)
-
+
    def cost(self) -> float:
-        '''
+        """
        compute to total cost of the tokens used
-        '''
+        """
        total_cost = 0
 
-        #compute cost for input tokens
+        # compute cost for input tokens
        for model_name, input_tokens in self.total_input_tokens.items():
            if model_name not in MODEL_COST:
                raise ValueError(f"Model {model_name} not supported for token counting")
            total_cost += input_tokens * MODEL_COST[model_name]["input"]
-
-        #compute cost for output tokens
+
+        # compute cost for output tokens
        for model_name, output_tokens in self.total_output_tokens.items():
            if model_name not in MODEL_COST:
                raise ValueError(f"Model {model_name} not supported for token counting")
            total_cost += output_tokens * MODEL_COST[model_name]["output"]
        return total_cost
-
-    def add_tokens(self, model_name:str, input_tokens=None, output_tokens=None):
-        '''
+
+    def add_tokens(self, model_name: str, input_tokens=None, output_tokens=None):
+        """
        update the token counts
-        '''
+        """
        if model_name not in MODEL_COST:
            raise ValueError(f"Model {model_name} not supported for token counting")
-
+
        if input_tokens is not None:
            self.total_input_tokens[model_name] += input_tokens
        if output_tokens is not None:
            self.total_output_tokens[model_name] += output_tokens
 
-    def remaining_output_tokens(self, model_name:str, max_budget:int) -> int:
-        '''
+    def remaining_output_tokens(self, model_name: str, max_budget: int) -> int:
+        """
        max_budget: the maximum dollar budget for the model
        compute the remaining tokens for a model
-        '''
+        """
        if model_name not in MODEL_COST:
            raise ValueError(f"Model {model_name} not supported for token counting")
-
+
        current_cost = self.cost
        remaining_budget = max_budget - current_cost
        if remaining_budget <= 0:
            return 0
        else:
            output_tokens_cost = MODEL_COST[model_name]["output"]
            return int(remaining_budget / output_tokens_cost)
-
+
    def exceed_budget_limit(self) -> bool:
-        '''
+        """
        check if the budget limit is exceeded
-        '''
-
+        """
+
        current_cost = self.cost
        return current_cost > self.cost_limit
-
+
+
 def query(
    system_message: PromptType | None,
    user_message: PromptType | None,
@@ -150,6 +158,8 @@ def query(
    logger.info(f"response: {output}", extra={"verbose": True})
    logger.info("---Query complete---", extra={"verbose": True})
    if token_counter is not None:
-        token_counter.add_tokens(model, input_tokens=in_tok_count, output_tokens=out_tok_count)
+        token_counter.add_tokens(
+            model, input_tokens=in_tok_count, output_tokens=out_tok_count
+        )
 
    return output
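The TokenCounter class reformatted above tracks per-model input/output token counts and converts them to dollars using the MODEL_COST table. A minimal standalone sketch of that bookkeeping (not the repository code; the usage numbers are made up for illustration):

    from collections import defaultdict

    # Per-token dollar rates, as in the MODEL_COST table shown in the diff above.
    MODEL_COST = {"gpt-4o-2024-08-06": {"input": 2.5 / 1000000, "output": 10 / 1000000}}

    input_tokens = defaultdict(int)
    output_tokens = defaultdict(int)

    # Record usage from one hypothetical query.
    input_tokens["gpt-4o-2024-08-06"] += 1200
    output_tokens["gpt-4o-2024-08-06"] += 800

    # cost = sum over models of (input tokens * input rate + output tokens * output rate)
    cost = sum(
        n * MODEL_COST[m]["input"] for m, n in input_tokens.items()
    ) + sum(n * MODEL_COST[m]["output"] for m, n in output_tokens.items())

    print(f"${cost:.6f}")  # 1200 * 2.5e-06 + 800 * 1e-05 = $0.011000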

aide/backend/backend_anthropic.py

Lines changed: 1 addition & 1 deletion
@@ -73,4 +73,4 @@ def query(
        "stop_reason": message.stop_reason,
    }
 
-    return output, req_time, in_tokens, out_tokens, info
\ No newline at end of file
+    return output, req_time, in_tokens, out_tokens, info

aide/backend/backend_gdm.py

Lines changed: 1 addition & 1 deletion
@@ -95,4 +95,4 @@ def query(
    info = {}  # this isnt used anywhere, but is an expected return value
 
    # only `output` is actually used by scaffolding
-    return output, req_time, in_tokens, out_tokens, info
\ No newline at end of file
+    return output, req_time, in_tokens, out_tokens, info

aide/backend/backend_openai.py

Lines changed: 4 additions & 2 deletions
@@ -42,7 +42,9 @@ def query(
    _setup_openai_client()
    filtered_kwargs: dict = select_values(notnone, model_kwargs)  # type: ignore
 
-    messages = opt_messages_to_list(system_message, user_message, convert_system_to_user=convert_system_to_user)
+    messages = opt_messages_to_list(
+        system_message, user_message, convert_system_to_user=convert_system_to_user
+    )
 
    if func_spec is not None:
        filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
@@ -86,4 +88,4 @@ def query(
        "created": completion.created,
    }
 
-    return output, req_time, in_tokens, out_tokens, info
\ No newline at end of file
+    return output, req_time, in_tokens, out_tokens, info

aide/backend/backend_openrouter.py

Lines changed: 1 addition & 1 deletion
@@ -83,4 +83,4 @@ def query(
        "created": completion.created,
    }
 
-    return output, req_time, in_tokens, out_tokens, info
\ No newline at end of file
+    return output, req_time, in_tokens, out_tokens, info

aide/backend/utils.py

Lines changed: 1 addition & 1 deletion
@@ -86,4 +86,4 @@ def openai_tool_choice_dict(self):
        return {
            "type": "function",
            "function": {"name": self.name},
-        }
\ No newline at end of file
+        }
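Both backend_openai.py and utils.py above deal with FunctionSpec's OpenAI tool dictionaries. As a rough sketch of the shapes involved: the tool_choice dict mirrors the one returned in the utils.py hunk, while the tools entry follows the standard OpenAI Chat Completions function-calling format; the concrete name, description, and schema are hypothetical, and the as_openai_tool_dict shape is an assumption rather than code from this commit.

    # Hypothetical illustration of the dictionaries passed to the OpenAI client.
    # `tool_choice` mirrors openai_tool_choice_dict from the utils.py hunk above;
    # `review_tool` is only an assumption about what as_openai_tool_dict produces.
    review_tool = {
        "type": "function",
        "function": {
            "name": "submit_review",                       # hypothetical name
            "description": "Review the execution output",  # hypothetical description
            "parameters": {"type": "object", "properties": {}},  # JSON schema for arguments
        },
    }
    tool_choice = {"type": "function", "function": {"name": "submit_review"}}

    # These would typically be forwarded as, e.g.:
    #   client.chat.completions.create(..., tools=[review_tool], tool_choice=tool_choice)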

aide/journal.py

Lines changed: 2 additions & 1 deletion
@@ -192,6 +192,7 @@ def generate_summary(self, include_code: bool = False) -> str:
            summary.append(summary_part)
        return "\n-------------------------------\n".join(summary)
 
+
 def get_path_to_node(journal: Journal, node_id: str) -> list[str]:
    path = [node_id]
 
@@ -243,4 +244,4 @@ def filter_journal(journal: Journal) -> Journal:
    else:
        filtered_journal = filter_for_longest_path(journal)
 
-    return filtered_journal
\ No newline at end of file
+    return filtered_journal

aide/run.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
 from rich.tree import Tree
 from .utils.config import load_task_desc, prep_agent_workspace, save_run, load_cfg
 
+
 class VerboseFilter(logging.Filter):
    """
    Filter (remove) logs that have verbose attribute set to True
@@ -34,6 +35,7 @@ class VerboseFilter(logging.Filter):
    def filter(self, record):
        return not (hasattr(record, "verbose") and record.verbose)
 
+
 def journal_to_rich_tree(journal: Journal):
    best_node = journal.get_best_node()
 