diff --git a/.gitignore b/.gitignore index c16d08c4..95add49e 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,8 @@ datasets tutorial2 site dump.rdb + + +tutorial/example_deep_finance/yaml/* +tutorial/example_deep_finance/config/* +tutorial/example_deep_finance/scripts/* \ No newline at end of file diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 5b9d0853..cb573457 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -603,7 +603,7 @@ def fit(self): # noqa: C901 } ) save_trajectory_as_json_file(context_tracker_arr, self.global_steps, self.config, prefix="train") - update_metrics(context_tracker_arr, metrics) + update_metrics(context_tracker_arr, metrics, prefix="train_") if self.config.ajet.execute_test: # apply a test probe from swanlab.data.run.main import get_run @@ -1047,7 +1047,7 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch): "mean_reward": sum(rewards) / len(rewards) if rewards else 0, } save_trajectory_as_json_file(ctx_trackers, self.global_steps, self.config, prefix="eval") - update_metrics(ctx_trackers, val_metrics) + update_metrics(ctx_trackers, val_metrics, prefix="eval_") print_dict( val_metrics, narrow=True, diff --git a/ajet/backbone/warm_up.py b/ajet/backbone/warm_up.py index fcae673f..c7505c49 100644 --- a/ajet/backbone/warm_up.py +++ b/ajet/backbone/warm_up.py @@ -6,8 +6,9 @@ import asyncio import logging import os -from ajet.utils.async_utils import apply_httpx_aclose_patch +from ajet.utils.async_utils import apply_httpx_aclose_patch, suppress_httpx_aclose_exception apply_httpx_aclose_patch() +suppress_httpx_aclose_exception() def init_parallel_rollout_logger(experiment_name): diff --git a/ajet/context_tracker/base_tracker.py b/ajet/context_tracker/base_tracker.py index 948aee3e..856cd89c 100644 --- a/ajet/context_tracker/base_tracker.py +++ b/ajet/context_tracker/base_tracker.py @@ -1,5 +1,4 @@ -from typing import List, Tuple, Union -from typing import List, Union, Tuple, Dict, Optional +from typing import Any, Dict, List, Optional, Tuple, Union from ajet.schema.task import WorkflowTask from ajet.schema.extended_msg import ( @@ -141,7 +140,7 @@ def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs): self.already_mad_flag: bool = False self.round_cnt = 0 self.generation_prompt_token = None - self.log_metrics: Optional[Dict[str, Union[float, List[float]]]] = None # Initialize workflow_metadata to store tool statistics + self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None # Initialize workflow_metadata to store tool statistics assert ( self.config.ajet.data.max_prompt_length diff --git a/ajet/launcher.py b/ajet/launcher.py index 73a347aa..47345ce2 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -60,10 +60,10 @@ def parse_args(): help="Launch appworld", ) parser.add_argument( - "--with-finworld", + "--with-deepfinance", action="store_true", default=False, - help="Launch finworld", + help="Launch deepfinance", ) parser.add_argument( "--with-webshop", @@ -303,8 +303,8 @@ def main(): if args.with_appworld: pty_launch("appworld") - if args.with_finworld: - pty_launch("finworld") + if args.with_deepfinance: + pty_launch("deepfinance") if args.with_crafters: pty_launch("crafters") diff --git a/ajet/schema/task.py b/ajet/schema/task.py index 6d94796c..a20a4b59 100644 --- a/ajet/schema/task.py +++ b/ajet/schema/task.py @@ -43,4 +43,4 @@ class WorkflowOutput(BaseModel): reward: Union[float, List[float], None] = 
Field(default=None) is_success: Union[bool, None] = Field(default=None) metadata: Dict[str, Any] = Field(default_factory=dict) - log_metrics: Dict[str, Union[float, List[float]]] = Field(default_factory=dict) + log_metrics: Dict[str, Union[float, List[float], Dict[str, Any]]] = Field(default_factory=dict) diff --git a/ajet/task_reader/__init__.py b/ajet/task_reader/__init__.py index 2d7d7322..b431456f 100644 --- a/ajet/task_reader/__init__.py +++ b/ajet/task_reader/__init__.py @@ -61,6 +61,10 @@ def __init__(self, reader_type, reader_config): self.task_reader = DataGeneratorTaskReader(reader_config) elif task_reader_type == "random_dummy": self.task_reader = RandomDummyTaskReader(reader_config) + elif task_reader_type == "deep_finance": + # deep_finance: load message from JSON file and assemble init_messages, tool calls go through env_service + from tutorial.example_deep_finance.deep_finance_reader import DeepFinanceReader + self.task_reader = DeepFinanceReader(reader_config) else: raise ValueError(f"Unsupported task reader type: {task_reader_type}") diff --git a/ajet/task_rollout/resource_keeper.py b/ajet/task_rollout/resource_keeper.py index 2498b415..5e23389e 100644 --- a/ajet/task_rollout/resource_keeper.py +++ b/ajet/task_rollout/resource_keeper.py @@ -25,7 +25,7 @@ def __enter__(self): self.tokenizer = self.workflow_task.tokenizer self.llm_inference_fn = self.workflow_task.llm_inference_fn self.observation_window = self.workflow_task.observation_window - if self.config.ajet.task_reader.type == "env_service": + if self.config.ajet.task_reader.type in ("env_service", "deep_finance"): url = self.config.ajet.task_reader.env_service.env_url env_type = self.config.ajet.task_reader.env_service.env_type self.env = EnvClientNg(base_url=url) @@ -74,7 +74,9 @@ def _initialize_environment_and_messages(self) -> List[dict]: Exception: If environment creation fails or required task data is missing """ - if self.config.ajet.task_reader.type == "env_service": + reader_type = self.config.ajet.task_reader.type + + if reader_type == "env_service": if self.env is None: raise ValueError("Environment client is None but env_service type is specified") try: @@ -95,6 +97,32 @@ def _initialize_environment_and_messages(self) -> List[dict]: if self.env is not None: self.env.release_instance(self.workflow_task.episode_uuid) raise e + elif reader_type == "deep_finance": + # deep_finance: call create_instance to register instance, but use init_messages assembled by the reader + if self.env is None: + raise ValueError("Environment client is None but deep_finance type is specified") + try: + # call create_instance, let the server create an instance, so that subsequent step() can work + self.env.create_instance( + env_type=self.env_type, + task_id=self.task_id, + instance_id=self.workflow_task.episode_uuid, + params=self.env_params, + ) + # Do not use the returned state, directly use the init_messages assembled by the reader + task = self.workflow_task.task + if task.init_messages: + init_messages = task.init_messages + else: + assert task.main_query, "deep_finance requires init_messages or main_query." 
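+                    # No pre-assembled messages from the reader: fall back to a single user turn built from main_query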
+ init_messages = [{"role": "user", "content": task.main_query}] + except Exception as e: + logger.bind(exception=True).exception( + f"encounter exception in env_worker.create_instance~ error={e.args}" + ) + if self.env is not None: + self.env.release_instance(self.workflow_task.episode_uuid) + raise e else: task = self.workflow_task.task if task.init_messages: @@ -177,11 +205,15 @@ def step(self, action: dict) -> Tuple[str, float, bool, dict]: action=action, ) obs = "" + reward = 0 + info = {} assert isinstance(env_output, dict) if isinstance(env_output["state"], list): # 1. If state is a list (new standard format), pass through directly obs = env_output["state"] + reward = env_output["reward"] + info = env_output["info"] else: # 2. If state is a dict (old format or error) if ("content" not in env_output["state"]) and ("error" in env_output["state"]): @@ -191,8 +223,6 @@ def step(self, action: dict) -> Tuple[str, float, bool, dict]: else: obs = env_output["state"]["content"] - reward = 0 - info = {} terminate = env_output["is_terminated"] return obs, reward, terminate, info # type: ignore diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py index 7ea76710..88f9ab11 100644 --- a/ajet/task_runner/general_runner.py +++ b/ajet/task_runner/general_runner.py @@ -55,6 +55,9 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: else: raw_reward, is_success = self.get_judge().compute_reward(workflow_task, workflow_output) + if "reward_stats" in workflow_output.metadata: + workflow_output.log_metrics["reward_stats"] = workflow_output.metadata["reward_stats"] + workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue assert not isinstance( diff --git a/ajet/utils/async_utils.py b/ajet/utils/async_utils.py index 219aba9c..c5869c1e 100644 --- a/ajet/utils/async_utils.py +++ b/ajet/utils/async_utils.py @@ -1,5 +1,6 @@ import asyncio import concurrent.futures +import logging from typing import Any def run_async_coroutine_with_timeout(coro, timeout: int = 3600) -> Any: @@ -68,3 +69,50 @@ def _patched_del(self) -> None: print("Applied httpx aclose patch.") except ImportError: pass + + +def suppress_httpx_aclose_exception(): + """ + Suppress the 'Task exception was never retrieved' error from httpx AsyncClient.aclose(). + This error occurs when the event loop is closed before the AsyncClient is properly closed. 
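+    It installs a custom asyncio exception handler plus a logging filter, so only the specific
+    shutdown-time RuntimeError / TCPTransport messages are hidden; all other exceptions still
+    reach the default handler.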
+ """ + # Custom exception handler for asyncio + def custom_exception_handler(loop, context): + exception = context.get('exception') + message = context.get('message', '') + + # Check if this is the specific httpx aclose RuntimeError we want to suppress + if exception is not None: + if isinstance(exception, RuntimeError): + exc_str = str(exception) + if 'unable to perform operation on' in exc_str and 'the handler is closed' in exc_str: + return # Suppress this specific error + if 'TCPTransport' in exc_str and 'closed' in exc_str: + return # Suppress this specific error + + # For other exceptions, use the default handler + loop.default_exception_handler(context) + + # Apply custom exception handler to current or new event loop + try: + loop = asyncio.get_running_loop() + loop.set_exception_handler(custom_exception_handler) + except RuntimeError: + # No running loop, will be applied when loop starts + pass + + # Also filter the logging output for this specific error + class HttpxAcloseFilter(logging.Filter): + def filter(self, record): + msg = record.getMessage() + if 'Task exception was never retrieved' in msg and 'aclose' in msg: + return False + if 'unable to perform operation on' in msg and 'the handler is closed' in msg: + return False + if 'TCPTransport' in msg and 'closed' in msg: + return False + return True + + # Apply filter to root logger and asyncio logger + logging.getLogger().addFilter(HttpxAcloseFilter()) + logging.getLogger('asyncio').addFilter(HttpxAcloseFilter()) diff --git a/ajet/utils/env_service_client/env_client_ng.py b/ajet/utils/env_service_client/env_client_ng.py index bee86619..a8e1112f 100644 --- a/ajet/utils/env_service_client/env_client_ng.py +++ b/ajet/utils/env_service_client/env_client_ng.py @@ -49,7 +49,7 @@ def retry_call( class EnvClient: def __init__(self, base_url: str = "http://localhost:8000"): self.base_url = base_url.rstrip("/") - self.timeout = 30.0 + self.timeout = 300.0 def _make_request( self, diff --git a/ajet/utils/metric_helper/__init__.py b/ajet/utils/metric_helper/__init__.py index 70ce2818..a0475743 100644 --- a/ajet/utils/metric_helper/__init__.py +++ b/ajet/utils/metric_helper/__init__.py @@ -7,9 +7,9 @@ def save_trajectory_as_json_file(ctx_trackers, global_steps, config, prefix): if config.ajet.trainer_common.save_trajectory_as_json_file: save_trajectory_as_json(ctx_trackers, global_steps, prefix) -def update_metrics(context_tracker_arr, metrics:dict): - tool_metrics = compute_tool_metrics_from_trajectories(context_tracker_arr) - reward_metrics = compute_reward_metrics_from_trajectories(context_tracker_arr) +def update_metrics(context_tracker_arr, metrics:dict, prefix): + tool_metrics = compute_tool_metrics_from_trajectories(context_tracker_arr, prefix) + reward_metrics = compute_reward_metrics_from_trajectories(context_tracker_arr, prefix) if tool_metrics: metrics.update(tool_metrics) if reward_metrics: diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py index 49e069bf..76d034bf 100644 --- a/ajet/utils/metric_helper/reward_metric_helper.py +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -1,8 +1,8 @@ """ -FinWorld Reward Metrics Helper +deep_finance Reward Metrics Helper Provides standalone utility functions for reward_stats extraction and SwanLab metrics formatting. -Decouples finworld-specific logic from core code, reducing intrusion into native_compat_trainer. +Decouples deep_finance-specific logic from core code, reducing intrusion into native_compat_trainer. 
SwanLab metrics directory structure: - rewards/ Top-level aggregated scores @@ -20,45 +20,19 @@ def extract_reward_stats_from_trajectories(trajectories: List[Any]) -> List[Dict Extract reward_stats from trajectories list. Args: - trajectories: List of trajectory objects containing workflow_metadata + trajectories: List of trajectory objects containing log_metrics Returns: List of reward_stats dictionaries """ reward_stats_list = [] for traj in trajectories: - if hasattr(traj, 'workflow_metadata') and traj.workflow_metadata: - if 'reward_stats' in traj.workflow_metadata: - reward_stats_list.append(traj.workflow_metadata['reward_stats']) + if hasattr(traj, 'log_metrics') and traj.log_metrics: + if 'reward_stats' in traj.log_metrics: + reward_stats_list.append(traj.log_metrics['reward_stats']) return reward_stats_list -def extract_reward_stats_from_cmts(cmts: List[Any]) -> tuple[List[Dict[str, Any]], Dict[str, int]]: - """ - Extract reward_stats from cmts list and return debug statistics. - - Args: - cmts: List of cmt objects containing workflow_metadata - - Returns: - Tuple of (reward_stats_list, debug_stats) - """ - reward_stats_list = [] - debug_stats = { - 'total_cmts': len(cmts), - 'has_workflow_metadata': 0, - 'has_reward_stats': 0, - } - - for _cmt in cmts: - if hasattr(_cmt, 'workflow_metadata') and _cmt.workflow_metadata: - debug_stats['has_workflow_metadata'] += 1 - if 'reward_stats' in _cmt.workflow_metadata: - debug_stats['has_reward_stats'] += 1 - reward_stats_list.append(_cmt.workflow_metadata['reward_stats']) - - return reward_stats_list, debug_stats - def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str = "") -> Dict[str, float]: """ @@ -103,7 +77,6 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str if openjudge_enabled_count > 0: # ========== OpenJudge Metrics ========== - metrics[f"{prefix}rewards/openjudge_enabled_rate"] = openjudge_enabled_count / n * 100 # Dynamically extract OpenJudge grader fields # Currently supported graders: report_resolution, trajectory_faithfulness, @@ -142,48 +115,19 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str rm_raw_list = [rs.get('rm_raw', 0.0) for rs in reward_stats_list] rm_contribution_list = [rs.get('rm_contribution', 0.0) for rs in reward_stats_list] - # RefJudge - ref_final_raw_list = [rs.get('ref_final_raw', 0.0) for rs in reward_stats_list] - ref_citation_raw_list = [rs.get('ref_citation_raw', 0.0) for rs in reward_stats_list] - ref_grounding_raw_list = [rs.get('ref_grounding_raw', 0.0) for rs in reward_stats_list] - ref_contribution_list = [rs.get('ref_contribution', 0.0) for rs in reward_stats_list] - - # StructureJudge - structure_raw_list = [rs.get('structure_raw', 0.0) for rs in reward_stats_list] - structure_contribution_list = [rs.get('structure_contribution', 0.0) for rs in reward_stats_list] - # dimensions/ raw scores metrics[f"{prefix}rewards/dimensions/rm_raw_mean"] = float(np.mean(rm_raw_list)) - metrics[f"{prefix}rewards/dimensions/ref_final_raw_mean"] = float(np.mean(ref_final_raw_list)) - metrics[f"{prefix}rewards/dimensions/ref_citation_raw_mean"] = float(np.mean(ref_citation_raw_list)) - metrics[f"{prefix}rewards/dimensions/ref_grounding_raw_mean"] = float(np.mean(ref_grounding_raw_list)) - metrics[f"{prefix}rewards/dimensions/structure_raw_mean"] = float(np.mean(structure_raw_list)) # contribution/ weighted contributions metrics[f"{prefix}rewards/contribution/rm_contribution_mean"] = float(np.mean(rm_contribution_list)) - 
metrics[f"{prefix}rewards/contribution/ref_contribution_mean"] = float(np.mean(ref_contribution_list)) - metrics[f"{prefix}rewards/contribution/structure_contribution_mean"] = float(np.mean(structure_contribution_list)) - # Enabled state statistics - ref_judge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('ref_judge_enabled', False)) - if ref_judge_enabled_count > 0: - metrics[f"{prefix}rewards/ref_judge_enabled_rate"] = ref_judge_enabled_count / n * 100 - - structure_judge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('structure_judge_enabled', False)) - if structure_judge_enabled_count > 0: - metrics[f"{prefix}rewards/structure_judge_enabled_rate"] = structure_judge_enabled_count / n * 100 # Time consumption statistics rm_time_list = [rs.get('rm_time', 0.0) for rs in reward_stats_list] - refstruc_time_list = [rs.get('refstruc_time', 0.0) for rs in reward_stats_list] - metrics[f"{prefix}judge_time/rm_time_mean"] = float(np.mean(rm_time_list)) - metrics[f"{prefix}judge_time/refstruc_time_mean"] = float(np.mean(refstruc_time_list)) if rm_time_list: metrics[f"{prefix}judge_time/rm_time_max"] = float(np.max(rm_time_list)) - if refstruc_time_list: - metrics[f"{prefix}judge_time/refstruc_time_max"] = float(np.max(refstruc_time_list)) # ========== General Time Consumption Statistics ========== judge_total_time_list = [rs.get('judge_total_time', 0.0) for rs in reward_stats_list] @@ -194,7 +138,7 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str return metrics -def compute_reward_metrics_from_trajectories(trajectories: List[Any]) -> Dict[str, float]: +def compute_reward_metrics_from_trajectories(trajectories: List[Any], prefix: str = "") -> Dict[str, float]: """ Training phase: Extract reward_stats from trajectories and compute metrics. @@ -205,27 +149,5 @@ def compute_reward_metrics_from_trajectories(trajectories: List[Any]) -> Dict[st Formatted metrics dictionary """ reward_stats_list = extract_reward_stats_from_trajectories(trajectories) - return compute_reward_metrics(reward_stats_list, prefix="train_") - - -def compute_reward_metrics_from_cmts(cmts: List[Any], print_debug: bool = True) -> Dict[str, float]: - """ - Validation phase: Extract reward_stats from cmts and compute metrics. 
- - Args: - cmts: List of cmt objects - print_debug: Whether to print debug information - - Returns: - Formatted metrics dictionary (with "val_reward/" prefix) - """ - reward_stats_list, debug_stats = extract_reward_stats_from_cmts(cmts) - - if print_debug: - print(f"\n[DEBUG eval_dataset()] reward_stats statistics:") - print(f" - Total cmts count: {debug_stats['total_cmts']}") - print(f" - Has workflow_metadata: {debug_stats['has_workflow_metadata']}") - print(f" - Has reward_stats: {debug_stats['has_reward_stats']}") - print(f" - Extracted samples count: {len(reward_stats_list)}") + return compute_reward_metrics(reward_stats_list, prefix=prefix) - return compute_reward_metrics(reward_stats_list, prefix="val_") diff --git a/ajet/utils/metric_helper/save_trajectory_as_json.py b/ajet/utils/metric_helper/save_trajectory_as_json.py index 344a6ab4..9dd51868 100644 --- a/ajet/utils/metric_helper/save_trajectory_as_json.py +++ b/ajet/utils/metric_helper/save_trajectory_as_json.py @@ -22,7 +22,7 @@ def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"): else: ctx_tracker.tag = "half_success" - formatted_traj = convert_grouped_steps_to_openai_format(ctx_tracker.timeline_cache) + formatted_traj = convert_grouped_steps_to_openai_format(ctx_tracker.saved_timelines) # Prepare trajectory data traj_data = { @@ -51,6 +51,5 @@ def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"): with open(traj_file_path, "w", encoding="utf-8") as f: json.dump(traj_data, f, ensure_ascii=False, indent=2) - # Print confirmation for evaluation trajectories - if prefix != "train": - print(f"Saved trajectory to {traj_file_path}") + + print(f"Saved trajectory to {traj_file_path}") diff --git a/ajet/utils/metric_helper/tool_metric_helper.py b/ajet/utils/metric_helper/tool_metric_helper.py index 51a488b8..3ce5da21 100644 --- a/ajet/utils/metric_helper/tool_metric_helper.py +++ b/ajet/utils/metric_helper/tool_metric_helper.py @@ -1,8 +1,8 @@ """ -FinWorld Tool Metrics Helper +DeepFinance Tool Metrics Helper Specialized module for extracting tool-related statistics and formatting SwanLab reports. -Extracts data from workflow_metadata['tool_stats']. +Extracts data from log_metrics['tool_stats']. SwanLab metrics directory structure: - tool_stats/ Overall statistics (success rate, cache hit rate, etc.) @@ -20,16 +20,16 @@ def extract_tool_stats_from_trajectories(trajectories: List[Any]) -> List[Dict[s Extract tool_stats from trajectories list. Args: - trajectories: List of trajectory objects containing workflow_metadata + trajectories: List of trajectory objects containing log_metrics Returns: List of tool_stats dictionaries """ tool_stats_list = [] for traj in trajectories: - if hasattr(traj, 'workflow_metadata') and traj.workflow_metadata: - if 'tool_stats' in traj.workflow_metadata: - tool_stats_list.append(traj.workflow_metadata['tool_stats']) + if hasattr(traj, 'log_metrics') and traj.log_metrics: + if 'tool_stats' in traj.log_metrics: + tool_stats_list.append(traj.log_metrics['tool_stats']) return tool_stats_list @@ -90,7 +90,6 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "" if time_list: metrics[f"{prefix}tool_time/{tool_name}/mean"] = float(np.mean(time_list)) metrics[f"{prefix}tool_time/{tool_name}/max"] = float(np.max(time_list)) - metrics[f"{prefix}tool_time/{tool_name}/count"] = len(time_list) # ========== 3. 
Cache Hit Rate by Tool ==========
     tool_cache_by_name = {}
@@ -109,8 +108,6 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = ""
         if total > 0:
             hit_rate = hits / total * 100
             metrics[f"{prefix}tool_cache/{tool_name}/hit_rate"] = round(hit_rate, 2)
-            metrics[f"{prefix}tool_cache/{tool_name}/hits"] = hits
-            metrics[f"{prefix}tool_cache/{tool_name}/misses"] = misses
 
     # ========== 4. Error Rate by Tool ==========
     tool_error_by_name = {}
@@ -128,17 +125,16 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = ""
         if calls > 0:
             error_rate = errors / calls * 100
             metrics[f"{prefix}tool_error/{tool_name}/error_rate"] = round(error_rate, 2)
-            metrics[f"{prefix}tool_error/{tool_name}/calls"] = calls
-            metrics[f"{prefix}tool_error/{tool_name}/errors"] = errors
+
     return metrics
 
 
-def compute_tool_metrics_from_trajectories(trajectories: List[Any]) -> Dict[str, float]:
+def compute_tool_metrics_from_trajectories(trajectories: List[Any], prefix: str = "") -> Dict[str, float]:
     """
     Training phase: Extract tool_stats from trajectories and compute metrics.
     """
     tool_stats_list = extract_tool_stats_from_trajectories(trajectories)
-    return compute_tool_metrics(tool_stats_list, prefix="train_")
+    return compute_tool_metrics(tool_stats_list, prefix=prefix)
diff --git a/tutorial/example_deep_finance/deep_finance.md b/tutorial/example_deep_finance/deep_finance.md
new file mode 100644
index 00000000..1ac6d0c0
--- /dev/null
+++ b/tutorial/example_deep_finance/deep_finance.md
@@ -0,0 +1 @@
+# deep_finance
\ No newline at end of file
diff --git a/tutorial/example_deep_finance/deep_finance.py b/tutorial/example_deep_finance/deep_finance.py
new file mode 100644
index 00000000..470e6225
--- /dev/null
+++ b/tutorial/example_deep_finance/deep_finance.py
@@ -0,0 +1,225 @@
+from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask
+from agentscope.message import Msg
+from pydantic import Field
+import logging
+import threading
+import time
+import copy
+from loguru import logger
+
+
+# Semaphore limiting how many env.step() calls may run concurrently (30 at a time)
+sem = threading.Semaphore(30)
+
+class ExampleDeepResearchProtocol(Workflow):
+
+
+    async def execute(
+        self, workflow_task: WorkflowTask, tuner: AjetTuner
+    ) -> WorkflowOutput:
+        from agentscope.agent import ReActAgent
+        from agentscope.formatter import DashScopeChatFormatter
+        from agentscope.memory import InMemoryMemory
+        # 1. Initialize messages
+        # init_messages is usually [System, User]
+        init_messages = workflow_task.task.init_messages
+
+        # Separate the system prompt from the initial user input
+        if len(init_messages) >= 2:
+            first_msg, user_msgs = init_messages[0], init_messages[1:]
+        else:
+            first_msg = {"content": "You're a helpful assistant."}
+            user_msgs = init_messages
+
+        # conversation_history: keeps the raw, standard OpenAI-format data (including role: tool)
+        # This is the "ground truth" used for evaluation and for saving training trajectories
+        conversation_history = [
+            {"role": "system", "content": first_msg["content"]},
+        ]
+        conversation_history.extend(user_msgs)
+
+        # 2. Initialize the agent
+        agent = ReActAgent(
+            name="Qwen",
+            sys_prompt=first_msg["content"],  # the agent manages the system prompt internally
+            model=tuner.as_agentscope_model(),
+            formatter=DashScopeChatFormatter(),
+            memory=InMemoryMemory(),
+            toolkit=None,
+            print_hint_msg=False,
+        )
+        agent.set_console_output_enabled(False)
+        env = workflow_task.gym_env
+
+        # 3. Build the initial agent input (List[Msg])
+        # Note: only user messages go here, not the system prompt, since that was set in the agent's init
+        # They must be converted to Msg objects
+        agent_input = []
+        for m in user_msgs:
+            agent_input.append(Msg(
+                name=m.get("name", "user"),
+                content=m.get("content", ""),
+                role=m.get("role", "user")
+            ))
+
+        # Cached statistics
+        latest_tool_stats = None
+        latest_reward_stats = {}
+        cumulative_tool_call_time = 0.0  # cumulative tool-call time
+        cumulative_tool_time = {}  # cumulative time per tool: {tool_name: [time1, time2, ...]}
+        step = 0
+        for step in range(tuner.config.ajet.rollout.multi_turn.max_steps):
+
+            # === Agent inference ===
+            _llm_start = time.time()
+            # Pass only the incremental messages (agent_input); the agent adds them to memory and generates a reply
+            reply_message = await agent(agent_input)
+            _llm_elapsed = time.time() - _llm_start
+            # Extract plain-text content (compatible with multimodal format)
+            if isinstance(reply_message.content, list):
+                # Multimodal format: [{'type': 'text', 'text': '...'}]
+                content_text = ''.join(item.get('text', '') for item in reply_message.content if isinstance(item, dict) and item.get('type') == 'text')
+            else:
+                content_text = reply_message.content
+
+            content_preview = content_text[:100].replace('\n', ' ')
+
+            # === Early-termination check: look at context_overflow before calling env.step() ===
+            # Fix: avoid blocking on further tool calls after a token overflow
+            if tuner.get_context_tracker().context_overflow:
+                logger.warning(f"Context overflow, skipping env.step() and ending immediately at step {step + 1}")
+                # Append a default closing response
+                conversation_history.append({
+                    "role": "assistant",
+                    "content": content_text
+                })
+                break
+
+            # === Env execution ===
+            _env_start = time.time()
+            with sem:
+                obs, reward, terminate, info = env.step(
+                    action={"content": content_text, "role": "assistant"}
+                )
+            _env_elapsed = time.time() - _env_start
+
+            # === 3. Update conversation_history (full history) ===
+            # A. Append the assistant message (with tool_calls filled in)
+            current_assistant_msg = {
+                "role": "assistant",
+                "content": content_text
+            }
+            if info and 'generated_tool_calls' in info and info['generated_tool_calls']:
+                current_assistant_msg['tool_calls'] = info['generated_tool_calls']
+            conversation_history.append(current_assistant_msg)
+
+            # B. Append tool messages (use obs directly)
+            # Note: obs may be wrapped as [tool_results_msgs], so it has to be unwrapped
+            if isinstance(obs, list):
+                actual_msgs = obs[0] if (len(obs) == 1 and isinstance(obs[0], list)) else obs
+                conversation_history.extend(actual_msgs)
+            else:
+                conversation_history.append({"role": "user", "content": obs})
+
+            # === 4. Update statistics ===
+            if info:
+                if 'tool_stats' in info:
+                    latest_tool_stats = info['tool_stats']
+                    if latest_tool_stats.get('total_calls', 0) > 0:
+                        logger.info(f"Step {step + 1} tool stats: calls={latest_tool_stats.get('total_calls', 0)}, "
+                                    f"success rate={latest_tool_stats.get('success_rate', 0):.1f}%")
+                if 'reward_stats' in info:
+                    latest_reward_stats = info['reward_stats']
+                    # Accumulate tool-call time
+                    step_tool_call_time = latest_reward_stats.get('tool_call_time', 0.0)
+                    cumulative_tool_call_time += step_tool_call_time
+                    # Accumulate per-tool elapsed time
+                    step_tool_time = latest_reward_stats.get('tool_time', {})
+                    for tool_name, time_list in step_tool_time.items():
+                        if tool_name not in cumulative_tool_time:
+                            cumulative_tool_time[tool_name] = []
+                        if isinstance(time_list, list):
+                            cumulative_tool_time[tool_name].extend(time_list)
+
+            # === 5. Prepare the agent input for the next turn (incremental) ===
+            # Convert the obs returned by the env into a list of Msg objects for the next agent() call
+            # Key point: only the new obs goes here, never the full history
+            agent_input = []
+
+            if isinstance(obs, list):
+                # Standard mode: obs is a list of tool messages
+                # Note: deep_finance_env.step returns {"state": [tool_results_msgs]}, i.e. wrapped in an extra list
+                # BaseGymEnv.step passes it through unchanged, so obs = [tool_results_msgs]
+                # Unwrap it to get the actual message list
+                actual_msgs = obs[0] if (len(obs) == 1 and isinstance(obs[0], list)) else obs
+
+                # Convert messages following AgentScope's ContentBlock format
+                # Agent.memory automatically stores the assistant's tool_call information,
+                # so only the tool_result messages need to be passed in here
+                for idx, m in enumerate(actual_msgs):
+                    origin_role = m.get('role', 'user')
+                    if origin_role == 'tool':
+                        # Use the ToolResultBlock format as the content of a user message
+                        tool_result_block = {
+                            "type": "tool_result",
+                            "id": m.get('tool_call_id', ''),
+                            "output": m.get('content', ''),
+                            "name": m.get('name', '')
+                        }
+                        new_msg = Msg(
+                            name="tool",
+                            content=[tool_result_block],
+                            role="user"
+                        )
+                        agent_input.append(new_msg)
+                    else:
+                        # Other messages (e.g. user hints) are appended directly
+                        content = m.get('content')
+                        if content is None: content = ""
+                        valid_role = origin_role if origin_role in ['user', 'assistant', 'system'] else 'user'
+                        new_msg = Msg(
+                            name=m.get('name', valid_role),
+                            content=content,
+                            role=valid_role
+                        )
+                        agent_input.append(new_msg)
+            else:
+                # Legacy mode
+                agent_input.append(Msg(name="env", content=obs, role="user"))
+
+            # === 6. Termination check ===
+            if terminate:
+                break
+
+            if tuner.get_context_tracker().context_overflow:
+                logger.warning(f"Context overflow, ending at step {step + 1}")
+                break
+
+        # === Final processing ===
+        final_tool_stats = latest_tool_stats or {
+            'total_calls': 0, 'total_errors': 0, 'success_calls': 0, 'success_rate': 0.0,
+            'cache_hits': 0, 'cache_misses': 0
+        }
+        # Merge the accumulated tool_time into tool_stats
+        final_tool_stats['tool_time'] = cumulative_tool_time
+        final_tool_stats['tool_call_time'] = cumulative_tool_call_time
+
+        logger.info(f"Task completion stats (Task ID: {workflow_task.task.task_id}):")
+        logger.info(f"  total steps: {step + 1}")
+        logger.info(f"  total tool calls: {final_tool_stats.get('total_calls', 0)}")
+        logger.info(f"  success rate: {final_tool_stats.get('success_rate', 0):.2f}%")
+
+        return WorkflowOutput(
+            reward=None,
+            metadata={
+                "total_step": step,
+                "tool_success_rate": round(final_tool_stats.get('success_rate', 0.0), 2),
+                "conversation_history": conversation_history,
+                "query": workflow_task.task.main_query,
+                "task_id": workflow_task.task.task_id,
+            },
+            log_metrics={
+                "tool_stats": final_tool_stats,
+                "reward_stats": latest_reward_stats,
+            }
+        )
\ No newline at end of file
diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh
new file mode 100644
index 00000000..6fd46f45
--- /dev/null
+++ b/tutorial/example_deep_finance/deep_finance.sh
@@ -0,0 +1,224 @@
+#!/bin/bash
+set -e
+#===============================================================================
+# 1. 
配置区域 - 用户只需修改这里 +#=============================================================================== +SUFFIX="ajet_deep_finance" # 实验后缀,影响所有日志和实验名称 +PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 + +# OpenJudge 模型配置 +OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 +RM_LLM='qwen-max' # RM Gallery 评分模型 +JUDGE_CONCURRENCY=10 + +# 奖励权重配置 +RM_WEIGHT=0.4 +CITATION_AUDIT_WEIGHT=0.2 +REPORT_RESOLUTION_WEIGHT=0.2 +TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 + +# 训练参数配置 +NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 +TRAIN_BATCH_SIZE=32 # 训练batchsize +NUM_STEPS=6 # 每个样本step轮数 +DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 + +# 主目录 +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet" + +NNODES=${WORLD_SIZE} + +# 涉密的配置(API_KEY以及模型、数据位置)从.env读取 +cd ${AJET_ROOT} +source .venv/bin/activate + +# API密钥配置 - 从 .env 文件加载 +ENV_FILE="${AJET_ROOT}/.env" +if [ -f "$ENV_FILE" ]; then + set -a + source "$ENV_FILE" + set +a + echo -e "\033[32m已从 $ENV_FILE 加载环境变量\033[0m" +else + echo -e "\033[31m警告: 找不到 .env 文件: $ENV_FILE\033[0m" +fi + +#=============================================================================== +# 2. 动态生成配置文件 (从yaml template生成yaml) +#=============================================================================== +# 修改:配置文件生成路径,现在动态生成到 yaml 目录下 +CONFIG_TEMPLATE="tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml" +CONFIG_FILE="${AJET_ROOT}/tutorial/example_deep_finance/yaml/${SUFFIX}.yaml" +mkdir -p $(dirname ${CONFIG_FILE}) + +sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ + -e "s|{{PREFIX}}|${PREFIX}|g" \ + -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ + -e "s|{{NNODES}}|${NNODES}|g" \ + -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ + -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ + -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ + -e "s|{{RM_LLM}}|${RM_LLM}|g" \ + -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ + -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ + -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ + -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ + -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ + -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ + -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ + -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ + -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ + -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ + -e "s|{{CKPT_SAVE_PATH}}|${CKPT_SAVE_PATH}|g" \ + ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} + +echo "配置文件已生成: ${CONFIG_FILE}" +echo "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" + +#=============================================================================== +# 3. 
环境配置 +#=============================================================================== +# MongoDB 缓存配置 +CACHE_TYPE="mongodb" +MONGO_URI="mongodb://${ADDR}:27117/" +MONGO_DB_NAME="finworld_cache" +MONGO_COLLECTION_NAME="tool_cache" +export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME + +# DeepFinance MCP 配置 +DEEPFINANCE_MCP_CONFIG="${AJET_ROOT}/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json" + +# 动态生成 MCP 配置文件 +mkdir -p $(dirname ${DEEPFINANCE_MCP_CONFIG}) +cat > ${DEEPFINANCE_MCP_CONFIG} << EOF +{ + "mcpServers": { + "flowllm": { + "transport": "sse", + "url": "http://${ADDR}:${MCP_PORT}/sse", + "timeout": 600, + "sse_read_timeout": 1200 + } + } +} +EOF +export DEEPFINANCE_MCP_CONFIG DEEPFINANCE_TOOL_RESULT_MAX_CHARS + +# 其他服务配置 +HF_ENDPOINT="https://hf-mirror.com" +ES_HOSTS="http://11.160.132.46:8200" +export HF_ENDPOINT ES_HOSTS + +# log 文件位置 +CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") +LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" +MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" +ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" +TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" + +# 多机训练参数配置 +GPUS_PER_NODE=8 +EXPECTED_WORKERS=$WORLD_SIZE + + +#=============================================================================== +# 4. 工具函数 以及 NCCL 配置(固定) +#=============================================================================== +print_green() { + echo -e "\033[32m$1\033[0m" +} + +log() { + echo -e "\033[0;32m[$(date '+%Y-%m-%d %H:%M:%S')]\033[0m \033[0;34m[INFO]\033[0m $1" +} + +check_workers() { + local status_output=$(ray status 2>/dev/null) + if [ -z "$status_output" ]; then echo 0; return; fi + local node_count=$(echo "$status_output" | grep -E "^[[:space:]]*1[[:space:]]+node_" | wc -l) + if [ "$node_count" -gt 0 ]; then echo $node_count; return; fi + echo $(echo "$status_output" | grep -o "node_[0-9a-f]\+" | sort -u | wc -l) +} + +check_gpu_resources() { + gpu_count=$(ray status 2>/dev/null | grep -A 10 "Resources" | grep "GPU" | awk '{print $1}' | cut -d'/' -f2) + if [ -z "$gpu_count" ]; then echo 0; else printf "%.0f" "$gpu_count"; fi +} + + +export NCCL_TIMEOUT=1800 +export NCCL_DEBUG=WARN +export NCCL_IB_TIMEOUT=23 +export NCCL_ASYNC_ERROR_HANDLING=1 + +#=============================================================================== +# 5. 工具envservice 环境变量 +#=============================================================================== + +export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" +export RAY_CLUSTER_MODE="multi_node" +export DEEPFINANCE_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 +export DEEPFINANCE_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && DEEPFINANCE_TOOL_RESULT_MAX_CHARS=${DEEPFINANCE_TOOL_RESULT_MAX_CHARS} DEEPFINANCE_MCP_CONFIG=${DEEPFINANCE_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" + + +#=============================================================================== +# 6. 
主流程 +#=============================================================================== +log "开始多机多卡训练: ${SUFFIX}" +log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" +mkdir -p ${LOG_DIR} +mkdir -p $(dirname ${CONFIG_FILE}) + +#=============================================================================== +# 6.1 Master 节点启动流程 +#=============================================================================== +if [[ $HOSTNAME == *"-master-"* ]]; then + print_green "==> This is MASTER node: $HOSTNAME" + + #--------------------------------------------------------------------------- + # 6.1.1 清理和初始化 Ray + #--------------------------------------------------------------------------- + rm -f "$MASTER_IP_FILE" + ray stop --force || true + sleep 3 + + #--------------------------------------------------------------------------- + # 6.1.2 启动 Ray Head + #--------------------------------------------------------------------------- + print_green "Starting Ray head node at $MASTER_ADDR" + ray start --head --node-ip-address $MASTER_ADDR --num-gpus 8 + sleep 10 + echo $MASTER_ADDR > $MASTER_IP_FILE + + #--------------------------------------------------------------------------- + # 6.1.3 启动训练任务 + #--------------------------------------------------------------------------- + print_green "Starting training job..." + source .venv/bin/activate + export RAY_ADDRESS="ray://localhost:10001" + + print_green "===================================" + print_green "Training Configuration" + print_green "Total GPUs: $((NNODES * GPUS_PER_NODE))" + print_green "Log: ${TRAIN_LOG}" + print_green "===================================" + + # 启动训练任务(最核心) + python ajet/launcher.py \ + --with-deepfinance \ + --conf ${CONFIG_FILE} \ + --backbone="verl" \ + 2>&1 | tee ${TRAIN_LOG} + + +#=============================================================================== +# 6.2 Worker 节点启动流程 +#=============================================================================== +else + print_green "==> This is WORKER node: $HOSTNAME" + while [ ! 
-f $MASTER_IP_FILE ]; do sleep 5; done + MASTER_ADDR=$(cat $MASTER_IP_FILE) + ray stop || true + ray start --address $MASTER_ADDR:6379 --num-gpus 8 + while true; do sleep 60; done +fi \ No newline at end of file diff --git a/tutorial/example_deep_finance/deep_finance.yaml b/tutorial/example_deep_finance/deep_finance.yaml new file mode 100644 index 00000000..f67d5a8b --- /dev/null +++ b/tutorial/example_deep_finance/deep_finance.yaml @@ -0,0 +1,88 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: ajet_deep_finance + experiment_name: "ajet_deep_finance" + # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) + judge: + openjudge_llm: qwen-flash # OpenJudge 模型 + rm_llm: qwen-max # RM Gallery 模型 + concurrency: 10 # Judge 并发数 + train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 + val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 + # OpenJudge 权重配置 + report_resolution_weight: 0.2 # 报告质量评估 + trajectory_faithfulness_weight: 0.2 # 事实准确性评估 + citation_audit_weight: 0.2 # 引用审计评估 (覆盖率 + 真实性) + rm_weight: 0.4 # RM Gallery 权重 + task_judge: + # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_deep_finance.deep_finance_judge->DeepFinanceJudgeByOpenJudge + model: + # ✨✨✨✨ 设置待训练的模型 + path: {{MODEL_PATH}} + trainer_common: + nnodes: 8 + n_gpus_per_node: 8 + val_before_train: True + val_pass_n: 8 + save_freq: 10 + test_freq: 2 + total_epochs: 200 + save_trajectory_as_json_file: True + rollout: + # ✨✨✨✨ 编写并选择Agent + user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol + force_disable_toolcalls: True + enable_oversample: False + tensor_model_parallel_size: 8 + num_repeat: 4 + max_env_worker: 64 # 增加环境并行数 + max_num_seqs: 64 # 增加VLLM并发序列数 + max_response_length_in_one_turn: 8000 + max_model_len: 50000 + agent_madness_reward: 0.0 + compute_madness_checklist: None + multi_turn: + max_steps: 6 + interchange_server: + interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + debug: + debug_max_parallel: 64 # 增加并行任务数,充分利用GPU + debug_first_n_tasks: 100 # 增加处理的任务数 + data: + train_batch_size: 32 + max_prompt_length: 8000 + max_response_length: 41000 + + task_reader: + type: deep_finance # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service + deep_finance: + training: + file_path: {{TRAIN_PATH}} + validation: + file_path: {{VAL_PATH}} + # env_service 仍需配置(用于工具调用) + env_service: + env_type: "finworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code +trainer: + default_local_dir: {{CKPT_SAVE_PATH}} + # resume_mode: disable # 禁用自动恢复,从头开始训练 +actor_rollout_ref: + rollout: + tensor_model_parallel_size: 8 + gpu_memory_utilization: 0.8 +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py new file mode 100644 index 00000000..f49d88d3 --- /dev/null +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -0,0 +1,773 @@ +"""DeepFinance Task Judge - OpenJudge 版本 +集成: RM Gallery, OpenJudge Graders (含 CitationAudit) +""" + +import os +import json +import asyncio +import time +import logging +from datetime import datetime +from typing import Dict, Any, Optional, Tuple, 
List + +from ajet.task_judge.base_judge import BaseJudge +from ajet.workflow import WorkflowOutput, WorkflowTask + +from openjudge.graders.agent.action.action_loop import ActionLoopDetectionGrader +from openjudge.graders.agent.observation.observation_information_gain import ( + ObservationInformationGainGrader, +) +from openjudge.graders.agent.trajectory.trajectory_comprehensive import ( + TrajectoryComprehensiveGrader, +) +from openjudge.models.openai_chat_model import OpenAIChatModel +from openjudge.models.schema.prompt_template import LanguageEnum +from openjudge.runner.grading_runner import GraderConfig, GradingRunner +from openjudge.scenarios.deep_research.graders.financial_report_resolution import ( + FinancialReportResolutionGrader, +) +from openjudge.scenarios.deep_research.graders.financial_trajectory_faithfulness import ( + FinancialTrajectoryFaithfulGrader, +) +from openjudge.scenarios.deep_research.graders.rubrics_based_trajectory_performance import ( + RubricsBasedTrajectoryPerformance, +) +from openjudge.scenarios.deep_research.graders.financial_report_citation_audit import ( + FinancialReportCitationAuditGrader, +) + + +# RewardStats 不再使用,OpenJudge 版本直接使用字典存储 +# Reference Answer 路径现在从 config 中读取,见 _init_reference_answers 方法 + +# OpenJudge imports +# ============================================================================= +# 全局辅助函数 +# ============================================================================= + +def extract_text_content(content) -> str: + """统一提取纯文本内容""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + texts = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + texts.append(item.get("text", "")) + elif isinstance(item, str): + texts.append(item) + return "".join(texts) + return str(content) + + +def load_reference_answers_from_file(file_path: str) -> Tuple[Dict[str, str], Dict[str, str]]: + """加载参考答案 (RM Gallery 需要)""" + if not os.path.exists(file_path): + raise FileNotFoundError(f"Reference answers file not found: {file_path}") + try: + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + ref_answers, ref_domains = {}, {} + for item in data: + task_id = item.get("task", {}).get("task_id") + if not task_id or "answer" not in item: continue + ref_answers[task_id] = item["answer"] + domain = item.get("task", {}).get("metadata", {}).get("domain") + if domain: ref_domains[task_id] = domain + return ref_answers, ref_domains + except Exception as e: + raise ValueError(f"Error loading reference answers: {e}") + + +# ============================================================================= +# DeepFinanceJudgeByOpenJudge 类 +# ============================================================================= + +class DeepFinanceJudgeByOpenJudge(BaseJudge): + """ + 使用 OpenJudge 框架的 DeepFinance Judge + 集成: RM Gallery, OpenJudge Graders (含 CitationAudit) + + 分析: + - compute_reward 每次处理 **一条采样**(单个 workflow_output) + - 输入:workflow_task, workflow_output + - 输出:(final_reward: float, is_success: bool) + - 副作用:更新 workflow_output.metadata["reward_stats"] + + 注意:GradingRunner 不能使用单例模式,因为其内部 Semaphore 会绑定到创建时的事件循环 + """ + + _model_instance = None # Model 可以复用 + _rm_evaluator_instance = None # RM Gallery Evaluator (单例) + _ref_answers_cache: Dict[str, Dict[str, str]] = {} # 参考答案缓存 + _ref_domains_cache: Dict[str, Dict[str, str]] = {} # 领域缓存 + + def __init__(self, config): + super().__init__(config) + self._setup_weights() + self._init_openjudge_model() # 
只初始化 model,runner 在每次调用时创建 + self._init_rm_components() # 初始化 RM Gallery 组件 + self._init_reference_answers() # 初始化参考答案 + + def _setup_weights(self): + """ + 配置 OpenJudge 各 grader 的权重并归一化 + + graders 对应关系: + - financial_report_resolution: 报告质量和问题解决能力 + - financial_trajectory_faithfulness: 事实准确性(忠实度) + - citation_audit: 引用审计(覆盖率 + 真实性) + - rubrics_based_trajectory_performance: 基于 rubrics 的评估 + - trajectory_comprehensive: 轨迹综合评估 + - observation_information_gain: 信息增益(去重) + - action_loop_detection: 动作循环检测(惩罚项) + """ + cfg = getattr(self.config, "ajet", None) + + # 定义各 grader 的权重(可从 config 中读取)- 与 deep_finance_judge.py 对齐 + self.w = { + "rm": getattr(cfg, "rm_weight", 1.0) if cfg else 1.0, # RM Gallery 权重 + "citation_audit": getattr(cfg, "citation_audit_weight", 0.0) if cfg else 0.0, # CitationAudit 权重 + "report_resolution": getattr(cfg, "report_resolution_weight", 0.0) if cfg else 0.0, + "trajectory_faithfulness": getattr(cfg, "trajectory_faithfulness_weight", 0.0) if cfg else 0.0, + # "rubrics_performance": getattr(cfg, "rubrics_performance_weight", 0.2) if cfg else 0.2, + # "trajectory_comprehensive": getattr(cfg, "trajectory_comprehensive_weight", 0.2) if cfg else 0.2, + # "information_gain": getattr(cfg, "information_gain_weight", 0.1) if cfg else 0.1, + # "action_loop": getattr(cfg, "action_loop_weight", 0.1) if cfg else 0.1 + } + + # 归一化(注意:action_loop 是惩罚项,不参与归一化;rm 需要参与归一化) + positive_weights = {k: v for k, v in self.w.items() if k != "action_loop" and v > 0} + total = sum(positive_weights.values()) + if total > 0: + for k in positive_weights: + self.w[k] = self.w[k] / total + + + def _init_openjudge_model(self): + """初始化 OpenJudge LLM Model""" + # --- model name from config.ajet.judge.* --- + openjudge_model_name = self.config.ajet.judge.openjudge_llm + openjudge_base_url = os.environ.get("OPENJUDGE_BASE_URL") + openjudge_api_key = os.environ.get("OPENJUDGE_API_KEY") + + self._model_instance = OpenAIChatModel( + model=openjudge_model_name, + base_url=openjudge_base_url, + api_key=openjudge_api_key, + ) + # 设置实例变量供 _create_runner_in_loop 使用 + self.model = self._model_instance + self.max_concurrency = getattr(self.config.ajet.judge, "concurrency", 6) + + print( + f"[Init OpenJudge Model] model={openjudge_model_name}, base_url={openjudge_base_url}, " + f"api_key={'SET' if openjudge_api_key else 'NONE'}, max_concurrency={self.max_concurrency}" + ) + + def _init_rm_components(self): + """初始化 RM Gallery Evaluator(仅当 rm_weight > 0 时)""" + self._rm_enabled = (self.w.get("rm", 0) > 0) + if self._rm_enabled: + if DeepFinanceJudgeByOpenJudge._rm_evaluator_instance is None: + self._init_rm_evaluator() + DeepFinanceJudgeByOpenJudge._rm_evaluator_instance = self.rm_evaluator + else: + self.rm_evaluator = DeepFinanceJudgeByOpenJudge._rm_evaluator_instance + else: + self.rm_evaluator = None + + def _init_rm_evaluator(self): + """初始化 RM Gallery Evaluator""" + try: + # Monkey patch OpenAI client timeout (RM Gallery 默认只有60s,对于30B模型不够用) + import openai + _original_openai_init = openai.OpenAI.__init__ + def _patched_openai_init(self, *args, **kwargs): + kwargs.setdefault('timeout', 600.0) # 增大到600秒 + return _original_openai_init(self, *args, **kwargs) + openai.OpenAI.__init__ = _patched_openai_init + + from rm_gallery.core.reward.registry import RewardRegistry + import logging + logging.getLogger("rm_gallery").setLevel(logging.WARNING) + + # 从 config 读取 rm_llm,环境变量作为 fallback + rm_llm_name = self.config.ajet.judge.rm_llm + rm_api_key = os.environ.get("RM_API_KEY") + rm_base_url = 
os.environ.get("RM_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + + rm_params = {"is_parallel": True, "enable_thinking": False, "base_url": rm_base_url} + if rm_api_key: + rm_params["api_key"] = rm_api_key + + self.rm_evaluator = RewardRegistry.get("finance_composition")( + llm=rm_llm_name, name="finance_composition", params=rm_params + ) + print(f"[Init RM Evaluator] llm={rm_llm_name}, base_url={rm_base_url}, api_key={'SET' if rm_api_key else 'NONE'} (timeout=600s)") + except Exception as e: + print(f"✗ Failed to initialize RM evaluator: {e}") + import traceback + traceback.print_exc() + self.rm_evaluator = None + + def _init_reference_answers(self): + """初始化参考答案缓存,从 config 中读取路径""" + # 从 config 中获取 reference answer 路径 + train_ref_ans_path = getattr(self.config.ajet.judge, "train_ref_ans_path", "") + val_ref_ans_path = getattr(self.config.ajet.judge, "val_ref_ans_path", "") + + def _load(path, key): + if path and key not in DeepFinanceJudgeByOpenJudge._ref_answers_cache: + try: + ans, dom = load_reference_answers_from_file(path) + DeepFinanceJudgeByOpenJudge._ref_answers_cache[key], DeepFinanceJudgeByOpenJudge._ref_domains_cache[key] = ans, dom + except Exception: + DeepFinanceJudgeByOpenJudge._ref_answers_cache[key], DeepFinanceJudgeByOpenJudge._ref_domains_cache[key] = {}, {} + _load(train_ref_ans_path, "train") + _load(val_ref_ans_path, "val") + + def _get_reference_data(self, task_id: str) -> Tuple[str, str]: + """获取任务的参考答案和领域""" + cache_key = "val" if task_id.startswith("val_") else "train" + ans = DeepFinanceJudgeByOpenJudge._ref_answers_cache.get(cache_key, {}).get(task_id, "") + dom = DeepFinanceJudgeByOpenJudge._ref_domains_cache.get(cache_key, {}).get(task_id) + return ans, dom + + + def _create_runner_in_loop(self) -> GradingRunner: + """ + 在当前事件循环中创建 GradingRunner + + 注意:GradingRunner 内部的 Semaphore 会绑定到创建时的事件循环, + 因此不能使用单例模式,必须在每次调用的事件循环中创建新实例。 + """ + language = LanguageEnum.ZH + grader_configs = self._create_grader_configs(self.model, language) + return GradingRunner( + grader_configs=grader_configs, + max_concurrency=self.max_concurrency, + show_progress=False + ) + + def _create_grader_configs(self, model: OpenAIChatModel, language: LanguageEnum) -> Dict[str, GraderConfig]: + """ + 创建所有 grader 的配置 + + 返回:Dict[str, GraderConfig] + - key: grader 名称 + - value: GraderConfig(grader=..., mapper=...) + """ + return { + # 1. 报告质量评估 - 需要 messages 和 chat_date + "report_resolution": GraderConfig( + grader=FinancialReportResolutionGrader(model=model, language=language), + mapper=lambda data: { + "messages": data["messages"], + "chat_date": data.get("chat_date") + }, + ), + + # 2. 事实准确性评估 - 需要 messages + "trajectory_faithfulness": GraderConfig( + grader=FinancialTrajectoryFaithfulGrader(model=model, language=language), + mapper=lambda data: {"messages": data["messages"]}, + ), + + # 3. 引用审计评估 - 需要 messages + "citation_audit": GraderConfig( + grader=FinancialReportCitationAuditGrader(model=model, language=language), + mapper=lambda data: {"messages": data["messages"]}, + ), + + # 4. Rubrics 评估 - 需要 messages 和 rubrics + # "rubrics_performance": GraderConfig( + # grader=RubricsBasedTrajectoryPerformance(model=model, language=language), + # mapper=lambda data: { + # "messages": data["messages"], + # "rubrics": data.get("rubrics", []) + # }, + # ), + + # 5. 轨迹综合评估 - 需要 messages + # "trajectory_comprehensive": GraderConfig( + # grader=TrajectoryComprehensiveGrader(model=model, language=language), + # mapper=lambda data: {"messages": data["messages"]}, + # ), + + # 6. 
信息增益评估 - 需要 messages(非 LLM grader) + # "information_gain": GraderConfig( + # grader=ObservationInformationGainGrader(similarity_threshold=0.5), + # mapper=lambda data: {"messages": data["messages"]}, + # ), + + # 7. 动作循环检测 - 需要 messages(非 LLM grader) + # "action_loop": GraderConfig( + # grader=ActionLoopDetectionGrader(similarity_threshold=1.0), + # mapper=lambda data: {"messages": data["messages"]}, + # ), + } + + def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> Tuple[float, bool]: + """ + 主计算逻辑:使用 OpenJudge Runner.arun 计算 reward + + 流程: + 1. 从 workflow_output.metadata 提取 conversation_history、query、rubrics 等 + 2. 转换为 OpenJudge 的输入格式 (messages, chat_date, rubrics) + 3. 调用 Runner.arun([sample]) 获取所有 graders 的评分 + 4. 加权融合各 grader 分数 + 5. 计算惩罚项(tool_calls) + 6. 更新 metadata["reward_stats"] + 7. 返回 (final_reward, is_success) + """ + judge_start_time = time.time() + + try: + metadata = workflow_output.metadata + + # 1. 提取输入数据 + history = metadata.get("conversation_history", []) + query = metadata.get("query") or getattr(workflow_task.task, "main_query", "") + task_id = metadata.get("task_id") or getattr(workflow_task.task, "task_id", "") + rubrics = metadata.get("rubrics") # 可能是 None 或 list of dicts + step_reward = metadata.get("reward_stats", {}).get("step_reward", 0.0) + chat_date = metadata.get("chat_date") if metadata else datetime.now().strftime("%Y-%m-%d") + + if not history: + print(f"⚠️ Empty conversation history for task_id={task_id}") + return 0.0, False + + # 1.5 RM Gallery 评估(如果启用) + ref_ans, domain = self._get_reference_data(task_id) + assistants = [extract_text_content(m["content"]) for m in history if m["role"] == "assistant"] + + # RM Gallery 耗时记录 + rm_start_time = time.time() + if self._rm_enabled and self.rm_evaluator: + rm_raw = self._evaluate_with_rm_gallery(query, assistants[-1] if assistants else "", ref_ans, task_id, domain) + else: + rm_raw = 0.0 + rm_time = time.time() - rm_start_time + + # 2. 转换为 OpenJudge 输入格式 + openjudge_sample = self._convert_to_openjudge_format( + history=history, + query=query, + task_id=task_id, + rubrics=rubrics, + chat_date=chat_date + ) + + # 3. 调用 OpenJudge Runner.arun(异步) + grading_start_time = time.time() + grader_results = self._run_openjudge_evaluation([openjudge_sample]) + grading_time = time.time() - grading_start_time + + # 4. 提取各 grader 分数(arun 返回 Dict[str, List[GraderScore]],这里取第一条) + grader_scores, quota_exceeded_flags = self._extract_grader_scores(grader_results) + + # 5. 加权融合(包含 RM Gallery 和 OpenJudge Graders) + fused_reward, contributions = self._fuse_grader_scores(grader_scores, rm_raw) + + # 6. 计算惩罚项(保留原有的 tool_calls 惩罚逻辑) + tool_calls = metadata.get("tool_stats", {}).get("total_calls", 0) + penalty = self._compute_penalty(tool_calls) + + # 7. 汇总 + final_reward = fused_reward + step_reward + penalty + + judge_total_time = time.time() - judge_start_time + + # 8. 
更新元数据(实例化 RewardStats) + time_stats = { + "rm_time": rm_time, + "grading_time": grading_time, + "judge_total_time": judge_total_time, + } + self._update_metadata_stats( + metadata=metadata, + final_reward=final_reward, + fused_reward=fused_reward, + penalty=penalty, + step_reward=step_reward, + grader_scores=grader_scores, + contributions=contributions, + time_stats=time_stats, + rm_raw=rm_raw, + quota_exceeded_flags=quota_exceeded_flags + ) + + print(f"DeepFinanceJudgeByOpenJudge: task_id={task_id}, fused={fused_reward:.4f}, final={final_reward:.4f}, rm_time={rm_time:.2f}s, grading_time={grading_time:.2f}s, total={judge_total_time:.2f}s") + + # 9. 判断是否成功(可根据实际需求调整阈值) + is_success = final_reward >= 0.7 + + return final_reward, is_success + + except Exception as e: + print(f"✗ Error in OpenJudge compute_reward: {e}") + import traceback + traceback.print_exc() + return 0.0, False + + def _convert_to_openjudge_format( + self, + history: List[Dict], + query: str, + task_id: str, + rubrics: Optional[Any], + chat_date: Optional[str] + ) -> Dict[str, Any]: + """ + 将训练框架的 conversation_history 转换为 OpenJudge 的输入格式 + + 输入: + - history: [{"role": "user/assistant/tool", "content": ..., "tool_calls": ...}, ...] + + 输出: + - { + "messages": [...], # OpenJudge 格式 + "chat_date": "YYYY-MM-DD", + "rubrics": [...] + } + """ + # 1. 规范化 messages + messages = [] + for msg in history: + content = extract_text_content(msg.get("content", "")) + normalized_msg = { + "role": msg.get("role", "user"), + "content": content + } + + # 透传 tool_calls 等字段(OpenJudge 需要) + for field in ["tool_calls", "tool_call_id", "name"]: + if field in msg: + normalized_msg[field] = msg[field] + + messages.append(normalized_msg) + + + # 3. 转换 rubrics 格式(如果存在) + # OpenJudge 期望的格式:[{"dimension": ..., "description": ..., "check_points": [...]}, ...] 
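+        # Illustrative example only (hypothetical values): a single rubric entry could look like
+        # {"dimension": "data_grounding", "description": "Claims are backed by tool outputs", "check_points": ["cites at least one tool result"]}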
+ openjudge_rubrics = [] + if rubrics: + if isinstance(rubrics, list): + openjudge_rubrics = rubrics + elif isinstance(rubrics, dict): + # 如果 rubrics 是 dict,尝试转换 + # 假设格式类似 {"criteria": [...], "scoring_dimensions": [...]} + if "criteria" in rubrics: + for criterion in rubrics.get("criteria", []): + openjudge_rubrics.append({ + "dimension": criterion.get("name", ""), + "description": criterion.get("description", ""), + "check_points": criterion.get("check_points", []) + }) + + return { + "messages": messages, + "chat_date": chat_date, + "rubrics": openjudge_rubrics + } + + def _run_openjudge_evaluation(self, dataset: List[Dict[str, Any]]) -> Dict[str, List[Any]]: + """ + 调用 OpenJudge Runner.arun 进行评估(带重试机制) + + 输入: + - dataset: List[Dict] - OpenJudge 格式的样本列表 + + 输出: + - Dict[str, List[GraderScore]] - 每个 grader 的评分结果 + + 注意:GradingRunner 必须在当前事件循环中创建,因为其内部 Semaphore 会绑定事件循环 + """ + result = {} + judge_instance = self # 保存引用以便在 async 函数中访问 + max_retries = 3 # 最大重试次数 + + async def run_with_retry(): + nonlocal result + last_exception = None + + for attempt in range(max_retries): + try: + # 在当前事件循环中创建 Runner(避免 Semaphore 绑定错误的事件循环) + runner = judge_instance._create_runner_in_loop() + result = await runner.arun(dataset) + return # 成功则直接返回 + except Exception as e: + last_exception = e + error_str = str(e) + + # 判断是否为可重试的连接错误 + is_connection_error = any(keyword in error_str for keyword in [ + "Connection", "connection", "TCPTransport", + "SSLWantReadError", "BrokenPipe", "timeout", + "closed", "APIConnectionError" + ]) + + if is_connection_error and attempt < max_retries - 1: + wait_time = 2 ** attempt # 指数退避: 1s, 2s, 4s + print(f"⚠️ OpenJudge connection error (attempt {attempt+1}/{max_retries}), retrying in {wait_time}s... Error: {error_str[:100]}") + await asyncio.sleep(wait_time) + continue + else: + # 非连接错误或已达最大重试次数 + raise last_exception + + # 所有重试都失败 + if last_exception: + raise last_exception + + try: + # 创建新的标准 asyncio 事件循环,并设置为当前线程的事件循环 + # 这样可以避免 Semaphore 绑定到不同事件循环的问题 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) # 关键:将新循环设置为当前线程的事件循环 + try: + loop.run_until_complete(run_with_retry()) + finally: + loop.close() + asyncio.set_event_loop(None) # 清理:避免引用已关闭的循环 + except Exception as e: + print(f"✗ OpenJudge Runner.arun failed after {max_retries} attempts: {e}") + import traceback + traceback.print_exc() + + return result + + def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[Dict[str, float], Dict[str, bool]]: + """ + 从 Runner.arun 结果中提取各 grader 的分数 + + 输入: + - grader_results: Dict[str, List[GraderScore]] + { + "report_resolution": [GraderScore(score=0.88, reason="...", metadata={...})], + "trajectory_faithfulness": [GraderScore(score=1.0, ...)], + ... 
+ } + + 输出: + - Tuple[Dict[str, float], Dict[str, bool]] + - scores: 每个 grader 的分数(取第一条采样的分数) + - quota_exceeded_flags: 每个 grader 是否发生 429 quota exceeded + """ + scores = {} + quota_exceeded_flags = {} + + for grader_name, score_list in grader_results.items(): + quota_exceeded_flags[grader_name] = False + if score_list and len(score_list) > 0: + # 取第一条采样的分数(因为每次只评估一条) + grader_score = score_list[0] + if hasattr(grader_score, "score"): + scores[grader_name] = grader_score.score + # 检测错误类型:分数为0且有错误信息 + if grader_score.score == 0.0 and hasattr(grader_score, "reason"): + reason = str(grader_score.reason) if grader_score.reason else "" + # 检测 429 quota exceeded + if "429" in reason or "insufficient_quota" in reason or "exceeded your current quota" in reason: + quota_exceeded_flags[grader_name] = True + else: + # 如果出错,设为 0 + scores[grader_name] = 0.0 + else: + scores[grader_name] = 0.0 + + print(f" [OpenJudge Scores] {scores}") + if any(quota_exceeded_flags.values()): + quota_graders = [k for k, v in quota_exceeded_flags.items() if v] + print(f" [OpenJudge QuotaExceeded] {quota_graders}") + return scores, quota_exceeded_flags + + def _fuse_grader_scores(self, grader_scores: Dict[str, float], rm_raw: float = 0.0) -> Tuple[float, Dict[str, float]]: + """ + 加权融合各 grader 的分数(包含 RM Gallery 和 OpenJudge Graders) + + 输入: + - grader_scores: Dict[str, float] - 各 grader 的原始分数 + - rm_raw: float - RM Gallery 原始分数 + + 输出: + - (fused_reward, contributions) + - fused_reward: 加权后的总分 + - contributions: Dict[str, float] - 各 grader 的贡献分数 + """ + contributions = {} + + # 添加 RM Gallery 贡献 + contributions["rm_contribution"] = self.w.get("rm", 0.0) * rm_raw + + # 添加 OpenJudge Graders 贡献(包括 citation_audit) + for grader_name, weight in self.w.items(): + if grader_name == "rm": + continue # 已单独处理 + score = grader_scores.get(grader_name, 0.0) + contributions[grader_name] = weight * score + + fused_reward = sum(contributions.values()) + + return fused_reward, contributions + + def _evaluate_with_rm_gallery(self, query: str, current: str, reference: str, task_id: str, domain: str) -> float: + """使用 RM Gallery 评估""" + if not self.rm_evaluator or not domain or not reference: + return 0.0 + try: + from rm_gallery.core.data.schema import DataSample + sample = DataSample( + unique_id=task_id, + input=[{"role": "user", "content": query}], + output=[ + {"answer": {"role": "assistant", "content": current, "label": {"model_name": "training"}}, "steps": None}, + {"answer": {"role": "assistant", "content": reference, "label": {"model_name": "reference"}}, "steps": None}, + ], + task_category="financial_analysis", source="finance_samples", metadata={"domain": domain} + ) + result = self.rm_evaluator.evaluate(sample) + self._save_rm_log(result, query, task_id) + return result.metadata["dimension_scores"]["overall_score"]["training"] + except Exception as e: + print(f"✗ RM Gallery evaluation failed: {e}") + return 0.0 + + def _save_rm_log(self, result, query: str, task_id: str): + """保存 RM Gallery 评估日志""" + try: + log = { + "task_id": task_id, + "query": query, + "timestamp": datetime.now().isoformat(), + "scores": result.metadata.get("dimension_scores", {}) + } + save_dir = "./outputs/rm_evaluation_logs" + os.makedirs(save_dir, exist_ok=True) + with open(os.path.join(save_dir, f"rmeval_{datetime.now().strftime('%Y%m%d')}.json"), "a") as f: + f.write(json.dumps(log, ensure_ascii=False) + "\n") + except Exception: + pass + + def _compute_penalty(self, tool_calls: int) -> float: + """ + 计算工具调用惩罚(保留原有逻辑) + + - 0 次调用:-1.0 + - 1-2 次:-0.5 + - 
3+ 次:0.0 + """ + if tool_calls == 0: + return -1.0 + elif tool_calls <= 2: + return -0.5 + else: + return 0.0 + + def _update_metadata_stats( + self, + metadata: Dict[str, Any], + final_reward: float, + fused_reward: float, + penalty: float, + step_reward: float, + grader_scores: Dict[str, float], + contributions: Dict[str, float], + time_stats: Dict[str, float], + rm_raw: float = 0.0, + quota_exceeded_flags: Optional[Dict[str, bool]] = None + ): + """ + 更新 metadata["reward_stats"] - 直接使用 OpenJudge 原始字段 + + OpenJudge graders(按实际启用情况): + - report_resolution: 报告质量和问题解决能力 + - trajectory_faithfulness: 事实准确性(忠实度) + - citation_audit: 引用审计(覆盖率 + 真实性) + - rubrics_performance: 基于 rubrics 的评估(可选) + - trajectory_comprehensive: 轨迹综合评估(可选) + - information_gain: 信息增益/去重(可选) + - action_loop: 动作循环检测(惩罚项,可选) + + 注意:不再硬套 RewardStats 的字段名,直接使用 openjudge_ 前缀 + """ + quota_exceeded_flags = quota_exceeded_flags or {} + + # 计算 quota exceeded 统计 + quota_exceeded_count = sum(1 for v in quota_exceeded_flags.values() if v) + quota_exceeded_any = quota_exceeded_count > 0 + + # 基础分数 + stats_dict = { + "final_reward": final_reward, + "fused_reward": fused_reward, + "penalty": penalty, + "step_reward": step_reward, + "openjudge_enabled": True, + # Quota exceeded (429) 统计 + "quota_exceeded_any": quota_exceeded_any, # 是否有任何 grader 超额 + "quota_exceeded_count": quota_exceeded_count, # 超额的 grader 数量 + "quota_exceeded_graders": quota_exceeded_flags, # 各 grader 的超额标记 + # RM Gallery 相关 + "rm_enabled": self._rm_enabled, + "rm_raw": rm_raw, + "rm_weight": self.w.get("rm", 0.0), + "rm_contribution": contributions.get("rm_contribution", 0.0), + } + + # OpenJudge grader 原始分数(dimensions) + for grader_name, score in grader_scores.items(): + stats_dict[f"openjudge_{grader_name}_raw"] = score + stats_dict[f"openjudge_{grader_name}_weight"] = self.w.get(grader_name, 0.0) + + # OpenJudge grader 加权贡献(contribution) + for grader_name, contrib in contributions.items(): + stats_dict[f"openjudge_{grader_name}_contribution"] = contrib + + # 保留原始字典便于调试 + stats_dict["openjudge_grader_scores"] = grader_scores + stats_dict["openjudge_contributions"] = contributions + + # 注入耗时统计 + if time_stats: + stats_dict.update(time_stats) + + metadata["reward_stats"] = stats_dict + + def _save_evaluation_log(self, task_id: str, grader_results: Dict[str, List[Any]], query: str): + """ + 保存 OpenJudge 评估日志(可选) + """ + try: + log = { + "task_id": task_id, + "query": query, + "timestamp": datetime.now().isoformat(), + "grader_results": {} + } + + # 简化 grader_results 以便序列化 + for grader_name, score_list in grader_results.items(): + log["grader_results"][grader_name] = [] + for score in score_list: + if hasattr(score, "score"): + log["grader_results"][grader_name].append({ + "score": score.score, + "reason": score.reason[:200] if hasattr(score, "reason") else "", + }) + + save_dir = "./outputs/openjudge_logs" + os.makedirs(save_dir, exist_ok=True) + + log_file = os.path.join(save_dir, f"openjudge_{datetime.now().strftime('%Y%m%d')}.json") + with open(log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(log, ensure_ascii=False) + "\n") + + except Exception as e: + print(f"⚠️ Failed to save evaluation log: {e}") + pass + diff --git a/tutorial/example_deep_finance/deep_finance_reader.py b/tutorial/example_deep_finance/deep_finance_reader.py new file mode 100644 index 00000000..1752bcd0 --- /dev/null +++ b/tutorial/example_deep_finance/deep_finance_reader.py @@ -0,0 +1,262 @@ +"""DeepFinance Reader + +从 JSON 文件加载任务数据,并现场组装 init_messages。 +- 数据来源:训练集/测试集 JSON 文件 +- 
消息组装:加载 prompt 模板 + query +- 工具调用:仍走 env_service +""" +import os +import json +import logging +from typing import List, Dict, Any +from datetime import datetime + +from ajet.schema.task import Task +from ajet.task_reader.task_reader_base import BaseTaskReader + +# 配置 logger +logger = logging.getLogger(__name__) + +# 控制 debug 输出的开关(可通过环境变量控制) +DEBUG_ENABLED = os.environ.get("DEEPFINANCE_DEBUG", "0") == "1" + +def _debug_log(msg: str): + """统一的 debug 日志输出""" + if DEBUG_ENABLED: + print(f"[DEBUG][DeepFinanceReader] {msg}") + logger.debug(msg) + + +class DeepFinanceReader(BaseTaskReader): + """ + DeepFinance 专用的数据加载器 + + 特点: + 1. 从 JSON 文件加载任务数据(支持 list 和 dict 格式) + 2. 现场组装 init_messages(system_prompt + user_query) + 3. env_type 固定为 "deep_finance",由 env_service 负责工具调用 + """ + + # 类级别缓存 + _prompt_template_cache = None + _tool_prompt_cache = None + + def __init__(self, reader_config): + super().__init__(reader_config) + self.reader_config = reader_config + + _debug_log(f"Initializing DeepFinanceReader...") + _debug_log(f"reader_config type: {type(reader_config).__name__}") + + # 获取 prompt 目录路径 + self.local_path = os.path.dirname(os.path.abspath(__file__)) + _debug_log(f"local_path: {self.local_path}") + + # 初始化 prompt 缓存 + self._init_prompt_templates() + _debug_log(f"Initialization complete.") + + def _init_prompt_templates(self): + """初始化 prompt 模板缓存""" + if DeepFinanceReader._prompt_template_cache is None: + prompt_file = os.path.join(self.local_path, 'prompt', 'finance_analyst_prompt.md') + _debug_log(f"Loading prompt template from: {prompt_file}") + with open(prompt_file, 'r', encoding='utf-8') as f: + DeepFinanceReader._prompt_template_cache = f.read() + _debug_log(f"Prompt template loaded, length: {len(DeepFinanceReader._prompt_template_cache)} chars") + else: + _debug_log(f"Using cached prompt template, length: {len(DeepFinanceReader._prompt_template_cache)} chars") + + if DeepFinanceReader._tool_prompt_cache is None: + # 使用 tool_prompt_builder.py 中的静态模板 + _debug_log(f"Loading tool prompt template...") + from tutorial.example_deep_finance.prompt.tool_prompt_builder import get_tool_prompt_template + DeepFinanceReader._tool_prompt_cache = get_tool_prompt_template() + _debug_log(f"Tool prompt template loaded, length: {len(DeepFinanceReader._tool_prompt_cache)} chars") + else: + _debug_log(f"Using cached tool prompt template, length: {len(DeepFinanceReader._tool_prompt_cache)} chars") + + def _build_system_prompt(self) -> str: + """构建 system prompt""" + current_date = datetime.now().strftime('%Y-%m-%d') + _debug_log(f"Building system prompt with date: {current_date}") + + # 替换日期占位符 + system_prompt = DeepFinanceReader._prompt_template_cache.replace( + '{current_date}', + current_date + ) + # 替换工具列表占位符 + system_prompt = system_prompt.replace( + '{tool_list}', + DeepFinanceReader._tool_prompt_cache + ) + _debug_log(f"System prompt built, final length: {len(system_prompt)} chars") + return system_prompt + + def _build_init_messages(self, query: str) -> List[Dict[str, Any]]: + """ + 构建 init_messages + + Args: + query: 用户问题 + + Returns: + [{"role": "system", "content": ...}, {"role": "user", "content": ...}] + """ + _debug_log(f"Building init_messages for query (len={len(query)}): {query[:100]}..." 
if len(query) > 100 else f"Building init_messages for query: {query}") + system_prompt = self._build_system_prompt() + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": query} + ] + _debug_log(f"init_messages built: {len(messages)} messages, system_prompt_len={len(system_prompt)}") + return messages + + def _read_json_file(self, file_path: str, split: str = "train") -> List[Task]: + """ + 从 JSON 文件读取任务列表 + + 支持的数据格式: + 1. List 格式: [{"task": {"task_id": ..., "query": ...}, ...}, ...] + 2. Dict 格式: {"task_id_1": {"task": {...}, ...}, "task_id_2": {...}, ...} + + Args: + file_path: JSON 文件路径 + split: 数据集划分(train/val) + + Returns: + List[Task]: 任务列表 + """ + _debug_log(f"Reading JSON file: {file_path}, split={split}") + + if not os.path.exists(file_path): + _debug_log(f"ERROR: File not found: {file_path}") + raise FileNotFoundError(f"JSON file not found: {file_path}") + + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + _debug_log(f"JSON data loaded, type: {type(data).__name__}, size: {len(data) if isinstance(data, (list, dict)) else 'N/A'}") + + tasks = [] + skipped_count = 0 + split_filtered_count = 0 + + # 解析数据 + if isinstance(data, list): + # List 格式 + _debug_log(f"Parsing List format data, total items: {len(data)}") + for idx, item in enumerate(data): + task_info = item.get('task', {}) + task_id = task_info.get('task_id', '') + query = task_info.get('query', '') + + if not task_id or not query: + skipped_count += 1 + _debug_log(f" Item {idx}: SKIPPED (missing task_id or query)") + continue + + # 过滤 split + item_split = task_info.get('metadata', {}).get('split', split) + if item_split != split: + split_filtered_count += 1 + _debug_log(f" Item {idx} ({task_id}): FILTERED by split (item_split={item_split}, expected={split})") + continue + + # 构建 Task + _debug_log(f" Item {idx} ({task_id}): Creating task...") + task = self._create_task(task_id, query, item) + tasks.append(task) + + elif isinstance(data, dict): + # Dict 格式 + _debug_log(f"Parsing Dict format data, total keys: {len(data)}") + for idx, (task_id, item) in enumerate(data.items()): + task_info = item.get('task', {}) + query = task_info.get('query', '') + + if not query: + skipped_count += 1 + _debug_log(f" Key {idx} ({task_id}): SKIPPED (missing query)") + continue + + # 过滤 split + item_split = task_info.get('metadata', {}).get('split', split) + if item_split != split: + split_filtered_count += 1 + _debug_log(f" Key {idx} ({task_id}): FILTERED by split (item_split={item_split}, expected={split})") + continue + + # 构建 Task(使用 dict key 作为 task_id) + _debug_log(f" Key {idx} ({task_id}): Creating task...") + task = self._create_task(task_id, query, item) + tasks.append(task) + + _debug_log(f"Summary: loaded={len(tasks)}, skipped={skipped_count}, split_filtered={split_filtered_count}") + print(f"[DeepFinanceReader] Loaded {len(tasks)} tasks from {file_path} (split={split})") + + if len(tasks) == 0: + raise ValueError(f"No tasks found in file: {file_path} for split={split}") + + return tasks + + def _create_task(self, task_id: str, query: str, raw_item: Dict[str, Any]) -> Task: + """ + 创建 Task 对象 + + Args: + task_id: 任务 ID + query: 用户问题 + raw_item: 原始数据项 + + Returns: + Task: 任务对象 + """ + _debug_log(f"Creating Task: task_id={task_id}") + + # 现场组装 init_messages + init_messages = self._build_init_messages(query) + + # 提取 metadata + task_info = raw_item.get('task', {}) + metadata = task_info.get('metadata', {}) + + # 将原始数据存入 metadata,供 env 和 judge 使用 + # 注意:序列化为 JSON 
字符串,避免嵌套字典导致 PyArrow 序列化时递归深度超限 + metadata['raw_task_data'] = json.dumps(raw_item, ensure_ascii=False) + metadata['query'] = query + metadata['confidence'] = raw_item.get('confidence', 1.0) + metadata['rubrics'] = raw_item.get('rubrics', None) + metadata['ground_truth'] = task_info.get('ground_truth', '') + + _debug_log(f" Task metadata: confidence={metadata['confidence']}, has_rubrics={metadata['rubrics'] is not None}, has_ground_truth={bool(metadata['ground_truth'])}") + _debug_log(f" Task init_messages: {len(init_messages)} messages") + + task = Task( + main_query=query, + init_messages=init_messages, + task_id=task_id, + env_type="deep_finance", # 固定为 deep_finance,由 env_service 处理 + metadata=metadata + ) + _debug_log(f" Task created successfully: {task_id}") + return task + + def get_training_tasks(self) -> List[Task]: + """获取训练任务""" + _debug_log(f"get_training_tasks() called") + file_path = self.reader_config.deep_finance.training.file_path + _debug_log(f"Training file path: {file_path}") + tasks = self._read_json_file(file_path, split="train") + _debug_log(f"get_training_tasks() returning {len(tasks)} tasks") + return tasks + + def get_validation_tasks(self) -> List[Task]: + """获取验证任务""" + _debug_log(f"get_validation_tasks() called") + file_path = self.reader_config.deep_finance.validation.file_path + _debug_log(f"Validation file path: {file_path}") + tasks = self._read_json_file(file_path, split="val") + _debug_log(f"get_validation_tasks() returning {len(tasks)} tasks") + return tasks diff --git a/tutorial/example_deep_finance/prompt/finance_analyst_prompt.md b/tutorial/example_deep_finance/prompt/finance_analyst_prompt.md new file mode 100644 index 00000000..f3dd2bad --- /dev/null +++ b/tutorial/example_deep_finance/prompt/finance_analyst_prompt.md @@ -0,0 +1,189 @@ +你是一位专业的金融研究分析师。你的任务是通过工具收集信息,进行深度研究,并最终输出一份结构化的 Markdown 格式研究报告。 + +当前日期: {current_date} + +## 研究流程 + +你必须采用两阶段深度研究方法: + +### 第一阶段:先大纲、后调研(必须执行) + +**你必须先输出研究大纲,再通过工具收集信息;禁止在没有数据支撑的情况下直接生成完整报告。** + +1. **理解需求**:分析用户问题的类型(个股分析/行业研究/事件解读/宏观分析/股票检索等),明确用户关注的核心结论与评估维度。 +2. **先写研究大纲(必须先做,且此时不要调用工具)**: + - 输出一个“报告大纲”,包含:一级/二级标题 + 每节要回答的关键问题(Key Questions)。 + - 大纲应明确:每一部分需要哪些证据类型(财务/估值/新闻/政策/行业对比/同业数据等)、需要哪些关键表格或对比指标。 + - 注意:此步骤只写结构与问题清单,不要在没有数据前给出确定性的数字结论。 +3. **按大纲逐段调研(必须执行)**: + - 以大纲为索引,一节一节地收集证据与数据,逐步补全每个 Key Question。 + - **必须使用工具收集数据** - 不要凭空猜测或使用过时信息 + - **分批调用工具** - 每次最多调用3个相关工具,避免一次性调用过多 + - 每轮调用工具后先做小结:本轮获得了哪些可用证据、还缺什么,再决定下一轮 + - 多维度交叉验证,确保数据全面性和准确性 + +### 第二阶段:深度分析与报告生成 + +**仅当你已经通过工具获取了充分的数据后,才能进入此阶段。** + +- 当收集到充分信息后,进入写作阶段,基于真实数据输出完整的 Markdown 格式研究报告,并在报告末尾添加 `[TASK_COMPLETED]` 标记。 +- 写作过程中如果发现关键结论缺少证据支撑,允许**追加 1-2 轮工具调用**进行补充取证,然后继续完成报告(不要为刷工具而调用)。 + +## 工具调用格式 + +当需要使用工具时,**必须严格按照以下标准JSON格式输出**(注意:使用单花括号`{`,不要使用双花括号`{{`): + +```json +[ + { + "tool_name": "工具名称", + "tool_args": { + "参数名": "参数值" + } + } +] +```` + +**重要限制:每次最多调用3个工具** + +* 先调用核心工具获取关键信息,分析后再调用补充工具 +* 例如:先查股票代码 → 再查财务数据 → 最后查行业新闻 + +**工具调用示例(每次1-3个工具)**: + +```json +[ + { + "tool_name": "dashscope_search", + "tool_args": { + "query": "茅台股票代码" + } + } +] +``` + +**重要提醒**: + +* ✓ 使用标准JSON格式,单花括号 `{` 和 `}` +* ✗ 不要使用双花括号 `{{` 或 `}}` +* ✓ 所有字符串必须用双引号包裹 +* ✓ 确保JSON格式可被 `json.loads()` 正确解析 + +## 可用工具 + +以下是当前环境中可用的工具列表。**带✅标记的工具是推荐优先使用的,它们经过验证更加稳定和可靠。** + +{tool_list} + +## 工具使用准则 + +1. **数量限制(重要)**:每次最多调用3个工具,采用多轮次渐进式调研,避免单次调用过多工具 +2. **优先使用推荐工具**:带✅标记的工具经过验证,更加稳定可靠,应优先考虑使用 +3. **先搜索后查询**:不确定的信息(如股票代码、行业分类)必须先用所提供的搜索工具查询,不要臆测 +4. **逐步推进**:每次调用工具后分析结果,再决定下一步调研方向,不要一次性规划所有调用 +5. **多源验证**:使用多个工具源交叉验证关键数据 +6. 
**全面覆盖**:根据研究主题,通过多轮调用逐步覆盖相关的多个维度(基本面、财务、估值、行业、新闻等) +7. **仅调用存在的工具**:只能使用上述"可用工具"列表中的工具,不要调用不存在的工具名称 + + +## 引用规范(必须遵守) + +你必须使用学术论文风格的引用标注,保证读者可追溯到工具获取的信息来源。 + +### 1) 何时必须引用(关键事实句) + +“关键事实句”指:包含数字/同比环比/日期/财务指标/估值倍数/明确事实结论/具体事件/具体公司或行业陈述/政策条款的句子。 + +* 所有关键事实句句末必须添加引用编号:**[1]** 或 **[1][2]**。 +* **同一来源务必在全文重复使用同一编号**。 + +### 2) References(必须,包含内容与格式要求) + +* 报告末尾必须包含 `## References` 小节。 +* 每条引用一行,编号从 `[1]` 开始连续;同一来源务必重复使用同一编号。 +* **URL 优先**:若工具返回有可用 `url`,References 中应填写该 URL,以及来源 +* **URL 提取指南**: + - 对于 `dashscope_search`:从 `search_results` 的 `"url"` 字段提取 + - 对于 `crawl_ths_*` 系列工具:从返回内容的 `"以下内容来自:"` 后提取 URL +* 禁止伪造链接/来源;无法证据支撑的只能写“推测/假设”,不要用引用包装成事实。 +* 正文出现的每个 `[n]` 必须在 References 中有对应条目;References 不得包含正文未使用的编号。 + +**行格式模板(URL 可选)**: + +* `[n] 标题或简述,来源 - URL` +* `[n] 标题或简述,工具:,参数:,数据日期/报告期: ,来源 - URL` + +### 3) 输出前自检(必须) + +输出前检查: + +* 所有关键事实句是否都有 `[n]`; +* `## References` 是否覆盖全部编号。 + + +## 最终报告要求 + +当信息收集完成后,必须输出 **Markdown 格式的结构化研究报告**。 + +### 报告结构说明 + +根据用户问题类型,选择合适的报告结构: + +**个股分析**:包含公司概况、财务分析、估值分析、行业地位、最新动态、投资建议等 +**行业研究**:包含行业概况、发展趋势、政策环境、竞争格局、龙头企业、投资机会等 +**事件解读**:包含事件背景、影响分析、相关标的、投资策略等 +**宏观分析**:包含宏观环境、政策分析、市场影响、配置建议等 +**股票检索**:包含筛选标准、候选标的、对比分析、推荐排序等 + +### 报告格式要求 + +1. **使用 Markdown 语法**:标题(#)、列表(-)、表格(|)、加粗(**)等 +2. **结构清晰**:使用多级标题组织内容 +3. **数据可视化**:适当使用表格展示关键数据对比(表格中的关键数据同样需要引用 [n]) +4. **逻辑完整**:包含执行摘要、详细分析、结论建议 +5. **引用与参考文献(必须)**:正文关键事实句使用 [n] 引用;文末提供 `## References` + +### 报告示例框架 + +```markdown +# [研究主题] + +## 摘要 +[核心观点和结论,3-5条,每条如包含关键事实也要加引用 [n]] + +## [主体部分 - 根据主题自适应] +### [二级标题] +[具体分析内容...关键事实句末尾加 [n]] + +## 结论与建议 +[明确的结论和操作建议...关键事实句末尾加 [n]] + +## References +[1] 标题或简要描述 - https://... +[2] 贵州茅台历史股价分析(报告期2025-09-30),工具:history_calculate,参数:code=600519,query=过去一周涨跌情况 - https:// +--- +*本报告基于公开信息整理分析,仅供参考,不构成投资建议。投资有风险,入市需谨慎。* + +[TASK_COMPLETED] +``` + +## 何时停止调用工具并输出报告 + +**必须满足以下所有条件后,才能输出最终报告:** + +1. ✓ **已实际调用工具获取足够证据**(通常至少 2-4 轮;以“信息充分支撑结论”为准,而非强制轮数) +2. ✓ **已获取核心数据**:财务数据、市场数据、新闻动态等关键信息 +3. ✓ **已交叉验证**:从多个数据源验证了关键结论(至少对关键数字/事件做到交叉验证) +4. 
✓ **数据完整性**:具备足够信息支撑每一个分析结论和投资建议(无法支撑的必须标注为推测/假设) + +**输出格式要求:** + +* 输出完整的 Markdown 格式研究报告(包含标题、摘要、分析、结论、References) +* 报告必须基于真实调用工具获取的数据,不能是空洞的框架 +* **在报告的最后一行单独输出** `[TASK_COMPLETED]` 标记 + +**警告:禁止在没有调用工具、没有真实数据的情况下直接输出报告框架+`[TASK_COMPLETED]`,这是无效的研究。** + +--- + +现在开始深度研究用户的问题。 \ No newline at end of file diff --git a/tutorial/example_deep_finance/prompt/tool_prompt_builder.py b/tutorial/example_deep_finance/prompt/tool_prompt_builder.py new file mode 100644 index 00000000..5c940fd7 --- /dev/null +++ b/tutorial/example_deep_finance/prompt/tool_prompt_builder.py @@ -0,0 +1,150 @@ +""" +工具信息Prompt构建模块 +用于生成清晰、结构化的工具使用说明 +""" + +def get_tool_prompt_template() -> str: + """ + 获取工具prompt模板(静态版本) + 基于实际探测到的19个工具进行配置 + + Returns: + 预定义的工具说明文本 + """ + + return """## 可用工具列表 + +### ⚠️ 重要说明 +**股票代码格式规范**: +- 涉及A股代码时,通常使用 **6位纯数字** 格式(如 `000001`、`600000`)。 +- **注意**: 用户输入股票名称时,必须先使用 `extract_entities_code` 转换为对应的代码。 + +--- + +### 🔍 实体与数据计算工具 + +#### ✅ extract_entities_code +**功能**: 从查询中提取金融实体(股票、债券、基金、加密货币、指数、商品、ETF等),并查找对应的代码。最后返回查询中出现的金融实体及其类型和代码。 +**参数**: + - `query` (必填, string): 关于金融实体的自然语言查询文本 + +#### ✅ history_calculate +**功能**: 获取指定A股股票的历史股价数据,并根据用户问题进行分析。 +**数据结构**: 工具内部包含以下字段的历史数据: + - `ts_code`(代码), `trade_date`(交易日期) + - `open`(开), `high`(高), `low`(低), `close`(收), `pre_close`(昨收) + - `change`(涨跌额), `pct_chg`(涨跌幅) + - `vol`(成交量), `amount`(成交额) +**使用说明**: 你无需编写任何代码——只需直接提问即可,例如:“过去一周涨了多少,有没有出现顶背离?”、“MACD是否形成了金叉?”。 +**参数**: + - `code` (必填, string): A股代码 (如 '600000' 或 '000001') + - `query` (必填, string): 关于股票历史表现的具体问题 + +--- + +### 💻 代码与通用网络工具 + +#### ✅ execute_code +**功能**: 执行 Python 代码,适用于复杂分析或计算场景。最终结果请使用 `print` 函数输出。 +**参数**: + - `code` (必填, string): 需要执行的代码 + +#### ✅ execute_shell +**功能**: 执行 Shell 命令 (如 `ls`, `pwd`, 运行脚本)。 +**注意**: 每次调用起始目录相同。如需多步操作,请在一条命令中使用 `&&` 连接 (例如: `cd aa/bb && bash xxx`)。 +**参数**: + - `command` (必填, string): 需要执行的命令 + +#### ✅ dashscope_search +**功能**: 使用搜索关键词从互联网检索相关信息。如果有多个关键词,请分开多次调用。 +**参数**: + - `query` (必填, string): 搜索关键词 + +#### ✅ crawl_url +**功能**: 网页内容解析工具,获取并格式化指定URL的网页内容。 +**参数**: + - `url` (必填, string): 目标网页URL + +# --- + +### 📈 同花顺专项数据工具 (Crawl THS) +*以下工具用于获取特定维度的深度金融数据,请根据用户意图选择最匹配的工具* + +#### ✅ crawl_ths_company +**功能**: 获取上市公司基本资料。 +**数据范围**: 详细情况、高管介绍、发行相关、参控股公司。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_holder +**功能**: 获取股东研究信息。 +**数据范围**: 股东人数、十大流通股东、十大股东、十大债券持有人、控股层级关系。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_operate +**功能**: 获取经营分析信息。 +**数据范围**: 主营介绍、运营业务数据、主营构成分析、主要客户及供应商、董事会经营评述、产品价格。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_equity +**功能**: 获取股本结构信息。 +**数据范围**: 解禁时间表、总股本构成、A股结构图、历次股本变动。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_capital +**功能**: 获取资本运作信息。 +**数据范围**: 募集资金来源、项目投资、收购兼并、股权投资、参股IPO、股权转让、关联交易、质押解冻。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_finance +**功能**: 获取财务分析信息。 +**数据范围**: 财务诊断、财务指标、指标变动说明、资产负债构成、财务报告、杜邦分析。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_worth +**功能**: 获取盈利预测信息。 +**数据范围**: 业绩预测、业绩预测详表、研报评级。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_news +**功能**: 获取新闻公告信息。 +**数据范围**: 新闻与股价联动、公告列表、热点新闻列表、研报列表。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_concept +**功能**: 获取概念题材信息。 +**数据范围**: 常规概念、其他概念、题材要点、概念对比。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_position +**功能**: 获取主力持仓信息。 +**数据范围**: 机构持股汇总、机构持股明细、被举牌情况、IPO获配机构。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_bonus +**功能**: 获取分红融资信息。 +**数据范围**: 
分红诊断、分红情况、增发机构获配明细、增发概况、配股概况。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_event +**功能**: 获取公司大事信息。 +**数据范围**: 高管持股变动、股东持股变动、担保明细、违规处理、机构调研、投资者互动。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) + +#### ✅ crawl_ths_field +**功能**: 获取行业对比信息。 +**数据范围**: 行业地位、行业新闻。 +**参数**: + - `code` (必填, string): 股票代码 (6位数字) +""" \ No newline at end of file diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml new file mode 100644 index 00000000..a2d2cd73 --- /dev/null +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -0,0 +1,88 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: ajet_deep_finance + experiment_name: "{{SUFFIX}}" + # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) + judge: + openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型 + rm_llm: {{RM_LLM}} # RM Gallery 模型 + concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数 + train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 + val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 + # OpenJudge 权重配置 + report_resolution_weight: {{REPORT_RESOLUTION_WEIGHT}} # 报告质量评估 + trajectory_faithfulness_weight: {{TRAJECTORY_FAITHFULNESS_WEIGHT}} # 事实准确性评估 + citation_audit_weight: {{CITATION_AUDIT_WEIGHT}} # 引用审计评估 (覆盖率 + 真实性) + rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 + task_judge: + # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) + judge_protocol: tutorial.example_deep_finance.deep_finance_judge->DeepFinanceJudgeByOpenJudge + model: + # ✨✨✨✨ 设置待训练的模型 + path: {{MODEL_PATH}} + trainer_common: + nnodes: {{NNODES}} + n_gpus_per_node: 8 + val_before_train: True + val_pass_n: 8 + save_freq: 10 + test_freq: 2 + total_epochs: 200 + save_trajectory_as_json_file: True + rollout: + # ✨✨✨✨ 编写并选择Agent + user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol + force_disable_toolcalls: True + enable_oversample: False + tensor_model_parallel_size: 8 + num_repeat: {{NUM_REPEAT}} + max_env_worker: 64 # 增加环境并行数 + max_num_seqs: 64 # 增加VLLM并发序列数 + max_response_length_in_one_turn: 8000 + max_model_len: 50000 + agent_madness_reward: 0.0 + compute_madness_checklist: None + multi_turn: + max_steps: {{NUM_STEPS}} + interchange_server: + interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + debug: + debug_max_parallel: 64 # 增加并行任务数,充分利用GPU + debug_first_n_tasks: 100 # 增加处理的任务数 + data: + train_batch_size: {{TRAIN_BATCH_SIZE}} + max_prompt_length: 8000 + max_response_length: 41000 + + task_reader: + type: deep_finance # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service + deep_finance: + training: + file_path: {{TRAIN_DATA_PATH}} + validation: + file_path: {{VAL_DATA_PATH}} + # env_service 仍需配置(用于工具调用) + env_service: + env_type: "finworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code +trainer: + default_local_dir: "{{CKPT_SAVE_PATH}}/{{PREFIX}}/{{SUFFIX}}" + # resume_mode: disable # 禁用自动恢复,从头开始训练 +actor_rollout_ref: + rollout: + tensor_model_parallel_size: 8 + gpu_memory_utilization: 0.8 +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_ma_deepresearch/ma_deepresearch.py 
b/tutorial/example_ma_deepresearch/ma_deepresearch.py index d044458b..3293d769 100644 --- a/tutorial/example_ma_deepresearch/ma_deepresearch.py +++ b/tutorial/example_ma_deepresearch/ma_deepresearch.py @@ -42,7 +42,7 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl init_messages=init_messages, task_id=workflow_task.task.task_id, main_query=workflow_task.task.main_query, - max_steps=tuner.config.astune.rollout.multi_turn.max_steps, + max_steps=tuner.config.ajet.rollout.multi_turn.max_steps, env_service_url=workflow_task.gym_env.service_url, )
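The retry wrapper in `_run_openjudge_evaluation` above retries only on connection-style failures, backs off exponentially (1s, 2s, 4s over three attempts), and drives everything from a freshly created event loop so the runner's semaphores bind to the loop that executes them. A minimal, self-contained sketch of that pattern follows; `flaky_call` is a hypothetical stand-in for `runner.arun(dataset)`, and the simulated failure message is invented.

```python
# Minimal sketch of the retry policy around Runner.arun: retry only on
# connection-style errors, with exponential backoff (1s, 2s, 4s for 3 attempts).
# `flaky_call` is a hypothetical stand-in for the real OpenJudge call.
import asyncio

RETRYABLE_KEYWORDS = ("Connection", "connection", "TCPTransport", "SSLWantReadError",
                      "BrokenPipe", "timeout", "closed", "APIConnectionError")
_attempts = {"n": 0}

async def flaky_call():
    """Stand-in for `await runner.arun(dataset)`: fails twice, then succeeds."""
    _attempts["n"] += 1
    if _attempts["n"] < 3:
        raise RuntimeError("APIConnectionError: simulated transient failure")
    return {"report_resolution": [0.88]}

async def run_with_retry(max_retries: int = 3):
    for attempt in range(max_retries):
        try:
            return await flaky_call()
        except Exception as e:
            if any(k in str(e) for k in RETRYABLE_KEYWORDS) and attempt < max_retries - 1:
                await asyncio.sleep(2 ** attempt)  # exponential backoff
                continue
            raise

# A fresh event loop per evaluation keeps semaphores bound to the loop that runs them,
# which is why the judge creates the runner inside the loop it is about to drive.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
    print(loop.run_until_complete(run_with_retry()))
finally:
    loop.close()
    asyncio.set_event_loop(None)
```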
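`_extract_grader_scores` treats a zero score whose reason mentions a 429 / quota error as "quota exceeded" rather than as a genuine zero. A rough sketch of that check is below; `Score` is a hypothetical stand-in for OpenJudge's `GraderScore`, and the reason strings are invented.

```python
# Rough sketch of the 429 / quota-exceeded detection applied to grader results.
# `Score` is a hypothetical stand-in for OpenJudge's GraderScore; reasons are invented.
from typing import NamedTuple

class Score(NamedTuple):
    score: float
    reason: str

QUOTA_MARKERS = ("429", "insufficient_quota", "exceeded your current quota")

def is_quota_exceeded(s: Score) -> bool:
    """A zero score whose reason mentions a quota error is flagged, not trusted."""
    return s.score == 0.0 and any(m in (s.reason or "") for m in QUOTA_MARKERS)

print(is_quota_exceeded(Score(0.0, "Error code: 429 - insufficient_quota")))  # True
print(is_quota_exceeded(Score(0.0, "missing citations for key claims")))      # False
```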
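The reward composition — weighted fusion of the RM Gallery score and the OpenJudge grader scores in `_fuse_grader_scores`, plus the step reward and the tool-call penalty from `_compute_penalty` — is plain arithmetic. The sketch below mirrors that logic with made-up weights and scores; `fuse` and `tool_call_penalty` are illustrative helpers, not the shipped implementation.

```python
# Illustrative sketch of the reward composition; weights and scores are made-up values.
from typing import Dict, Tuple

def fuse(weights: Dict[str, float], grader_scores: Dict[str, float],
         rm_raw: float) -> Tuple[float, Dict[str, float]]:
    """Weighted sum of the RM Gallery score and OpenJudge grader scores."""
    contributions = {"rm_contribution": weights.get("rm", 0.0) * rm_raw}
    for name, weight in weights.items():
        if name != "rm":
            contributions[name] = weight * grader_scores.get(name, 0.0)
    return sum(contributions.values()), contributions

def tool_call_penalty(tool_calls: int) -> float:
    """Penalty tiers: 0 calls -> -1.0, 1-2 calls -> -0.5, 3+ calls -> 0.0."""
    if tool_calls == 0:
        return -1.0
    return -0.5 if tool_calls <= 2 else 0.0

weights = {"rm": 0.2, "report_resolution": 0.4,
           "trajectory_faithfulness": 0.2, "citation_audit": 0.2}
scores = {"report_resolution": 0.88, "trajectory_faithfulness": 1.0, "citation_audit": 0.75}

fused, contributions = fuse(weights, scores, rm_raw=0.9)
final_reward = fused + 0.0 + tool_call_penalty(tool_calls=5)  # step_reward assumed 0.0
print(contributions)
print(f"fused={fused:.4f}, final={final_reward:.4f}, success={final_reward >= 0.7}")
```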
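`DeepFinanceReader._read_json_file` accepts two JSON shapes: a list of items carrying `task.task_id`, or a dict keyed by task id, each optionally tagged with `metadata.split`. The sketch below shows both shapes and the split filter; the task ids and queries are invented examples.

```python
# The two JSON shapes the reader accepts, with invented task ids and queries.
list_style = [
    {"task": {"task_id": "t-001",
              "query": "Analyze Kweichow Moutai's results over the past year",
              "metadata": {"split": "train"}},
     "rubrics": None, "confidence": 1.0},
]
dict_style = {
    "t-002": {"task": {"query": "Competitive landscape of the baijiu industry",
                       "metadata": {"split": "val"}}},
}

def iter_tasks(data, split="train"):
    """Yield (task_id, query) pairs the way the reader filters by split."""
    items = enumerate(data) if isinstance(data, list) else data.items()
    for key, item in items:
        info = item.get("task", {})
        task_id = info.get("task_id", "") or (key if isinstance(data, dict) else "")
        query = info.get("query", "")
        if not task_id or not query:
            continue  # skipped, like the reader's missing-field branch
        if info.get("metadata", {}).get("split", split) != split:
            continue  # filtered out by split
        yield task_id, query

print(list(iter_tasks(list_style, "train")))  # [('t-001', "Analyze Kweichow Moutai's ...")]
print(list(iter_tasks(dict_style, "val")))    # [('t-002', 'Competitive landscape ...')]
```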
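`_build_system_prompt` fills exactly two placeholders, `{current_date}` and `{tool_list}`, using `str.replace` (presumably because the real template also contains literal single-brace JSON examples, which `str.format` would choke on). A standalone sketch with a shortened, paraphrased template:

```python
# Standalone sketch of the placeholder substitution in _build_system_prompt.
# The template here is a shortened, paraphrased stand-in for finance_analyst_prompt.md.
from datetime import datetime

TEMPLATE = ("You are a financial research analyst.\n"
            "Current date: {current_date}\n\n## Available tools\n{tool_list}\n")
TOOL_LIST = "#### extract_entities_code\n#### dashscope_search\n"

def build_system_prompt(template: str, tool_list: str) -> str:
    # str.replace is used rather than str.format because the real template also
    # contains literal single-brace JSON examples for tool calls.
    prompt = template.replace("{current_date}", datetime.now().strftime("%Y-%m-%d"))
    return prompt.replace("{tool_list}", tool_list)

print(build_system_prompt(TEMPLATE, TOOL_LIST))
```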
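The system prompt requires tool calls to be emitted as a JSON array of `{"tool_name", "tool_args"}` objects that `json.loads` can parse, capped at three calls per turn. A small validator for that contract might look as follows; this helper is illustrative and is not part of the diff.

```python
# Sketch of a validator for the tool-call contract in the system prompt:
# a JSON array of {"tool_name": str, "tool_args": dict}, at most 3 entries per turn.
# This helper is illustrative and not part of the diff.
import json
from typing import Any, Dict, List

def parse_tool_calls(text: str, max_calls: int = 3) -> List[Dict[str, Any]]:
    calls = json.loads(text)
    if not isinstance(calls, list):
        raise ValueError("tool calls must be a JSON array")
    if len(calls) > max_calls:
        raise ValueError(f"at most {max_calls} tool calls per turn")
    for call in calls:
        if (not isinstance(call, dict)
                or not isinstance(call.get("tool_name"), str)
                or not isinstance(call.get("tool_args"), dict)):
            raise ValueError("each call needs a string tool_name and a dict tool_args")
    return calls

example = '[{"tool_name": "dashscope_search", "tool_args": {"query": "茅台股票代码"}}]'
print(parse_tool_calls(example))
```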
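The prompt's pre-output self-check (every `[n]` used in the body must have a matching `## References` entry, and References must not list unused numbers) is mechanical enough to sketch as code. The checker below is illustrative only; the sample report and URL are invented.

```python
# Illustrative checker for the citation self-check described in the prompt:
# every [n] used in the body must appear in "## References", and vice versa.
import re

def check_citations(report: str) -> bool:
    body, _, refs = report.partition("## References")
    used = set(re.findall(r"\[(\d+)\]", body))
    listed = set(re.findall(r"^\[(\d+)\]", refs, flags=re.MULTILINE))
    return bool(used) and used == listed

report = ("# Summary\nRevenue grew 15% year over year [1].\n\n"
          "## References\n[1] FY2024 annual report - https://example.com/annual-report\n")
print(check_citations(report))  # True
```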
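`deep_finance_template.yaml` leaves `{{PLACEHOLDER}}` tokens (model path, data paths, judge weights, batch size, and so on) to be substituted before launch; the scripts that fill them are not shown in this diff. A rough sketch of how such a template could be rendered, where `render_template` and the example values are assumptions for illustration:

```python
# Rough sketch of filling the {{...}} placeholders in deep_finance_template.yaml.
# `render_template` and the example values are assumptions for illustration only.
import re

def render_template(text: str, values: dict) -> str:
    def sub(match):
        key = match.group(1)
        if key not in values:
            raise KeyError(f"missing value for placeholder {key}")
        return str(values[key])
    return re.sub(r"\{\{(\w+)\}\}", sub, text)

template = "model:\n  path: {{MODEL_PATH}}\ndata:\n  train_batch_size: {{TRAIN_BATCH_SIZE}}\n"
print(render_template(template, {"MODEL_PATH": "/models/my-base-model", "TRAIN_BATCH_SIZE": 32}))
```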