diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index e13982c..74e7cf7 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -82,6 +82,18 @@ def extract_text_content_from_content_dict(self, msg): # }, # ], # } + # or tool_result format?? not observed yet: + # msg = { + # "role": "tool", + # "content": [ + # { + # "type": "tool_result", + # "id": "call_xxx", + # "output": "tool output content", + # "name": "tool_name" + # }, + # ], + # } str_content = "" for item in msg["content"]: @@ -89,6 +101,9 @@ def extract_text_content_from_content_dict(self, msg): # "type": "text", # "text": "some text" # }, + item_type = item.get("type", "") + assert not item_type == "tool_use", f"never observed such protocal yet" + assert not item_type == "tool_result", f"never observed such protocal yet" assert isinstance(item, dict), f"Unsupported non-dict item in message content: {item}. Full message: {msg}" diff --git a/ajet/launcher.py b/ajet/launcher.py index 47345ce..4055713 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -99,6 +99,7 @@ def parse_args(): default=False, help="Kill system processes (ray + vllm + python) that may block the current experiment", ) + parser.add_argument("--prefix", type=str, default="", required=False, help="Prefix for deepfinance service names") return parser.parse_args() @@ -304,7 +305,7 @@ def main(): pty_launch("appworld") if args.with_deepfinance: - pty_launch("deepfinance") + pty_launch("deepfinance", prefix=args.prefix) if args.with_crafters: pty_launch("crafters") diff --git a/ajet/utils/metric_helper/save_trajectory_as_json.py b/ajet/utils/metric_helper/save_trajectory_as_json.py index 9dd5186..4ab263f 100644 --- a/ajet/utils/metric_helper/save_trajectory_as_json.py +++ b/ajet/utils/metric_helper/save_trajectory_as_json.py @@ -40,7 +40,7 @@ def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"): # Define save directory and file path traj_save_dir = os.path.join( os.environ.get("BEST_LOGGER_PATH", "launcher_record"), - "ctx_trackers", + "trajectory", prefix, f"step_{global_steps}" ) diff --git a/ajet/utils/metric_helper/tool_metric_helper.py b/ajet/utils/metric_helper/tool_metric_helper.py index 3ce5da2..a656a07 100644 --- a/ajet/utils/metric_helper/tool_metric_helper.py +++ b/ajet/utils/metric_helper/tool_metric_helper.py @@ -125,6 +125,7 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "" if calls > 0: error_rate = errors / calls * 100 metrics[f"{prefix}tool_error/{tool_name}/error_rate"] = round(error_rate, 2) + metrics[f"{prefix}tool_error/{tool_name}/calls"] = calls return metrics diff --git a/ajet/utils/pty.py b/ajet/utils/pty.py index 6d859ae..e675611 100644 --- a/ajet/utils/pty.py +++ b/ajet/utils/pty.py @@ -96,13 +96,15 @@ def pty_wrapper_final(human_cmd, dir, env_dict): pty_wrapper(["/bin/bash", "-c", human_cmd], dir, env_dict) -def pty_launch(service_name: str, success_std_string="Starting server on"): +def pty_launch(service_name: str, success_std_string="Starting server on", prefix: str=""): from ajet.utils.smart_daemon import LaunchCommandWhenAbsent service_path = os.environ.get(f"{service_name.upper()}_PATH") service_script = os.environ.get(f"{service_name.upper()}_SCRIPT") if service_path is None or service_script is None: raise ValueError(f"Environment variables for {service_name} not properly set.") + if prefix != "": + service_name = prefix + "_" + service_name companion = LaunchCommandWhenAbsent( full_argument_list=[service_script], dir=service_path, diff --git a/pyproject.toml b/pyproject.toml index 856cddc..474e902 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ ] requires-python = ">=3.10,<3.13" dependencies = [ - "agentscope==1.0.7", + "agentscope==1.0.8", "chromadb", "httpx", "tenacity", diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index 6fd46f4..d9f624a 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -3,7 +3,7 @@ set -e #=============================================================================== # 1. 配置区域 - 用户只需修改这里 #=============================================================================== -SUFFIX="ajet_deep_finance" # 实验后缀,影响所有日志和实验名称 +SUFFIX="deep_finance" # 实验后缀,影响所有日志和实验名称 PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 # OpenJudge 模型配置 @@ -208,6 +208,7 @@ if [[ $HOSTNAME == *"-master-"* ]]; then --with-deepfinance \ --conf ${CONFIG_FILE} \ --backbone="verl" \ + --prefix=${SUFFIX} \ 2>&1 | tee ${TRAIN_LOG} diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index f49d88d..31e4be0 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -373,8 +373,12 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO fused_reward, contributions = self._fuse_grader_scores(grader_scores, rm_raw) # 6. 计算惩罚项(保留原有的 tool_calls 惩罚逻辑) - tool_calls = metadata.get("tool_stats", {}).get("total_calls", 0) + # 从 log_metrics 中提取 tool_stats(deep_finance.py 将其放在 log_metrics 而非 metadata) + tool_stats = workflow_output.log_metrics.get("tool_stats", {}) + tool_calls = tool_stats.get("total_calls", 0) penalty = self._compute_penalty(tool_calls) + if penalty < 0: + print(f"⚠️ Penalty applied: penalty={penalty}, tool_calls={tool_stats}") # 7. 汇总 final_reward = fused_reward + step_reward + penalty diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index a2d2cd7..8e6065d 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -32,7 +32,7 @@ ajet: rollout: # ✨✨✨✨ 编写并选择Agent user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol - force_disable_toolcalls: True + force_disable_toolcalls: False enable_oversample: False tensor_model_parallel_size: 8 num_repeat: {{NUM_REPEAT}}