- Extend regime detection to top 10 cryptocurrencies (45 pairs)
- Dynamic pair selection based on divergence score (|z_score| * probability)
- Universal ML model trained on all pairs
- Correlation-based filtering to avoid redundant positions
- Funding rate integration from OKX for all 10 assets
- ATR-based dynamic stop-loss and take-profit
- Walk-forward training with 70/30 split

Performance: +35.69% return (vs +28.66% baseline), 63.6% win rate
526 lines
19 KiB
Python
526 lines
19 KiB
Python
"""
|
|
Multi-Pair Divergence Selection Strategy.
|
|
|
|
Main strategy class that orchestrates pair scanning, feature calculation,
|
|
model training, and signal generation for backtesting.
|
|
"""
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
from strategies.base import BaseStrategy
|
|
from engine.market import MarketType
|
|
from engine.logging_config import get_logger
|
|
from .config import MultiPairConfig
|
|
from .pair_scanner import PairScanner, TradingPair
|
|
from .correlation import CorrelationFilter
|
|
from .feature_engine import MultiPairFeatureEngine
|
|
from .divergence_scorer import DivergenceScorer, DivergenceSignal
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
@dataclass
class PositionState:
    """Tracks current position state.

    A default instance means "no open position" (``pair is None``).
    ``last_exit_idx`` is carried into the replacement instance created
    after each exit so the entry cooldown can be enforced.
    """

    # Currently held trading pair; None when flat.
    pair: TradingPair | None = None
    # Trade side of the open position; None when flat.
    direction: str | None = None  # 'long' or 'short'
    # Fill price recorded at entry (the signal's base price).
    entry_price: float = 0.0
    # Position of the entry bar within the common index; -1 when flat.
    entry_idx: int = -1
    # Stop-loss price set at entry (ATR-based, percentage fallback).
    stop_loss: float = 0.0
    # Take-profit price set at entry (ATR-based, percentage fallback).
    take_profit: float = 0.0
    atr: float = 0.0  # ATR at entry for reference
    # Bar index of the most recent exit; -100 so a fresh strategy
    # starts outside any cooldown window.
    last_exit_idx: int = -100  # For cooldown tracking
|
|
|
|
|
|
class MultiPairDivergenceStrategy(BaseStrategy):
    """
    Multi-Pair Divergence Selection Strategy.

    Scans multiple cryptocurrency pairs for spread divergence,
    selects the most divergent pair using ML-enhanced scoring,
    and trades mean-reversion opportunities.

    Key Features:
    - Universal ML model across all pairs
    - Correlation-based pair filtering
    - Dynamic SL/TP based on volatility
    - Walk-forward training

    At most one pair is held at a time (tracked in ``PositionState``);
    entries, exits and sizes are emitted aligned to the backtester's
    reference close index.
    """
|
|
|
|
def __init__(
    self,
    config: MultiPairConfig | None = None,
    model_path: str = "data/multi_pair_model.pkl"
):
    """Set up components, backtester defaults and empty runtime state.

    Args:
        config: Strategy configuration; a default ``MultiPairConfig``
            is created when None is given.
        model_path: Location of the pickled universal ML model used
            by the divergence scorer.
    """
    super().__init__()

    cfg = config or MultiPairConfig()
    self.config = cfg

    # Backtester-facing defaults.
    self.default_market_type = MarketType.PERPETUAL
    self.default_leverage = 1

    # Collaborating components, all sharing the same configuration.
    self.pair_scanner = PairScanner(cfg)
    self.correlation_filter = CorrelationFilter(cfg)
    self.feature_engine = MultiPairFeatureEngine(cfg)
    self.divergence_scorer = DivergenceScorer(cfg, model_path)

    # Mutable state populated during run().
    self.pairs: list[TradingPair] = []
    self.asset_data: dict[str, pd.DataFrame] = {}
    self.pair_features: dict[str, pd.DataFrame] = {}
    self.position = PositionState()
    self.train_end_idx: int = 0
|
|
|
|
def run(self, close: pd.Series, **kwargs) -> tuple:
    """
    Execute the multi-pair divergence strategy.

    This method is called by the backtester with the primary asset's
    close prices. For multi-pair, we load all assets internally.

    Args:
        close: Primary close prices (used for index alignment)
        **kwargs: Additional data (high, low, volume); not read here,
            accepted for backtester interface compatibility.

    Returns:
        Tuple of (long_entries, long_exits, short_entries, short_exits, size)
    """
    logger.info("Starting Multi-Pair Divergence Strategy")

    # 1. Load all asset data spanning the reference series' date range.
    start_date = close.index.min().strftime("%Y-%m-%d")
    end_date = close.index.max().strftime("%Y-%m-%d")

    self.asset_data = self.feature_engine.load_all_assets(
        start_date=start_date,
        end_date=end_date
    )

    # 1b. Load funding rate data for all assets (cached between runs).
    self.feature_engine.load_funding_data(
        start_date=start_date,
        end_date=end_date,
        use_cache=True
    )

    # At least two assets are needed to form a single trading pair.
    if len(self.asset_data) < 2:
        logger.error("Insufficient assets loaded, need at least 2")
        return self._empty_signals(close)

    # 2. Generate pairs; exchange listing check skipped for backtests.
    self.pairs = self.pair_scanner.generate_pairs(check_exchange=False)

    # Filter to pairs for which both legs have loaded data.
    available_assets = set(self.asset_data.keys())
    self.pairs = [
        p for p in self.pairs
        if p.base_asset in available_assets
        and p.quote_asset in available_assets
    ]

    logger.info("Trading %d pairs from %d assets", len(self.pairs), len(self.asset_data))

    # 3. Calculate features for all pairs
    self.pair_features = self.feature_engine.calculate_all_pair_features(
        self.pairs, self.asset_data
    )

    if not self.pair_features:
        logger.error("No pair features calculated")
        return self._empty_signals(close)

    # 4. Align to the intersection of all pair feature indices; require
    # a minimum history so the walk-forward split is meaningful.
    common_index = self._get_common_index()
    if len(common_index) < 200:
        logger.error("Insufficient common data across pairs")
        return self._empty_signals(close)

    # 5. Walk-forward split: first train_ratio of bars for training,
    # the remainder for out-of-sample signal generation.
    n_samples = len(common_index)
    train_size = int(n_samples * self.config.train_ratio)
    self.train_end_idx = train_size

    # NOTE(review): assumes 0 < train_ratio < 1 — train_ratio == 1.0
    # would make common_index[train_size] raise IndexError.
    train_end_date = common_index[train_size - 1]
    test_start_date = common_index[train_size]

    logger.info(
        "Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)",
        train_size, train_end_date.strftime('%Y-%m-%d'),
        n_samples - train_size, test_start_date.strftime('%Y-%m-%d')
    )

    # 6. Train model on the training period only. Skipped when a
    # pre-trained model was already loaded by the scorer.
    if self.divergence_scorer.model is None:
        # Restrict each pair's features to the training window to
        # avoid look-ahead into the test period.
        train_features = {
            pid: feat[feat.index <= train_end_date]
            for pid, feat in self.pair_features.items()
        }
        combined = self.feature_engine.get_combined_features(train_features)
        self.divergence_scorer.train_model(combined, train_features)

    # 7. Generate signals for the test period only.
    return self._generate_signals(common_index, train_size, close)
|
|
|
|
def _generate_signals(
    self,
    index: pd.DatetimeIndex,
    train_size: int,
    reference_close: pd.Series
) -> tuple:
    """
    Generate entry/exit signals for the test period.

    Iterates through each bar in the test period, scoring pairs
    and generating signals based on divergence scores.

    Args:
        index: Common index across all pair feature frames.
        train_size: Number of leading bars reserved for training;
            signal generation starts at index[train_size].
        reference_close: Primary asset close series; all output
            signal arrays are aligned to its index.

    Returns:
        Tuple of (long_entries, long_exits, short_entries,
        short_exits, size), all indexed by reference_close.index.
    """
    # Initialize signal arrays aligned to reference close.
    long_entries = pd.Series(False, index=reference_close.index)
    long_exits = pd.Series(False, index=reference_close.index)
    short_entries = pd.Series(False, index=reference_close.index)
    short_exits = pd.Series(False, index=reference_close.index)
    size = pd.Series(1.0, index=reference_close.index)

    # Start flat; position state is mutated as the loop progresses.
    self.position = PositionState()

    # Per-asset close series for the correlation filter.
    price_data = {
        symbol: df['close'] for symbol, df in self.asset_data.items()
    }

    # Iterate through the out-of-sample test period only.
    test_indices = index[train_size:]

    trade_count = 0

    for i, timestamp in enumerate(test_indices):
        current_idx = train_size + i

        # --- Exit handling (checked before any new entry) ---
        if self.position.pair is not None:
            # Enforce minimum hold period: within it, only hard
            # SL/TP exits are allowed (no mean-reversion exit).
            bars_held = current_idx - self.position.entry_idx
            if bars_held < self.config.min_hold_bars:
                should_exit, exit_reason = self._check_sl_tp_only(timestamp)
            else:
                should_exit, exit_reason = self._check_exit(timestamp)

            if should_exit:
                # Signals are only recorded for bars present in the
                # reference index; otherwise state still resets.
                if timestamp in reference_close.index:
                    if self.position.direction == 'long':
                        long_exits.loc[timestamp] = True
                    else:
                        short_exits.loc[timestamp] = True

                logger.debug(
                    "Exit %s %s at %s: %s (held %d bars)",
                    self.position.direction,
                    self.position.pair.name,
                    timestamp.strftime('%Y-%m-%d %H:%M'),
                    exit_reason,
                    bars_held
                )
                # Reset to flat, remembering the exit bar for cooldown.
                self.position = PositionState(last_exit_idx=current_idx)

        # --- Candidate scoring ---
        # When a position is open, the correlation filter excludes
        # pairs correlated with the held base asset.
        held_asset = None
        if self.position.pair is not None:
            held_asset = self.position.pair.base_asset

        # Filter pairs by correlation
        candidate_pairs = self.correlation_filter.filter_pairs(
            self.pairs,
            held_asset,
            price_data,
            current_idx
        )

        # Restrict feature frames to the surviving candidates.
        candidate_features = {
            pid: feat for pid, feat in self.pair_features.items()
            if any(p.pair_id == pid for p in candidate_pairs)
        }

        # Score pairs
        signals = self.divergence_scorer.score_pairs(
            candidate_features,
            candidate_pairs,
            timestamp
        )

        # Get best signal
        best = self.divergence_scorer.select_best_pair(signals)

        if best is None:
            continue

        # --- Entry / switch decision ---
        should_enter = False

        # Cooldown applies only to fresh entries from a flat state.
        bars_since_exit = current_idx - self.position.last_exit_idx
        in_cooldown = bars_since_exit < self.config.cooldown_bars

        if self.position.pair is None and not in_cooldown:
            # No position and not in cooldown, can enter.
            should_enter = True
        elif self.position.pair is not None:
            # Position switch: requires min hold AND the new signal to
            # beat the held pair's re-scored value by switch_threshold.
            bars_held = current_idx - self.position.entry_idx
            current_score = self._get_current_score(timestamp)

            if (bars_held >= self.config.min_hold_bars and
                best.divergence_score > current_score * self.config.switch_threshold):
                # Exit the current position on this same bar, then
                # fall through to the entry logic below.
                if timestamp in reference_close.index:
                    if self.position.direction == 'long':
                        long_exits.loc[timestamp] = True
                    else:
                        short_exits.loc[timestamp] = True
                self.position = PositionState(last_exit_idx=current_idx)
                should_enter = True

        if should_enter:
            # ATR-based dynamic SL/TP for the new position.
            sl_price, tp_price = self._calculate_sl_tp(
                best.base_price,
                best.direction,
                best.atr,
                best.atr_pct
            )

            # Record the new position state.
            self.position = PositionState(
                pair=best.pair,
                direction=best.direction,
                entry_price=best.base_price,
                entry_idx=current_idx,
                stop_loss=sl_price,
                take_profit=tp_price,
                atr=best.atr
            )

            # Size scales with divergence conviction (up to 2x).
            pos_size = self._calculate_size(best.divergence_score)

            # Emit the entry signal on the reference index.
            if timestamp in reference_close.index:
                if best.direction == 'long':
                    long_entries.loc[timestamp] = True
                else:
                    short_entries.loc[timestamp] = True
                size.loc[timestamp] = pos_size

            trade_count += 1
            logger.debug(
                "Entry %s %s at %s: z=%.2f, prob=%.2f, score=%.3f",
                best.direction,
                best.pair.name,
                timestamp.strftime('%Y-%m-%d %H:%M'),
                best.z_score,
                best.probability,
                best.divergence_score
            )

    logger.info("Generated %d trades in test period", trade_count)

    return long_entries, long_exits, short_entries, short_exits, size
|
|
|
|
def _check_exit(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
|
|
"""
|
|
Check if current position should be exited.
|
|
|
|
Exit conditions:
|
|
1. Z-Score reverted to mean (|Z| < threshold)
|
|
2. Stop-loss hit
|
|
3. Take-profit hit
|
|
|
|
Returns:
|
|
Tuple of (should_exit, reason)
|
|
"""
|
|
if self.position.pair is None:
|
|
return False, ""
|
|
|
|
pair_id = self.position.pair.pair_id
|
|
if pair_id not in self.pair_features:
|
|
return True, "pair_data_missing"
|
|
|
|
features = self.pair_features[pair_id]
|
|
valid = features[features.index <= timestamp]
|
|
|
|
if len(valid) == 0:
|
|
return True, "no_data"
|
|
|
|
latest = valid.iloc[-1]
|
|
z_score = latest['z_score']
|
|
current_price = latest['base_close']
|
|
|
|
# Check mean reversion (primary exit)
|
|
if abs(z_score) < self.config.z_exit_threshold:
|
|
return True, f"mean_reversion (z={z_score:.2f})"
|
|
|
|
# Check SL/TP
|
|
return self._check_sl_tp(current_price)
|
|
|
|
def _check_sl_tp_only(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
|
|
"""
|
|
Check only stop-loss and take-profit conditions.
|
|
Used during minimum hold period.
|
|
"""
|
|
if self.position.pair is None:
|
|
return False, ""
|
|
|
|
pair_id = self.position.pair.pair_id
|
|
if pair_id not in self.pair_features:
|
|
return True, "pair_data_missing"
|
|
|
|
features = self.pair_features[pair_id]
|
|
valid = features[features.index <= timestamp]
|
|
|
|
if len(valid) == 0:
|
|
return True, "no_data"
|
|
|
|
latest = valid.iloc[-1]
|
|
current_price = latest['base_close']
|
|
|
|
return self._check_sl_tp(current_price)
|
|
|
|
def _check_sl_tp(self, current_price: float) -> tuple[bool, str]:
|
|
"""Check stop-loss and take-profit levels."""
|
|
if self.position.direction == 'long':
|
|
if current_price <= self.position.stop_loss:
|
|
return True, f"stop_loss ({current_price:.2f} <= {self.position.stop_loss:.2f})"
|
|
if current_price >= self.position.take_profit:
|
|
return True, f"take_profit ({current_price:.2f} >= {self.position.take_profit:.2f})"
|
|
else: # short
|
|
if current_price >= self.position.stop_loss:
|
|
return True, f"stop_loss ({current_price:.2f} >= {self.position.stop_loss:.2f})"
|
|
if current_price <= self.position.take_profit:
|
|
return True, f"take_profit ({current_price:.2f} <= {self.position.take_profit:.2f})"
|
|
|
|
return False, ""
|
|
|
|
def _get_current_score(self, timestamp: pd.Timestamp) -> float:
|
|
"""Get current position's divergence score for comparison."""
|
|
if self.position.pair is None:
|
|
return 0.0
|
|
|
|
pair_id = self.position.pair.pair_id
|
|
if pair_id not in self.pair_features:
|
|
return 0.0
|
|
|
|
features = self.pair_features[pair_id]
|
|
valid = features[features.index <= timestamp]
|
|
|
|
if len(valid) == 0:
|
|
return 0.0
|
|
|
|
latest = valid.iloc[-1]
|
|
z_score = abs(latest['z_score'])
|
|
|
|
# Re-score with model
|
|
if self.divergence_scorer.model is not None:
|
|
feature_row = latest[self.divergence_scorer.feature_cols].fillna(0)
|
|
feature_row = feature_row.replace([np.inf, -np.inf], 0)
|
|
X = pd.DataFrame(
|
|
[feature_row.values],
|
|
columns=self.divergence_scorer.feature_cols
|
|
)
|
|
prob = self.divergence_scorer.model.predict_proba(X)[0, 1]
|
|
return z_score * prob
|
|
|
|
return z_score * 0.5
|
|
|
|
def _calculate_sl_tp(
|
|
self,
|
|
entry_price: float,
|
|
direction: str,
|
|
atr: float,
|
|
atr_pct: float
|
|
) -> tuple[float, float]:
|
|
"""
|
|
Calculate ATR-based dynamic stop-loss and take-profit prices.
|
|
|
|
Uses ATR (Average True Range) to set stops that adapt to
|
|
each asset's volatility. More volatile assets get wider stops.
|
|
|
|
Args:
|
|
entry_price: Entry price
|
|
direction: 'long' or 'short'
|
|
atr: ATR in price units
|
|
atr_pct: ATR as percentage of price
|
|
|
|
Returns:
|
|
Tuple of (stop_loss_price, take_profit_price)
|
|
"""
|
|
# Calculate SL/TP as ATR multiples
|
|
if atr > 0 and atr_pct > 0:
|
|
# ATR-based calculation
|
|
sl_distance = atr * self.config.sl_atr_multiplier
|
|
tp_distance = atr * self.config.tp_atr_multiplier
|
|
|
|
# Convert to percentage for bounds checking
|
|
sl_pct = sl_distance / entry_price
|
|
tp_pct = tp_distance / entry_price
|
|
else:
|
|
# Fallback to fixed percentages if ATR unavailable
|
|
sl_pct = self.config.base_sl_pct
|
|
tp_pct = self.config.base_tp_pct
|
|
|
|
# Apply bounds to prevent extreme stops
|
|
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
|
|
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
|
|
|
|
# Calculate actual prices
|
|
if direction == 'long':
|
|
stop_loss = entry_price * (1 - sl_pct)
|
|
take_profit = entry_price * (1 + tp_pct)
|
|
else: # short
|
|
stop_loss = entry_price * (1 + sl_pct)
|
|
take_profit = entry_price * (1 - tp_pct)
|
|
|
|
return stop_loss, take_profit
|
|
|
|
def _calculate_size(self, divergence_score: float) -> float:
|
|
"""
|
|
Calculate position size based on divergence score.
|
|
|
|
Higher divergence = larger position (up to 2x).
|
|
"""
|
|
# Base score threshold (Z=1.0, prob=0.5 -> score=0.5)
|
|
base_threshold = 0.5
|
|
|
|
# Scale factor
|
|
if divergence_score <= base_threshold:
|
|
return 1.0
|
|
|
|
# Linear scaling: 1.0 at threshold, up to 2.0 at 2x threshold
|
|
scale = 1.0 + (divergence_score - base_threshold) / base_threshold
|
|
return min(scale, 2.0)
|
|
|
|
def _get_common_index(self) -> pd.DatetimeIndex:
|
|
"""Get the intersection of all pair feature indices."""
|
|
if not self.pair_features:
|
|
return pd.DatetimeIndex([])
|
|
|
|
common = None
|
|
for features in self.pair_features.values():
|
|
if common is None:
|
|
common = features.index
|
|
else:
|
|
common = common.intersection(features.index)
|
|
|
|
return common.sort_values()
|
|
|
|
def _empty_signals(self, close: pd.Series) -> tuple:
|
|
"""Return empty signal arrays."""
|
|
empty = self.create_empty_signals(close)
|
|
size = pd.Series(1.0, index=close.index)
|
|
return empty, empty, empty, empty, size
|