
Data Caching Strategy

Overview

The TradingAgents framework uses a multi-level caching strategy to optimize data-access performance, reduce API call costs, and improve system responsiveness. This document covers the cache architecture, strategies, implementation, and best practices in detail.

Caching Architecture

Multi-Level Cache Design

graph TB
    subgraph "Application Layer"
        AGENT1[Analyst Agents]
        AGENT2[Researcher Agents]
        AGENT3[Trader Agents]
    end

    subgraph "Cache Levels"
        L1[L1: In-Memory Cache<br/>Fastest Access]
        L2[L2: Local File Cache<br/>Persistent Storage]
        L3[L3: Redis Cache<br/>Distributed Sharing]
        L4[L4: Database Cache<br/>Long-Term Storage]
    end

    subgraph "Data Sources"
        API1[FinnHub API]
        API2[Yahoo Finance]
        API3[Reddit API]
        API4[Google News]
    end

    AGENT1 --> L1
    AGENT2 --> L1
    AGENT3 --> L1

    L1 --> L2
    L2 --> L3
    L3 --> L4

    L4 --> API1
    L4 --> API2
    L4 --> API3
    L4 --> API4
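
The levels in this diagram map directly onto the configuration consumed by the `CacheManager` shown in the next section. As a minimal sketch (the key names mirror the `config.get(...)` calls in `CacheManager.__init__`; all values are illustrative assumptions):

# Hypothetical CacheManager configuration; key names follow the
# config.get(...) calls in CacheManager.__init__, values are examples only.
cache_config = {
    "memory_cache": {"max_size": 1000, "cleanup_threshold": 0.8},
    "file_cache": {"cache_dir": "./cache", "compression": True},
    "redis_enabled": True,
    "redis_cache": {"host": "localhost", "port": 6379, "db": 0},
    "db_enabled": False,
    "db_cache": {},
}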

1. Cache Manager

Core Cache Class

from typing import Any, Dict, Optional


class CacheManager:
    """Cache manager - unified management of all cache levels"""

    def __init__(self, config: Dict):
        self.config = config
        self.memory_cache = MemoryCache(config.get("memory_cache", {}))
        self.file_cache = FileCache(config.get("file_cache", {}))
        self.redis_cache = RedisCache(config.get("redis_cache", {})) if config.get("redis_enabled") else None
        self.db_cache = DatabaseCache(config.get("db_cache", {})) if config.get("db_enabled") else None

        # Cache strategy configuration
        self.cache_strategies = self._load_cache_strategies()
        self.ttl_config = self._load_ttl_config()

    def get(self, key: str, data_type: str = "default") -> Optional[Any]:
        """Get cached data - search level by level"""

        # L1: in-memory cache
        data = self.memory_cache.get(key)
        if data is not None:
            self._record_cache_hit("memory", key, data_type)
            return data

        # L2: file cache
        data = self.file_cache.get(key)
        if data is not None:
            # Backfill the in-memory cache
            self.memory_cache.set(key, data, self._get_ttl(data_type))
            self._record_cache_hit("file", key, data_type)
            return data

        # L3: Redis cache
        if self.redis_cache:
            data = self.redis_cache.get(key)
            if data is not None:
                # Backfill the higher cache levels
                self.file_cache.set(key, data, self._get_ttl(data_type))
                self.memory_cache.set(key, data, self._get_ttl(data_type))
                self._record_cache_hit("redis", key, data_type)
                return data

        # L4: database cache
        if self.db_cache:
            data = self.db_cache.get(key)
            if data is not None:
                # Backfill all higher cache levels
                if self.redis_cache:
                    self.redis_cache.set(key, data, self._get_ttl(data_type))
                self.file_cache.set(key, data, self._get_ttl(data_type))
                self.memory_cache.set(key, data, self._get_ttl(data_type))
                self._record_cache_hit("database", key, data_type)
                return data

        # Cache miss
        self._record_cache_miss(key, data_type)
        return None

    def set(self, key: str, data: Any, data_type: str = "default", ttl: Optional[int] = None) -> None:
        """Set cached data - write to all applicable levels"""

        if ttl is None:
            ttl = self._get_ttl(data_type)

        # Decide the caching strategy based on data type and size
        cache_strategy = self._determine_cache_strategy(data, data_type)

        # L1: in-memory cache (always cache small data)
        if cache_strategy["memory"]:
            self.memory_cache.set(key, data, ttl)

        # L2: file cache (cache medium-sized data)
        if cache_strategy["file"]:
            self.file_cache.set(key, data, ttl)

        # L3: Redis cache (cache shared data)
        if cache_strategy["redis"] and self.redis_cache:
            self.redis_cache.set(key, data, ttl)

        # L4: database cache (cache important data)
        if cache_strategy["database"] and self.db_cache:
            self.db_cache.set(key, data, ttl)

    def _determine_cache_strategy(self, data: Any, data_type: str) -> Dict[str, bool]:
        """Determine the caching strategy"""

        data_size = self._estimate_data_size(data)
        data_importance = self._assess_data_importance(data_type)

        strategy = {
            "memory": data_size < 1024 * 1024,  # smaller than 1MB
            "file": data_size < 10 * 1024 * 1024,  # smaller than 10MB
            "redis": data_importance >= 0.7,  # important data
            "database": data_importance >= 0.8 or data_type in ["fundamental_data", "company_profile"]
        }

        return strategy

    def _get_ttl(self, data_type: str) -> int:
        """Get the TTL for a data type"""
        return self.ttl_config.get(data_type, self.ttl_config["default"])
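
A usage sketch for the read-through pattern this class enables. It assumes the elided helpers (`_load_cache_strategies`, `_record_cache_hit`, `_record_cache_miss`, `_estimate_data_size`, `_assess_data_importance`) are implemented, and `fetch_price_data` is a hypothetical data-source call:

# Hypothetical usage - on a miss across all levels, fetch from the data
# source and write the result back through every applicable cache level.
cache = CacheManager(cache_config)

key = "AAPL:price_data:2024-01-15"
data = cache.get(key, data_type="price_data")
if data is None:
    data = fetch_price_data("AAPL", "2024-01-15")  # hypothetical fetch
    cache.set(key, data, data_type="price_data")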

2. In-Memory Cache (L1)

High-Speed In-Memory Cache

import time


class MemoryCache:
    """In-memory cache - the fastest cache level"""

    def __init__(self, config: Dict):
        self.config = config
        self.cache = {}
        self.access_times = {}
        self.max_size = config.get("max_size", 1000)
        self.cleanup_threshold = config.get("cleanup_threshold", 0.8)

    def get(self, key: str) -> Optional[Any]:
        """Get a cache entry"""
        if key in self.cache:
            item = self.cache[key]

            # Check whether the entry has expired
            if self._is_expired(item):
                del self.cache[key]
                if key in self.access_times:
                    del self.access_times[key]
                return None

            # Update the access time
            self.access_times[key] = time.time()
            return item["data"]

        return None

    def set(self, key: str, data: Any, ttl: int) -> None:
        """Set a cache entry"""

        # Check whether a cleanup is needed
        if len(self.cache) >= self.max_size * self.cleanup_threshold:
            self._cleanup_cache()

        # Store the data
        self.cache[key] = {
            "data": data,
            "timestamp": time.time(),
            "ttl": ttl
        }
        self.access_times[key] = time.time()

    def _cleanup_cache(self) -> None:
        """Evict expired and least-recently-used cache entries"""

        # First remove expired entries
        expired_keys = []
        for key, item in self.cache.items():
            if self._is_expired(item):
                expired_keys.append(key)

        for key in expired_keys:
            del self.cache[key]
            if key in self.access_times:
                del self.access_times[key]

        # If the cache is still too large, evict using an LRU policy
        if len(self.cache) >= self.max_size * self.cleanup_threshold:
            # Sort by access time and remove the least recently used entries
            sorted_keys = sorted(self.access_times.keys(), key=lambda k: self.access_times[k])
            keys_to_remove = sorted_keys[:len(sorted_keys) // 4]  # evict 25%

            for key in keys_to_remove:
                if key in self.cache:
                    del self.cache[key]
                if key in self.access_times:
                    del self.access_times[key]

    def _is_expired(self, item: Dict) -> bool:
        """Check whether a cache entry has expired"""
        return time.time() - item["timestamp"] > item["ttl"]
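
A self-contained check of the TTL behavior, as a minimal sketch:

import time

mem = MemoryCache({"max_size": 100})
mem.set("quote:AAPL", {"price": 185.2}, ttl=1)  # expires after 1 second
print(mem.get("quote:AAPL"))                    # {'price': 185.2}
time.sleep(1.1)
print(mem.get("quote:AAPL"))                    # None - the entry expired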

3. File Cache (L2)

Persistent File Cache

import hashlib
import pickle
import time
from pathlib import Path


class FileCache:
    """File cache - persistent storage"""

    def __init__(self, config: Dict):
        self.config = config
        self.cache_dir = Path(config.get("cache_dir", "./cache"))
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.compression_enabled = config.get("compression", True)
        self.max_file_size = config.get("max_file_size", 50 * 1024 * 1024)  # 50MB

    def get(self, key: str) -> Optional[Any]:
        """Get cached data from a file"""

        cache_file = self._get_cache_file_path(key)

        if not cache_file.exists():
            return None

        try:
            # Check whether the entry has expired
            if self._is_file_expired(cache_file):
                cache_file.unlink()  # delete the expired file
                return None

            # Read the stored entry
            with open(cache_file, 'rb') as f:
                if self.compression_enabled:
                    compressed_data = f.read()
                    cache_data = self._decompress_data(compressed_data)
                else:
                    cache_data = pickle.load(f)

            # Unwrap the payload from the cache metadata
            return cache_data["data"]

        except Exception as e:
            print(f"Error reading cache file {cache_file}: {e}")
            # Delete the corrupted cache file
            if cache_file.exists():
                cache_file.unlink()
            return None

    def set(self, key: str, data: Any, ttl: int) -> None:
        """Write data to the file cache"""

        cache_file = self._get_cache_file_path(key)

        try:
            # Check the data size
            data_size = self._estimate_data_size(data)
            if data_size > self.max_file_size:
                print(f"Data too large for file cache: {data_size} bytes")
                return

            # Create the cache entry with metadata
            cache_data = {
                "data": data,
                "timestamp": time.time(),
                "ttl": ttl,
                "key": key
            }

            # Write the entry to disk
            with open(cache_file, 'wb') as f:
                if self.compression_enabled:
                    compressed_data = self._compress_data(cache_data)
                    f.write(compressed_data)
                else:
                    pickle.dump(cache_data, f)

        except Exception as e:
            print(f"Error writing cache file {cache_file}: {e}")

    def _get_cache_file_path(self, key: str) -> Path:
        """Get the cache file path for a key"""
        # Hash the key to avoid overlong file names or special characters
        key_hash = hashlib.md5(key.encode()).hexdigest()
        return self.cache_dir / f"{key_hash}.cache"

    def _estimate_data_size(self, data: Any) -> int:
        """Estimate the serialized size of the data in bytes"""
        try:
            return len(pickle.dumps(data))
        except Exception:
            return 0

    def _is_file_expired(self, cache_file: Path) -> bool:
        """Check whether a cache file has expired"""
        try:
            with open(cache_file, 'rb') as f:
                if self.compression_enabled:
                    compressed_data = f.read()
                    cache_data = self._decompress_data(compressed_data)
                else:
                    cache_data = pickle.load(f)

            return time.time() - cache_data["timestamp"] > cache_data["ttl"]

        except Exception:
            return True  # treat unreadable files as expired

    def _compress_data(self, data: Any) -> bytes:
        """Compress data"""
        import gzip
        pickled_data = pickle.dumps(data)
        return gzip.compress(pickled_data)

    def _decompress_data(self, compressed_data: bytes) -> Any:
        """Decompress data"""
        import gzip
        pickled_data = gzip.decompress(compressed_data)
        return pickle.loads(pickled_data)

    def cleanup_expired_files(self) -> None:
        """Remove expired cache files"""
        for cache_file in self.cache_dir.glob("*.cache"):
            try:
                if self._is_file_expired(cache_file):
                    cache_file.unlink()
            except Exception as e:
                print(f"Error checking cache file {cache_file}: {e}")

4. Redis Cache (L3)

Distributed Shared Cache

import pickle


class RedisCache:
    """Redis cache - distributed shared cache"""

    def __init__(self, config: Dict):
        self.config = config
        self.redis_client = self._initialize_redis_client()
        self.key_prefix = config.get("key_prefix", "tradingagents:")
        self.serialization_format = config.get("serialization", "pickle")  # pickle, json, msgpack

    def _initialize_redis_client(self):
        """Initialize the Redis client"""
        try:
            import redis

            redis_config = {
                "host": self.config.get("host", "localhost"),
                "port": self.config.get("port", 6379),
                "db": self.config.get("db", 0),
                "password": self.config.get("password"),
                "socket_timeout": self.config.get("timeout", 5),
                "socket_connect_timeout": self.config.get("connect_timeout", 5),
                "retry_on_timeout": True,
                "health_check_interval": 30
            }

            # Drop None values
            redis_config = {k: v for k, v in redis_config.items() if v is not None}

            client = redis.Redis(**redis_config)

            # Test the connection
            client.ping()
            print("Redis connection established")

            return client

        except Exception as e:
            print(f"Failed to connect to Redis: {e}")
            return None

    def get(self, key: str) -> Optional[Any]:
        """Get data from Redis"""
        if not self.redis_client:
            return None

        try:
            full_key = self.key_prefix + key
            data = self.redis_client.get(full_key)

            if data is None:
                return None

            # Deserialize the data
            return self._deserialize_data(data)

        except Exception as e:
            print(f"Error getting data from Redis: {e}")
            return None

    def set(self, key: str, data: Any, ttl: int) -> None:
        """Store data in Redis"""
        if not self.redis_client:
            return

        try:
            full_key = self.key_prefix + key

            # Serialize the data
            serialized_data = self._serialize_data(data)

            # Store the data with its TTL
            self.redis_client.setex(full_key, ttl, serialized_data)

        except Exception as e:
            print(f"Error setting data to Redis: {e}")

    def _serialize_data(self, data: Any) -> bytes:
        """Serialize data"""
        if self.serialization_format == "pickle":
            return pickle.dumps(data)
        elif self.serialization_format == "json":
            import json
            return json.dumps(data, default=str).encode('utf-8')
        elif self.serialization_format == "msgpack":
            import msgpack
            return msgpack.packb(data, default=str)
        else:
            raise ValueError(f"Unsupported serialization format: {self.serialization_format}")

    def _deserialize_data(self, data: bytes) -> Any:
        """Deserialize data"""
        if self.serialization_format == "pickle":
            return pickle.loads(data)
        elif self.serialization_format == "json":
            import json
            return json.loads(data.decode('utf-8'))
        elif self.serialization_format == "msgpack":
            import msgpack
            return msgpack.unpackb(data, raw=False)
        else:
            raise ValueError(f"Unsupported serialization format: {self.serialization_format}")

    def delete(self, key: str) -> None:
        """Delete data from Redis"""
        if not self.redis_client:
            return

        try:
            full_key = self.key_prefix + key
            self.redis_client.delete(full_key)
        except Exception as e:
            print(f"Error deleting data from Redis: {e}")

    def clear_expired(self) -> None:
        """Clean up expired keys (Redis handles TTL expiry automatically)"""
        # Redis evicts expired keys on its own; extra cleanup logic can go here
        pass
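
A usage sketch; it requires a reachable Redis server, and the connection parameters are illustrative:

rc = RedisCache({"host": "localhost", "port": 6379, "key_prefix": "tradingagents:"})
if rc.redis_client:
    rc.set("news:TSLA", ["headline 1", "headline 2"], ttl=1800)
    print(rc.get("news:TSLA"))  # ['headline 1', 'headline 2']
    rc.delete("news:TSLA")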

5. Cache Strategy Configuration

TTL Configuration

# TTL configuration for different data types
TTL_CONFIG = {
    "price_data": 60,            # 1 minute - price data changes quickly
    "fundamental_data": 3600,    # 1 hour - fundamental data is relatively stable
    "company_profile": 86400,    # 24 hours - company information rarely changes
    "news_data": 1800,           # 30 minutes - news data changes at a moderate rate
    "social_data": 900,          # 15 minutes - social media data changes fairly quickly
    "technical_indicators": 300, # 5 minutes - technical indicators need fresh data
    "market_data": 600,          # 10 minutes - market data changes at a moderate rate
    "historical_data": 7200,     # 2 hours - historical data is relatively stable
    "default": 1800              # 30 minutes - default TTL
}

# Importance scores used by the caching strategy
DATA_IMPORTANCE = {
    "price_data": 0.9,           # high importance
    "fundamental_data": 0.8,     # high importance
    "company_profile": 0.7,      # medium-high importance
    "news_data": 0.6,            # medium importance
    "social_data": 0.5,          # medium importance
    "technical_indicators": 0.7, # medium-high importance
    "market_data": 0.6,          # medium importance
    "historical_data": 0.8,      # high importance
}
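
The `CacheManager` helpers elided earlier could read these tables directly. A minimal sketch of `_load_ttl_config` and `_assess_data_importance` under that assumption (the 0.5 fallback is also an assumption, not the project's actual default):

# Plausible bodies for two of the elided CacheManager helpers - not the
# project's actual implementation, just one way to consume the tables above.
def _load_ttl_config(self) -> Dict[str, int]:
    return dict(TTL_CONFIG)

def _assess_data_importance(self, data_type: str) -> float:
    return DATA_IMPORTANCE.get(data_type, 0.5)  # assumed default: medium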

6. Cache Monitoring and Optimization

Cache Performance Monitoring

from collections import defaultdict


class CacheMonitor:
    """Cache performance monitoring"""

    def __init__(self):
        self.metrics = {
            "hits": defaultdict(int),
            "misses": defaultdict(int),
            "hit_rates": defaultdict(float),
            "response_times": defaultdict(list),
            "cache_sizes": defaultdict(int)
        }

    def record_hit(self, cache_level: str, key: str, data_type: str, response_time: float = None):
        """Record a cache hit"""
        self.metrics["hits"][cache_level] += 1
        if response_time is not None:
            self.metrics["response_times"][cache_level].append(response_time)

        self._update_hit_rate(cache_level)

    def record_miss(self, key: str, data_type: str):
        """Record a cache miss (all levels missed, so it is tracked globally)"""
        self.metrics["misses"]["total"] += 1
        self._update_hit_rate("total")

    def _update_hit_rate(self, cache_level: str):
        """Update the hit rate for a cache level"""
        hits = self.metrics["hits"][cache_level]
        misses = self.metrics["misses"].get(cache_level, 0)
        total = hits + misses

        if total > 0:
            self.metrics["hit_rates"][cache_level] = hits / total

    def get_performance_report(self) -> Dict:
        """Get a performance report"""
        return {
            "hit_rates": dict(self.metrics["hit_rates"]),
            "total_hits": sum(self.metrics["hits"].values()),
            "total_misses": sum(self.metrics["misses"].values()),
            "avg_response_times": {
                level: sum(times) / len(times) if times else 0
                for level, times in self.metrics["response_times"].items()
            },
            "cache_efficiency": self._calculate_cache_efficiency()
        }

    def _calculate_cache_efficiency(self) -> float:
        """Calculate overall cache efficiency (hits / all requests)"""
        total_hits = sum(self.metrics["hits"].values())
        total_requests = total_hits + sum(self.metrics["misses"].values())

        return total_hits / total_requests if total_requests > 0 else 0.0
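
A quick demonstration of the reporting: three hits across two levels plus one miss yield an overall cache efficiency of 3 / 4 = 0.75 (keys and timings are illustrative):

monitor = CacheMonitor()
monitor.record_hit("memory", "AAPL:price_data", "price_data", response_time=0.0002)
monitor.record_hit("memory", "MSFT:price_data", "price_data", response_time=0.0003)
monitor.record_hit("file", "TSLA:news_data", "news_data", response_time=0.004)
monitor.record_miss("NVDA:social_data", "social_data")
print(monitor.get_performance_report())  # "cache_efficiency": 0.75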

7. Cache Best Practices

Usage Recommendations

class CacheBestPractices:
    """Cache best-practice guidelines"""

    @staticmethod
    def generate_cache_key(symbol: str, data_type: str, date: str = None, **kwargs) -> str:
        """Generate a standardized cache key"""

        key_parts = [symbol.upper(), data_type]

        if date:
            key_parts.append(date)

        # Append any additional parameters in a deterministic order
        for k, v in sorted(kwargs.items()):
            key_parts.append(f"{k}:{v}")

        return ":".join(key_parts)

    @staticmethod
    def should_cache_data(data: Any, data_type: str) -> bool:
        """Decide whether data should be cached"""

        # Do not cache empty data
        if not data:
            return False

        # Do not cache error responses
        if isinstance(data, dict) and "error" in data:
            return False

        # Do not cache oversized data
        data_size = CacheBestPractices._estimate_size(data)
        if data_size > 100 * 1024 * 1024:  # 100MB
            return False

        return True

    @staticmethod
    def _estimate_size(obj: Any) -> int:
        """Estimate the size of an object in bytes"""
        try:
            return len(pickle.dumps(obj))
        except Exception:
            return 0
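
The helpers above in action (the ticker and parameters are illustrative):

key = CacheBestPractices.generate_cache_key(
    "aapl", "price_data", date="2024-01-15", interval="1d"
)
print(key)  # AAPL:price_data:2024-01-15:interval:1d
print(CacheBestPractices.should_cache_data({"close": 185.2}, "price_data"))        # True
print(CacheBestPractices.should_cache_data({"error": "rate limit"}, "price_data")) # False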

With this layered caching strategy in place, TradingAgents can significantly improve data-access performance, reduce API call costs, and deliver a better user experience.