Files
lowkey_backtest/strategies/multi_pair/strategy.py
Simon Moisy df37366603 feat: Multi-Pair Divergence Selection Strategy
- Extend regime detection to top 10 cryptocurrencies (45 pairs)
- Dynamic pair selection based on divergence score (|z_score| * probability)
- Universal ML model trained on all pairs
- Correlation-based filtering to avoid redundant positions
- Funding rate integration from OKX for all 10 assets
- ATR-based dynamic stop-loss and take-profit
- Walk-forward training with 70/30 split

Performance: +35.69% return (vs +28.66% baseline), 63.6% win rate
2026-01-15 20:47:23 +08:00

526 lines
19 KiB
Python

"""
Multi-Pair Divergence Selection Strategy.
Main strategy class that orchestrates pair scanning, feature calculation,
model training, and signal generation for backtesting.
"""
from dataclasses import dataclass
from typing import Optional
import pandas as pd
import numpy as np
from strategies.base import BaseStrategy
from engine.market import MarketType
from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import PairScanner, TradingPair
from .correlation import CorrelationFilter
from .feature_engine import MultiPairFeatureEngine
from .divergence_scorer import DivergenceScorer, DivergenceSignal
logger = get_logger(__name__)
@dataclass
class PositionState:
    """Tracks current position state."""
    pair: TradingPair | None = None  # Pair currently held; None means flat
    direction: str | None = None # 'long' or 'short'
    entry_price: float = 0.0  # Base-asset price at entry (0.0 when flat)
    entry_idx: int = -1  # Bar index of entry; -1 = no position taken yet
    stop_loss: float = 0.0  # Absolute stop-loss price level
    take_profit: float = 0.0  # Absolute take-profit price level
    atr: float = 0.0 # ATR at entry for reference
    last_exit_idx: int = -100 # For cooldown tracking; -100 so a fresh state is never in cooldown
class MultiPairDivergenceStrategy(BaseStrategy):
"""
Multi-Pair Divergence Selection Strategy.
Scans multiple cryptocurrency pairs for spread divergence,
selects the most divergent pair using ML-enhanced scoring,
and trades mean-reversion opportunities.
Key Features:
- Universal ML model across all pairs
- Correlation-based pair filtering
- Dynamic SL/TP based on volatility
- Walk-forward training
"""
def __init__(
self,
config: MultiPairConfig | None = None,
model_path: str = "data/multi_pair_model.pkl"
):
super().__init__()
self.config = config or MultiPairConfig()
# Initialize components
self.pair_scanner = PairScanner(self.config)
self.correlation_filter = CorrelationFilter(self.config)
self.feature_engine = MultiPairFeatureEngine(self.config)
self.divergence_scorer = DivergenceScorer(self.config, model_path)
# Strategy configuration
self.default_market_type = MarketType.PERPETUAL
self.default_leverage = 1
# Runtime state
self.pairs: list[TradingPair] = []
self.asset_data: dict[str, pd.DataFrame] = {}
self.pair_features: dict[str, pd.DataFrame] = {}
self.position = PositionState()
self.train_end_idx: int = 0
    def run(self, close: pd.Series, **kwargs) -> tuple:
        """
        Execute the multi-pair divergence strategy.

        This method is called by the backtester with the primary asset's
        close prices. For multi-pair, we load all assets internally.

        Pipeline: load assets + funding data -> generate/filter pairs ->
        compute pair features -> align on a common index -> walk-forward
        split -> (optionally) train model on the train slice -> emit
        signals for the test slice only.

        Args:
            close: Primary close prices (used for index alignment)
            **kwargs: Additional data (high, low, volume)
        Returns:
            Tuple of (long_entries, long_exits, short_entries, short_exits, size)
        """
        logger.info("Starting Multi-Pair Divergence Strategy")
        # 1. Load all asset data for the backtest window defined by `close`.
        start_date = close.index.min().strftime("%Y-%m-%d")
        end_date = close.index.max().strftime("%Y-%m-%d")
        self.asset_data = self.feature_engine.load_all_assets(
            start_date=start_date,
            end_date=end_date
        )
        # 1b. Load funding rate data for all assets
        # (stored inside the feature engine; consumed during feature calc).
        self.feature_engine.load_funding_data(
            start_date=start_date,
            end_date=end_date,
            use_cache=True
        )
        if len(self.asset_data) < 2:
            # A pair needs two distinct assets, so fewer than 2 is fatal.
            logger.error("Insufficient assets loaded, need at least 2")
            return self._empty_signals(close)
        # 2. Generate pairs
        self.pairs = self.pair_scanner.generate_pairs(check_exchange=False)
        # Filter to pairs with available data
        available_assets = set(self.asset_data.keys())
        self.pairs = [
            p for p in self.pairs
            if p.base_asset in available_assets
            and p.quote_asset in available_assets
        ]
        logger.info("Trading %d pairs from %d assets", len(self.pairs), len(self.asset_data))
        # 3. Calculate features for all pairs
        self.pair_features = self.feature_engine.calculate_all_pair_features(
            self.pairs, self.asset_data
        )
        if not self.pair_features:
            logger.error("No pair features calculated")
            return self._empty_signals(close)
        # 4. Align to common index (intersection across all pair feature frames)
        common_index = self._get_common_index()
        if len(common_index) < 200:
            # 200 bars minimum — presumably enough history for indicator
            # warm-up plus a meaningful train/test split; TODO confirm.
            logger.error("Insufficient common data across pairs")
            return self._empty_signals(close)
        # 5. Walk-forward split
        n_samples = len(common_index)
        train_size = int(n_samples * self.config.train_ratio)
        self.train_end_idx = train_size
        train_end_date = common_index[train_size - 1]
        # NOTE(review): assumes train_ratio < 1.0; train_size == n_samples
        # would make the next line raise IndexError.
        test_start_date = common_index[train_size]
        logger.info(
            "Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)",
            train_size, train_end_date.strftime('%Y-%m-%d'),
            n_samples - train_size, test_start_date.strftime('%Y-%m-%d')
        )
        # 6. Train model on training period only if no pre-trained model was
        # loaded from disk (avoids retraining between runs).
        if self.divergence_scorer.model is None:
            train_features = {
                pid: feat[feat.index <= train_end_date]
                for pid, feat in self.pair_features.items()
            }
            combined = self.feature_engine.get_combined_features(train_features)
            self.divergence_scorer.train_model(combined, train_features)
        # 7. Generate signals for test period
        return self._generate_signals(common_index, train_size, close)
    def _generate_signals(
        self,
        index: pd.DatetimeIndex,
        train_size: int,
        reference_close: pd.Series
    ) -> tuple:
        """
        Generate entry/exit signals for the test period.

        Iterates through each bar in the test period, scoring pairs
        and generating signals based on divergence scores.

        Per-bar flow: (1) evaluate exits for the held position (SL/TP only
        while inside the minimum hold window, full exit logic afterwards);
        (2) score candidate pairs after correlation filtering; (3) enter on
        the best signal if flat and out of cooldown, or switch when the new
        score sufficiently beats the held position's current score.

        Args:
            index: Common DatetimeIndex across all pair feature frames.
            train_size: Number of leading bars reserved for training.
            reference_close: Primary asset closes; signals are aligned to
                this index (bars absent from it get no signal).
        Returns:
            Tuple of (long_entries, long_exits, short_entries, short_exits, size).
        """
        # Initialize signal arrays aligned to reference close
        long_entries = pd.Series(False, index=reference_close.index)
        long_exits = pd.Series(False, index=reference_close.index)
        short_entries = pd.Series(False, index=reference_close.index)
        short_exits = pd.Series(False, index=reference_close.index)
        size = pd.Series(1.0, index=reference_close.index)
        # Track position state (reset to flat for the simulation)
        self.position = PositionState()
        # Price data for correlation calculation
        price_data = {
            symbol: df['close'] for symbol, df in self.asset_data.items()
        }
        # Iterate through test period
        test_indices = index[train_size:]
        trade_count = 0
        for i, timestamp in enumerate(test_indices):
            current_idx = train_size + i
            # Check exit conditions first
            if self.position.pair is not None:
                # Enforce minimum hold period
                bars_held = current_idx - self.position.entry_idx
                if bars_held < self.config.min_hold_bars:
                    # Only allow SL/TP exits during min hold period
                    should_exit, exit_reason = self._check_sl_tp_only(timestamp)
                else:
                    should_exit, exit_reason = self._check_exit(timestamp)
                if should_exit:
                    # Map exit signal to reference index
                    if timestamp in reference_close.index:
                        if self.position.direction == 'long':
                            long_exits.loc[timestamp] = True
                        else:
                            short_exits.loc[timestamp] = True
                    logger.debug(
                        "Exit %s %s at %s: %s (held %d bars)",
                        self.position.direction,
                        self.position.pair.name,
                        timestamp.strftime('%Y-%m-%d %H:%M'),
                        exit_reason,
                        bars_held
                    )
                    # Reset to flat; recording last_exit_idx starts the cooldown,
                    # which also blocks a same-bar re-entry below.
                    self.position = PositionState(last_exit_idx=current_idx)
            # Score pairs (with correlation filter if position exists)
            held_asset = None
            if self.position.pair is not None:
                held_asset = self.position.pair.base_asset
            # Filter pairs by correlation
            candidate_pairs = self.correlation_filter.filter_pairs(
                self.pairs,
                held_asset,
                price_data,
                current_idx
            )
            # Get candidate features
            # (membership test is O(pairs) per key; acceptable for ~45 pairs)
            candidate_features = {
                pid: feat for pid, feat in self.pair_features.items()
                if any(p.pair_id == pid for p in candidate_pairs)
            }
            # Score pairs
            signals = self.divergence_scorer.score_pairs(
                candidate_features,
                candidate_pairs,
                timestamp
            )
            # Get best signal
            best = self.divergence_scorer.select_best_pair(signals)
            if best is None:
                continue
            # Check if we should switch positions or enter new
            should_enter = False
            # Check cooldown
            bars_since_exit = current_idx - self.position.last_exit_idx
            in_cooldown = bars_since_exit < self.config.cooldown_bars
            if self.position.pair is None and not in_cooldown:
                # No position and not in cooldown, can enter
                should_enter = True
            elif self.position.pair is not None:
                # Check if we should switch (requires min hold + significant improvement)
                # NOTE(review): switching deliberately bypasses the cooldown check.
                bars_held = current_idx - self.position.entry_idx
                current_score = self._get_current_score(timestamp)
                if (bars_held >= self.config.min_hold_bars and
                    best.divergence_score > current_score * self.config.switch_threshold):
                    # New opportunity is significantly better
                    if timestamp in reference_close.index:
                        if self.position.direction == 'long':
                            long_exits.loc[timestamp] = True
                        else:
                            short_exits.loc[timestamp] = True
                    self.position = PositionState(last_exit_idx=current_idx)
                    should_enter = True
            if should_enter:
                # Calculate ATR-based dynamic SL/TP
                sl_price, tp_price = self._calculate_sl_tp(
                    best.base_price,
                    best.direction,
                    best.atr,
                    best.atr_pct
                )
                # Set position
                self.position = PositionState(
                    pair=best.pair,
                    direction=best.direction,
                    entry_price=best.base_price,
                    entry_idx=current_idx,
                    stop_loss=sl_price,
                    take_profit=tp_price,
                    atr=best.atr
                )
                # Calculate position size based on divergence
                pos_size = self._calculate_size(best.divergence_score)
                # Generate entry signal
                if timestamp in reference_close.index:
                    if best.direction == 'long':
                        long_entries.loc[timestamp] = True
                    else:
                        short_entries.loc[timestamp] = True
                    size.loc[timestamp] = pos_size
                trade_count += 1
                logger.debug(
                    "Entry %s %s at %s: z=%.2f, prob=%.2f, score=%.3f",
                    best.direction,
                    best.pair.name,
                    timestamp.strftime('%Y-%m-%d %H:%M'),
                    best.z_score,
                    best.probability,
                    best.divergence_score
                )
        logger.info("Generated %d trades in test period", trade_count)
        return long_entries, long_exits, short_entries, short_exits, size
def _check_exit(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
"""
Check if current position should be exited.
Exit conditions:
1. Z-Score reverted to mean (|Z| < threshold)
2. Stop-loss hit
3. Take-profit hit
Returns:
Tuple of (should_exit, reason)
"""
if self.position.pair is None:
return False, ""
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return True, "pair_data_missing"
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return True, "no_data"
latest = valid.iloc[-1]
z_score = latest['z_score']
current_price = latest['base_close']
# Check mean reversion (primary exit)
if abs(z_score) < self.config.z_exit_threshold:
return True, f"mean_reversion (z={z_score:.2f})"
# Check SL/TP
return self._check_sl_tp(current_price)
def _check_sl_tp_only(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
"""
Check only stop-loss and take-profit conditions.
Used during minimum hold period.
"""
if self.position.pair is None:
return False, ""
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return True, "pair_data_missing"
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return True, "no_data"
latest = valid.iloc[-1]
current_price = latest['base_close']
return self._check_sl_tp(current_price)
def _check_sl_tp(self, current_price: float) -> tuple[bool, str]:
"""Check stop-loss and take-profit levels."""
if self.position.direction == 'long':
if current_price <= self.position.stop_loss:
return True, f"stop_loss ({current_price:.2f} <= {self.position.stop_loss:.2f})"
if current_price >= self.position.take_profit:
return True, f"take_profit ({current_price:.2f} >= {self.position.take_profit:.2f})"
else: # short
if current_price >= self.position.stop_loss:
return True, f"stop_loss ({current_price:.2f} >= {self.position.stop_loss:.2f})"
if current_price <= self.position.take_profit:
return True, f"take_profit ({current_price:.2f} <= {self.position.take_profit:.2f})"
return False, ""
def _get_current_score(self, timestamp: pd.Timestamp) -> float:
"""Get current position's divergence score for comparison."""
if self.position.pair is None:
return 0.0
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return 0.0
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return 0.0
latest = valid.iloc[-1]
z_score = abs(latest['z_score'])
# Re-score with model
if self.divergence_scorer.model is not None:
feature_row = latest[self.divergence_scorer.feature_cols].fillna(0)
feature_row = feature_row.replace([np.inf, -np.inf], 0)
X = pd.DataFrame(
[feature_row.values],
columns=self.divergence_scorer.feature_cols
)
prob = self.divergence_scorer.model.predict_proba(X)[0, 1]
return z_score * prob
return z_score * 0.5
def _calculate_sl_tp(
self,
entry_price: float,
direction: str,
atr: float,
atr_pct: float
) -> tuple[float, float]:
"""
Calculate ATR-based dynamic stop-loss and take-profit prices.
Uses ATR (Average True Range) to set stops that adapt to
each asset's volatility. More volatile assets get wider stops.
Args:
entry_price: Entry price
direction: 'long' or 'short'
atr: ATR in price units
atr_pct: ATR as percentage of price
Returns:
Tuple of (stop_loss_price, take_profit_price)
"""
# Calculate SL/TP as ATR multiples
if atr > 0 and atr_pct > 0:
# ATR-based calculation
sl_distance = atr * self.config.sl_atr_multiplier
tp_distance = atr * self.config.tp_atr_multiplier
# Convert to percentage for bounds checking
sl_pct = sl_distance / entry_price
tp_pct = tp_distance / entry_price
else:
# Fallback to fixed percentages if ATR unavailable
sl_pct = self.config.base_sl_pct
tp_pct = self.config.base_tp_pct
# Apply bounds to prevent extreme stops
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
# Calculate actual prices
if direction == 'long':
stop_loss = entry_price * (1 - sl_pct)
take_profit = entry_price * (1 + tp_pct)
else: # short
stop_loss = entry_price * (1 + sl_pct)
take_profit = entry_price * (1 - tp_pct)
return stop_loss, take_profit
def _calculate_size(self, divergence_score: float) -> float:
"""
Calculate position size based on divergence score.
Higher divergence = larger position (up to 2x).
"""
# Base score threshold (Z=1.0, prob=0.5 -> score=0.5)
base_threshold = 0.5
# Scale factor
if divergence_score <= base_threshold:
return 1.0
# Linear scaling: 1.0 at threshold, up to 2.0 at 2x threshold
scale = 1.0 + (divergence_score - base_threshold) / base_threshold
return min(scale, 2.0)
def _get_common_index(self) -> pd.DatetimeIndex:
"""Get the intersection of all pair feature indices."""
if not self.pair_features:
return pd.DatetimeIndex([])
common = None
for features in self.pair_features.values():
if common is None:
common = features.index
else:
common = common.intersection(features.index)
return common.sort_values()
def _empty_signals(self, close: pd.Series) -> tuple:
"""Return empty signal arrays."""
empty = self.create_empty_signals(close)
size = pd.Series(1.0, index=close.index)
return empty, empty, empty, empty, size