-
Notifications
You must be signed in to change notification settings - Fork 1
DeepFinance Enhancements #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
bac05b5
ba41164
c7ca8c7
7f2b017
9dd3c42
757f8a1
079e4bd
bcce8f0
4662d63
de81c1d
248acc4
9d651fd
7475ecc
b95d491
f20ab91
ea87d4b
3082bca
ef44b63
0889483
db7114c
5a25550
623b7d9
0aaab86
04f4959
d0ff68b
1c356d7
37dcbcc
529ae7e
f4eb231
1e07515
08ba184
3d55692
a478827
88be3e4
fb41962
a1f909b
8d2e5d7
3c85960
9b541c5
06fda5f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -82,27 +82,46 @@ def extract_text_content_from_content_dict(self, msg): | |
| # }, | ||
| # ], | ||
| # } | ||
| # or tool_result format: | ||
| # msg = { | ||
| # "role": "tool", | ||
| # "content": [ | ||
| # { | ||
| # "type": "tool_result", | ||
| # "id": "call_xxx", | ||
| # "output": "tool output content", | ||
| # "name": "tool_name" | ||
| # }, | ||
| # ], | ||
| # } | ||
|
|
||
| str_content = "" | ||
| for item in msg["content"]: | ||
| # item = { | ||
| # "type": "text", | ||
| # "text": "some text" | ||
| # }, | ||
|
|
||
| assert isinstance(item, dict), f"Unsupported non-dict item in message content: {item}. Full message: {msg}" | ||
|
|
||
| if ("text" not in item): | ||
| item_type = item.get("type", "") | ||
|
|
||
| # Handle text content block | ||
| if "text" in item: | ||
| if isinstance(item["text"], str): | ||
| str_content += item["text"] | ||
| # Handle tool_result content block (AgentScope format) | ||
| elif item_type == "tool_result" and "output" in item: | ||
| output = item["output"] | ||
| if isinstance(output, str): | ||
| str_content += output | ||
| else: | ||
| str_content += str(output) | ||
| # Handle tool_use content block (for completeness) | ||
| elif item_type == "tool_use": | ||
| # tool_use blocks are handled via tool_calls field, skip content extraction | ||
| continue | ||
| else: | ||
| logger.warning( | ||
| f"Non-text content in message content detected: {item}. Ignoring." | ||
| f"Non-text content in message content detected: {item}. Ignoring this item." | ||
| ) | ||
| should_skip_message = True | ||
| return str_content, should_skip_message | ||
|
|
||
| if isinstance(item["text"], str): | ||
| str_content += str(item["text"]) | ||
| else: | ||
| str_content = "" | ||
| # Continue processing other items instead of skipping the entire message | ||
| continue | ||
|
Comment on lines
120
to
+124
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Changing the error handling from returning immediately to logging a warning and continuing to process the other items in the message content is a significant improvement. This makes the system more resilient to malformed or unexpected content blocks, ensuring that the valid parts of a message are still processed instead of the entire message being skipped. |
||
|
|
||
| should_skip_message = False | ||
| return str_content, should_skip_message | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -99,6 +99,7 @@ def parse_args(): | |
| default=False, | ||
| help="Kill system processes (ray + vllm + python) that may block the current experiment", | ||
| ) | ||
| parser.add_argument("--prefix", type=str, default="", required=False, help="Prefix for service names") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| return parser.parse_args() | ||
|
|
||
|
|
||
|
|
@@ -304,7 +305,7 @@ def main(): | |
| pty_launch("appworld") | ||
|
|
||
| if args.with_deepfinance: | ||
| pty_launch("deepfinance") | ||
| pty_launch("deepfinance", prefix=args.prefix) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| if args.with_crafters: | ||
| pty_launch("crafters") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,7 +40,7 @@ def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"): | |
| # Define save directory and file path | ||
| traj_save_dir = os.path.join( | ||
| os.environ.get("BEST_LOGGER_PATH", "launcher_record"), | ||
| "ctx_trackers", | ||
| "trajectory", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| prefix, | ||
| f"step_{global_steps}" | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -125,6 +125,7 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "" | |
| if calls > 0: | ||
| error_rate = errors / calls * 100 | ||
| metrics[f"{prefix}tool_error/{tool_name}/error_rate"] = round(error_rate, 2) | ||
| metrics[f"{prefix}tool_error/{tool_name}/calls"] = calls | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
|
|
||
| return metrics | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -96,13 +96,15 @@ def pty_wrapper_final(human_cmd, dir, env_dict): | |
| pty_wrapper(["/bin/bash", "-c", human_cmd], dir, env_dict) | ||
|
|
||
|
|
||
| def pty_launch(service_name: str, success_std_string="Starting server on"): | ||
| def pty_launch(service_name: str, success_std_string="Starting server on", prefix: str=""): | ||
| from ajet.utils.smart_daemon import LaunchCommandWhenAbsent | ||
|
|
||
| service_path = os.environ.get(f"{service_name.upper()}_PATH") | ||
| service_script = os.environ.get(f"{service_name.upper()}_SCRIPT") | ||
| if service_path is None or service_script is None: | ||
| raise ValueError(f"Environment variables for {service_name} not properly set.") | ||
| if prefix != "": | ||
| service_name = prefix + "_" + service_name | ||
|
Comment on lines
+106
to
+107
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| companion = LaunchCommandWhenAbsent( | ||
| full_argument_list=[service_script], | ||
| dir=service_path, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -373,8 +373,12 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO | |
| fused_reward, contributions = self._fuse_grader_scores(grader_scores, rm_raw) | ||
|
|
||
| # 6. 计算惩罚项(保留原有的 tool_calls 惩罚逻辑) | ||
| tool_calls = metadata.get("tool_stats", {}).get("total_calls", 0) | ||
| # 从 log_metrics 中提取 tool_stats(deep_finance.py 将其放在 log_metrics 而非 metadata) | ||
| tool_stats = workflow_output.log_metrics.get("tool_stats", {}) | ||
| tool_calls = tool_stats.get("total_calls", 0) | ||
|
Comment on lines
+377
to
+378
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| penalty = self._compute_penalty(tool_calls) | ||
| if penalty < 0: | ||
| print(f"⚠️ Penalty applied: penalty={penalty}, tool_calls={tool_stats}") | ||
|
Comment on lines
+380
to
+381
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| # 7. 汇总 | ||
| final_reward = fused_reward + step_reward + penalty | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,7 +32,7 @@ ajet: | |
| rollout: | ||
| # ✨✨✨✨ 编写并选择Agent | ||
| user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol | ||
| force_disable_toolcalls: True | ||
| force_disable_toolcalls: False | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| enable_oversample: False | ||
| tensor_model_parallel_size: 8 | ||
| num_repeat: {{NUM_REPEAT}} | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The updated logic to handle
`tool_result` content blocks and convert their `output` to a string, even if it's not initially a string, is a good improvement. This makes the content extraction more robust and prevents potential data loss or type errors when processing diverse tool outputs.