TradingAgents-CN

LangGraph 图结构设计

概述

TradingAgents 基于 LangGraph 框架构建,采用有向无环图(DAG)结构来组织智能体的工作流程。这种设计确保了智能体之间的有序协作和信息的正确流转。

图结构架构

整体工作流图

graph TD
    START([开始]) --> INIT[初始化状态]
    INIT --> PARALLEL_ANALYSIS{并行分析阶段}
    
    PARALLEL_ANALYSIS --> FA[基本面分析师]
    PARALLEL_ANALYSIS --> TA[技术分析师]
    PARALLEL_ANALYSIS --> NA[新闻分析师]
    PARALLEL_ANALYSIS --> SA[社交媒体分析师]
    
    FA --> COLLECT_ANALYSIS[收集分析结果]
    TA --> COLLECT_ANALYSIS
    NA --> COLLECT_ANALYSIS
    SA --> COLLECT_ANALYSIS
    
    COLLECT_ANALYSIS --> RESEARCH_MANAGER[研究经理]
    RESEARCH_MANAGER --> DEBATE_INIT[初始化辩论]
    
    DEBATE_INIT --> BULL_RESEARCHER[看涨研究员]
    DEBATE_INIT --> BEAR_RESEARCHER[看跌研究员]
    
    BULL_RESEARCHER --> DEBATE_ROUND{辩论轮次}
    BEAR_RESEARCHER --> DEBATE_ROUND
    
    DEBATE_ROUND -->|继续辩论| BULL_RESEARCHER
    DEBATE_ROUND -->|辩论结束| RESEARCH_CONSENSUS[研究共识]
    
    RESEARCH_CONSENSUS --> TRADER[交易员]
    TRADER --> TRADING_DECISION[交易决策]
    
    TRADING_DECISION --> RISK_PARALLEL{并行风险评估}
    
    RISK_PARALLEL --> AGGRESSIVE_RISK[激进风险评估]
    RISK_PARALLEL --> CONSERVATIVE_RISK[保守风险评估]
    RISK_PARALLEL --> NEUTRAL_RISK[中性风险评估]
    
    AGGRESSIVE_RISK --> RISK_DEBATE[风险辩论]
    CONSERVATIVE_RISK --> RISK_DEBATE
    NEUTRAL_RISK --> RISK_DEBATE
    
    RISK_DEBATE --> RISK_MANAGER[风险经理]
    RISK_MANAGER --> PORTFOLIO_DECISION[投资组合决策]
    
    PORTFOLIO_DECISION --> END([结束])

核心组件设计

1. 图构建器 (GraphSetup)

class GraphSetup:
    """LangGraph 图结构设置"""
    
    def build_graph(self) -> StateGraph:
        """构建完整的交易决策图"""
        
        # 创建状态图
        workflow = StateGraph(AgentState)
        
        # 添加节点
        self._add_analysis_nodes(workflow)
        self._add_research_nodes(workflow)
        self._add_trading_nodes(workflow)
        self._add_risk_nodes(workflow)
        
        # 添加边和条件逻辑
        self._add_edges(workflow)
        self._add_conditional_edges(workflow)
        
        # 设置入口和出口
        workflow.set_entry_point("initialize")
        workflow.set_finish_point("portfolio_decision")
        
        return workflow.compile()
    
    def _add_analysis_nodes(self, workflow: StateGraph):
        """添加分析师节点"""
        workflow.add_node("fundamentals_analyst", self.fundamentals_analyst)
        workflow.add_node("technical_analyst", self.technical_analyst)
        workflow.add_node("news_analyst", self.news_analyst)
        workflow.add_node("social_analyst", self.social_analyst)
    
    def _add_research_nodes(self, workflow: StateGraph):
        """添加研究员节点"""
        workflow.add_node("research_manager", self.research_manager)
        workflow.add_node("bull_researcher", self.bull_researcher)
        workflow.add_node("bear_researcher", self.bear_researcher)
    
    def _add_trading_nodes(self, workflow: StateGraph):
        """添加交易节点"""
        workflow.add_node("trader", self.trader)
    
    def _add_risk_nodes(self, workflow: StateGraph):
        """添加风险管理节点"""
        workflow.add_node("aggressive_risk", self.aggressive_risk)
        workflow.add_node("conservative_risk", self.conservative_risk)
        workflow.add_node("neutral_risk", self.neutral_risk)
        workflow.add_node("risk_manager", self.risk_manager)

2. 条件逻辑 (ConditionalLogic)

class ConditionalLogic:
    """图的条件逻辑控制"""
    
    def should_continue_debate(self, state: AgentState) -> str:
        """判断是否继续研究员辩论"""
        
        current_round = state.get("debate_round", 0)
        max_rounds = self.config.get("max_debate_rounds", 3)
        
        # 检查辩论轮次
        if current_round >= max_rounds:
            return "end_debate"
        
        # 检查是否达成共识
        if self._has_consensus(state):
            return "end_debate"
        
        # 检查分歧是否足够大
        if self._has_significant_disagreement(state):
            return "continue_debate"
        
        return "end_debate"
    
    def route_to_risk_assessment(self, state: AgentState) -> List[str]:
        """路由到风险评估节点"""
        
        trading_decision = state.get("trader_decision", {})
        risk_level = trading_decision.get("risk_level", "medium")
        
        # 根据风险级别决定评估路径
        if risk_level == "high":
            return ["aggressive_risk", "conservative_risk", "neutral_risk"]
        elif risk_level == "low":
            return ["conservative_risk", "neutral_risk"]
        else:
            return ["neutral_risk"]
    
    def should_approve_trade(self, state: AgentState) -> str:
        """判断是否批准交易"""
        
        risk_assessment = state.get("risk_assessment", {})
        risk_score = risk_assessment.get("overall_risk_score", 0.5)
        
        # 风险阈值检查
        if risk_score > self.config.get("risk_threshold", 0.8):
            return "reject_trade"
        
        # 一致性检查
        if self._risk_assessments_consistent(state):
            return "approve_trade"
        
        return "request_review"

3. 状态传播 (Propagator)

class Propagator:
    """状态传播管理器"""
    
    def propagate(self, symbol: str, date: str) -> Tuple[AgentState, Dict]:
        """执行完整的传播流程"""
        
        # 初始化状态
        initial_state = self._initialize_state(symbol, date)
        
        # 执行图传播
        final_state = self.graph.invoke(initial_state)
        
        # 提取决策结果
        decision = self._extract_decision(final_state)
        
        return final_state, decision
    
    def _initialize_state(self, symbol: str, date: str) -> AgentState:
        """初始化智能体状态"""
        return AgentState(
            ticker=symbol,
            date=date,
            analyst_reports={},
            research_reports={},
            trader_decision={},
            risk_assessment={},
            portfolio_decision={},
            messages=[],
            metadata={}
        )
    
    def _extract_decision(self, state: AgentState) -> Dict:
        """从最终状态提取决策信息"""
        return {
            "action": state.portfolio_decision.get("action", "hold"),
            "quantity": state.portfolio_decision.get("quantity", 0),
            "confidence": state.portfolio_decision.get("confidence", 0.5),
            "reasoning": state.portfolio_decision.get("reasoning", ""),
            "risk_score": state.risk_assessment.get("overall_risk_score", 0.5)
        }

节点类型详解

1. 分析节点 (Analysis Nodes)

def fundamentals_analyst_node(state: AgentState) -> AgentState:
    """基本面分析师节点"""
    
    # 获取数据
    data = get_fundamental_data(state.ticker, state.date)
    
    # 执行分析
    analysis = fundamentals_analyst.analyze(data)
    
    # 更新状态
    state.analyst_reports["fundamentals"] = analysis
    
    return state

2. 决策节点 (Decision Nodes)

def trader_node(state: AgentState) -> AgentState:
    """交易员决策节点"""
    
    # 收集所有分析报告
    all_reports = {
        **state.analyst_reports,
        **state.research_reports
    }
    
    # 制定交易决策
    decision = trader.make_decision(all_reports)
    
    # 更新状态
    state.trader_decision = decision
    
    return state

3. 并行节点 (Parallel Nodes)

def parallel_analysis_node(state: AgentState) -> AgentState:
    """并行分析节点"""
    
    # 并行执行多个分析师
    with ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(fundamentals_analyst.analyze, state): "fundamentals",
            executor.submit(technical_analyst.analyze, state): "technical",
            executor.submit(news_analyst.analyze, state): "news",
            executor.submit(social_analyst.analyze, state): "social"
        }
        
        # 收集结果
        for future in as_completed(futures):
            analyst_type = futures[future]
            result = future.result()
            state.analyst_reports[analyst_type] = result
    
    return state

边和路由设计

1. 顺序边 (Sequential Edges)

# 简单的顺序连接
workflow.add_edge("initialize", "parallel_analysis")
workflow.add_edge("parallel_analysis", "research_manager")
workflow.add_edge("research_manager", "trader")

2. 条件边 (Conditional Edges)

# 基于条件的路由
workflow.add_conditional_edges(
    "debate_round",
    conditional_logic.should_continue_debate,
    {
        "continue_debate": "bull_researcher",
        "end_debate": "research_consensus"
    }
)

3. 并行边 (Parallel Edges)

# 并行执行多个节点
workflow.add_conditional_edges(
    "trading_decision",
    conditional_logic.route_to_risk_assessment,
    {
        "aggressive_risk": "aggressive_risk_node",
        "conservative_risk": "conservative_risk_node",
        "neutral_risk": "neutral_risk_node"
    }
)

状态管理

1. 状态结构

@dataclass
class AgentState:
    """智能体状态数据结构"""
    
    # 基本信息
    ticker: str
    date: str
    
    # 分析结果
    analyst_reports: Dict[str, Any]
    research_reports: Dict[str, Any]
    
    # 决策信息
    trader_decision: Dict[str, Any]
    risk_assessment: Dict[str, Any]
    portfolio_decision: Dict[str, Any]
    
    # 通信记录
    messages: List[BaseMessage]
    
    # 元数据
    metadata: Dict[str, Any]
    
    # 控制信息
    debate_round: int = 0
    risk_round: int = 0

2. 状态更新

class StateManager:
    """状态管理器"""
    
    def update_state(self, state: AgentState, updates: Dict) -> AgentState:
        """安全地更新状态"""
        
        # 深拷贝状态
        new_state = copy.deepcopy(state)
        
        # 应用更新
        for key, value in updates.items():
            if hasattr(new_state, key):
                setattr(new_state, key, value)
        
        # 验证状态一致性
        self._validate_state(new_state)
        
        return new_state
    
    def _validate_state(self, state: AgentState):
        """验证状态一致性"""
        
        # 检查必需字段
        required_fields = ["ticker", "date"]
        for field in required_fields:
            if not getattr(state, field):
                raise ValueError(f"Required field {field} is missing")
        
        # 检查数据类型
        if not isinstance(state.analyst_reports, dict):
            raise TypeError("analyst_reports must be a dictionary")

错误处理和恢复

1. 节点级错误处理

def safe_node_execution(node_func):
    """节点执行的安全包装器"""
    
    def wrapper(state: AgentState) -> AgentState:
        try:
            return node_func(state)
        except Exception as e:
            # 记录错误
            logger.error(f"Node {node_func.__name__} failed: {e}")
            
            # 添加错误信息到状态
            state.metadata["errors"] = state.metadata.get("errors", [])
            state.metadata["errors"].append({
                "node": node_func.__name__,
                "error": str(e),
                "timestamp": datetime.now().isoformat()
            })
            
            return state
    
    return wrapper

2. 图级错误恢复

class GraphRecovery:
    """图执行恢复机制"""
    
    def execute_with_recovery(self, graph, initial_state):
        """带恢复机制的图执行"""
        
        try:
            return graph.invoke(initial_state)
        except Exception as e:
            # 尝试从检查点恢复
            if checkpoint := self._find_last_checkpoint(initial_state):
                logger.info("Recovering from checkpoint")
                return self._recover_from_checkpoint(graph, checkpoint)
            
            # 降级执行
            logger.warning("Falling back to degraded execution")
            return self._degraded_execution(initial_state)

这种图结构设计确保了智能体工作流的清晰性、可维护性和容错性,同时提供了灵活的扩展机制。