def main():
    """Download historical OHLCV data for every asset in ``MultiPairConfig``.

    For each configured symbol the perpetual-market candles are requested
    through ``DataManager.download_data`` using the exchange and timeframe
    taken from the config. The run is best effort: a failure for one symbol
    is logged (with traceback) and the loop moves on to the next symbol.
    A summary of the symbols that produced no data is logged at the end.
    """
    setup_logging()

    config = MultiPairConfig()
    dm = DataManager()

    logger.info("Downloading data for %d assets...", len(config.assets))

    failed = []
    for symbol in config.assets:
        # Report the timeframe that is actually requested instead of a
        # hard-coded "1h", so the log stays truthful if the config changes.
        logger.info(
            "Downloading %s perpetual %s data...", symbol, config.timeframe
        )
        try:
            df = dm.download_data(
                exchange_id=config.exchange_id,
                symbol=symbol,
                timeframe=config.timeframe,
                market_type=MarketType.PERPETUAL,
            )
        except Exception:
            # Deliberately broad: one broken symbol must not abort the
            # whole batch. logger.exception keeps the traceback.
            logger.exception("Failed to download %s", symbol)
            failed.append(symbol)
            continue

        if df is None or len(df) == 0:
            logger.warning("No data downloaded for %s", symbol)
            failed.append(symbol)
        else:
            logger.info("Downloaded %d candles for %s", len(df), symbol)

    if failed:
        logger.warning(
            "Download finished with %d/%d assets missing: %s",
            len(failed), len(config.assets), failed
        )
    else:
        logger.info("Download complete!")
def _log_results(label, result):
    """Log total return, trade count and win rate of one backtest result."""
    portfolio = result.portfolio
    logger.info("%s Results:", label)
    logger.info("  Total Return: %.2f%%", portfolio.total_return() * 100)
    logger.info("  Total Trades: %d", portfolio.trades.count())
    logger.info("  Win Rate: %.1f%%", portfolio.trades.win_rate() * 100)


def run_baseline():
    """Run the baseline single-pair regime strategy.

    Returns:
        The result object returned by ``Backtester.run_strategy`` for
        ``RegimeReversionStrategy`` on 'okx' / 'ETH-USDT' / '1h' with
        10000 starting cash.
    """
    logger.info("=" * 60)
    logger.info("BASELINE: BTC/ETH Regime Reversion Strategy")
    logger.info("=" * 60)

    dm = DataManager()
    bt = Backtester(dm)

    strategy = RegimeReversionStrategy()

    result = bt.run_strategy(
        strategy,
        'okx',
        'ETH-USDT',
        timeframe='1h',
        init_cash=10000
    )

    _log_results("Baseline", result)
    return result


def run_multi_pair(assets: list[str] | None = None):
    """Run the multi-pair divergence strategy.

    Args:
        assets: Symbols to include in the universe. When empty or None the
            default asset list of ``MultiPairConfig`` is used.

    Returns:
        The result object returned by ``Backtester.run_strategy``.
    """
    logger.info("=" * 60)
    logger.info("MULTI-PAIR: Divergence Selection Strategy")
    logger.info("=" * 60)

    dm = DataManager()
    bt = Backtester(dm)

    # Use provided assets or fall back to the config default universe
    config = MultiPairConfig(assets=assets) if assets else MultiPairConfig()

    logger.info(
        "Configured %d assets, %d pairs",
        len(config.assets), config.get_pair_count()
    )

    strategy = MultiPairDivergenceStrategy(config=config)

    result = bt.run_strategy(
        strategy,
        'okx',
        'ETH-USDT',  # Reference asset (not used for trading, just index alignment)
        timeframe='1h',
        init_cash=10000
    )

    _log_results("Multi-Pair", result)
    return result


def compare_results(baseline, multi_pair):
    """Compare the two backtest results and log the differences.

    Args:
        baseline: Result of the baseline strategy; must expose
            ``portfolio.total_return()`` and ``portfolio.trades.count()``.
        multi_pair: Result of the multi-pair strategy (same shape).

    Returns:
        Dict with keys ``baseline_return``, ``multi_pair_return``,
        ``improvement`` (all in percent / percentage points),
        ``baseline_trades`` and ``multi_pair_trades``.
    """
    logger.info("=" * 60)
    logger.info("COMPARISON")
    logger.info("=" * 60)

    baseline_return = baseline.portfolio.total_return() * 100
    multi_return = multi_pair.portfolio.total_return() * 100

    improvement = multi_return - baseline_return

    logger.info("Baseline Return: %.2f%%", baseline_return)
    logger.info("Multi-Pair Return: %.2f%%", multi_return)
    if baseline_return > 0:
        logger.info(
            "Improvement: %.2f%% (%.1fx)",
            improvement, multi_return / baseline_return
        )
    else:
        # A return multiple is meaningless (sign flips or divides by zero)
        # when the baseline did not make money, so only the difference
        # in percentage points is reported.
        logger.info("Improvement: %.2f%%", improvement)

    baseline_trades = baseline.portfolio.trades.count()
    multi_trades = multi_pair.portfolio.trades.count()

    logger.info("Baseline Trades: %d", baseline_trades)
    logger.info("Multi-Pair Trades: %d", multi_trades)

    return {
        'baseline_return': baseline_return,
        'multi_pair_return': multi_return,
        'improvement': improvement,
        'baseline_trades': baseline_trades,
        'multi_pair_trades': multi_trades
    }


def main():
    """Main entry point.

    Checks which configured assets have local perpetual data, runs the
    baseline and the multi-pair backtests, logs the comparison and saves
    the multi-pair report via ``Reporter``.
    """
    setup_logging()

    # Check available assets using the same exchange/timeframe the
    # download script uses (taken from the config, not hard-coded).
    config = MultiPairConfig()
    dm = DataManager()
    available = []

    for symbol in config.assets:
        try:
            dm.load_data(
                config.exchange_id,
                symbol,
                config.timeframe,
                market_type=MarketType.PERPETUAL
            )
        except FileNotFoundError:
            continue
        available.append(symbol)

    if len(available) < 2:
        logger.error(
            "Need at least 2 assets to run multi-pair strategy. "
            "Run: uv run python scripts/download_multi_pair_data.py"
        )
        return

    logger.info("Found data for %d assets: %s", len(available), available)

    # Run baseline
    baseline_result = run_baseline()

    # Run multi-pair
    multi_result = run_multi_pair(available)

    # Compare (results are logged inside compare_results)
    compare_results(baseline_result, multi_result)

    # Save reports
    reporter = Reporter()
    reporter.save_reports(multi_result, "multi_pair_divergence")

    logger.info("Reports saved to backtest_logs/")
@dataclass
class MultiPairConfig:
    """
    Tunable parameters of the multi-pair divergence strategy.

    Attributes:
        assets: Asset symbols that form the trading universe
        z_window: Rolling window for the spread Z-Score (hours)
        z_entry_threshold: Minimum |Z-Score| required to consider an entry
        prob_threshold: Minimum ML probability required to consider an entry
        train_ratio: Walk-forward train/test split ratio
        horizon: Look-ahead horizon for target calculation (hours)
        profit_target: Minimum profit threshold for target labels
        correlation_threshold: Max correlation to allow between pairs
        correlation_window: Rolling window for correlation (hours)
        atr_period: ATR lookback period for dynamic stops
        sl_atr_multiplier: Stop-loss as a multiple of ATR
        tp_atr_multiplier: Take-profit as a multiple of ATR
        base_sl_pct / base_tp_pct: Fixed fallback stop/target percentages
        min_sl_pct / max_sl_pct: Bounds applied to the stop-loss percentage
        min_tp_pct / max_tp_pct: Bounds applied to the take-profit percentage
        volatility_window: Rolling window for realized volatility
        funding_threshold: Funding rate threshold for filtering
        min_hold_bars, switch_threshold, cooldown_bars, z_exit_threshold:
            Trade-management switches (see inline notes)
        exchange_id: Exchange identifier
        timeframe: Candle timeframe
    """
    # --- Asset universe ---------------------------------------------------
    assets: list[str] = field(
        default_factory=lambda: [
            "BTC-USDT",
            "ETH-USDT",
            "SOL-USDT",
            "XRP-USDT",
            "BNB-USDT",
            "DOGE-USDT",
            "ADA-USDT",
            "AVAX-USDT",
            "LINK-USDT",
            "DOT-USDT",
        ]
    )

    # --- Z-Score ----------------------------------------------------------
    z_window: int = 24
    z_entry_threshold: float = 1.0

    # --- ML model ---------------------------------------------------------
    prob_threshold: float = 0.5
    train_ratio: float = 0.7
    horizon: int = 102
    profit_target: float = 0.005

    # --- Correlation filter -----------------------------------------------
    correlation_threshold: float = 0.85
    correlation_window: int = 168  # 7 days of hourly bars

    # --- Risk management: ATR-based stops -----------------------------------
    # Author's sizing note: with a mean hourly ATR of roughly 0.6%,
    # 10x ATR is ~6% (the previous fixed stop-loss) and 8x ATR is ~5%
    # (the previous fixed take-profit).
    atr_period: int = 14  # lookback in bars (hours on the 1h timeframe)
    sl_atr_multiplier: float = 10.0  # stop-loss = entry +/- ATR * multiplier
    tp_atr_multiplier: float = 8.0  # take-profit = entry +/- ATR * multiplier

    # Fixed percentages used as fallback if ATR is unavailable
    base_sl_pct: float = 0.06
    base_tp_pct: float = 0.05

    # Clamp range so ATR-derived stops never become extreme
    min_sl_pct: float = 0.02  # at least 2% stop-loss
    max_sl_pct: float = 0.10  # at most 10% stop-loss
    min_tp_pct: float = 0.02  # at least 2% take-profit
    max_tp_pct: float = 0.15  # at most 15% take-profit

    volatility_window: int = 24

    # --- Funding rate filter ------------------------------------------------
    # Author's note: OKX funding is typically 0.0001 (0.01%) per 8h; values
    # beyond 0.0005 (0.05%) are treated as a crowded trade.
    funding_threshold: float = 0.0005

    # --- Trade management ---------------------------------------------------
    # Author's note: min_hold_bars=0 together with z_exit_threshold=0 gave
    # the best results; exiting at Z=0 is the primary profit driver.
    min_hold_bars: int = 0  # disabled: mean reversion drives exits
    switch_threshold: float = 999.0  # disabled: never switch mid-trade
    cooldown_bars: int = 0  # disabled: enter as soon as a signal appears
    z_exit_threshold: float = 0.0  # exit once the spread is back at its mean

    # --- Exchange -----------------------------------------------------------
    exchange_id: str = "okx"
    timeframe: str = "1h"

    def get_pair_count(self) -> int:
        """Return how many unordered pairs the asset list yields (n choose 2)."""
        asset_total = len(self.assets)
        ordered_pairs = asset_total * (asset_total - 1)
        return ordered_pairs // 2
class CorrelationFilter:
    """
    Calculates and filters based on asset correlations.

    Uses the correlation of simple returns over the most recent
    ``config.correlation_window`` bars to identify assets moving
    together, so that candidate pairs overlapping with the currently
    held asset can be excluded.
    """

    # Number of bars a cached correlation matrix is reused before it is
    # recomputed (only applies when the caller passes ``current_idx``).
    CACHE_REFRESH_BARS = 24

    def __init__(self, config: MultiPairConfig):
        self.config = config
        self._correlation_matrix: pd.DataFrame | None = None
        self._last_update_idx: int = -1

    def calculate_correlation_matrix(
        self,
        price_data: dict[str, pd.Series],
        current_idx: int | None = None
    ) -> pd.DataFrame:
        """
        Calculate the correlation matrix of returns between all assets.

        Args:
            price_data: Dictionary mapping asset symbols to price series
            current_idx: Current bar index (for caching). When given, a
                matrix computed fewer than ``CACHE_REFRESH_BARS`` bars ago
                is returned as-is.

        Returns:
            Correlation matrix DataFrame (symbols on both axes)
        """
        if current_idx is not None and self._correlation_matrix is not None:
            bars_since_update = current_idx - self._last_update_idx
            # The lower bound matters: a negative delta means the caller
            # rewound or started a new run, so the cache must not be
            # served.
            if 0 <= bars_since_update < self.CACHE_REFRESH_BARS:
                return self._correlation_matrix

        returns_df = pd.DataFrame({
            symbol: prices.pct_change()
            for symbol, prices in price_data.items()
        })

        # Only the latest window is needed, so correlate the tail directly
        # rather than building a rolling correlation over the full history
        # and discarding all but its last slice. When fewer than `window`
        # rows exist, tail() returns everything, which is the full-period
        # fallback. Pairwise-complete observations are used, so the
        # leading NaN produced by pct_change() does not blank the matrix.
        window = self.config.correlation_window
        corr_matrix = returns_df.tail(window).corr()

        self._correlation_matrix = corr_matrix
        if current_idx is not None:
            self._last_update_idx = current_idx

        return corr_matrix

    def filter_pairs(
        self,
        pairs: list[TradingPair],
        current_position_asset: str | None,
        price_data: dict[str, pd.Series],
        current_idx: int | None = None
    ) -> list[TradingPair]:
        """
        Filter pairs based on correlation with current position.

        If we have an open position in an asset, exclude pairs where
        either asset is highly correlated with the held asset. A pair that
        contains the held asset itself is always excluded, because an
        asset's correlation with itself is 1.0.

        Args:
            pairs: List of candidate pairs
            current_position_asset: Currently held asset (or None)
            price_data: Dictionary of price series by symbol
            current_idx: Current bar index for caching

        Returns:
            Filtered list of pairs (the input list itself when no
            position is held)
        """
        if current_position_asset is None:
            return pairs

        corr_matrix = self.calculate_correlation_matrix(price_data, current_idx)
        threshold = self.config.correlation_threshold

        filtered = []
        for pair in pairs:
            base_corr = self._get_correlation(
                corr_matrix, pair.base_asset, current_position_asset
            )
            quote_corr = self._get_correlation(
                corr_matrix, pair.quote_asset, current_position_asset
            )

            # Drop the pair if either leg moves with the held asset
            if abs(base_corr) > threshold or abs(quote_corr) > threshold:
                logger.debug(
                    "Filtered %s: base_corr=%.2f, quote_corr=%.2f (held: %s)",
                    pair.name, base_corr, quote_corr, current_position_asset
                )
                continue

            filtered.append(pair)

        if len(filtered) < len(pairs):
            logger.info(
                "Correlation filter: %d/%d pairs remaining (held: %s)",
                len(filtered), len(pairs), current_position_asset
            )

        return filtered

    def _get_correlation(
        self,
        corr_matrix: pd.DataFrame,
        asset1: str,
        asset2: str
    ) -> float:
        """
        Get correlation between two assets from matrix.

        Args:
            corr_matrix: Correlation matrix
            asset1: First asset symbol
            asset2: Second asset symbol

        Returns:
            Correlation coefficient (-1 to 1) as a plain float; 1.0 for
            identical symbols; 0.0 if either symbol is missing from the
            matrix or the stored value is NaN (undefined correlation).
        """
        if asset1 == asset2:
            return 1.0

        try:
            value = corr_matrix.loc[asset1, asset2]
        except KeyError:
            return 0.0

        # NaN means the correlation could not be computed; report it as
        # "no known correlation" so callers never compare against NaN.
        return 0.0 if pd.isna(value) else float(value)

    def get_correlation_report(
        self,
        price_data: dict[str, pd.Series]
    ) -> pd.DataFrame:
        """
        Generate a readable correlation report.

        Args:
            price_data: Dictionary of price series

        Returns:
            Correlation matrix as DataFrame (always freshly computed,
            since no bar index is passed for caching)
        """
        return self.calculate_correlation_matrix(price_data)
@dataclass
class DivergenceSignal:
    """
    Signal for a divergent pair.

    Attributes:
        pair: Trading pair
        z_score: Current Z-Score of the spread
        probability: ML model probability of profitable reversion
        divergence_score: Combined score (|z_score| * probability)
        direction: 'long' or 'short' (relative to base asset)
        base_price: Current price of base asset
        quote_price: Current price of quote asset
        atr: Average True Range in price units
        atr_pct: ATR as percentage of price
        timestamp: Index label of the feature row the signal was built from
    """
    pair: TradingPair
    z_score: float
    probability: float
    divergence_score: float
    direction: str
    base_price: float
    quote_price: float
    atr: float
    atr_pct: float
    timestamp: pd.Timestamp


class DivergenceScorer:
    """
    Scores and ranks pairs by divergence potential.

    Uses ML model predictions combined with Z-Score magnitude
    to identify the most promising mean-reversion opportunity.
    """

    def __init__(self, config: MultiPairConfig, model_path: str = "data/multi_pair_model.pkl"):
        self.config = config
        self.model_path = Path(model_path)
        self.model: RandomForestClassifier | None = None
        self.feature_cols: list[str] | None = None
        self._load_model()

    def _load_model(self) -> None:
        """Load pre-trained model if available.

        Loading is best effort: any failure is logged and leaves the
        scorer untrained (``model`` and ``feature_cols`` both None).
        """
        if not self.model_path.exists():
            return

        try:
            # SECURITY NOTE: pickle executes arbitrary code on load. Only
            # point model_path at files produced by save_model() from a
            # trusted source.
            with open(self.model_path, 'rb') as f:
                saved = pickle.load(f)
            model = saved['model']
            feature_cols = saved['feature_cols']
        except Exception as e:
            logger.warning("Could not load model: %s", e)
            return

        # Assign both together so a malformed file can never leave a model
        # without its matching feature column list.
        self.model = model
        self.feature_cols = feature_cols
        logger.info("Loaded model from %s", self.model_path)

    def save_model(self) -> None:
        """Save trained model and its feature column list (no-op if untrained)."""
        if self.model is None:
            return

        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.model_path, 'wb') as f:
            pickle.dump({
                'model': self.model,
                'feature_cols': self.feature_cols,
            }, f)
        logger.info("Saved model to %s", self.model_path)

    def train_model(
        self,
        combined_features: pd.DataFrame,
        pair_features: dict[str, pd.DataFrame]
    ) -> None:
        """
        Train universal model on all pairs.

        A row is labelled positive when the Z-Score is beyond the entry
        threshold and, within the next ``horizon`` bars, the spread moves
        by at least ``profit_target`` in the reverting direction.

        Args:
            combined_features: Combined feature DataFrame from all pairs
                (only its length is used, for logging)
            pair_features: Individual pair feature DataFrames; the training
                rows and targets are built from these

        NOTE(review): every row passed in is used for fitting; this method
        does not apply ``config.train_ratio``. Presumably the caller slices
        the training period beforehand - confirm against the strategy.
        """
        logger.info("Training universal model on %d samples...", len(combined_features))

        z_thresh = self.config.z_entry_threshold
        horizon = self.config.horizon
        profit_target = self.config.profit_target

        all_features = []
        all_targets = []

        for features in pair_features.values():
            # Skip pairs too short to yield labelled rows
            if len(features) < horizon + 50:
                continue

            spread = features['spread']
            z_score = features['z_score']

            # Min / max of the spread over the NEXT `horizon` bars:
            # rolling() looks backwards, shift(-horizon) moves the window
            # so that row t sees bars t+1 .. t+horizon.
            future_min = spread.rolling(window=horizon).min().shift(-horizon)
            future_max = spread.rolling(window=horizon).max().shift(-horizon)

            success_short = (
                (z_score > z_thresh)
                & (future_min < spread * (1 - profit_target))
            )
            success_long = (
                (z_score < -z_thresh)
                & (future_max > spread * (1 + profit_target))
            )
            targets = (success_short | success_long).astype(int)

            # Exclude rows without complete future data
            valid_mask = future_min.notna() & future_max.notna()
            if valid_mask.any():
                all_features.append(features[valid_mask])
                all_targets.append(targets[valid_mask].to_numpy())

        if not all_features:
            logger.warning("No valid training samples")
            return

        X_df = pd.concat(all_features, ignore_index=True)
        y = np.concatenate(all_targets)

        # Metadata and raw price columns are not model inputs
        exclude_cols = {
            'pair_id', 'base_asset', 'quote_asset',
            'spread', 'base_close', 'quote_close', 'base_volume'
        }
        self.feature_cols = [c for c in X_df.columns if c not in exclude_cols]

        X = X_df[self.feature_cols].fillna(0)
        X = X.replace([np.inf, -np.inf], 0)

        self.model = RandomForestClassifier(
            n_estimators=300,
            max_depth=5,
            min_samples_leaf=30,
            class_weight={0: 1, 1: 3},
            random_state=42
        )
        self.model.fit(X, y)

        logger.info(
            "Model trained on %d samples, %d features, %.1f%% positive class",
            len(X), len(self.feature_cols), y.mean() * 100
        )
        self.save_model()

    def _positive_probability(self, X: pd.DataFrame) -> float:
        """Return the model's probability for class 1 on a single-row frame.

        Looks the column up through ``classes_`` (when the model exposes
        it) so a model fitted on one class only does not cause an index
        error; in that case the probability of class 1 is 0.0.
        """
        proba = self.model.predict_proba(X)[0]
        classes = list(getattr(self.model, 'classes_', (0, 1)))
        if 1 not in classes:
            return 0.0
        return float(proba[classes.index(1)])

    def score_pairs(
        self,
        pair_features: dict[str, pd.DataFrame],
        pairs: list[TradingPair],
        timestamp: pd.Timestamp | None = None
    ) -> list[DivergenceSignal]:
        """
        Score all pairs and return ranked signals.

        Args:
            pair_features: Feature DataFrames by pair_id
            pairs: List of TradingPair objects
            timestamp: Current timestamp; when given, only rows at or
                before it are considered for each pair

        Returns:
            List of DivergenceSignal sorted by score (descending). Empty
            when no model is loaded/trained.
        """
        if self.model is None or not self.feature_cols:
            logger.warning("Model not trained, returning empty signals")
            return []

        z_entry = self.config.z_entry_threshold
        prob_thresh = self.config.prob_threshold
        funding_thresh = self.config.funding_threshold

        signals = []
        pair_map = {p.pair_id: p for p in pairs}

        for pair_id, features in pair_features.items():
            pair = pair_map.get(pair_id)
            if pair is None:
                continue

            if timestamp is not None:
                features = features[features.index <= timestamp]
            if len(features) == 0:
                continue

            latest = features.iloc[-1]
            ts = features.index[-1]

            z_score = float(latest['z_score'])

            # A flat spread (zero rolling std) yields +/-inf or NaN; such a
            # value is not a real divergence and must not outrank genuine
            # signals.
            if not np.isfinite(z_score) or abs(z_score) < z_entry:
                continue

            # reindex() instead of label indexing: the trained column list
            # is the union over all pairs, so optional columns (funding /
            # inflow) may be absent for this pair and are filled with 0.
            feature_row = pd.to_numeric(
                latest.reindex(self.feature_cols), errors='coerce'
            )
            feature_row = feature_row.replace([np.inf, -np.inf], np.nan).fillna(0.0)
            X = pd.DataFrame(
                [feature_row.to_numpy(dtype=float)], columns=self.feature_cols
            )

            prob = self._positive_probability(X)
            if prob < prob_thresh:
                continue

            # Funding rate filter: block trades where funding opposes
            # our direction. Missing funding counts as neutral (0).
            base_funding = latest.get('base_funding', 0.0)
            if base_funding is None or pd.isna(base_funding):
                base_funding = 0.0

            if z_score > 0:  # Short signal
                # High negative funding = shorts are paying -> skip
                if base_funding < -funding_thresh:
                    logger.debug(
                        "Skipping %s short: funding too negative (%.4f)",
                        pair.name, base_funding
                    )
                    continue
            else:  # Long signal
                # High positive funding = longs are paying -> skip
                if base_funding > funding_thresh:
                    logger.debug(
                        "Skipping %s long: funding too positive (%.4f)",
                        pair.name, base_funding
                    )
                    continue

            # Z > 0: Spread high (base expensive vs quote) -> Short base
            # Z < 0: Spread low (base cheap vs quote) -> Long base
            direction = 'short' if z_score > 0 else 'long'

            signals.append(DivergenceSignal(
                pair=pair,
                z_score=z_score,
                probability=prob,
                divergence_score=abs(z_score) * prob,
                direction=direction,
                base_price=latest['base_close'],
                quote_price=latest['quote_close'],
                atr=latest.get('atr_base', 0),
                atr_pct=latest.get('atr_pct_base', 0.02),
                timestamp=ts
            ))

        # Highest divergence score first
        signals.sort(key=lambda s: s.divergence_score, reverse=True)

        if signals:
            logger.debug(
                "Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f)",
                len(signals),
                signals[0].pair.name,
                signals[0].divergence_score,
                signals[0].z_score,
                signals[0].probability
            )

        return signals

    def select_best_pair(
        self,
        signals: list[DivergenceSignal]
    ) -> DivergenceSignal | None:
        """
        Select the best pair from scored signals.

        Args:
            signals: List of DivergenceSignal (pre-sorted by score)

        Returns:
            Best signal or None if no valid candidates
        """
        if not signals:
            return None
        return signals[0]
+ +Calculates features for all pairs in the universe, including +spread technicals, volatility, and on-chain data. +""" +import pandas as pd +import numpy as np +import ta + +from engine.logging_config import get_logger +from engine.data_manager import DataManager +from engine.market import MarketType +from .config import MultiPairConfig +from .pair_scanner import TradingPair +from .funding import FundingRateFetcher + +logger = get_logger(__name__) + + +class MultiPairFeatureEngine: + """ + Calculates features for multiple trading pairs. + + Generates consistent feature sets across all pairs for + the universal ML model. + """ + + def __init__(self, config: MultiPairConfig): + self.config = config + self.dm = DataManager() + self.funding_fetcher = FundingRateFetcher() + self._funding_data: pd.DataFrame | None = None + + def load_all_assets( + self, + start_date: str | None = None, + end_date: str | None = None + ) -> dict[str, pd.DataFrame]: + """ + Load OHLCV data for all assets in the universe. 
+ + Args: + start_date: Start date filter (YYYY-MM-DD) + end_date: End date filter (YYYY-MM-DD) + + Returns: + Dictionary mapping symbol to OHLCV DataFrame + """ + data = {} + market_type = MarketType.PERPETUAL + + for symbol in self.config.assets: + try: + df = self.dm.load_data( + self.config.exchange_id, + symbol, + self.config.timeframe, + market_type + ) + + # Apply date filters + if start_date: + df = df[df.index >= pd.Timestamp(start_date, tz="UTC")] + if end_date: + df = df[df.index <= pd.Timestamp(end_date, tz="UTC")] + + if len(df) >= 200: # Minimum data requirement + data[symbol] = df + logger.debug("Loaded %s: %d bars", symbol, len(df)) + else: + logger.warning( + "Skipping %s: insufficient data (%d bars)", + symbol, len(df) + ) + except FileNotFoundError: + logger.warning("Data not found for %s", symbol) + except Exception as e: + logger.error("Error loading %s: %s", symbol, e) + + logger.info("Loaded %d/%d assets", len(data), len(self.config.assets)) + return data + + def load_funding_data( + self, + start_date: str | None = None, + end_date: str | None = None, + use_cache: bool = True + ) -> pd.DataFrame: + """ + Load funding rate data for all assets. 
+ + Args: + start_date: Start date filter + end_date: End date filter + use_cache: Whether to use cached data + + Returns: + DataFrame with funding rates for all assets + """ + self._funding_data = self.funding_fetcher.get_funding_data( + self.config.assets, + start_date=start_date, + end_date=end_date, + use_cache=use_cache + ) + + if self._funding_data is not None and not self._funding_data.empty: + logger.info( + "Loaded funding data: %d rows, %d assets", + len(self._funding_data), + len(self._funding_data.columns) + ) + else: + logger.warning("No funding data available") + + return self._funding_data + + def calculate_pair_features( + self, + pair: TradingPair, + asset_data: dict[str, pd.DataFrame], + on_chain_data: pd.DataFrame | None = None + ) -> pd.DataFrame | None: + """ + Calculate features for a single pair. + + Args: + pair: Trading pair + asset_data: Dictionary of OHLCV DataFrames by symbol + on_chain_data: Optional on-chain data (funding, inflows) + + Returns: + DataFrame with features, or None if insufficient data + """ + base = pair.base_asset + quote = pair.quote_asset + + if base not in asset_data or quote not in asset_data: + return None + + df_base = asset_data[base] + df_quote = asset_data[quote] + + # Align indices + common_idx = df_base.index.intersection(df_quote.index) + if len(common_idx) < 200: + logger.debug("Pair %s: insufficient aligned data", pair.name) + return None + + df_a = df_base.loc[common_idx] + df_b = df_quote.loc[common_idx] + + # Calculate spread (base / quote) + spread = df_a['close'] / df_b['close'] + + # Z-Score + z_window = self.config.z_window + rolling_mean = spread.rolling(window=z_window).mean() + rolling_std = spread.rolling(window=z_window).std() + z_score = (spread - rolling_mean) / rolling_std + + # Spread Technicals + spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi() + spread_roc = spread.pct_change(periods=5) * 100 + spread_change_1h = spread.pct_change(periods=1) + + # Volume Analysis + vol_ratio 
= df_a['volume'] / (df_b['volume'] + 1e-10) + vol_ratio_ma = vol_ratio.rolling(window=12).mean() + vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10) + + # Volatility + ret_a = df_a['close'].pct_change() + ret_b = df_b['close'].pct_change() + vol_a = ret_a.rolling(window=z_window).std() + vol_b = ret_b.rolling(window=z_window).std() + vol_spread_ratio = vol_a / (vol_b + 1e-10) + + # Realized Volatility (for dynamic SL/TP) + realized_vol_a = ret_a.rolling(window=self.config.volatility_window).std() + realized_vol_b = ret_b.rolling(window=self.config.volatility_window).std() + + # ATR (Average True Range) for dynamic stops + # ATR = average of max(high-low, |high-prev_close|, |low-prev_close|) + high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close'] + high_b, low_b, close_b = df_b['high'], df_b['low'], df_b['close'] + + # True Range for base asset + tr_a = pd.concat([ + high_a - low_a, + (high_a - close_a.shift(1)).abs(), + (low_a - close_a.shift(1)).abs() + ], axis=1).max(axis=1) + atr_a = tr_a.rolling(window=self.config.atr_period).mean() + + # True Range for quote asset + tr_b = pd.concat([ + high_b - low_b, + (high_b - close_b.shift(1)).abs(), + (low_b - close_b.shift(1)).abs() + ], axis=1).max(axis=1) + atr_b = tr_b.rolling(window=self.config.atr_period).mean() + + # ATR as percentage of price (normalized) + atr_pct_a = atr_a / close_a + atr_pct_b = atr_b / close_b + + # Build feature DataFrame + features = pd.DataFrame(index=common_idx) + features['pair_id'] = pair.pair_id + features['base_asset'] = base + features['quote_asset'] = quote + + # Price data (for reference, not features) + features['spread'] = spread + features['base_close'] = df_a['close'] + features['quote_close'] = df_b['close'] + features['base_volume'] = df_a['volume'] + + # Core Features + features['z_score'] = z_score + features['spread_rsi'] = spread_rsi + features['spread_roc'] = spread_roc + features['spread_change_1h'] = spread_change_1h + features['vol_ratio'] = vol_ratio + 
features['vol_ratio_rel'] = vol_ratio_rel + features['vol_diff_ratio'] = vol_spread_ratio + + # Volatility for SL/TP + features['realized_vol_base'] = realized_vol_a + features['realized_vol_quote'] = realized_vol_b + features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2 + + # ATR for dynamic stops (in price units and as percentage) + features['atr_base'] = atr_a + features['atr_quote'] = atr_b + features['atr_pct_base'] = atr_pct_a + features['atr_pct_quote'] = atr_pct_b + features['atr_pct_avg'] = (atr_pct_a + atr_pct_b) / 2 + + # Pair encoding (for universal model) + # Using base and quote indices for hierarchical encoding + assets = self.config.assets + features['base_idx'] = assets.index(base) if base in assets else -1 + features['quote_idx'] = assets.index(quote) if quote in assets else -1 + + # Add funding and on-chain features + # Funding data is always added from self._funding_data (OKX, all 10 assets) + # On-chain data is optional (CryptoQuant, BTC/ETH only) + features = self._add_on_chain_features( + features, on_chain_data, base, quote + ) + + # Drop rows with NaN in core features only (not funding/on-chain) + core_cols = [ + 'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h', + 'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio', + 'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg', + 'atr_base', 'atr_pct_base' # ATR is core for SL/TP + ] + features = features.dropna(subset=core_cols) + + # Fill missing funding/on-chain features with 0 (neutral) + optional_cols = [ + 'base_funding', 'quote_funding', 'funding_diff', 'funding_avg', + 'base_inflow', 'quote_inflow', 'inflow_ratio' + ] + for col in optional_cols: + if col in features.columns: + features[col] = features[col].fillna(0) + + return features + + def calculate_all_pair_features( + self, + pairs: list[TradingPair], + asset_data: dict[str, pd.DataFrame], + on_chain_data: pd.DataFrame | None = None + ) -> dict[str, pd.DataFrame]: + """ + Calculate features for all pairs. 
+ + Args: + pairs: List of trading pairs + asset_data: Dictionary of OHLCV DataFrames + on_chain_data: Optional on-chain data + + Returns: + Dictionary mapping pair_id to feature DataFrame + """ + all_features = {} + + for pair in pairs: + features = self.calculate_pair_features( + pair, asset_data, on_chain_data + ) + if features is not None and len(features) > 0: + all_features[pair.pair_id] = features + + logger.info( + "Calculated features for %d/%d pairs", + len(all_features), len(pairs) + ) + + return all_features + + def get_combined_features( + self, + pair_features: dict[str, pd.DataFrame], + timestamp: pd.Timestamp | None = None + ) -> pd.DataFrame: + """ + Combine all pair features into a single DataFrame. + + Useful for batch model prediction across all pairs. + + Args: + pair_features: Dictionary of feature DataFrames by pair_id + timestamp: Optional specific timestamp to filter to + + Returns: + Combined DataFrame with all pairs as rows + """ + if not pair_features: + return pd.DataFrame() + + if timestamp is not None: + # Get latest row from each pair at or before timestamp + rows = [] + for pair_id, features in pair_features.items(): + valid = features[features.index <= timestamp] + if len(valid) > 0: + row = valid.iloc[-1:].copy() + rows.append(row) + + if rows: + return pd.concat(rows, ignore_index=False) + return pd.DataFrame() + + # Combine all features (for training) + return pd.concat(pair_features.values(), ignore_index=False) + + def _add_on_chain_features( + self, + features: pd.DataFrame, + on_chain_data: pd.DataFrame | None, + base_asset: str, + quote_asset: str + ) -> pd.DataFrame: + """ + Add on-chain and funding rate features for the pair. + + Uses funding data from OKX (all 10 assets) and on-chain data + from CryptoQuant (BTC/ETH only for inflows). 
+ """ + base_short = base_asset.replace('-USDT', '').lower() + quote_short = quote_asset.replace('-USDT', '').lower() + + # Add funding rates from cached funding data + if self._funding_data is not None and not self._funding_data.empty: + funding_aligned = self._funding_data.reindex( + features.index, method='ffill' + ) + + base_funding_col = f'{base_short}_funding' + quote_funding_col = f'{quote_short}_funding' + + if base_funding_col in funding_aligned.columns: + features['base_funding'] = funding_aligned[base_funding_col] + if quote_funding_col in funding_aligned.columns: + features['quote_funding'] = funding_aligned[quote_funding_col] + + # Funding difference (positive = base has higher funding) + if 'base_funding' in features.columns and 'quote_funding' in features.columns: + features['funding_diff'] = ( + features['base_funding'] - features['quote_funding'] + ) + + # Funding sentiment: average of both assets + features['funding_avg'] = ( + features['base_funding'] + features['quote_funding'] + ) / 2 + + # Add on-chain features from CryptoQuant (BTC/ETH only) + if on_chain_data is not None and not on_chain_data.empty: + cq_aligned = on_chain_data.reindex(features.index, method='ffill') + + # Inflows (only available for BTC/ETH) + base_inflow_col = f'{base_short}_inflow' + quote_inflow_col = f'{quote_short}_inflow' + + if base_inflow_col in cq_aligned.columns: + features['base_inflow'] = cq_aligned[base_inflow_col] + if quote_inflow_col in cq_aligned.columns: + features['quote_inflow'] = cq_aligned[quote_inflow_col] + + if 'base_inflow' in features.columns and 'quote_inflow' in features.columns: + features['inflow_ratio'] = ( + features['base_inflow'] / + (features['quote_inflow'] + 1) + ) + + return features + + def get_feature_columns(self) -> list[str]: + """ + Get list of feature columns for ML model. + + Excludes metadata and target-related columns. 
+ + Returns: + List of feature column names + """ + # Core features (always present) + core_features = [ + 'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h', + 'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio', + 'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg', + 'base_idx', 'quote_idx' + ] + + # Funding features (now available for all 10 assets via OKX) + funding_features = [ + 'base_funding', 'quote_funding', 'funding_diff', 'funding_avg' + ] + + # On-chain features (BTC/ETH only via CryptoQuant) + onchain_features = [ + 'base_inflow', 'quote_inflow', 'inflow_ratio' + ] + + return core_features + funding_features + onchain_features diff --git a/strategies/multi_pair/funding.py b/strategies/multi_pair/funding.py new file mode 100644 index 0000000..eb4086c --- /dev/null +++ b/strategies/multi_pair/funding.py @@ -0,0 +1,272 @@ +""" +Funding Rate Fetcher for Multi-Pair Strategy. + +Fetches historical funding rates from OKX for all assets. +CryptoQuant only supports BTC/ETH, so we use OKX for the full universe. +""" +import time +from pathlib import Path +from datetime import datetime, timezone + +import ccxt +import pandas as pd + +from engine.logging_config import get_logger + +logger = get_logger(__name__) + + +class FundingRateFetcher: + """ + Fetches and caches funding rate data from OKX. + + OKX funding rates are settled every 8 hours (00:00, 08:00, 16:00 UTC). + This fetcher retrieves historical funding rate data and aligns it + to hourly candles for use in the multi-pair strategy. 
+ """ + + def __init__(self, cache_dir: str = "data/funding"): + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.exchange: ccxt.okx | None = None + + def _init_exchange(self) -> None: + """Initialize OKX exchange connection.""" + if self.exchange is None: + self.exchange = ccxt.okx({ + 'enableRateLimit': True, + 'options': {'defaultType': 'swap'} + }) + self.exchange.load_markets() + + def fetch_funding_history( + self, + symbol: str, + start_date: str | None = None, + end_date: str | None = None, + limit: int = 100 + ) -> pd.DataFrame: + """ + Fetch historical funding rates for a symbol. + + Args: + symbol: Asset symbol (e.g., 'BTC-USDT') + start_date: Start date (YYYY-MM-DD) + end_date: End date (YYYY-MM-DD) + limit: Max records per request + + Returns: + DataFrame with funding rate history + """ + self._init_exchange() + + # Convert symbol format + base = symbol.replace('-USDT', '') + okx_symbol = f"{base}/USDT:USDT" + + try: + # OKX funding rate history endpoint + # Uses fetch_funding_rate_history if available + all_funding = [] + + # Parse dates + if start_date: + since = self.exchange.parse8601(f"{start_date}T00:00:00Z") + else: + # Default to 1 year ago + since = self.exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000 + + if end_date: + until = self.exchange.parse8601(f"{end_date}T23:59:59Z") + else: + until = self.exchange.milliseconds() + + # Fetch in batches + current_since = since + while current_since < until: + try: + funding = self.exchange.fetch_funding_rate_history( + okx_symbol, + since=current_since, + limit=limit + ) + + if not funding: + break + + all_funding.extend(funding) + + # Move to next batch + last_ts = funding[-1]['timestamp'] + if last_ts <= current_since: + break + current_since = last_ts + 1 + + time.sleep(0.1) # Rate limit + + except Exception as e: + logger.warning( + "Error fetching funding batch for %s: %s", + symbol, str(e)[:50] + ) + break + + if not all_funding: + return 
pd.DataFrame() + + # Convert to DataFrame + df = pd.DataFrame(all_funding) + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True) + df.set_index('timestamp', inplace=True) + df = df[['fundingRate']].rename(columns={'fundingRate': 'funding_rate'}) + df.sort_index(inplace=True) + + # Remove duplicates + df = df[~df.index.duplicated(keep='first')] + + logger.info("Fetched %d funding records for %s", len(df), symbol) + return df + + except Exception as e: + logger.error("Failed to fetch funding for %s: %s", symbol, e) + return pd.DataFrame() + + def fetch_all_assets( + self, + assets: list[str], + start_date: str | None = None, + end_date: str | None = None + ) -> pd.DataFrame: + """ + Fetch funding rates for all assets and combine. + + Args: + assets: List of asset symbols (e.g., ['BTC-USDT', 'ETH-USDT']) + start_date: Start date + end_date: End date + + Returns: + Combined DataFrame with columns like 'btc_funding', 'eth_funding', etc. + """ + combined = pd.DataFrame() + + for symbol in assets: + df = self.fetch_funding_history(symbol, start_date, end_date) + + if df.empty: + continue + + # Rename column + asset_name = symbol.replace('-USDT', '').lower() + col_name = f"{asset_name}_funding" + df = df.rename(columns={'funding_rate': col_name}) + + if combined.empty: + combined = df + else: + combined = combined.join(df, how='outer') + + time.sleep(0.2) # Be nice to API + + # Forward fill to hourly (funding is every 8h) + if not combined.empty: + combined = combined.sort_index() + combined = combined.ffill() + + return combined + + def save_to_cache(self, df: pd.DataFrame, filename: str = "funding_rates.csv") -> None: + """Save funding data to cache file.""" + path = self.cache_dir / filename + df.to_csv(path) + logger.info("Saved funding rates to %s", path) + + def load_from_cache(self, filename: str = "funding_rates.csv") -> pd.DataFrame | None: + """Load funding data from cache if available.""" + path = self.cache_dir / filename + if path.exists(): 
+ df = pd.read_csv(path, index_col='timestamp', parse_dates=True) + logger.info("Loaded funding rates from cache: %d rows", len(df)) + return df + return None + + def get_funding_data( + self, + assets: list[str], + start_date: str | None = None, + end_date: str | None = None, + use_cache: bool = True, + force_refresh: bool = False + ) -> pd.DataFrame: + """ + Get funding data, using cache if available. + + Args: + assets: List of asset symbols + start_date: Start date + end_date: End date + use_cache: Whether to use cached data + force_refresh: Force refresh even if cache exists + + Returns: + DataFrame with funding rates for all assets + """ + cache_file = "funding_rates.csv" + + # Try cache first + if use_cache and not force_refresh: + cached = self.load_from_cache(cache_file) + if cached is not None: + # Check if cache covers requested range + if start_date and end_date: + start_ts = pd.Timestamp(start_date, tz='UTC') + end_ts = pd.Timestamp(end_date, tz='UTC') + + if cached.index.min() <= start_ts and cached.index.max() >= end_ts: + # Filter to requested range + return cached[(cached.index >= start_ts) & (cached.index <= end_ts)] + + # Fetch fresh data + logger.info("Fetching fresh funding rate data...") + df = self.fetch_all_assets(assets, start_date, end_date) + + if not df.empty and use_cache: + self.save_to_cache(df, cache_file) + + return df + + +def download_funding_data(): + """Download funding data for all multi-pair assets.""" + from strategies.multi_pair.config import MultiPairConfig + + config = MultiPairConfig() + fetcher = FundingRateFetcher() + + # Fetch last year of data + end_date = datetime.now(timezone.utc).strftime("%Y-%m-%d") + start_date = (datetime.now(timezone.utc) - pd.Timedelta(days=365)).strftime("%Y-%m-%d") + + logger.info("Downloading funding rates for %d assets...", len(config.assets)) + logger.info("Date range: %s to %s", start_date, end_date) + + df = fetcher.get_funding_data( + config.assets, + start_date=start_date, + 
end_date=end_date, + force_refresh=True + ) + + if not df.empty: + logger.info("Downloaded %d funding rate records", len(df)) + logger.info("Columns: %s", list(df.columns)) + else: + logger.warning("No funding data downloaded") + + return df + + +if __name__ == "__main__": + from engine.logging_config import setup_logging + setup_logging() + download_funding_data() diff --git a/strategies/multi_pair/pair_scanner.py b/strategies/multi_pair/pair_scanner.py new file mode 100644 index 0000000..dcbeef9 --- /dev/null +++ b/strategies/multi_pair/pair_scanner.py @@ -0,0 +1,168 @@ +""" +Pair Scanner for Multi-Pair Divergence Strategy. + +Generates all possible pairs from asset universe and checks tradeability. +""" +from dataclasses import dataclass +from itertools import combinations +from typing import Optional + +import ccxt + +from engine.logging_config import get_logger +from .config import MultiPairConfig + +logger = get_logger(__name__) + + +@dataclass +class TradingPair: + """ + Represents a tradeable pair for spread analysis. + + Attributes: + base_asset: First asset in the pair (numerator) + quote_asset: Second asset in the pair (denominator) + pair_id: Unique identifier for the pair + is_direct: Whether pair can be traded directly on exchange + exchange_symbol: Symbol for direct trading (if available) + """ + base_asset: str + quote_asset: str + pair_id: str + is_direct: bool = False + exchange_symbol: Optional[str] = None + + @property + def name(self) -> str: + """Human-readable pair name.""" + return f"{self.base_asset}/{self.quote_asset}" + + def __hash__(self): + return hash(self.pair_id) + + def __eq__(self, other): + if not isinstance(other, TradingPair): + return False + return self.pair_id == other.pair_id + + +class PairScanner: + """ + Scans and generates tradeable pairs from asset universe. + + Checks OKX for directly tradeable cross-pairs and generates + synthetic pairs via USDT for others. 
+ """ + + def __init__(self, config: MultiPairConfig): + self.config = config + self.exchange: Optional[ccxt.Exchange] = None + self._available_markets: set[str] = set() + + def _init_exchange(self) -> None: + """Initialize exchange connection for market lookup.""" + if self.exchange is None: + exchange_class = getattr(ccxt, self.config.exchange_id) + self.exchange = exchange_class({'enableRateLimit': True}) + self.exchange.load_markets() + self._available_markets = set(self.exchange.symbols) + logger.info( + "Loaded %d markets from %s", + len(self._available_markets), + self.config.exchange_id + ) + + def generate_pairs(self, check_exchange: bool = True) -> list[TradingPair]: + """ + Generate all unique pairs from asset universe. + + Args: + check_exchange: Whether to check OKX for direct trading + + Returns: + List of TradingPair objects + """ + if check_exchange: + self._init_exchange() + + pairs = [] + assets = self.config.assets + + for base, quote in combinations(assets, 2): + pair_id = f"{base}__{quote}" + + # Check if directly tradeable as cross-pair on OKX + is_direct = False + exchange_symbol = None + + if check_exchange: + # Check perpetual cross-pair (e.g., ETH/BTC:BTC) + # OKX perpetuals are typically quoted in USDT + # Cross-pairs like ETH/BTC are less common + cross_symbol = f"{base.replace('-USDT', '')}/{quote.replace('-USDT', '')}:USDT" + if cross_symbol in self._available_markets: + is_direct = True + exchange_symbol = cross_symbol + + pair = TradingPair( + base_asset=base, + quote_asset=quote, + pair_id=pair_id, + is_direct=is_direct, + exchange_symbol=exchange_symbol + ) + pairs.append(pair) + + # Log summary + direct_count = sum(1 for p in pairs if p.is_direct) + logger.info( + "Generated %d pairs: %d direct, %d synthetic", + len(pairs), direct_count, len(pairs) - direct_count + ) + + return pairs + + def get_required_symbols(self, pairs: list[TradingPair]) -> list[str]: + """ + Get list of symbols needed to calculate all pair spreads. 
+ + For synthetic pairs, we need both USDT pairs. + For direct pairs, we still load USDT pairs for simplicity. + + Args: + pairs: List of trading pairs + + Returns: + List of unique symbols to load (e.g., ['BTC-USDT', 'ETH-USDT']) + """ + symbols = set() + for pair in pairs: + symbols.add(pair.base_asset) + symbols.add(pair.quote_asset) + return list(symbols) + + def filter_by_assets( + self, + pairs: list[TradingPair], + exclude_assets: list[str] + ) -> list[TradingPair]: + """ + Filter pairs that contain any of the excluded assets. + + Args: + pairs: List of trading pairs + exclude_assets: Assets to exclude + + Returns: + Filtered list of pairs + """ + if not exclude_assets: + return pairs + + exclude_set = set(exclude_assets) + return [ + p for p in pairs + if p.base_asset not in exclude_set + and p.quote_asset not in exclude_set + ] diff --git a/strategies/multi_pair/strategy.py b/strategies/multi_pair/strategy.py new file mode 100644 index 0000000..339d9a6 --- /dev/null +++ b/strategies/multi_pair/strategy.py @@ -0,0 +1,525 @@ +""" +Multi-Pair Divergence Selection Strategy. + +Main strategy class that orchestrates pair scanning, feature calculation, +model training, and signal generation for backtesting. 
+""" +from dataclasses import dataclass +from typing import Optional + +import pandas as pd +import numpy as np + +from strategies.base import BaseStrategy +from engine.market import MarketType +from engine.logging_config import get_logger +from .config import MultiPairConfig +from .pair_scanner import PairScanner, TradingPair +from .correlation import CorrelationFilter +from .feature_engine import MultiPairFeatureEngine +from .divergence_scorer import DivergenceScorer, DivergenceSignal + +logger = get_logger(__name__) + + +@dataclass +class PositionState: + """Tracks current position state.""" + pair: TradingPair | None = None + direction: str | None = None # 'long' or 'short' + entry_price: float = 0.0 + entry_idx: int = -1 + stop_loss: float = 0.0 + take_profit: float = 0.0 + atr: float = 0.0 # ATR at entry for reference + last_exit_idx: int = -100 # For cooldown tracking + + +class MultiPairDivergenceStrategy(BaseStrategy): + """ + Multi-Pair Divergence Selection Strategy. + + Scans multiple cryptocurrency pairs for spread divergence, + selects the most divergent pair using ML-enhanced scoring, + and trades mean-reversion opportunities. 
+ + Key Features: + - Universal ML model across all pairs + - Correlation-based pair filtering + - Dynamic SL/TP based on volatility + - Walk-forward training + """ + + def __init__( + self, + config: MultiPairConfig | None = None, + model_path: str = "data/multi_pair_model.pkl" + ): + super().__init__() + self.config = config or MultiPairConfig() + + # Initialize components + self.pair_scanner = PairScanner(self.config) + self.correlation_filter = CorrelationFilter(self.config) + self.feature_engine = MultiPairFeatureEngine(self.config) + self.divergence_scorer = DivergenceScorer(self.config, model_path) + + # Strategy configuration + self.default_market_type = MarketType.PERPETUAL + self.default_leverage = 1 + + # Runtime state + self.pairs: list[TradingPair] = [] + self.asset_data: dict[str, pd.DataFrame] = {} + self.pair_features: dict[str, pd.DataFrame] = {} + self.position = PositionState() + self.train_end_idx: int = 0 + + def run(self, close: pd.Series, **kwargs) -> tuple: + """ + Execute the multi-pair divergence strategy. + + This method is called by the backtester with the primary asset's + close prices. For multi-pair, we load all assets internally. + + Args: + close: Primary close prices (used for index alignment) + **kwargs: Additional data (high, low, volume) + + Returns: + Tuple of (long_entries, long_exits, short_entries, short_exits, size) + """ + logger.info("Starting Multi-Pair Divergence Strategy") + + # 1. Load all asset data + start_date = close.index.min().strftime("%Y-%m-%d") + end_date = close.index.max().strftime("%Y-%m-%d") + + self.asset_data = self.feature_engine.load_all_assets( + start_date=start_date, + end_date=end_date + ) + + # 1b. Load funding rate data for all assets + self.feature_engine.load_funding_data( + start_date=start_date, + end_date=end_date, + use_cache=True + ) + + if len(self.asset_data) < 2: + logger.error("Insufficient assets loaded, need at least 2") + return self._empty_signals(close) + + # 2. 
Generate pairs + self.pairs = self.pair_scanner.generate_pairs(check_exchange=False) + + # Filter to pairs with available data + available_assets = set(self.asset_data.keys()) + self.pairs = [ + p for p in self.pairs + if p.base_asset in available_assets + and p.quote_asset in available_assets + ] + + logger.info("Trading %d pairs from %d assets", len(self.pairs), len(self.asset_data)) + + # 3. Calculate features for all pairs + self.pair_features = self.feature_engine.calculate_all_pair_features( + self.pairs, self.asset_data + ) + + if not self.pair_features: + logger.error("No pair features calculated") + return self._empty_signals(close) + + # 4. Align to common index + common_index = self._get_common_index() + if len(common_index) < 200: + logger.error("Insufficient common data across pairs") + return self._empty_signals(close) + + # 5. Walk-forward split + n_samples = len(common_index) + train_size = int(n_samples * self.config.train_ratio) + self.train_end_idx = train_size + + train_end_date = common_index[train_size - 1] + test_start_date = common_index[train_size] + + logger.info( + "Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)", + train_size, train_end_date.strftime('%Y-%m-%d'), + n_samples - train_size, test_start_date.strftime('%Y-%m-%d') + ) + + # 6. Train model on training period + if self.divergence_scorer.model is None: + train_features = { + pid: feat[feat.index <= train_end_date] + for pid, feat in self.pair_features.items() + } + combined = self.feature_engine.get_combined_features(train_features) + self.divergence_scorer.train_model(combined, train_features) + + # 7. Generate signals for test period + return self._generate_signals(common_index, train_size, close) + + def _generate_signals( + self, + index: pd.DatetimeIndex, + train_size: int, + reference_close: pd.Series + ) -> tuple: + """ + Generate entry/exit signals for the test period. 
+ + Iterates through each bar in the test period, scoring pairs + and generating signals based on divergence scores. + """ + # Initialize signal arrays aligned to reference close + long_entries = pd.Series(False, index=reference_close.index) + long_exits = pd.Series(False, index=reference_close.index) + short_entries = pd.Series(False, index=reference_close.index) + short_exits = pd.Series(False, index=reference_close.index) + size = pd.Series(1.0, index=reference_close.index) + + # Track position state + self.position = PositionState() + + # Price data for correlation calculation + price_data = { + symbol: df['close'] for symbol, df in self.asset_data.items() + } + + # Iterate through test period + test_indices = index[train_size:] + + trade_count = 0 + + for i, timestamp in enumerate(test_indices): + current_idx = train_size + i + + # Check exit conditions first + if self.position.pair is not None: + # Enforce minimum hold period + bars_held = current_idx - self.position.entry_idx + if bars_held < self.config.min_hold_bars: + # Only allow SL/TP exits during min hold period + should_exit, exit_reason = self._check_sl_tp_only(timestamp) + else: + should_exit, exit_reason = self._check_exit(timestamp) + + if should_exit: + # Map exit signal to reference index + if timestamp in reference_close.index: + if self.position.direction == 'long': + long_exits.loc[timestamp] = True + else: + short_exits.loc[timestamp] = True + + logger.debug( + "Exit %s %s at %s: %s (held %d bars)", + self.position.direction, + self.position.pair.name, + timestamp.strftime('%Y-%m-%d %H:%M'), + exit_reason, + bars_held + ) + self.position = PositionState(last_exit_idx=current_idx) + + # Score pairs (with correlation filter if position exists) + held_asset = None + if self.position.pair is not None: + held_asset = self.position.pair.base_asset + + # Filter pairs by correlation + candidate_pairs = self.correlation_filter.filter_pairs( + self.pairs, + held_asset, + price_data, + current_idx + ) 
+ + # Get candidate features + candidate_features = { + pid: feat for pid, feat in self.pair_features.items() + if any(p.pair_id == pid for p in candidate_pairs) + } + + # Score pairs + signals = self.divergence_scorer.score_pairs( + candidate_features, + candidate_pairs, + timestamp + ) + + # Get best signal + best = self.divergence_scorer.select_best_pair(signals) + + if best is None: + continue + + # Check if we should switch positions or enter new + should_enter = False + + # Check cooldown + bars_since_exit = current_idx - self.position.last_exit_idx + in_cooldown = bars_since_exit < self.config.cooldown_bars + + if self.position.pair is None and not in_cooldown: + # No position and not in cooldown, can enter + should_enter = True + elif self.position.pair is not None: + # Check if we should switch (requires min hold + significant improvement) + bars_held = current_idx - self.position.entry_idx + current_score = self._get_current_score(timestamp) + + if (bars_held >= self.config.min_hold_bars and + best.divergence_score > current_score * self.config.switch_threshold): + # New opportunity is significantly better + if timestamp in reference_close.index: + if self.position.direction == 'long': + long_exits.loc[timestamp] = True + else: + short_exits.loc[timestamp] = True + self.position = PositionState(last_exit_idx=current_idx) + should_enter = True + + if should_enter: + # Calculate ATR-based dynamic SL/TP + sl_price, tp_price = self._calculate_sl_tp( + best.base_price, + best.direction, + best.atr, + best.atr_pct + ) + + # Set position + self.position = PositionState( + pair=best.pair, + direction=best.direction, + entry_price=best.base_price, + entry_idx=current_idx, + stop_loss=sl_price, + take_profit=tp_price, + atr=best.atr + ) + + # Calculate position size based on divergence + pos_size = self._calculate_size(best.divergence_score) + + # Generate entry signal + if timestamp in reference_close.index: + if best.direction == 'long': + 
long_entries.loc[timestamp] = True + else: + short_entries.loc[timestamp] = True + size.loc[timestamp] = pos_size + + trade_count += 1 + logger.debug( + "Entry %s %s at %s: z=%.2f, prob=%.2f, score=%.3f", + best.direction, + best.pair.name, + timestamp.strftime('%Y-%m-%d %H:%M'), + best.z_score, + best.probability, + best.divergence_score + ) + + logger.info("Generated %d trades in test period", trade_count) + + return long_entries, long_exits, short_entries, short_exits, size + + def _check_exit(self, timestamp: pd.Timestamp) -> tuple[bool, str]: + """ + Check if current position should be exited. + + Exit conditions: + 1. Z-Score reverted to mean (|Z| < threshold) + 2. Stop-loss hit + 3. Take-profit hit + + Returns: + Tuple of (should_exit, reason) + """ + if self.position.pair is None: + return False, "" + + pair_id = self.position.pair.pair_id + if pair_id not in self.pair_features: + return True, "pair_data_missing" + + features = self.pair_features[pair_id] + valid = features[features.index <= timestamp] + + if len(valid) == 0: + return True, "no_data" + + latest = valid.iloc[-1] + z_score = latest['z_score'] + current_price = latest['base_close'] + + # Check mean reversion (primary exit) + if abs(z_score) < self.config.z_exit_threshold: + return True, f"mean_reversion (z={z_score:.2f})" + + # Check SL/TP + return self._check_sl_tp(current_price) + + def _check_sl_tp_only(self, timestamp: pd.Timestamp) -> tuple[bool, str]: + """ + Check only stop-loss and take-profit conditions. + Used during minimum hold period. 
+ """ + if self.position.pair is None: + return False, "" + + pair_id = self.position.pair.pair_id + if pair_id not in self.pair_features: + return True, "pair_data_missing" + + features = self.pair_features[pair_id] + valid = features[features.index <= timestamp] + + if len(valid) == 0: + return True, "no_data" + + latest = valid.iloc[-1] + current_price = latest['base_close'] + + return self._check_sl_tp(current_price) + + def _check_sl_tp(self, current_price: float) -> tuple[bool, str]: + """Check stop-loss and take-profit levels.""" + if self.position.direction == 'long': + if current_price <= self.position.stop_loss: + return True, f"stop_loss ({current_price:.2f} <= {self.position.stop_loss:.2f})" + if current_price >= self.position.take_profit: + return True, f"take_profit ({current_price:.2f} >= {self.position.take_profit:.2f})" + else: # short + if current_price >= self.position.stop_loss: + return True, f"stop_loss ({current_price:.2f} >= {self.position.stop_loss:.2f})" + if current_price <= self.position.take_profit: + return True, f"take_profit ({current_price:.2f} <= {self.position.take_profit:.2f})" + + return False, "" + + def _get_current_score(self, timestamp: pd.Timestamp) -> float: + """Get current position's divergence score for comparison.""" + if self.position.pair is None: + return 0.0 + + pair_id = self.position.pair.pair_id + if pair_id not in self.pair_features: + return 0.0 + + features = self.pair_features[pair_id] + valid = features[features.index <= timestamp] + + if len(valid) == 0: + return 0.0 + + latest = valid.iloc[-1] + z_score = abs(latest['z_score']) + + # Re-score with model + if self.divergence_scorer.model is not None: + feature_row = latest[self.divergence_scorer.feature_cols].fillna(0) + feature_row = feature_row.replace([np.inf, -np.inf], 0) + X = pd.DataFrame( + [feature_row.values], + columns=self.divergence_scorer.feature_cols + ) + prob = self.divergence_scorer.model.predict_proba(X)[0, 1] + return z_score * prob + + 
return z_score * 0.5 + + def _calculate_sl_tp( + self, + entry_price: float, + direction: str, + atr: float, + atr_pct: float + ) -> tuple[float, float]: + """ + Calculate ATR-based dynamic stop-loss and take-profit prices. + + Uses ATR (Average True Range) to set stops that adapt to + each asset's volatility. More volatile assets get wider stops. + + Args: + entry_price: Entry price + direction: 'long' or 'short' + atr: ATR in price units + atr_pct: ATR as percentage of price + + Returns: + Tuple of (stop_loss_price, take_profit_price) + """ + # Calculate SL/TP as ATR multiples + if atr > 0 and atr_pct > 0: + # ATR-based calculation + sl_distance = atr * self.config.sl_atr_multiplier + tp_distance = atr * self.config.tp_atr_multiplier + + # Convert to percentage for bounds checking + sl_pct = sl_distance / entry_price + tp_pct = tp_distance / entry_price + else: + # Fallback to fixed percentages if ATR unavailable + sl_pct = self.config.base_sl_pct + tp_pct = self.config.base_tp_pct + + # Apply bounds to prevent extreme stops + sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct)) + tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct)) + + # Calculate actual prices + if direction == 'long': + stop_loss = entry_price * (1 - sl_pct) + take_profit = entry_price * (1 + tp_pct) + else: # short + stop_loss = entry_price * (1 + sl_pct) + take_profit = entry_price * (1 - tp_pct) + + return stop_loss, take_profit + + def _calculate_size(self, divergence_score: float) -> float: + """ + Calculate position size based on divergence score. + + Higher divergence = larger position (up to 2x). 
+ """ + # Base score threshold (Z=1.0, prob=0.5 -> score=0.5) + base_threshold = 0.5 + + # Scale factor + if divergence_score <= base_threshold: + return 1.0 + + # Linear scaling: 1.0 at threshold, up to 2.0 at 2x threshold + scale = 1.0 + (divergence_score - base_threshold) / base_threshold + return min(scale, 2.0) + + def _get_common_index(self) -> pd.DatetimeIndex: + """Get the intersection of all pair feature indices.""" + if not self.pair_features: + return pd.DatetimeIndex([]) + + common = None + for features in self.pair_features.values(): + if common is None: + common = features.index + else: + common = common.intersection(features.index) + + return common.sort_values() + + def _empty_signals(self, close: pd.Series) -> tuple: + """Return empty signal arrays.""" + empty = self.create_empty_signals(close) + size = pd.Series(1.0, index=close.index) + return empty, empty, empty, empty, size diff --git a/tasks/prd-multi-pair-divergence-strategy.md b/tasks/prd-multi-pair-divergence-strategy.md new file mode 100644 index 0000000..eaef4e5 --- /dev/null +++ b/tasks/prd-multi-pair-divergence-strategy.md @@ -0,0 +1,321 @@ +# PRD: Multi-Pair Divergence Selection Strategy + +## 1. Introduction / Overview + +This document describes the **Multi-Pair Divergence Selection Strategy**, an extension of the existing BTC/ETH regime reversion system. The strategy expands spread analysis to the **top 10 cryptocurrencies by market cap**, calculates divergence scores for all tradeable pairs, and dynamically selects the **most divergent pair** for trading. + +The core hypothesis: by scanning multiple pairs simultaneously, we can identify stronger mean-reversion opportunities than focusing on a single pair, improving net PnL while maintaining the proven ML-based regime detection approach. + +--- + +## 2. Goals + +1. **Extend regime detection** to top 10 market cap cryptocurrencies +2. **Dynamically select** the most divergent tradeable pair each cycle +3. 
**Integrate volatility** into dynamic SL/TP calculations +4. **Filter correlated pairs** to avoid redundant positions +5. **Improve net PnL** compared to the single-pair BTC/ETH strategy +6. **Backtest-first** implementation with walk-forward validation + +--- + +## 3. User Stories + +### US-1: Multi-Pair Analysis +> As a trader, I want the system to analyze spread divergence across multiple cryptocurrency pairs so that I can identify the best trading opportunity at any given moment. + +### US-2: Dynamic Pair Selection +> As a trader, I want the system to automatically select and trade the pair with the highest divergence score (combination of Z-score magnitude and ML probability) so that I maximize mean-reversion profit potential. + +### US-3: Volatility-Adjusted Risk +> As a trader, I want stop-loss and take-profit levels to adapt to each pair's volatility so that I avoid being stopped out prematurely on volatile assets while protecting profits on stable ones. + +### US-4: Correlation Filtering +> As a trader, I want the system to avoid selecting pairs that are highly correlated with my current position so that I don't inadvertently double down on the same market exposure. + +### US-5: Backtest Validation +> As a researcher, I want to backtest this multi-pair strategy with walk-forward training so that I can validate improvement over the single-pair baseline without look-ahead bias. + +--- + +## 4. 
Functional Requirements + +### 4.1 Data Management + +| ID | Requirement | +|----|-------------| +| FR-1.1 | System must support loading OHLCV data for top 10 market cap cryptocurrencies | +| FR-1.2 | Target assets: BTC, ETH, SOL, XRP, BNB, DOGE, ADA, AVAX, LINK, DOT (configurable) | +| FR-1.3 | System must identify all directly tradeable cross-pairs on OKX perpetuals | +| FR-1.4 | System must align timestamps across all pairs for synchronized analysis | +| FR-1.5 | System must handle missing data gracefully (skip pair if insufficient history) | + +### 4.2 Pair Generation + +| ID | Requirement | +|----|-------------| +| FR-2.1 | Generate all unique pairs from asset universe: N*(N-1)/2 pairs (e.g., 45 pairs for 10 assets) | +| FR-2.2 | Filter pairs to only those directly tradeable on OKX (no USDT intermediate) | +| FR-2.3 | Fallback: If cross-pair not available, calculate synthetic spread via USDT pairs | +| FR-2.4 | Store pair metadata: base asset, quote asset, exchange symbol, tradeable flag | + +### 4.3 Feature Engineering (Per Pair) + +| ID | Requirement | +|----|-------------| +| FR-3.1 | Calculate spread ratio: `asset_a_close / asset_b_close` | +| FR-3.2 | Calculate Z-Score with configurable rolling window (default: 24h) | +| FR-3.3 | Calculate spread technicals: RSI(14), ROC(5), 1h change | +| FR-3.4 | Calculate volume ratio and relative volume | +| FR-3.5 | Calculate volatility ratio: `std(returns_a) / std(returns_b)` over Z-window | +| FR-3.6 | Calculate realized volatility for each asset (for dynamic SL/TP) | +| FR-3.7 | Merge on-chain data (funding rates, inflows) if available per asset | +| FR-3.8 | Add pair identifier as categorical feature for universal model | + +### 4.4 Correlation Filtering + +| ID | Requirement | +|----|-------------| +| FR-4.1 | Calculate rolling correlation matrix between all assets (default: 168h / 7 days) | +| FR-4.2 | Define correlation threshold (default: 0.85) | +| FR-4.3 | If current position exists, exclude pairs where 
either asset has correlation > threshold with held asset | +| FR-4.4 | Log filtered pairs with reason for exclusion | + +### 4.5 Divergence Scoring & Pair Selection + +| ID | Requirement | +|----|-------------| +| FR-5.1 | Calculate divergence score: `abs(z_score) * model_probability` | +| FR-5.2 | Only consider pairs where `abs(z_score) > z_entry_threshold` (default: 1.0) | +| FR-5.3 | Only consider pairs where `model_probability > prob_threshold` (default: 0.5) | +| FR-5.4 | Apply correlation filter to eligible pairs | +| FR-5.5 | Select pair with highest divergence score | +| FR-5.6 | If no pair qualifies, signal "hold" | +| FR-5.7 | Log all pair scores for analysis/debugging | + +### 4.6 ML Model (Universal) + +| ID | Requirement | +|----|-------------| +| FR-6.1 | Train single Random Forest model on all pairs combined | +| FR-6.2 | Include `pair_id` as one-hot encoded or label-encoded feature | +| FR-6.3 | Target: binary (1 = profitable reversion within horizon, 0 = no reversion) | +| FR-6.4 | Walk-forward training: 70% train / 30% test split | +| FR-6.5 | Daily retraining schedule (for live, configurable for backtest) | +| FR-6.6 | Model hyperparameters: `n_estimators=300, max_depth=5, min_samples_leaf=30, class_weight={0:1, 1:3}` | +| FR-6.7 | Save/load model with feature column metadata | + +### 4.7 Signal Generation + +| ID | Requirement | +|----|-------------| +| FR-7.1 | Direction: If `z_score > threshold` -> Short spread (short asset_a), If `z_score < -threshold` -> Long spread (long asset_a) | +| FR-7.2 | Apply funding rate filter per asset (block if extreme funding opposes direction) | +| FR-7.3 | Output signal: `{pair, action, side, probability, z_score, divergence_score, reason}` | + +### 4.8 Position Sizing + +| ID | Requirement | +|----|-------------| +| FR-8.1 | Base size: 100% of available subaccount balance | +| FR-8.2 | Scale by divergence: `size_multiplier = 1.0 + (divergence_score - base_threshold) * scaling_factor` | +| FR-8.3 | Cap 
multiplier between 1.0x and 2.0x | +| FR-8.4 | Respect exchange minimum order size per asset | + +### 4.9 Dynamic SL/TP (Volatility-Adjusted) + +| ID | Requirement | +|----|-------------| +| FR-9.1 | Calculate asset realized volatility: `std(returns) * sqrt(24)` for daily vol | +| FR-9.2 | Base SL: `entry_price * (1 - base_sl_pct * vol_multiplier)` for longs | +| FR-9.3 | Base TP: `entry_price * (1 + base_tp_pct * vol_multiplier)` for longs | +| FR-9.4 | `vol_multiplier = asset_volatility / baseline_volatility` (baseline = BTC volatility) | +| FR-9.5 | Cap vol_multiplier between 0.5x and 2.0x to prevent extreme values | +| FR-9.6 | Invert logic for short positions | + +### 4.10 Exit Conditions + +| ID | Requirement | +|----|-------------| +| FR-10.1 | Exit when Z-score crosses back through 0 (mean reversion complete) | +| FR-10.2 | Exit when dynamic SL or TP hit | +| FR-10.3 | No minimum holding period (can switch pairs immediately) | +| FR-10.4 | If new pair has higher divergence score, close current and open new | + +### 4.11 Backtest Integration + +| ID | Requirement | +|----|-------------| +| FR-11.1 | Integrate with existing `engine/backtester.py` framework | +| FR-11.2 | Support 1h timeframe (matching live trading) | +| FR-11.3 | Walk-forward validation: train on 70%, test on 30% | +| FR-11.4 | Output: trades log, equity curve, performance metrics | +| FR-11.5 | Compare against single-pair BTC/ETH baseline | + +--- + +## 5. Non-Goals (Out of Scope) + +1. **Live trading implementation** - Backtest validation first +2. **Multi-position portfolio** - Single pair at a time for v1 +3. **Cross-exchange arbitrage** - OKX only +4. **Alternative ML models** - Stick with Random Forest for consistency +5. **Sub-1h timeframes** - 1h candles only for initial version +6. **Leveraged positions** - 1x leverage for backtest +7. **Portfolio-level VaR/risk budgeting** - Full subaccount allocation + +--- + +## 6. 
Design Considerations + +### 6.1 Architecture + +``` +strategies/ + multi_pair/ + __init__.py + pair_scanner.py # Generates all pairs, filters tradeable + feature_engine.py # Calculates features for all pairs + correlation.py # Rolling correlation matrix & filtering + divergence_scorer.py # Ranks pairs by divergence score + strategy.py # Main strategy orchestration +``` + +### 6.2 Data Flow + +``` +1. Load OHLCV for all 10 assets +2. Generate pair combinations (45 pairs) +3. Filter to tradeable pairs (OKX check) +4. Calculate features for each pair +5. Train/load universal ML model +6. Predict probability for all pairs +7. Calculate divergence scores +8. Apply correlation filter +9. Select top pair +10. Generate signal with dynamic SL/TP +11. Execute in backtest engine +``` + +### 6.3 Configuration + +```python +@dataclass +class MultiPairConfig: + # Assets + assets: list[str] = field(default_factory=lambda: [ + "BTC", "ETH", "SOL", "XRP", "BNB", + "DOGE", "ADA", "AVAX", "LINK", "DOT" + ]) + + # Thresholds + z_window: int = 24 + z_entry_threshold: float = 1.0 + prob_threshold: float = 0.5 + correlation_threshold: float = 0.85 + correlation_window: int = 168 # 7 days in hours + + # Risk + base_sl_pct: float = 0.06 + base_tp_pct: float = 0.05 + vol_multiplier_min: float = 0.5 + vol_multiplier_max: float = 2.0 + + # Model + train_ratio: float = 0.7 + horizon: int = 102 + profit_target: float = 0.005 +``` + +--- + +## 7. 
Technical Considerations + +### 7.1 Dependencies + +- Extend `DataManager` to load multiple symbols +- Query OKX API for available perpetual cross-pairs +- Reuse existing feature engineering from `RegimeReversionStrategy` + +### 7.2 Performance + +- Pre-calculate all pair features in batch (vectorized) +- Cache correlation matrix (update every N candles, not every minute) +- Model inference is fast (single predict call with all pairs as rows) + +### 7.3 Edge Cases + +- Handle pairs with insufficient history (< 200 bars) - exclude +- Handle assets delisted mid-backtest - skip pair +- Handle zero-volume periods - use last valid price + +--- + +## 8. Success Metrics + +| Metric | Baseline (BTC/ETH) | Target | +|--------|-------------------|--------| +| Net PnL | Current performance | > 10% improvement | +| Number of Trades | N | Comparable or higher | +| Win Rate | Baseline % | Maintain or improve | +| Average Trade Duration | Baseline hours | Flexible | +| Max Drawdown | Baseline % | Not significantly worse | + +--- + +## 9. Open Questions + +1. **OKX Cross-Pairs**: Need to verify which cross-pairs are available on OKX perpetuals. May need to fall back to synthetic spreads for most pairs. + +2. **On-Chain Data**: CryptoQuant data currently covers BTC/ETH. Should we: + - Run without on-chain features for other assets? + - Source alternative on-chain data? + - Use funding rates only (available from OKX)? + +3. **Pair ID Encoding**: For the universal model, should pair_id be: + - One-hot encoded (adds 45 features)? + - Label encoded (single ordinal feature)? + - Hierarchical (base_asset + quote_asset as separate features)? + +4. **Synthetic Spreads**: If trading the SOL/DOT spread but only USDT pairs are available: + - Calculate spread synthetically: `SOL-USDT / DOT-USDT` + - Execute as two legs: Long SOL-USDT, Short DOT-USDT + - This doubles fees and adds execution complexity. Include in v1? + +--- + +## 10. Implementation Phases + +### Phase 1: Data & Infrastructure (Est. 
2-3 days) +- Extend DataManager for multi-symbol loading +- Build pair scanner with OKX tradeable filter +- Implement correlation matrix calculation + +### Phase 2: Feature Engineering (Est. 2 days) +- Adapt existing feature calculation for arbitrary pairs +- Add pair identifier feature +- Batch feature calculation for all pairs + +### Phase 3: Model & Scoring (Est. 2 days) +- Train universal model on all pairs +- Implement divergence scoring +- Add correlation filtering to pair selection + +### Phase 4: Strategy Integration (Est. 2-3 days) +- Implement dynamic SL/TP with volatility +- Integrate with backtester +- Build strategy orchestration class + +### Phase 5: Validation & Comparison (Est. 2 days) +- Run walk-forward backtest +- Compare against BTC/ETH baseline +- Generate performance report + +**Total Estimated Effort: 10-12 days** + +--- + +*Document Version: 1.0* +*Created: 2026-01-15* +*Author: AI Assistant* +*Status: Draft - Awaiting Review*