feat: Multi-Pair Divergence Selection Strategy
- Extend regime detection to top 10 cryptocurrencies (45 pairs)
- Dynamic pair selection based on divergence score (|z_score| * probability)
- Universal ML model trained on all pairs
- Correlation-based filtering to avoid redundant positions
- Funding rate integration from OKX for all 10 assets
- ATR-based dynamic stop-loss and take-profit
- Walk-forward training with 70/30 split

Performance: +35.69% return (vs +28.66% baseline), 63.6% win rate
This commit is contained in:
525
strategies/multi_pair/strategy.py
Normal file
525
strategies/multi_pair/strategy.py
Normal file
@@ -0,0 +1,525 @@
|
||||
"""
|
||||
Multi-Pair Divergence Selection Strategy.
|
||||
|
||||
Main strategy class that orchestrates pair scanning, feature calculation,
|
||||
model training, and signal generation for backtesting.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from strategies.base import BaseStrategy
|
||||
from engine.market import MarketType
|
||||
from engine.logging_config import get_logger
|
||||
from .config import MultiPairConfig
|
||||
from .pair_scanner import PairScanner, TradingPair
|
||||
from .correlation import CorrelationFilter
|
||||
from .feature_engine import MultiPairFeatureEngine
|
||||
from .divergence_scorer import DivergenceScorer, DivergenceSignal
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class PositionState:
    """Tracks current position state.

    A fresh instance (all defaults) represents "no open position"
    (``pair is None``). ``last_exit_idx`` is explicitly carried forward
    when the position is reset so the entry cooldown can be enforced.
    """

    # Pair currently held; None means flat.
    pair: TradingPair | None = None
    direction: str | None = None  # 'long' or 'short'
    entry_price: float = 0.0
    # Bar index (into the common index) at which the position was opened.
    entry_idx: int = -1
    stop_loss: float = 0.0
    take_profit: float = 0.0
    atr: float = 0.0  # ATR at entry for reference
    last_exit_idx: int = -100  # For cooldown tracking
|
||||
|
||||
|
||||
class MultiPairDivergenceStrategy(BaseStrategy):
    """
    Multi-Pair Divergence Selection Strategy.

    Scans multiple cryptocurrency pairs for spread divergence,
    selects the most divergent pair using ML-enhanced scoring,
    and trades mean-reversion opportunities.

    Key Features:
    - Universal ML model across all pairs
    - Correlation-based pair filtering
    - Dynamic SL/TP based on volatility
    - Walk-forward training
    """

    def __init__(
        self,
        config: MultiPairConfig | None = None,
        model_path: str = "data/multi_pair_model.pkl"
    ):
        """Initialize strategy components and runtime state.

        Args:
            config: Strategy configuration; a default ``MultiPairConfig``
                is created when omitted.
            model_path: Path to a pickled scorer model; passed through to
                ``DivergenceScorer`` (loading semantics live there).
        """
        super().__init__()
        self.config = config or MultiPairConfig()

        # Initialize components (all share the same config object)
        self.pair_scanner = PairScanner(self.config)
        self.correlation_filter = CorrelationFilter(self.config)
        self.feature_engine = MultiPairFeatureEngine(self.config)
        self.divergence_scorer = DivergenceScorer(self.config, model_path)

        # Strategy configuration
        self.default_market_type = MarketType.PERPETUAL
        self.default_leverage = 1

        # Runtime state — populated by run(); empty until then
        self.pairs: list[TradingPair] = []
        self.asset_data: dict[str, pd.DataFrame] = {}
        self.pair_features: dict[str, pd.DataFrame] = {}
        self.position = PositionState()
        self.train_end_idx: int = 0
||||
    def run(self, close: pd.Series, **kwargs) -> tuple:
        """
        Execute the multi-pair divergence strategy.

        This method is called by the backtester with the primary asset's
        close prices. For multi-pair, we load all assets internally.

        Pipeline: load assets + funding data, generate/filter pairs,
        compute features, walk-forward split, (optionally) train the
        scorer model, then generate test-period signals.

        Args:
            close: Primary close prices (used for index alignment).
                Must have a DatetimeIndex — min()/max() are formatted
                as dates below.
            **kwargs: Additional data (high, low, volume)

        Returns:
            Tuple of (long_entries, long_exits, short_entries, short_exits, size)
        """
        logger.info("Starting Multi-Pair Divergence Strategy")

        # 1. Load all asset data for the backtest window
        start_date = close.index.min().strftime("%Y-%m-%d")
        end_date = close.index.max().strftime("%Y-%m-%d")

        self.asset_data = self.feature_engine.load_all_assets(
            start_date=start_date,
            end_date=end_date
        )

        # 1b. Load funding rate data for all assets
        self.feature_engine.load_funding_data(
            start_date=start_date,
            end_date=end_date,
            use_cache=True
        )

        # Need at least two assets to form a single pair.
        if len(self.asset_data) < 2:
            logger.error("Insufficient assets loaded, need at least 2")
            return self._empty_signals(close)

        # 2. Generate pairs
        self.pairs = self.pair_scanner.generate_pairs(check_exchange=False)

        # Filter to pairs where both legs have loaded data
        available_assets = set(self.asset_data.keys())
        self.pairs = [
            p for p in self.pairs
            if p.base_asset in available_assets
            and p.quote_asset in available_assets
        ]

        logger.info("Trading %d pairs from %d assets", len(self.pairs), len(self.asset_data))

        # 3. Calculate features for all pairs
        self.pair_features = self.feature_engine.calculate_all_pair_features(
            self.pairs, self.asset_data
        )

        if not self.pair_features:
            logger.error("No pair features calculated")
            return self._empty_signals(close)

        # 4. Align to the index common to every pair's features
        common_index = self._get_common_index()
        if len(common_index) < 200:
            logger.error("Insufficient common data across pairs")
            return self._empty_signals(close)

        # 5. Walk-forward split (train on first train_ratio of bars)
        n_samples = len(common_index)
        train_size = int(n_samples * self.config.train_ratio)
        self.train_end_idx = train_size

        # NOTE(review): common_index[train_size] assumes train_ratio < 1.0
        # (otherwise IndexError) — confirm config guarantees this.
        train_end_date = common_index[train_size - 1]
        test_start_date = common_index[train_size]

        logger.info(
            "Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)",
            train_size, train_end_date.strftime('%Y-%m-%d'),
            n_samples - train_size, test_start_date.strftime('%Y-%m-%d')
        )

        # 6. Train model on the training period only, and only when the
        # scorer did not already load a persisted model.
        if self.divergence_scorer.model is None:
            train_features = {
                pid: feat[feat.index <= train_end_date]
                for pid, feat in self.pair_features.items()
            }
            combined = self.feature_engine.get_combined_features(train_features)
            self.divergence_scorer.train_model(combined, train_features)

        # 7. Generate signals for the held-out test period
        return self._generate_signals(common_index, train_size, close)
||||
|
||||
    def _generate_signals(
        self,
        index: pd.DatetimeIndex,
        train_size: int,
        reference_close: pd.Series
    ) -> tuple:
        """
        Generate entry/exit signals for the test period.

        Iterates through each bar in the test period, scoring pairs
        and generating signals based on divergence scores.

        Per bar: (1) check exits for any open position (SL/TP only during
        the minimum hold period), (2) score correlation-filtered candidate
        pairs, (3) enter flat+off-cooldown, or switch when a new signal
        beats the held position's score by the configured threshold.

        Args:
            index: Common DatetimeIndex across all pair features.
            train_size: Number of leading bars reserved for training;
                signals start at index[train_size].
            reference_close: Primary asset closes; output series are
                aligned to this index.

        Returns:
            Tuple of (long_entries, long_exits, short_entries, short_exits, size).
        """
        # Initialize signal arrays aligned to reference close
        long_entries = pd.Series(False, index=reference_close.index)
        long_exits = pd.Series(False, index=reference_close.index)
        short_entries = pd.Series(False, index=reference_close.index)
        short_exits = pd.Series(False, index=reference_close.index)
        size = pd.Series(1.0, index=reference_close.index)

        # Reset position state for a fresh pass
        self.position = PositionState()

        # Per-asset close series for correlation calculation
        price_data = {
            symbol: df['close'] for symbol, df in self.asset_data.items()
        }

        # Iterate through test period
        test_indices = index[train_size:]

        trade_count = 0

        for i, timestamp in enumerate(test_indices):
            current_idx = train_size + i

            # Check exit conditions first
            if self.position.pair is not None:
                # Enforce minimum hold period
                bars_held = current_idx - self.position.entry_idx
                if bars_held < self.config.min_hold_bars:
                    # Only allow SL/TP exits during min hold period
                    should_exit, exit_reason = self._check_sl_tp_only(timestamp)
                else:
                    should_exit, exit_reason = self._check_exit(timestamp)

                if should_exit:
                    # Map exit signal to reference index (timestamp may not
                    # exist there if the reference index is sparser)
                    if timestamp in reference_close.index:
                        if self.position.direction == 'long':
                            long_exits.loc[timestamp] = True
                        else:
                            short_exits.loc[timestamp] = True

                    logger.debug(
                        "Exit %s %s at %s: %s (held %d bars)",
                        self.position.direction,
                        self.position.pair.name,
                        timestamp.strftime('%Y-%m-%d %H:%M'),
                        exit_reason,
                        bars_held
                    )
                    # Reset to flat, carrying the exit bar for cooldown
                    self.position = PositionState(last_exit_idx=current_idx)

            # Score pairs (with correlation filter if position exists)
            held_asset = None
            if self.position.pair is not None:
                held_asset = self.position.pair.base_asset

            # Filter pairs by correlation against the held asset
            candidate_pairs = self.correlation_filter.filter_pairs(
                self.pairs,
                held_asset,
                price_data,
                current_idx
            )

            # Restrict features to the surviving candidate pairs
            candidate_features = {
                pid: feat for pid, feat in self.pair_features.items()
                if any(p.pair_id == pid for p in candidate_pairs)
            }

            # Score pairs
            signals = self.divergence_scorer.score_pairs(
                candidate_features,
                candidate_pairs,
                timestamp
            )

            # Get best signal
            best = self.divergence_scorer.select_best_pair(signals)

            if best is None:
                continue

            # Check if we should switch positions or enter new
            should_enter = False

            # Check cooldown (an exit earlier THIS bar puts us in cooldown,
            # so there is no immediate same-bar re-entry after a plain exit)
            bars_since_exit = current_idx - self.position.last_exit_idx
            in_cooldown = bars_since_exit < self.config.cooldown_bars

            if self.position.pair is None and not in_cooldown:
                # No position and not in cooldown, can enter
                should_enter = True
            elif self.position.pair is not None:
                # Check if we should switch (requires min hold + significant improvement)
                bars_held = current_idx - self.position.entry_idx
                current_score = self._get_current_score(timestamp)

                if (bars_held >= self.config.min_hold_bars and
                        best.divergence_score > current_score * self.config.switch_threshold):
                    # New opportunity is significantly better: exit the held
                    # position and enter the new one on the same bar (the
                    # cooldown deliberately does not apply to a switch)
                    if timestamp in reference_close.index:
                        if self.position.direction == 'long':
                            long_exits.loc[timestamp] = True
                        else:
                            short_exits.loc[timestamp] = True
                    self.position = PositionState(last_exit_idx=current_idx)
                    should_enter = True

            if should_enter:
                # Calculate ATR-based dynamic SL/TP
                sl_price, tp_price = self._calculate_sl_tp(
                    best.base_price,
                    best.direction,
                    best.atr,
                    best.atr_pct
                )

                # Set position
                self.position = PositionState(
                    pair=best.pair,
                    direction=best.direction,
                    entry_price=best.base_price,
                    entry_idx=current_idx,
                    stop_loss=sl_price,
                    take_profit=tp_price,
                    atr=best.atr
                )

                # Calculate position size based on divergence
                pos_size = self._calculate_size(best.divergence_score)

                # Generate entry signal
                if timestamp in reference_close.index:
                    if best.direction == 'long':
                        long_entries.loc[timestamp] = True
                    else:
                        short_entries.loc[timestamp] = True
                    size.loc[timestamp] = pos_size

                trade_count += 1
                logger.debug(
                    "Entry %s %s at %s: z=%.2f, prob=%.2f, score=%.3f",
                    best.direction,
                    best.pair.name,
                    timestamp.strftime('%Y-%m-%d %H:%M'),
                    best.z_score,
                    best.probability,
                    best.divergence_score
                )

        logger.info("Generated %d trades in test period", trade_count)

        return long_entries, long_exits, short_entries, short_exits, size
||||
|
||||
def _check_exit(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
|
||||
"""
|
||||
Check if current position should be exited.
|
||||
|
||||
Exit conditions:
|
||||
1. Z-Score reverted to mean (|Z| < threshold)
|
||||
2. Stop-loss hit
|
||||
3. Take-profit hit
|
||||
|
||||
Returns:
|
||||
Tuple of (should_exit, reason)
|
||||
"""
|
||||
if self.position.pair is None:
|
||||
return False, ""
|
||||
|
||||
pair_id = self.position.pair.pair_id
|
||||
if pair_id not in self.pair_features:
|
||||
return True, "pair_data_missing"
|
||||
|
||||
features = self.pair_features[pair_id]
|
||||
valid = features[features.index <= timestamp]
|
||||
|
||||
if len(valid) == 0:
|
||||
return True, "no_data"
|
||||
|
||||
latest = valid.iloc[-1]
|
||||
z_score = latest['z_score']
|
||||
current_price = latest['base_close']
|
||||
|
||||
# Check mean reversion (primary exit)
|
||||
if abs(z_score) < self.config.z_exit_threshold:
|
||||
return True, f"mean_reversion (z={z_score:.2f})"
|
||||
|
||||
# Check SL/TP
|
||||
return self._check_sl_tp(current_price)
|
||||
|
||||
def _check_sl_tp_only(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
|
||||
"""
|
||||
Check only stop-loss and take-profit conditions.
|
||||
Used during minimum hold period.
|
||||
"""
|
||||
if self.position.pair is None:
|
||||
return False, ""
|
||||
|
||||
pair_id = self.position.pair.pair_id
|
||||
if pair_id not in self.pair_features:
|
||||
return True, "pair_data_missing"
|
||||
|
||||
features = self.pair_features[pair_id]
|
||||
valid = features[features.index <= timestamp]
|
||||
|
||||
if len(valid) == 0:
|
||||
return True, "no_data"
|
||||
|
||||
latest = valid.iloc[-1]
|
||||
current_price = latest['base_close']
|
||||
|
||||
return self._check_sl_tp(current_price)
|
||||
|
||||
def _check_sl_tp(self, current_price: float) -> tuple[bool, str]:
|
||||
"""Check stop-loss and take-profit levels."""
|
||||
if self.position.direction == 'long':
|
||||
if current_price <= self.position.stop_loss:
|
||||
return True, f"stop_loss ({current_price:.2f} <= {self.position.stop_loss:.2f})"
|
||||
if current_price >= self.position.take_profit:
|
||||
return True, f"take_profit ({current_price:.2f} >= {self.position.take_profit:.2f})"
|
||||
else: # short
|
||||
if current_price >= self.position.stop_loss:
|
||||
return True, f"stop_loss ({current_price:.2f} >= {self.position.stop_loss:.2f})"
|
||||
if current_price <= self.position.take_profit:
|
||||
return True, f"take_profit ({current_price:.2f} <= {self.position.take_profit:.2f})"
|
||||
|
||||
return False, ""
|
||||
|
||||
def _get_current_score(self, timestamp: pd.Timestamp) -> float:
|
||||
"""Get current position's divergence score for comparison."""
|
||||
if self.position.pair is None:
|
||||
return 0.0
|
||||
|
||||
pair_id = self.position.pair.pair_id
|
||||
if pair_id not in self.pair_features:
|
||||
return 0.0
|
||||
|
||||
features = self.pair_features[pair_id]
|
||||
valid = features[features.index <= timestamp]
|
||||
|
||||
if len(valid) == 0:
|
||||
return 0.0
|
||||
|
||||
latest = valid.iloc[-1]
|
||||
z_score = abs(latest['z_score'])
|
||||
|
||||
# Re-score with model
|
||||
if self.divergence_scorer.model is not None:
|
||||
feature_row = latest[self.divergence_scorer.feature_cols].fillna(0)
|
||||
feature_row = feature_row.replace([np.inf, -np.inf], 0)
|
||||
X = pd.DataFrame(
|
||||
[feature_row.values],
|
||||
columns=self.divergence_scorer.feature_cols
|
||||
)
|
||||
prob = self.divergence_scorer.model.predict_proba(X)[0, 1]
|
||||
return z_score * prob
|
||||
|
||||
return z_score * 0.5
|
||||
|
||||
def _calculate_sl_tp(
|
||||
self,
|
||||
entry_price: float,
|
||||
direction: str,
|
||||
atr: float,
|
||||
atr_pct: float
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate ATR-based dynamic stop-loss and take-profit prices.
|
||||
|
||||
Uses ATR (Average True Range) to set stops that adapt to
|
||||
each asset's volatility. More volatile assets get wider stops.
|
||||
|
||||
Args:
|
||||
entry_price: Entry price
|
||||
direction: 'long' or 'short'
|
||||
atr: ATR in price units
|
||||
atr_pct: ATR as percentage of price
|
||||
|
||||
Returns:
|
||||
Tuple of (stop_loss_price, take_profit_price)
|
||||
"""
|
||||
# Calculate SL/TP as ATR multiples
|
||||
if atr > 0 and atr_pct > 0:
|
||||
# ATR-based calculation
|
||||
sl_distance = atr * self.config.sl_atr_multiplier
|
||||
tp_distance = atr * self.config.tp_atr_multiplier
|
||||
|
||||
# Convert to percentage for bounds checking
|
||||
sl_pct = sl_distance / entry_price
|
||||
tp_pct = tp_distance / entry_price
|
||||
else:
|
||||
# Fallback to fixed percentages if ATR unavailable
|
||||
sl_pct = self.config.base_sl_pct
|
||||
tp_pct = self.config.base_tp_pct
|
||||
|
||||
# Apply bounds to prevent extreme stops
|
||||
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
|
||||
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
|
||||
|
||||
# Calculate actual prices
|
||||
if direction == 'long':
|
||||
stop_loss = entry_price * (1 - sl_pct)
|
||||
take_profit = entry_price * (1 + tp_pct)
|
||||
else: # short
|
||||
stop_loss = entry_price * (1 + sl_pct)
|
||||
take_profit = entry_price * (1 - tp_pct)
|
||||
|
||||
return stop_loss, take_profit
|
||||
|
||||
def _calculate_size(self, divergence_score: float) -> float:
|
||||
"""
|
||||
Calculate position size based on divergence score.
|
||||
|
||||
Higher divergence = larger position (up to 2x).
|
||||
"""
|
||||
# Base score threshold (Z=1.0, prob=0.5 -> score=0.5)
|
||||
base_threshold = 0.5
|
||||
|
||||
# Scale factor
|
||||
if divergence_score <= base_threshold:
|
||||
return 1.0
|
||||
|
||||
# Linear scaling: 1.0 at threshold, up to 2.0 at 2x threshold
|
||||
scale = 1.0 + (divergence_score - base_threshold) / base_threshold
|
||||
return min(scale, 2.0)
|
||||
|
||||
def _get_common_index(self) -> pd.DatetimeIndex:
|
||||
"""Get the intersection of all pair feature indices."""
|
||||
if not self.pair_features:
|
||||
return pd.DatetimeIndex([])
|
||||
|
||||
common = None
|
||||
for features in self.pair_features.values():
|
||||
if common is None:
|
||||
common = features.index
|
||||
else:
|
||||
common = common.intersection(features.index)
|
||||
|
||||
return common.sort_values()
|
||||
|
||||
def _empty_signals(self, close: pd.Series) -> tuple:
|
||||
"""Return empty signal arrays."""
|
||||
empty = self.create_empty_signals(close)
|
||||
size = pd.Series(1.0, index=close.index)
|
||||
return empty, empty, empty, empty, size
|
||||
Reference in New Issue
Block a user