""" Multi-Pair Data Feed for Live Trading. Fetches real-time OHLCV and funding data for all assets in the universe. """ import logging from itertools import combinations import pandas as pd import numpy as np import ta from live_trading.okx_client import OKXClient from .config import MultiPairLiveConfig, PathConfig logger = logging.getLogger(__name__) class TradingPair: """ Represents a tradeable pair for spread analysis. Attributes: base_asset: First asset symbol (e.g., ETH/USDT:USDT) quote_asset: Second asset symbol (e.g., BTC/USDT:USDT) pair_id: Unique identifier """ def __init__(self, base_asset: str, quote_asset: str): self.base_asset = base_asset self.quote_asset = quote_asset self.pair_id = f"{base_asset}__{quote_asset}" @property def name(self) -> str: """Human-readable pair name.""" base = self.base_asset.split("/")[0] quote = self.quote_asset.split("/")[0] return f"{base}/{quote}" def __hash__(self): return hash(self.pair_id) def __eq__(self, other): if not isinstance(other, TradingPair): return False return self.pair_id == other.pair_id class MultiPairDataFeed: """ Real-time data feed for multi-pair strategy. Fetches OHLCV data for all assets and calculates spread features for all pair combinations. """ def __init__( self, okx_client: OKXClient, config: MultiPairLiveConfig, path_config: PathConfig ): self.client = okx_client self.config = config self.paths = path_config # Cache for asset data self._asset_data: dict[str, pd.DataFrame] = {} self._funding_rates: dict[str, float] = {} self._pairs: list[TradingPair] = [] # Generate pairs self._generate_pairs() def _generate_pairs(self) -> None: """Generate all unique pairs from asset universe.""" self._pairs = [] for base, quote in combinations(self.config.assets, 2): pair = TradingPair(base_asset=base, quote_asset=quote) self._pairs.append(pair) logger.info("Generated %d pairs from %d assets", len(self._pairs), len(self.config.assets)) @property def pairs(self) -> list[TradingPair]: """Get list of trading pairs.""" return self._pairs def fetch_all_ohlcv(self) -> dict[str, pd.DataFrame]: """ Fetch OHLCV data for all assets. Returns: Dictionary mapping symbol to OHLCV DataFrame """ self._asset_data = {} for symbol in self.config.assets: try: ohlcv = self.client.fetch_ohlcv( symbol, self.config.timeframe, self.config.candles_to_fetch ) df = self._ohlcv_to_dataframe(ohlcv) if len(df) >= 200: self._asset_data[symbol] = df logger.debug("Fetched %s: %d candles", symbol, len(df)) else: logger.warning("Skipping %s: insufficient data (%d)", symbol, len(df)) except Exception as e: logger.error("Error fetching %s: %s", symbol, e) logger.info("Fetched data for %d/%d assets", len(self._asset_data), len(self.config.assets)) return self._asset_data def _ohlcv_to_dataframe(self, ohlcv: list) -> pd.DataFrame: """Convert OHLCV list to DataFrame.""" df = pd.DataFrame( ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'] ) df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True) df.set_index('timestamp', inplace=True) return df def fetch_all_funding_rates(self) -> dict[str, float]: """ Fetch current funding rates for all assets. Returns: Dictionary mapping symbol to funding rate """ self._funding_rates = {} for symbol in self.config.assets: try: rate = self.client.get_funding_rate(symbol) self._funding_rates[symbol] = rate except Exception as e: logger.warning("Could not get funding for %s: %s", symbol, e) self._funding_rates[symbol] = 0.0 return self._funding_rates def calculate_pair_features( self, pair: TradingPair ) -> pd.DataFrame | None: """ Calculate features for a single pair. Args: pair: Trading pair Returns: DataFrame with features, or None if insufficient data """ base = pair.base_asset quote = pair.quote_asset if base not in self._asset_data or quote not in self._asset_data: return None df_base = self._asset_data[base] df_quote = self._asset_data[quote] # Align indices common_idx = df_base.index.intersection(df_quote.index) if len(common_idx) < 200: return None df_a = df_base.loc[common_idx] df_b = df_quote.loc[common_idx] # Calculate spread (base / quote) spread = df_a['close'] / df_b['close'] # Z-Score z_window = self.config.z_window rolling_mean = spread.rolling(window=z_window).mean() rolling_std = spread.rolling(window=z_window).std() z_score = (spread - rolling_mean) / rolling_std # Spread Technicals spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi() spread_roc = spread.pct_change(periods=5) * 100 spread_change_1h = spread.pct_change(periods=1) # Volume Analysis vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10) vol_ratio_ma = vol_ratio.rolling(window=12).mean() vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10) # Volatility ret_a = df_a['close'].pct_change() ret_b = df_b['close'].pct_change() vol_a = ret_a.rolling(window=z_window).std() vol_b = ret_b.rolling(window=z_window).std() vol_spread_ratio = vol_a / (vol_b + 1e-10) # Realized Volatility realized_vol_a = ret_a.rolling(window=24).std() realized_vol_b = ret_b.rolling(window=24).std() # ATR (Average True Range) high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close'] tr_a = pd.concat([ high_a - low_a, (high_a - close_a.shift(1)).abs(), (low_a - close_a.shift(1)).abs() ], axis=1).max(axis=1) atr_a = tr_a.rolling(window=self.config.atr_period).mean() atr_pct_a = atr_a / close_a # Build feature DataFrame features = pd.DataFrame(index=common_idx) features['pair_id'] = pair.pair_id features['base_asset'] = base features['quote_asset'] = quote # Price data features['spread'] = spread features['base_close'] = df_a['close'] features['quote_close'] = df_b['close'] features['base_volume'] = df_a['volume'] # Core Features features['z_score'] = z_score features['spread_rsi'] = spread_rsi features['spread_roc'] = spread_roc features['spread_change_1h'] = spread_change_1h features['vol_ratio'] = vol_ratio features['vol_ratio_rel'] = vol_ratio_rel features['vol_diff_ratio'] = vol_spread_ratio # Volatility features['realized_vol_base'] = realized_vol_a features['realized_vol_quote'] = realized_vol_b features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2 # ATR features['atr_base'] = atr_a features['atr_pct_base'] = atr_pct_a # Pair encoding assets = self.config.assets features['base_idx'] = assets.index(base) if base in assets else -1 features['quote_idx'] = assets.index(quote) if quote in assets else -1 # Funding rates base_funding = self._funding_rates.get(base, 0.0) quote_funding = self._funding_rates.get(quote, 0.0) features['base_funding'] = base_funding features['quote_funding'] = quote_funding features['funding_diff'] = base_funding - quote_funding features['funding_avg'] = (base_funding + quote_funding) / 2 # Drop NaN rows in core features core_cols = [ 'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h', 'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio', 'realized_vol_base', 'atr_base', 'atr_pct_base' ] features = features.dropna(subset=core_cols) return features def calculate_all_pair_features(self) -> dict[str, pd.DataFrame]: """ Calculate features for all pairs. Returns: Dictionary mapping pair_id to feature DataFrame """ all_features = {} for pair in self._pairs: features = self.calculate_pair_features(pair) if features is not None and len(features) > 0: all_features[pair.pair_id] = features logger.info("Calculated features for %d/%d pairs", len(all_features), len(self._pairs)) return all_features def get_latest_data(self) -> dict[str, pd.DataFrame] | None: """ Fetch and process latest market data for all pairs. Returns: Dictionary of pair features or None on error """ try: # Fetch OHLCV for all assets self.fetch_all_ohlcv() if len(self._asset_data) < 2: logger.warning("Insufficient assets fetched") return None # Fetch funding rates self.fetch_all_funding_rates() # Calculate features for all pairs pair_features = self.calculate_all_pair_features() if not pair_features: logger.warning("No pair features calculated") return None logger.info("Processed %d pairs with valid features", len(pair_features)) return pair_features except Exception as e: logger.error("Error fetching market data: %s", e, exc_info=True) return None def get_pair_by_id(self, pair_id: str) -> TradingPair | None: """Get pair object by ID.""" for pair in self._pairs: if pair.pair_id == pair_id: return pair return None def get_current_price(self, symbol: str) -> float | None: """Get current price for a symbol.""" if symbol in self._asset_data: return self._asset_data[symbol]['close'].iloc[-1] return None