"""
Multi-Pair Divergence Selection Strategy.

Main strategy class that orchestrates pair scanning, feature calculation,
model training, and signal generation for backtesting.
"""
from dataclasses import dataclass
from typing import Optional

import pandas as pd
import numpy as np

from strategies.base import BaseStrategy
from engine.market import MarketType
from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import PairScanner, TradingPair
from .correlation import CorrelationFilter
from .feature_engine import MultiPairFeatureEngine
from .divergence_scorer import DivergenceScorer, DivergenceSignal

logger = get_logger(__name__)


@dataclass
class PositionState:
    """Tracks current position state."""
    pair: TradingPair | None = None
    direction: str | None = None  # 'long' or 'short'
    entry_price: float = 0.0
    entry_idx: int = -1
    stop_loss: float = 0.0
    take_profit: float = 0.0
    atr: float = 0.0  # ATR at entry for reference
    last_exit_idx: int = -100  # For cooldown tracking


class MultiPairDivergenceStrategy(BaseStrategy):
    """
    Multi-Pair Divergence Selection Strategy.

    Scans multiple cryptocurrency pairs for spread divergence,
    selects the most divergent pair using ML-enhanced scoring,
    and trades mean-reversion opportunities.

    Key Features:
    - Universal ML model across all pairs
    - Correlation-based pair filtering
    - Dynamic SL/TP based on volatility
    - Walk-forward training
    """

    def __init__(
        self,
        config: MultiPairConfig | None = None,
        model_path: str = "data/multi_pair_model.pkl"
    ):
        """
        Build the strategy and its helper components.

        Args:
            config: Strategy configuration; a default ``MultiPairConfig()``
                is created when omitted.
            model_path: Path handed to ``DivergenceScorer``.
        """
        super().__init__()
        self.config = config or MultiPairConfig()

        # Initialize components
        self.pair_scanner = PairScanner(self.config)
        self.correlation_filter = CorrelationFilter(self.config)
        self.feature_engine = MultiPairFeatureEngine(self.config)
        self.divergence_scorer = DivergenceScorer(self.config, model_path)

        # Strategy configuration
        self.default_market_type = MarketType.PERPETUAL
        self.default_leverage = 1

        # Runtime state
        self.pairs: list[TradingPair] = []
        self.asset_data: dict[str, pd.DataFrame] = {}
        self.pair_features: dict[str, pd.DataFrame] = {}
        self.position = PositionState()
        self.train_end_idx: int = 0

    def run(self, close: pd.Series, **kwargs) -> tuple:
        """
        Execute the multi-pair divergence strategy.

        This method is called by the backtester with the primary asset's
        close prices. For multi-pair, we load all assets internally.

        Args:
            close: Primary close prices (used for index alignment)
            **kwargs: Additional data (high, low, volume); not read here.

        Returns:
            Tuple of (long_entries, long_exits, short_entries, short_exits, size).
            Empty signals are returned when there is not enough data to
            split into a training and a test period.
        """
        logger.info("Starting Multi-Pair Divergence Strategy")

        # 1. Load all asset data
        start_date = close.index.min().strftime("%Y-%m-%d")
        end_date = close.index.max().strftime("%Y-%m-%d")

        self.asset_data = self.feature_engine.load_all_assets(
            start_date=start_date,
            end_date=end_date
        )

        # 1b. Load funding rate data for all assets
        self.feature_engine.load_funding_data(
            start_date=start_date,
            end_date=end_date,
            use_cache=True
        )

        if len(self.asset_data) < 2:
            logger.error("Insufficient assets loaded, need at least 2")
            return self._empty_signals(close)

        # 2. Generate pairs
        self.pairs = self.pair_scanner.generate_pairs(check_exchange=False)

        # Filter to pairs with available data
        available_assets = set(self.asset_data.keys())
        self.pairs = [
            p for p in self.pairs
            if p.base_asset in available_assets
            and p.quote_asset in available_assets
        ]

        logger.info("Trading %d pairs from %d assets", len(self.pairs), len(self.asset_data))

        # 3. Calculate features for all pairs
        self.pair_features = self.feature_engine.calculate_all_pair_features(
            self.pairs, self.asset_data
        )

        if not self.pair_features:
            logger.error("No pair features calculated")
            return self._empty_signals(close)

        # 4. Align to common index
        common_index = self._get_common_index()
        if len(common_index) < 200:
            logger.error("Insufficient common data across pairs")
            return self._empty_signals(close)

        # 5. Walk-forward split
        n_samples = len(common_index)
        train_size = int(n_samples * self.config.train_ratio)

        # common_index[train_size] below needs at least one test bar;
        # without this guard a train_ratio of 1.0 raises IndexError.
        if train_size >= n_samples:
            logger.error(
                "Walk-forward split leaves no test bars (train_ratio=%s)",
                self.config.train_ratio
            )
            return self._empty_signals(close)

        self.train_end_idx = train_size

        # NOTE(review): with train_size == 0 the index below is -1, i.e. the
        # last bar, so a model trained in step 6 would see the whole period.
        # Confirm train_ratio is validated by the config.
        train_end_date = common_index[train_size - 1]
        test_start_date = common_index[train_size]

        logger.info(
            "Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)",
            train_size, train_end_date.strftime('%Y-%m-%d'),
            n_samples - train_size, test_start_date.strftime('%Y-%m-%d')
        )

        # 6. Train model on training period (skipped when a model is already set)
        if self.divergence_scorer.model is None:
            train_features = {
                pid: feat[feat.index <= train_end_date]
                for pid, feat in self.pair_features.items()
            }
            combined = self.feature_engine.get_combined_features(train_features)
            self.divergence_scorer.train_model(combined, train_features)

        # 7. Generate signals for test period
        return self._generate_signals(common_index, train_size, close)

    @staticmethod
    def _mark_signal(
        direction: str | None,
        timestamp: pd.Timestamp,
        long_signals: pd.Series,
        short_signals: pd.Series
    ) -> bool:
        """
        Flag ``timestamp`` on the long or short series matching ``direction``.

        Any direction other than 'long' is written to ``short_signals``.

        Returns:
            True when the timestamp exists on the series index and was
            flagged, False when it is absent (nothing is written).
        """
        if timestamp not in long_signals.index:
            return False
        target = long_signals if direction == 'long' else short_signals
        target.loc[timestamp] = True
        return True

    def _generate_signals(
        self,
        index: pd.DatetimeIndex,
        train_size: int,
        reference_close: pd.Series
    ) -> tuple:
        """
        Generate entry/exit signals for the test period.

        Iterates through each bar in the test period, scoring pairs
        and generating signals based on divergence scores.

        Args:
            index: Common index across all pair features.
            train_size: Number of leading bars of ``index`` reserved for
                training; iteration starts at this position.
            reference_close: Series whose index the output signals use.
                Signals at timestamps missing from this index are dropped,
                but position state is still updated.

        Returns:
            Tuple of (long_entries, long_exits, short_entries, short_exits, size).
        """
        # Initialize signal arrays aligned to reference close
        ref_index = reference_close.index
        long_entries = pd.Series(False, index=ref_index)
        long_exits = pd.Series(False, index=ref_index)
        short_entries = pd.Series(False, index=ref_index)
        short_exits = pd.Series(False, index=ref_index)
        size = pd.Series(1.0, index=ref_index)

        # Track position state
        self.position = PositionState()

        # Price data for correlation calculation
        price_data = {
            symbol: df['close'] for symbol, df in self.asset_data.items()
        }

        trade_count = 0

        # Iterate through test period; current_idx is the position in `index`
        for current_idx, timestamp in enumerate(index[train_size:], start=train_size):
            # Check exit conditions first
            if self.position.pair is not None:
                # Enforce minimum hold period
                bars_held = current_idx - self.position.entry_idx
                if bars_held < self.config.min_hold_bars:
                    # Only allow SL/TP exits during min hold period
                    should_exit, exit_reason = self._check_sl_tp_only(timestamp)
                else:
                    should_exit, exit_reason = self._check_exit(timestamp)

                if should_exit:
                    # Map exit signal to reference index
                    self._mark_signal(
                        self.position.direction, timestamp, long_exits, short_exits
                    )

                    logger.debug(
                        "Exit %s %s at %s: %s (held %d bars)",
                        self.position.direction,
                        self.position.pair.name,
                        timestamp.strftime('%Y-%m-%d %H:%M'),
                        exit_reason,
                        bars_held
                    )
                    self.position = PositionState(last_exit_idx=current_idx)

            # Score pairs (with correlation filter if position exists)
            held_asset = None
            if self.position.pair is not None:
                held_asset = self.position.pair.base_asset

            # Filter pairs by correlation
            candidate_pairs = self.correlation_filter.filter_pairs(
                self.pairs,
                held_asset,
                price_data,
                current_idx
            )

            # Get candidate features; collect ids once per bar so each
            # pair is a set lookup instead of a scan over candidate_pairs
            candidate_ids = {p.pair_id for p in candidate_pairs}
            candidate_features = {
                pid: feat for pid, feat in self.pair_features.items()
                if pid in candidate_ids
            }

            # Score pairs
            signals = self.divergence_scorer.score_pairs(
                candidate_features,
                candidate_pairs,
                timestamp
            )

            # Get best signal
            best = self.divergence_scorer.select_best_pair(signals)
            if best is None:
                continue

            # Check if we should switch positions or enter new
            should_enter = False

            # Check cooldown
            bars_since_exit = current_idx - self.position.last_exit_idx
            in_cooldown = bars_since_exit < self.config.cooldown_bars

            if self.position.pair is None and not in_cooldown:
                # No position and not in cooldown, can enter
                should_enter = True
            elif self.position.pair is not None:
                # Check if we should switch (requires min hold + significant improvement)
                bars_held = current_idx - self.position.entry_idx
                current_score = self._get_current_score(timestamp)

                if (bars_held >= self.config.min_hold_bars and
                        best.divergence_score > current_score * self.config.switch_threshold):
                    # New opportunity is significantly better: close the
                    # current position on this bar and open the new one
                    self._mark_signal(
                        self.position.direction, timestamp, long_exits, short_exits
                    )
                    self.position = PositionState(last_exit_idx=current_idx)
                    should_enter = True

            if should_enter:
                # Calculate ATR-based dynamic SL/TP
                sl_price, tp_price = self._calculate_sl_tp(
                    best.base_price,
                    best.direction,
                    best.atr,
                    best.atr_pct
                )

                # Set position
                self.position = PositionState(
                    pair=best.pair,
                    direction=best.direction,
                    entry_price=best.base_price,
                    entry_idx=current_idx,
                    stop_loss=sl_price,
                    take_profit=tp_price,
                    atr=best.atr
                )

                # Calculate position size based on divergence
                pos_size = self._calculate_size(best.divergence_score)

                # Generate entry signal (size is only written when the
                # timestamp exists on the reference index)
                if self._mark_signal(
                    best.direction, timestamp, long_entries, short_entries
                ):
                    size.loc[timestamp] = pos_size

                trade_count += 1
                logger.debug(
                    "Entry %s %s at %s: z=%.2f, prob=%.2f, score=%.3f",
                    best.direction,
                    best.pair.name,
                    timestamp.strftime('%Y-%m-%d %H:%M'),
                    best.z_score,
                    best.probability,
                    best.divergence_score
                )

        logger.info("Generated %d trades in test period", trade_count)

        return long_entries, long_exits, short_entries, short_exits, size

    def _latest_position_row(
        self, timestamp: pd.Timestamp
    ) -> tuple[pd.Series | None, str]:
        """
        Fetch the most recent feature row of the held pair at or before
        ``timestamp``.

        Callers must ensure a position is open (``self.position.pair`` set).

        Returns:
            ``(row, "")`` on success, otherwise ``(None, reason)`` where
            reason is "pair_data_missing" or "no_data".
        """
        pair_id = self.position.pair.pair_id
        if pair_id not in self.pair_features:
            return None, "pair_data_missing"

        features = self.pair_features[pair_id]
        valid = features[features.index <= timestamp]
        if len(valid) == 0:
            return None, "no_data"

        return valid.iloc[-1], ""

    def _check_exit(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
        """
        Check if current position should be exited.

        Exit conditions:
        1. Z-Score reverted to mean (|Z| < threshold)
        2. Stop-loss hit
        3. Take-profit hit

        A position whose feature data is missing is also exited.

        Returns:
            Tuple of (should_exit, reason)
        """
        if self.position.pair is None:
            return False, ""

        latest, reason = self._latest_position_row(timestamp)
        if latest is None:
            return True, reason

        z_score = latest['z_score']
        current_price = latest['base_close']

        # Check mean reversion (primary exit)
        if abs(z_score) < self.config.z_exit_threshold:
            return True, f"mean_reversion (z={z_score:.2f})"

        # Check SL/TP
        return self._check_sl_tp(current_price)

    def _check_sl_tp_only(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
        """
        Check only stop-loss and take-profit conditions.
        Used during minimum hold period.

        Returns:
            Tuple of (should_exit, reason)
        """
        if self.position.pair is None:
            return False, ""

        latest, reason = self._latest_position_row(timestamp)
        if latest is None:
            return True, reason

        return self._check_sl_tp(latest['base_close'])

    def _check_sl_tp(self, current_price: float) -> tuple[bool, str]:
        """
        Check stop-loss and take-profit levels.

        Any direction other than 'long' is evaluated as a short.

        Returns:
            Tuple of (should_exit, reason)
        """
        if self.position.direction == 'long':
            if current_price <= self.position.stop_loss:
                return True, f"stop_loss ({current_price:.2f} <= {self.position.stop_loss:.2f})"
            if current_price >= self.position.take_profit:
                return True, f"take_profit ({current_price:.2f} >= {self.position.take_profit:.2f})"
        else:  # short
            if current_price >= self.position.stop_loss:
                return True, f"stop_loss ({current_price:.2f} >= {self.position.stop_loss:.2f})"
            if current_price <= self.position.take_profit:
                return True, f"take_profit ({current_price:.2f} <= {self.position.take_profit:.2f})"

        return False, ""

    def _get_current_score(self, timestamp: pd.Timestamp) -> float:
        """
        Get current position's divergence score for comparison.

        Returns:
            ``|z_score| * probability`` using the scorer's model when one is
            set, ``|z_score| * 0.5`` otherwise, and 0.0 when there is no
            position or no feature data for it.
        """
        if self.position.pair is None:
            return 0.0

        latest, _ = self._latest_position_row(timestamp)
        if latest is None:
            return 0.0

        z_score = abs(latest['z_score'])

        # Re-score with model
        if self.divergence_scorer.model is not None:
            feature_cols = self.divergence_scorer.feature_cols
            # NaN and +/-inf are both replaced by 0 before prediction
            feature_row = latest[feature_cols].fillna(0)
            feature_row = feature_row.replace([np.inf, -np.inf], 0)
            X = pd.DataFrame([feature_row.values], columns=feature_cols)
            prob = self.divergence_scorer.model.predict_proba(X)[0, 1]
            return z_score * prob

        return z_score * 0.5

    def _calculate_sl_tp(
        self,
        entry_price: float,
        direction: str,
        atr: float,
        atr_pct: float
    ) -> tuple[float, float]:
        """
        Calculate ATR-based dynamic stop-loss and take-profit prices.

        Uses ATR (Average True Range) to set stops that adapt to
        each asset's volatility. More volatile assets get wider stops.

        Args:
            entry_price: Entry price
            direction: 'long' or 'short' (anything but 'long' is treated as short)
            atr: ATR in price units
            atr_pct: ATR as percentage of price

        Returns:
            Tuple of (stop_loss_price, take_profit_price)
        """
        # Calculate SL/TP as ATR multiples. entry_price must be positive
        # because it is the divisor when converting distances to percentages.
        if atr > 0 and atr_pct > 0 and entry_price > 0:
            # ATR-based calculation
            sl_distance = atr * self.config.sl_atr_multiplier
            tp_distance = atr * self.config.tp_atr_multiplier

            # Convert to percentage for bounds checking
            sl_pct = sl_distance / entry_price
            tp_pct = tp_distance / entry_price
        else:
            # Fallback to fixed percentages if ATR unavailable
            sl_pct = self.config.base_sl_pct
            tp_pct = self.config.base_tp_pct

        # Apply bounds to prevent extreme stops
        sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
        tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))

        # Calculate actual prices
        if direction == 'long':
            stop_loss = entry_price * (1 - sl_pct)
            take_profit = entry_price * (1 + tp_pct)
        else:  # short
            stop_loss = entry_price * (1 + sl_pct)
            take_profit = entry_price * (1 - tp_pct)

        return stop_loss, take_profit

    def _calculate_size(self, divergence_score: float) -> float:
        """
        Calculate position size based on divergence score.

        Higher divergence = larger position (up to 2x).

        Returns:
            1.0 at or below a score of 0.5, then linear in the score,
            capped at 2.0 (reached at a score of 1.0).
        """
        # Base score threshold (Z=1.0, prob=0.5 -> score=0.5)
        base_threshold = 0.5

        # Scale factor
        if divergence_score <= base_threshold:
            return 1.0

        # Linear scaling: 1.0 at threshold, up to 2.0 at 2x threshold
        scale = 1.0 + (divergence_score - base_threshold) / base_threshold
        return min(scale, 2.0)

    def _get_common_index(self) -> pd.DatetimeIndex:
        """
        Get the sorted intersection of all pair feature indices.

        Returns an empty ``DatetimeIndex`` when no pair features exist.
        """
        if not self.pair_features:
            return pd.DatetimeIndex([])

        common = None
        for features in self.pair_features.values():
            if common is None:
                common = features.index
            else:
                common = common.intersection(features.index)

        return common.sort_values()

    def _empty_signals(self, close: pd.Series) -> tuple:
        """
        Return empty signal arrays.

        Returns:
            Tuple of (long_entries, long_exits, short_entries, short_exits, size)
            with size set to 1.0 on every bar of ``close``.
        """
        # Build each signal separately so the four returned objects are not
        # aliases of one another.
        long_entries = self.create_empty_signals(close)
        long_exits = self.create_empty_signals(close)
        short_entries = self.create_empty_signals(close)
        short_exits = self.create_empty_signals(close)
        size = pd.Series(1.0, index=close.index)
        return long_entries, long_exits, short_entries, short_exits, size