feat: Multi-Pair Divergence Selection Strategy

- Extend regime detection to top 10 cryptocurrencies (45 pairs)
- Dynamic pair selection based on divergence score (|z_score| * probability)
- Universal ML model trained on all pairs
- Correlation-based filtering to avoid redundant positions
- Funding rate integration from OKX for all 10 assets
- ATR-based dynamic stop-loss and take-profit
- Walk-forward training with 70/30 split

Performance: +35.69% return (vs +28.66% baseline), 63.6% win rate
2026-01-15 20:47:23 +08:00
parent 7e4a6874a2
commit df37366603
13 changed files with 2531 additions and 0 deletions

BIN
data/multi_pair_model.pkl Normal file

Binary file not shown.

@@ -0,0 +1,47 @@
#!/usr/bin/env python3
"""
Download historical data for Multi-Pair Divergence Strategy.

Downloads 1h OHLCV data for top 10 cryptocurrencies from OKX.
"""
import sys

sys.path.insert(0, '.')

from engine.data_manager import DataManager
from engine.market import MarketType
from engine.logging_config import setup_logging, get_logger
from strategies.multi_pair import MultiPairConfig

logger = get_logger(__name__)


def main():
    """Download data for all configured assets."""
    setup_logging()
    config = MultiPairConfig()
    dm = DataManager()
    logger.info("Downloading data for %d assets...", len(config.assets))
    for symbol in config.assets:
        # Log the configured timeframe rather than a hard-coded "1h"
        logger.info("Downloading %s perpetual %s data...", symbol, config.timeframe)
        try:
            df = dm.download_data(
                exchange_id=config.exchange_id,
                symbol=symbol,
                timeframe=config.timeframe,
                market_type=MarketType.PERPETUAL
            )
            if df is not None:
                logger.info("Downloaded %d candles for %s", len(df), symbol)
            else:
                logger.warning("No data downloaded for %s", symbol)
        except Exception as e:
            # Keep going so one failing symbol does not abort the whole run
            logger.error("Failed to download %s: %s", symbol, e)
    logger.info("Download complete!")


if __name__ == "__main__":
    main()

@@ -0,0 +1,156 @@
#!/usr/bin/env python3
"""
Run Multi-Pair Divergence Strategy backtest and compare with baseline.

Compares the multi-pair strategy against the single-pair BTC/ETH regime strategy.
"""
import sys

sys.path.insert(0, '.')

from engine.backtester import Backtester
from engine.data_manager import DataManager
from engine.logging_config import setup_logging, get_logger
from engine.reporting import Reporter
from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig
from strategies.regime_strategy import RegimeReversionStrategy
from engine.market import MarketType

logger = get_logger(__name__)


def run_baseline():
    """Run baseline BTC/ETH regime strategy."""
    logger.info("=" * 60)
    logger.info("BASELINE: BTC/ETH Regime Reversion Strategy")
    logger.info("=" * 60)
    dm = DataManager()
    bt = Backtester(dm)
    strategy = RegimeReversionStrategy()
    result = bt.run_strategy(
        strategy,
        'okx',
        'ETH-USDT',
        timeframe='1h',
        init_cash=10000
    )
    logger.info("Baseline Results:")
    logger.info(" Total Return: %.2f%%", result.portfolio.total_return() * 100)
    logger.info(" Total Trades: %d", result.portfolio.trades.count())
    logger.info(" Win Rate: %.1f%%", result.portfolio.trades.win_rate() * 100)
    return result


def run_multi_pair(assets: list[str] | None = None):
    """Run multi-pair divergence strategy."""
    logger.info("=" * 60)
    logger.info("MULTI-PAIR: Divergence Selection Strategy")
    logger.info("=" * 60)
    dm = DataManager()
    bt = Backtester(dm)
    # Use provided assets or default
    if assets:
        config = MultiPairConfig(assets=assets)
    else:
        config = MultiPairConfig()
    logger.info("Configured %d assets, %d pairs", len(config.assets), config.get_pair_count())
    strategy = MultiPairDivergenceStrategy(config=config)
    result = bt.run_strategy(
        strategy,
        'okx',
        'ETH-USDT',  # Reference asset (not used for trading, just index alignment)
        timeframe='1h',
        init_cash=10000
    )
    logger.info("Multi-Pair Results:")
    logger.info(" Total Return: %.2f%%", result.portfolio.total_return() * 100)
    logger.info(" Total Trades: %d", result.portfolio.trades.count())
    logger.info(" Win Rate: %.1f%%", result.portfolio.trades.win_rate() * 100)
    return result


def compare_results(baseline, multi_pair):
    """Compare and display results."""
    logger.info("=" * 60)
    logger.info("COMPARISON")
    logger.info("=" * 60)
    baseline_return = baseline.portfolio.total_return() * 100
    multi_return = multi_pair.portfolio.total_return() * 100
    improvement = multi_return - baseline_return
    logger.info("Baseline Return: %.2f%%", baseline_return)
    logger.info("Multi-Pair Return: %.2f%%", multi_return)
    logger.info("Improvement: %.2f%% (%.1fx)",
                improvement,
                multi_return / baseline_return if baseline_return != 0 else 0)
    baseline_trades = baseline.portfolio.trades.count()
    multi_trades = multi_pair.portfolio.trades.count()
    logger.info("Baseline Trades: %d", baseline_trades)
    logger.info("Multi-Pair Trades: %d", multi_trades)
    return {
        'baseline_return': baseline_return,
        'multi_pair_return': multi_return,
        'improvement': improvement,
        'baseline_trades': baseline_trades,
        'multi_pair_trades': multi_trades
    }


def main():
    """Main entry point."""
    setup_logging()
    # Check available assets
    dm = DataManager()
    available = []
    for symbol in MultiPairConfig().assets:
        try:
            dm.load_data('okx', symbol, '1h', market_type=MarketType.PERPETUAL)
            available.append(symbol)
        except FileNotFoundError:
            pass
    if len(available) < 2:
        logger.error(
            "Need at least 2 assets to run multi-pair strategy. "
            "Run: uv run python scripts/download_multi_pair_data.py"
        )
        return
    logger.info("Found data for %d assets: %s", len(available), available)
    # Run baseline
    baseline_result = run_baseline()
    # Run multi-pair
    multi_result = run_multi_pair(available)
    # Compare (the summary is logged; the returned dict is not needed here)
    compare_results(baseline_result, multi_result)
    # Save reports
    reporter = Reporter()
    reporter.save_reports(multi_result, "multi_pair_divergence")
    logger.info("Reports saved to backtest_logs/")


if __name__ == "__main__":
    main()

@@ -37,6 +37,7 @@ def _build_registry() -> dict[str, StrategyConfig]:
    from strategies.examples import MaCrossStrategy, RsiStrategy
    from strategies.supertrend import MetaSupertrendStrategy
    from strategies.regime_strategy import RegimeReversionStrategy
    from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig

    return {
        "rsi": StrategyConfig(
@@ -98,6 +99,18 @@ def _build_registry() -> dict[str, StrategyConfig]:
                'stop_loss': [0.04, 0.06, 0.08],
                'funding_threshold': [0.005, 0.01, 0.02]
            }
        ),
        "multi_pair": StrategyConfig(
            strategy_class=MultiPairDivergenceStrategy,
            default_params={
                # Multi-pair divergence strategy uses config object
                # Parameters passed here will override MultiPairConfig defaults
            },
            grid_params={
                'z_entry_threshold': [0.8, 1.0, 1.2],
                'prob_threshold': [0.4, 0.5, 0.6],
                'correlation_threshold': [0.75, 0.85, 0.95]
            }
        )
    }
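
Note on the grid above: three values for each of three parameters give 27 combinations if the optimizer searches the full Cartesian product. The optimizer itself is not part of this diff, so the following is only a sketch under that assumption; `expand_grid` is a hypothetical helper, not an engine function.

```python
from itertools import product

# Grid copied from the "multi_pair" registry entry above.
grid_params = {
    'z_entry_threshold': [0.8, 1.0, 1.2],
    'prob_threshold': [0.4, 0.5, 0.6],
    'correlation_threshold': [0.75, 0.85, 0.95],
}


def expand_grid(grid):
    """Hypothetical helper: every parameter combination as a dict."""
    keys = list(grid)
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


combos = expand_grid(grid_params)
print(len(combos))  # 3 * 3 * 3 combinations
```

Each dict could then be passed as overrides for `MultiPairConfig`, if the engine applies grid values that way.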

@@ -0,0 +1,24 @@
"""
Multi-Pair Divergence Selection Strategy.

Extends regime detection to multiple cryptocurrency pairs and dynamically
selects the most divergent pair for trading.
"""
from .config import MultiPairConfig
from .pair_scanner import PairScanner, TradingPair
from .correlation import CorrelationFilter
from .feature_engine import MultiPairFeatureEngine
from .divergence_scorer import DivergenceScorer
from .strategy import MultiPairDivergenceStrategy
from .funding import FundingRateFetcher

__all__ = [
    "MultiPairConfig",
    "PairScanner",
    "TradingPair",
    "CorrelationFilter",
    "MultiPairFeatureEngine",
    "DivergenceScorer",
    "MultiPairDivergenceStrategy",
    "FundingRateFetcher",
]

@@ -0,0 +1,88 @@
"""
Configuration for Multi-Pair Divergence Strategy.
"""
from dataclasses import dataclass, field


@dataclass
class MultiPairConfig:
    """
    Configuration parameters for multi-pair divergence strategy.

    Attributes:
        assets: List of asset symbols to analyze (top 10 by market cap)
        z_window: Rolling window for Z-Score calculation (hours)
        z_entry_threshold: Minimum |Z-Score| to consider for entry
        prob_threshold: Minimum ML probability to consider for entry
        correlation_threshold: Max correlation to allow between pairs
        correlation_window: Rolling window for correlation (hours)
        atr_period: ATR lookback period for dynamic stops
        sl_atr_multiplier: Stop-loss as multiple of ATR
        tp_atr_multiplier: Take-profit as multiple of ATR
        train_ratio: Walk-forward train/test split ratio
        horizon: Look-ahead horizon for target calculation (hours)
        profit_target: Minimum profit threshold for target labels
        funding_threshold: Funding rate threshold for filtering
    """
    # Asset Universe
    assets: list[str] = field(default_factory=lambda: [
        "BTC-USDT", "ETH-USDT", "SOL-USDT", "XRP-USDT", "BNB-USDT",
        "DOGE-USDT", "ADA-USDT", "AVAX-USDT", "LINK-USDT", "DOT-USDT"
    ])

    # Z-Score Thresholds
    z_window: int = 24
    z_entry_threshold: float = 1.0

    # ML Thresholds
    prob_threshold: float = 0.5
    train_ratio: float = 0.7
    horizon: int = 102
    profit_target: float = 0.005

    # Correlation Filtering
    correlation_threshold: float = 0.85
    correlation_window: int = 168  # 7 days in hours

    # Risk Management - ATR-Based Stops
    # SL/TP are calculated as multiples of ATR
    # Mean ATR for crypto is ~0.6% per hour, so:
    # - 10x ATR = ~6% SL (matches previous fixed 6%)
    # - 8x ATR = ~5% TP (matches previous fixed 5%)
    atr_period: int = 14  # ATR lookback period (hours for 1h timeframe)
    sl_atr_multiplier: float = 10.0  # Stop-loss = entry +/- (ATR * multiplier)
    tp_atr_multiplier: float = 8.0  # Take-profit = entry +/- (ATR * multiplier)

    # Fallback fixed percentages (used if ATR is unavailable)
    base_sl_pct: float = 0.06
    base_tp_pct: float = 0.05

    # ATR bounds to prevent extreme stops
    min_sl_pct: float = 0.02  # Minimum 2% stop-loss
    max_sl_pct: float = 0.10  # Maximum 10% stop-loss
    min_tp_pct: float = 0.02  # Minimum 2% take-profit
    max_tp_pct: float = 0.15  # Maximum 15% take-profit
    volatility_window: int = 24

    # Funding Rate Filter
    # OKX funding rates are typically 0.0001 (0.01%) per 8h
    # Extreme funding is > 0.0005 (0.05%) which indicates crowded trade
    funding_threshold: float = 0.0005  # 0.05% - filter extreme funding

    # Trade Management
    # Note: Setting min_hold_bars=0 and z_exit_threshold=0 gives best results
    # The mean-reversion exit at Z=0 is the primary profit driver
    min_hold_bars: int = 0  # Disabled - let mean reversion drive exits
    switch_threshold: float = 999.0  # Disabled - don't switch mid-trade
    cooldown_bars: int = 0  # Disabled - enter when signal appears
    z_exit_threshold: float = 0.0  # Exit at Z=0 (mean reversion complete)

    # Exchange
    exchange_id: str = "okx"
    timeframe: str = "1h"

    def get_pair_count(self) -> int:
        """Calculate number of unique pairs from asset list."""
        n = len(self.assets)
        return n * (n - 1) // 2
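
Note on the ATR settings above: the comments describe stops as ATR multiples, clamped to the min/max bounds, with fixed percentages as a fallback. The code that consumes these fields is not shown in this part of the diff, so the sketch below is an assumption about how they might combine; `atr_stop_levels` is a hypothetical name, and its defaults mirror the `MultiPairConfig` values.

```python
def atr_stop_levels(entry, atr, direction,
                    sl_mult=10.0, tp_mult=8.0,
                    min_sl=0.02, max_sl=0.10,
                    min_tp=0.02, max_tp=0.15,
                    base_sl=0.06, base_tp=0.05):
    """Hypothetical helper: return (stop_loss_price, take_profit_price)."""
    if atr is None or atr <= 0 or entry <= 0:
        # ATR unavailable: fall back to the fixed percentages
        sl_pct, tp_pct = base_sl, base_tp
    else:
        atr_pct = atr / entry
        # Scale by the multipliers, then clamp to the configured bounds
        sl_pct = min(max(atr_pct * sl_mult, min_sl), max_sl)
        tp_pct = min(max(atr_pct * tp_mult, min_tp), max_tp)
    if direction == "long":
        return entry * (1 - sl_pct), entry * (1 + tp_pct)
    # Short: stop above entry, target below
    return entry * (1 + sl_pct), entry * (1 - tp_pct)


# With ATR at 0.6% of price, 10x and 8x ATR land near the old fixed 6% / 5%.
stop, target = atr_stop_levels(entry=100.0, atr=0.6, direction="long")
print(round(stop, 2), round(target, 2))
```

With a very large ATR (for example 5% of price) the clamps take over and the stop distance is capped at 10% and the target at 15%.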

@@ -0,0 +1,173 @@
"""
Correlation Filter for Multi-Pair Divergence Strategy.

Calculates rolling correlation matrix and filters pairs
to avoid highly correlated positions.
"""
import pandas as pd
import numpy as np

from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import TradingPair

logger = get_logger(__name__)


class CorrelationFilter:
    """
    Calculates and filters based on asset correlations.

    Uses rolling correlation of returns to identify assets
    moving together, avoiding redundant positions.
    """

    def __init__(self, config: MultiPairConfig):
        self.config = config
        self._correlation_matrix: pd.DataFrame | None = None
        self._last_update_idx: int = -1

    def calculate_correlation_matrix(
        self,
        price_data: dict[str, pd.Series],
        current_idx: int | None = None
    ) -> pd.DataFrame:
        """
        Calculate the latest correlation matrix between all assets.

        Args:
            price_data: Dictionary mapping asset symbols to price series
            current_idx: Current bar index (for caching)

        Returns:
            Correlation matrix DataFrame
        """
        # Use cached if recent
        if (
            current_idx is not None
            and self._correlation_matrix is not None
            and current_idx - self._last_update_idx < 24  # Update every 24 bars
        ):
            return self._correlation_matrix

        # Calculate returns
        returns_df = pd.DataFrame({
            symbol: prices.pct_change()
            for symbol, prices in price_data.items()
        })

        # Only the most recent matrix is used, so correlate the last `window`
        # rows directly instead of computing a rolling correlation over the
        # full history and discarding all but the final timestamp.
        window = self.config.correlation_window
        if len(returns_df) >= window:
            corr_matrix = returns_df.iloc[-window:].corr()
        else:
            # Fallback to full-period correlation if not enough data
            corr_matrix = returns_df.corr()

        self._correlation_matrix = corr_matrix
        if current_idx is not None:
            self._last_update_idx = current_idx
        return corr_matrix

    def filter_pairs(
        self,
        pairs: list[TradingPair],
        current_position_asset: str | None,
        price_data: dict[str, pd.Series],
        current_idx: int | None = None
    ) -> list[TradingPair]:
        """
        Filter pairs based on correlation with current position.

        If we have an open position in an asset, exclude pairs where
        either asset is highly correlated with the held asset.

        Args:
            pairs: List of candidate pairs
            current_position_asset: Currently held asset (or None)
            price_data: Dictionary of price series by symbol
            current_idx: Current bar index for caching

        Returns:
            Filtered list of pairs
        """
        if current_position_asset is None:
            return pairs

        corr_matrix = self.calculate_correlation_matrix(price_data, current_idx)
        threshold = self.config.correlation_threshold
        filtered = []
        for pair in pairs:
            # Check correlation of base and quote with held asset
            base_corr = self._get_correlation(
                corr_matrix, pair.base_asset, current_position_asset
            )
            quote_corr = self._get_correlation(
                corr_matrix, pair.quote_asset, current_position_asset
            )
            # Filter if either asset highly correlated with position
            if abs(base_corr) > threshold or abs(quote_corr) > threshold:
                logger.debug(
                    "Filtered %s: base_corr=%.2f, quote_corr=%.2f (held: %s)",
                    pair.name, base_corr, quote_corr, current_position_asset
                )
                continue
            filtered.append(pair)

        if len(filtered) < len(pairs):
            logger.info(
                "Correlation filter: %d/%d pairs remaining (held: %s)",
                len(filtered), len(pairs), current_position_asset
            )
        return filtered

    def _get_correlation(
        self,
        corr_matrix: pd.DataFrame,
        asset1: str,
        asset2: str
    ) -> float:
        """
        Get correlation between two assets from matrix.

        Args:
            corr_matrix: Correlation matrix
            asset1: First asset symbol
            asset2: Second asset symbol

        Returns:
            Correlation coefficient (-1 to 1), or 0 if not found or undefined
        """
        if asset1 == asset2:
            return 1.0
        try:
            value = corr_matrix.loc[asset1, asset2]
        except KeyError:
            return 0.0
        # A NaN correlation (e.g. constant or missing returns) would silently
        # pass every threshold comparison, so map it to 0 explicitly.
        return 0.0 if pd.isna(value) else float(value)

    def get_correlation_report(
        self,
        price_data: dict[str, pd.Series]
    ) -> pd.DataFrame:
        """
        Generate a readable correlation report.

        Args:
            price_data: Dictionary of price series

        Returns:
            Correlation matrix as DataFrame
        """
        return self.calculate_correlation_matrix(price_data)
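
Note on the filter above: a candidate asset is rejected when its return correlation with the held asset exceeds the threshold in absolute value. The standalone sketch below illustrates that rule on synthetic prices; `latest_correlation` and `is_redundant` are hypothetical helpers, not part of `CorrelationFilter`, and the 168-bar window and 0.85 threshold are taken from `MultiPairConfig`.

```python
import numpy as np
import pandas as pd


def latest_correlation(prices, window):
    """Correlation of returns over the most recent `window` bars."""
    returns = pd.DataFrame(prices).pct_change()
    return returns.tail(window).corr()


def is_redundant(corr, asset, held, threshold=0.85):
    """True if `asset` moves too closely with the held asset."""
    if asset == held:
        return True
    if asset not in corr.index or held not in corr.columns:
        return False  # unknown asset: treated as uncorrelated
    value = corr.loc[asset, held]
    return bool(pd.notna(value) and abs(value) > threshold)


# Synthetic prices: B is a scaled copy of A, C is an independent random walk.
rng = np.random.default_rng(0)
a = pd.Series(100 * np.cumprod(1 + rng.normal(0, 0.01, 300)))
b = 2 * a
c = pd.Series(50 * np.cumprod(1 + rng.normal(0, 0.01, 300)))
corr = latest_correlation({"A": a, "B": b, "C": c}, window=168)

print(is_redundant(corr, "B", "A"))  # B has the same returns as A
print(is_redundant(corr, "C", "A"))  # C is independent of A
```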

@@ -0,0 +1,311 @@
"""
Divergence Scorer for Multi-Pair Strategy.

Ranks pairs by divergence score and selects the best candidate.
"""
from dataclasses import dataclass
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import pickle
from pathlib import Path

from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import TradingPair

logger = get_logger(__name__)


@dataclass
class DivergenceSignal:
    """
    Signal for a divergent pair.

    Attributes:
        pair: Trading pair
        z_score: Current Z-Score of the spread
        probability: ML model probability of profitable reversion
        divergence_score: Combined score (|z_score| * probability)
        direction: 'long' or 'short' (relative to base asset)
        base_price: Current price of base asset
        quote_price: Current price of quote asset
        atr: Average True Range in price units
        atr_pct: ATR as a fraction of price (e.g. 0.006 = 0.6%)
        timestamp: Timestamp of the feature row the signal was computed from
    """
    pair: TradingPair
    z_score: float
    probability: float
    divergence_score: float
    direction: str
    base_price: float
    quote_price: float
    atr: float
    atr_pct: float
    timestamp: pd.Timestamp


class DivergenceScorer:
    """
    Scores and ranks pairs by divergence potential.

    Uses ML model predictions combined with Z-Score magnitude
    to identify the most promising mean-reversion opportunity.
    """

    def __init__(self, config: MultiPairConfig, model_path: str = "data/multi_pair_model.pkl"):
        self.config = config
        self.model_path = Path(model_path)
        self.model: RandomForestClassifier | None = None
        self.feature_cols: list[str] | None = None
        self._load_model()

    def _load_model(self) -> None:
        """Load pre-trained model if available."""
        if self.model_path.exists():
            try:
                # pickle can execute arbitrary code on load; only load model
                # files from a trusted source.
                with open(self.model_path, 'rb') as f:
                    saved = pickle.load(f)
                self.model = saved['model']
                self.feature_cols = saved['feature_cols']
                logger.info("Loaded model from %s", self.model_path)
            except Exception as e:
                logger.warning("Could not load model: %s", e)

    def save_model(self) -> None:
        """Save trained model."""
        if self.model is None:
            return
        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.model_path, 'wb') as f:
            pickle.dump({
                'model': self.model,
                'feature_cols': self.feature_cols,
            }, f)
        logger.info("Saved model to %s", self.model_path)

    def train_model(
        self,
        combined_features: pd.DataFrame,
        pair_features: dict[str, pd.DataFrame]
    ) -> None:
        """
        Train universal model on all pairs.

        Args:
            combined_features: Combined feature DataFrame from all pairs
            pair_features: Individual pair feature DataFrames (for target calculation)
        """
        logger.info("Training universal model on %d samples...", len(combined_features))
        z_thresh = self.config.z_entry_threshold
        horizon = self.config.horizon
        profit_target = self.config.profit_target

        # Calculate targets for each pair
        all_targets = []
        all_features = []
        for pair_id, features in pair_features.items():
            if len(features) < horizon + 50:
                continue
            spread = features['spread']
            z_score = features['z_score']

            # Future price movements: min/max of the spread over the next
            # `horizon` bars (rolling window shifted back by `horizon`)
            future_min = spread.rolling(window=horizon).min().shift(-horizon)
            future_max = spread.rolling(window=horizon).max().shift(-horizon)

            # Target labels
            target_short = spread * (1 - profit_target)
            target_long = spread * (1 + profit_target)
            success_short = (z_score > z_thresh) & (future_min < target_short)
            success_long = (z_score < -z_thresh) & (future_max > target_long)
            targets = np.select([success_short, success_long], [1, 1], default=0)

            # Valid mask (exclude rows without complete future data)
            valid_mask = future_min.notna() & future_max.notna()

            # Collect valid samples
            valid_features = features[valid_mask]
            valid_targets = targets[valid_mask.values]
            if len(valid_features) > 0:
                all_features.append(valid_features)
                all_targets.extend(valid_targets)

        if not all_features:
            logger.warning("No valid training samples")
            return

        # Combine all training data
        X_df = pd.concat(all_features, ignore_index=True)
        y = np.array(all_targets)

        # Get feature columns
        exclude_cols = [
            'pair_id', 'base_asset', 'quote_asset',
            'spread', 'base_close', 'quote_close', 'base_volume'
        ]
        self.feature_cols = [c for c in X_df.columns if c not in exclude_cols]

        # Prepare features
        X = X_df[self.feature_cols].fillna(0)
        X = X.replace([np.inf, -np.inf], 0)

        # Train model
        self.model = RandomForestClassifier(
            n_estimators=300,
            max_depth=5,
            min_samples_leaf=30,
            class_weight={0: 1, 1: 3},
            random_state=42
        )
        self.model.fit(X, y)
        logger.info(
            "Model trained on %d samples, %d features, %.1f%% positive class",
            len(X), len(self.feature_cols), y.mean() * 100
        )
        self.save_model()

    def score_pairs(
        self,
        pair_features: dict[str, pd.DataFrame],
        pairs: list[TradingPair],
        timestamp: pd.Timestamp | None = None
    ) -> list[DivergenceSignal]:
        """
        Score all pairs and return ranked signals.

        Args:
            pair_features: Feature DataFrames by pair_id
            pairs: List of TradingPair objects
            timestamp: Current timestamp for feature extraction

        Returns:
            List of DivergenceSignal sorted by score (descending)
        """
        if self.model is None:
            logger.warning("Model not trained, returning empty signals")
            return []

        signals = []
        pair_map = {p.pair_id: p for p in pairs}
        for pair_id, features in pair_features.items():
            if pair_id not in pair_map:
                continue
            pair = pair_map[pair_id]

            # Get latest features
            if timestamp is not None:
                valid = features[features.index <= timestamp]
                if len(valid) == 0:
                    continue
                latest = valid.iloc[-1]
                ts = valid.index[-1]
            else:
                latest = features.iloc[-1]
                ts = features.index[-1]
            z_score = latest['z_score']

            # Skip if Z-score below threshold
            if abs(z_score) < self.config.z_entry_threshold:
                continue

            # Prepare features for prediction
            feature_row = latest[self.feature_cols].fillna(0).infer_objects(copy=False)
            feature_row = feature_row.replace([np.inf, -np.inf], 0)
            X = pd.DataFrame([feature_row.values], columns=self.feature_cols)

            # Predict probability
            prob = self.model.predict_proba(X)[0, 1]

            # Skip if probability below threshold
            if prob < self.config.prob_threshold:
                continue

            # Apply funding rate filter
            # Block trades where funding opposes our direction
            base_funding = latest.get('base_funding', 0) or 0
            funding_thresh = self.config.funding_threshold
            if z_score > 0:  # Short signal
                # High negative funding = shorts are paying -> skip
                if base_funding < -funding_thresh:
                    logger.debug(
                        "Skipping %s short: funding too negative (%.4f)",
                        pair.name, base_funding
                    )
                    continue
            else:  # Long signal
                # High positive funding = longs are paying -> skip
                if base_funding > funding_thresh:
                    logger.debug(
                        "Skipping %s long: funding too positive (%.4f)",
                        pair.name, base_funding
                    )
                    continue

            # Calculate divergence score
            divergence_score = abs(z_score) * prob

            # Determine direction
            # Z > 0: Spread high (base expensive vs quote) -> Short base
            # Z < 0: Spread low (base cheap vs quote) -> Long base
            direction = 'short' if z_score > 0 else 'long'
            signal = DivergenceSignal(
                pair=pair,
                z_score=z_score,
                probability=prob,
                divergence_score=divergence_score,
                direction=direction,
                base_price=latest['base_close'],
                quote_price=latest['quote_close'],
                atr=latest.get('atr_base', 0),
                atr_pct=latest.get('atr_pct_base', 0.02),
                timestamp=ts
            )
            signals.append(signal)

        # Sort by divergence score (highest first)
        signals.sort(key=lambda s: s.divergence_score, reverse=True)
        if signals:
            logger.debug(
                "Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f)",
                len(signals),
                signals[0].pair.name,
                signals[0].divergence_score,
                signals[0].z_score,
                signals[0].probability
            )
        return signals

    def select_best_pair(
        self,
        signals: list[DivergenceSignal]
    ) -> DivergenceSignal | None:
        """
        Select the best pair from scored signals.

        Args:
            signals: List of DivergenceSignal (pre-sorted by score)

        Returns:
            Best signal or None if no valid candidates
        """
        if not signals:
            return None
        return signals[0]
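
Note on the scoring above: candidates are first gated by the Z-Score and probability thresholds, then ranked by |z_score| * probability, and the direction follows the sign of the Z-Score. The sketch below reproduces only that ranking step in isolation. `Candidate` and `rank_candidates` are hypothetical names, the pair names and numbers are invented for illustration, and the funding filter is left out.

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    name: str
    z_score: float
    probability: float

    @property
    def score(self) -> float:
        # Divergence score as described in the commit message
        return abs(self.z_score) * self.probability

    @property
    def direction(self) -> str:
        # Z > 0: base is rich relative to quote, so short the base
        return 'short' if self.z_score > 0 else 'long'


def rank_candidates(cands, z_entry=1.0, prob_min=0.5):
    """Drop candidates below either threshold, then sort by score."""
    kept = [c for c in cands
            if abs(c.z_score) >= z_entry and c.probability >= prob_min]
    return sorted(kept, key=lambda c: c.score, reverse=True)


ranked = rank_candidates([
    Candidate("ETH/BTC", 2.0, 0.60),   # score 1.2
    Candidate("SOL/ETH", -3.0, 0.55),  # score 1.65
    Candidate("XRP/BTC", 0.5, 0.90),   # dropped: |z| below 1.0
    Candidate("ADA/DOT", 2.5, 0.40),   # dropped: probability below 0.5
])
best = ranked[0] if ranked else None
print(best.name, best.direction)
```

One consequence of the product form is that a large Z-Score can outrank a higher-probability signal, as SOL/ETH does here.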

@@ -0,0 +1,433 @@
"""
Feature Engineering for Multi-Pair Divergence Strategy.

Calculates features for all pairs in the universe, including
spread technicals, volatility, and on-chain data.
"""
import pandas as pd
import numpy as np
import ta

from engine.logging_config import get_logger
from engine.data_manager import DataManager
from engine.market import MarketType
from .config import MultiPairConfig
from .pair_scanner import TradingPair
from .funding import FundingRateFetcher

logger = get_logger(__name__)


class MultiPairFeatureEngine:
    """
    Calculates features for multiple trading pairs.

    Generates consistent feature sets across all pairs for
    the universal ML model.
    """

    def __init__(self, config: MultiPairConfig):
        self.config = config
        self.dm = DataManager()
        self.funding_fetcher = FundingRateFetcher()
        self._funding_data: pd.DataFrame | None = None

    def load_all_assets(
        self,
        start_date: str | None = None,
        end_date: str | None = None
    ) -> dict[str, pd.DataFrame]:
        """
        Load OHLCV data for all assets in the universe.

        Args:
            start_date: Start date filter (YYYY-MM-DD)
            end_date: End date filter (YYYY-MM-DD)

        Returns:
            Dictionary mapping symbol to OHLCV DataFrame
        """
        data = {}
        market_type = MarketType.PERPETUAL
        for symbol in self.config.assets:
            try:
                df = self.dm.load_data(
                    self.config.exchange_id,
                    symbol,
                    self.config.timeframe,
                    market_type
                )
                # Apply date filters
                if start_date:
                    df = df[df.index >= pd.Timestamp(start_date, tz="UTC")]
                if end_date:
                    df = df[df.index <= pd.Timestamp(end_date, tz="UTC")]
                if len(df) >= 200:  # Minimum data requirement
                    data[symbol] = df
                    logger.debug("Loaded %s: %d bars", symbol, len(df))
                else:
                    logger.warning(
                        "Skipping %s: insufficient data (%d bars)",
                        symbol, len(df)
                    )
            except FileNotFoundError:
                logger.warning("Data not found for %s", symbol)
            except Exception as e:
                logger.error("Error loading %s: %s", symbol, e)
        logger.info("Loaded %d/%d assets", len(data), len(self.config.assets))
        return data

    def load_funding_data(
        self,
        start_date: str | None = None,
        end_date: str | None = None,
        use_cache: bool = True
    ) -> pd.DataFrame | None:
        """
        Load funding rate data for all assets.

        Args:
            start_date: Start date filter
            end_date: End date filter
            use_cache: Whether to use cached data

        Returns:
            DataFrame with funding rates for all assets
            (may be None or empty if no funding data is available)
        """
        self._funding_data = self.funding_fetcher.get_funding_data(
            self.config.assets,
            start_date=start_date,
            end_date=end_date,
            use_cache=use_cache
        )
        if self._funding_data is not None and not self._funding_data.empty:
            logger.info(
                "Loaded funding data: %d rows, %d assets",
                len(self._funding_data),
                len(self._funding_data.columns)
            )
        else:
            logger.warning("No funding data available")
        return self._funding_data

    def calculate_pair_features(
        self,
        pair: TradingPair,
        asset_data: dict[str, pd.DataFrame],
        on_chain_data: pd.DataFrame | None = None
    ) -> pd.DataFrame | None:
        """
        Calculate features for a single pair.

        Args:
            pair: Trading pair
            asset_data: Dictionary of OHLCV DataFrames by symbol
            on_chain_data: Optional on-chain data (funding, inflows)

        Returns:
            DataFrame with features, or None if insufficient data
        """
        base = pair.base_asset
        quote = pair.quote_asset
        if base not in asset_data or quote not in asset_data:
            return None
        df_base = asset_data[base]
        df_quote = asset_data[quote]

        # Align indices
        common_idx = df_base.index.intersection(df_quote.index)
        if len(common_idx) < 200:
            logger.debug("Pair %s: insufficient aligned data", pair.name)
            return None
        df_a = df_base.loc[common_idx]
        df_b = df_quote.loc[common_idx]

        # Calculate spread (base / quote)
        spread = df_a['close'] / df_b['close']

        # Z-Score
        z_window = self.config.z_window
        rolling_mean = spread.rolling(window=z_window).mean()
        rolling_std = spread.rolling(window=z_window).std()
        z_score = (spread - rolling_mean) / rolling_std

        # Spread Technicals
        spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
        spread_roc = spread.pct_change(periods=5) * 100
        spread_change_1h = spread.pct_change(periods=1)

        # Volume Analysis
        vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10)
        vol_ratio_ma = vol_ratio.rolling(window=12).mean()
        vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10)

        # Volatility
        ret_a = df_a['close'].pct_change()
        ret_b = df_b['close'].pct_change()
        vol_a = ret_a.rolling(window=z_window).std()
        vol_b = ret_b.rolling(window=z_window).std()
        vol_spread_ratio = vol_a / (vol_b + 1e-10)

        # Realized Volatility (for dynamic SL/TP)
        realized_vol_a = ret_a.rolling(window=self.config.volatility_window).std()
        realized_vol_b = ret_b.rolling(window=self.config.volatility_window).std()

        # ATR (Average True Range) for dynamic stops
        # ATR = average of max(high-low, |high-prev_close|, |low-prev_close|)
        high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close']
        high_b, low_b, close_b = df_b['high'], df_b['low'], df_b['close']

        # True Range for base asset
        tr_a = pd.concat([
            high_a - low_a,
            (high_a - close_a.shift(1)).abs(),
            (low_a - close_a.shift(1)).abs()
        ], axis=1).max(axis=1)
        atr_a = tr_a.rolling(window=self.config.atr_period).mean()

        # True Range for quote asset
        tr_b = pd.concat([
            high_b - low_b,
            (high_b - close_b.shift(1)).abs(),
            (low_b - close_b.shift(1)).abs()
        ], axis=1).max(axis=1)
        atr_b = tr_b.rolling(window=self.config.atr_period).mean()

        # ATR as percentage of price (normalized)
        atr_pct_a = atr_a / close_a
        atr_pct_b = atr_b / close_b

        # Build feature DataFrame
        features = pd.DataFrame(index=common_idx)
        features['pair_id'] = pair.pair_id
        features['base_asset'] = base
        features['quote_asset'] = quote

        # Price data (for reference, not features)
        features['spread'] = spread
        features['base_close'] = df_a['close']
        features['quote_close'] = df_b['close']
        features['base_volume'] = df_a['volume']

        # Core Features
        features['z_score'] = z_score
        features['spread_rsi'] = spread_rsi
        features['spread_roc'] = spread_roc
        features['spread_change_1h'] = spread_change_1h
        features['vol_ratio'] = vol_ratio
        features['vol_ratio_rel'] = vol_ratio_rel
        features['vol_diff_ratio'] = vol_spread_ratio

        # Volatility for SL/TP
        features['realized_vol_base'] = realized_vol_a
        features['realized_vol_quote'] = realized_vol_b
        features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2

        # ATR for dynamic stops (in price units and as percentage)
        features['atr_base'] = atr_a
        features['atr_quote'] = atr_b
        features['atr_pct_base'] = atr_pct_a
        features['atr_pct_quote'] = atr_pct_b
        features['atr_pct_avg'] = (atr_pct_a + atr_pct_b) / 2

        # Pair encoding (for universal model)
        # Using base and quote indices for hierarchical encoding
        assets = self.config.assets
        features['base_idx'] = assets.index(base) if base in assets else -1
        features['quote_idx'] = assets.index(quote) if quote in assets else -1

        # Add funding and on-chain features
        # Funding data is always added from self._funding_data (OKX, all 10 assets)
        # On-chain data is optional (CryptoQuant, BTC/ETH only)
        features = self._add_on_chain_features(
            features, on_chain_data, base, quote
        )

        # Drop rows with NaN in core features only (not funding/on-chain)
        core_cols = [
            'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
            'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
            'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
            'atr_base', 'atr_pct_base'  # ATR is core for SL/TP
        ]
        features = features.dropna(subset=core_cols)

        # Fill missing funding/on-chain features with 0 (neutral)
        optional_cols = [
            'base_funding', 'quote_funding', 'funding_diff', 'funding_avg',
            'base_inflow', 'quote_inflow', 'inflow_ratio'
        ]
        for col in optional_cols:
            if col in features.columns:
                features[col] = features[col].fillna(0)
        return features

    def calculate_all_pair_features(
        self,
        pairs: list[TradingPair],
        asset_data: dict[str, pd.DataFrame],
        on_chain_data: pd.DataFrame | None = None
    ) -> dict[str, pd.DataFrame]:
        """
        Calculate features for all pairs.

        Args:
            pairs: List of trading pairs
            asset_data: Dictionary of OHLCV DataFrames
            on_chain_data: Optional on-chain data

        Returns:
            Dictionary mapping pair_id to feature DataFrame
        """
        all_features = {}
        for pair in pairs:
            features = self.calculate_pair_features(
                pair, asset_data, on_chain_data
            )
            if features is not None and len(features) > 0:
                all_features[pair.pair_id] = features
        logger.info(
            "Calculated features for %d/%d pairs",
            len(all_features), len(pairs)
        )
        return all_features

    def get_combined_features(
        self,
        pair_features: dict[str, pd.DataFrame],
        timestamp: pd.Timestamp | None = None
    ) -> pd.DataFrame:
        """
        Combine all pair features into a single DataFrame.

        Useful for batch model prediction across all pairs.

        Args:
            pair_features: Dictionary of feature DataFrames by pair_id
            timestamp: Optional specific timestamp to filter to

        Returns:
            Combined DataFrame with all pairs as rows
        """
        if not pair_features:
            return pd.DataFrame()
        if timestamp is not None:
            # Get latest row from each pair at or before timestamp
            rows = []
            for pair_id, features in pair_features.items():
                valid = features[features.index <= timestamp]
                if len(valid) > 0:
                    row = valid.iloc[-1:].copy()
                    rows.append(row)
            if rows:
                return pd.concat(rows, ignore_index=False)
            return pd.DataFrame()
        # Combine all features (for training)
        return pd.concat(pair_features.values(), ignore_index=False)

    def _add_on_chain_features(
        self,
        features: pd.DataFrame,
        on_chain_data: pd.DataFrame | None,
        base_asset: str,
        quote_asset: str
    ) -> pd.DataFrame:
        """
        Add on-chain and funding rate features for the pair.

        Uses funding data from OKX (all 10 assets) and on-chain data
        from CryptoQuant (BTC/ETH only for inflows).
        """
        base_short = base_asset.replace('-USDT', '').lower()
        quote_short = quote_asset.replace('-USDT', '').lower()

        # Add funding rates from cached funding data
        if self._funding_data is not None and not self._funding_data.empty:
            funding_aligned = self._funding_data.reindex(
                features.index, method='ffill'
            )
            base_funding_col = f'{base_short}_funding'
            quote_funding_col = f'{quote_short}_funding'
            if base_funding_col in funding_aligned.columns:
                features['base_funding'] = funding_aligned[base_funding_col]
            if quote_funding_col in funding_aligned.columns:
                features['quote_funding'] = funding_aligned[quote_funding_col]
            # Funding difference (positive = base has higher funding)
            if 'base_funding' in features.columns and 'quote_funding' in features.columns:
                features['funding_diff'] = (
                    features['base_funding'] - features['quote_funding']
                )
                # Funding sentiment: average of both assets
                features['funding_avg'] = (
                    features['base_funding'] + features['quote_funding']
                ) / 2

        # Add on-chain features from CryptoQuant (BTC/ETH only)
        if on_chain_data is not None and not on_chain_data.empty:
            cq_aligned = on_chain_data.reindex(features.index, method='ffill')
            # Inflows (only available for BTC/ETH)
            base_inflow_col = f'{base_short}_inflow'
            quote_inflow_col = f'{quote_short}_inflow'
            if base_inflow_col in cq_aligned.columns:
                features['base_inflow'] = cq_aligned[base_inflow_col]
            if quote_inflow_col in cq_aligned.columns:
                features['quote_inflow'] = cq_aligned[quote_inflow_col]
            if 'base_inflow' in features.columns and 'quote_inflow' in features.columns:
                features['inflow_ratio'] = (
                    features['base_inflow'] /
                    (features['quote_inflow'] + 1)
                )
        return features

def get_feature_columns(self) -> list[str]:
"""
Get list of feature columns for ML model.
Excludes metadata and target-related columns.
Returns:
List of feature column names
"""
# Core features (always present)
core_features = [
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
'base_idx', 'quote_idx'
]
# Funding features (now available for all 10 assets via OKX)
funding_features = [
'base_funding', 'quote_funding', 'funding_diff', 'funding_avg'
]
# On-chain features (BTC/ETH only via CryptoQuant)
onchain_features = [
'base_inflow', 'quote_inflow', 'inflow_ratio'
]
return core_features + funding_features + onchain_features


@@ -0,0 +1,272 @@
"""
Funding Rate Fetcher for Multi-Pair Strategy.
Fetches historical funding rates from OKX for all assets.
CryptoQuant only supports BTC/ETH, so we use OKX for the full universe.
"""
import time
from pathlib import Path
from datetime import datetime, timezone
import ccxt
import pandas as pd
from engine.logging_config import get_logger
logger = get_logger(__name__)
class FundingRateFetcher:
"""
Fetches and caches funding rate data from OKX.
OKX funding rates are settled every 8 hours (00:00, 08:00, 16:00 UTC).
This fetcher retrieves historical funding rate data and forward-fills it
so the feature engine can align it to hourly candles for the multi-pair strategy.
"""
def __init__(self, cache_dir: str = "data/funding"):
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.exchange: ccxt.okx | None = None
def _init_exchange(self) -> None:
"""Initialize OKX exchange connection."""
if self.exchange is None:
self.exchange = ccxt.okx({
'enableRateLimit': True,
'options': {'defaultType': 'swap'}
})
self.exchange.load_markets()

    def fetch_funding_history(
        self,
        symbol: str,
        start_date: str | None = None,
        end_date: str | None = None,
        limit: int = 100
    ) -> pd.DataFrame:
        """
        Fetch historical funding rates for a symbol.

        Args:
            symbol: Asset symbol (e.g., 'BTC-USDT')
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            limit: Max records per request

        Returns:
            DataFrame with funding rate history
        """
        self._init_exchange()

        # Convert symbol format
        base = symbol.replace('-USDT', '')
        okx_symbol = f"{base}/USDT:USDT"

        try:
            # Page through the funding rate history via ccxt's
            # fetch_funding_rate_history
            all_funding = []

            # Parse dates
            if start_date:
                since = self.exchange.parse8601(f"{start_date}T00:00:00Z")
            else:
                # Default to 1 year ago
                since = self.exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000

            if end_date:
                until = self.exchange.parse8601(f"{end_date}T23:59:59Z")
            else:
                until = self.exchange.milliseconds()

            # Fetch in batches
            current_since = since
            while current_since < until:
                try:
                    funding = self.exchange.fetch_funding_rate_history(
                        okx_symbol,
                        since=current_since,
                        limit=limit
                    )
                    if not funding:
                        break

                    all_funding.extend(funding)

                    # Move to next batch; stop if the cursor did not advance
                    last_ts = funding[-1]['timestamp']
                    if last_ts <= current_since:
                        break
                    current_since = last_ts + 1

                    time.sleep(0.1)  # Rate limit
                except Exception as e:
                    logger.warning(
                        "Error fetching funding batch for %s: %s",
                        symbol, str(e)[:50]
                    )
                    break

            if not all_funding:
                return pd.DataFrame()

            # Convert to DataFrame
            df = pd.DataFrame(all_funding)
            df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
            df.set_index('timestamp', inplace=True)
            df = df[['fundingRate']].rename(columns={'fundingRate': 'funding_rate'})
            df.sort_index(inplace=True)

            # Remove duplicates
            df = df[~df.index.duplicated(keep='first')]

            # The last batch can overshoot the requested end date; trim it
            df = df[df.index <= pd.to_datetime(until, unit='ms', utc=True)]

            logger.info("Fetched %d funding records for %s", len(df), symbol)
            return df

        except Exception as e:
            logger.error("Failed to fetch funding for %s: %s", symbol, e)
            return pd.DataFrame()

def fetch_all_assets(
self,
assets: list[str],
start_date: str | None = None,
end_date: str | None = None
) -> pd.DataFrame:
"""
Fetch funding rates for all assets and combine.
Args:
assets: List of asset symbols (e.g., ['BTC-USDT', 'ETH-USDT'])
start_date: Start date
end_date: End date
Returns:
Combined DataFrame with columns like 'btc_funding', 'eth_funding', etc.
"""
combined = pd.DataFrame()
for symbol in assets:
df = self.fetch_funding_history(symbol, start_date, end_date)
if df.empty:
continue
# Rename column
asset_name = symbol.replace('-USDT', '').lower()
col_name = f"{asset_name}_funding"
df = df.rename(columns={'funding_rate': col_name})
if combined.empty:
combined = df
else:
combined = combined.join(df, how='outer')
time.sleep(0.2) # Be nice to API
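# Forward-fill gaps left by the outer join (funding settles every 8h);
# hourly alignment happens later when the feature engine reindexes with ffill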
if not combined.empty:
combined = combined.sort_index()
combined = combined.ffill()
return combined
def save_to_cache(self, df: pd.DataFrame, filename: str = "funding_rates.csv") -> None:
"""Save funding data to cache file."""
path = self.cache_dir / filename
df.to_csv(path)
logger.info("Saved funding rates to %s", path)
def load_from_cache(self, filename: str = "funding_rates.csv") -> pd.DataFrame | None:
"""Load funding data from cache if available."""
path = self.cache_dir / filename
if path.exists():
df = pd.read_csv(path, index_col='timestamp', parse_dates=True)
logger.info("Loaded funding rates from cache: %d rows", len(df))
return df
return None
def get_funding_data(
self,
assets: list[str],
start_date: str | None = None,
end_date: str | None = None,
use_cache: bool = True,
force_refresh: bool = False
) -> pd.DataFrame:
"""
Get funding data, using cache if available.
Args:
assets: List of asset symbols
start_date: Start date
end_date: End date
use_cache: Whether to use cached data
force_refresh: Force refresh even if cache exists
Returns:
DataFrame with funding rates for all assets
"""
cache_file = "funding_rates.csv"
# Try cache first
if use_cache and not force_refresh:
cached = self.load_from_cache(cache_file)
if cached is not None:
# Check if cache covers requested range
if start_date and end_date:
start_ts = pd.Timestamp(start_date, tz='UTC')
end_ts = pd.Timestamp(end_date, tz='UTC')
if cached.index.min() <= start_ts and cached.index.max() >= end_ts:
# Filter to requested range
return cached[(cached.index >= start_ts) & (cached.index <= end_ts)]
# Fetch fresh data
logger.info("Fetching fresh funding rate data...")
df = self.fetch_all_assets(assets, start_date, end_date)
if not df.empty and use_cache:
self.save_to_cache(df, cache_file)
return df
def download_funding_data():
"""Download funding data for all multi-pair assets."""
from strategies.multi_pair.config import MultiPairConfig
config = MultiPairConfig()
fetcher = FundingRateFetcher()
# Fetch last year of data
end_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
start_date = (datetime.now(timezone.utc) - pd.Timedelta(days=365)).strftime("%Y-%m-%d")
logger.info("Downloading funding rates for %d assets...", len(config.assets))
logger.info("Date range: %s to %s", start_date, end_date)
df = fetcher.get_funding_data(
config.assets,
start_date=start_date,
end_date=end_date,
force_refresh=True
)
if not df.empty:
logger.info("Downloaded %d funding rate records", len(df))
logger.info("Columns: %s", list(df.columns))
else:
logger.warning("No funding data downloaded")
return df
if __name__ == "__main__":
from engine.logging_config import setup_logging
setup_logging()
download_funding_data()


@@ -0,0 +1,168 @@
"""
Pair Scanner for Multi-Pair Divergence Strategy.
Generates all possible pairs from asset universe and checks tradeability.
"""
from dataclasses import dataclass
from itertools import combinations
from typing import Optional
import ccxt
from engine.logging_config import get_logger
from .config import MultiPairConfig
logger = get_logger(__name__)

@dataclass
class TradingPair:
    """
    Represents a tradeable pair for spread analysis.

    Attributes:
        base_asset: First asset in the pair (numerator)
        quote_asset: Second asset in the pair (denominator)
        pair_id: Unique identifier for the pair
        is_direct: Whether pair can be traded directly on exchange
        exchange_symbol: Symbol for direct trading (if available)
    """
    base_asset: str
    quote_asset: str
    pair_id: str
    is_direct: bool = False
    exchange_symbol: Optional[str] = None

    @property
    def name(self) -> str:
        """Human-readable pair name."""
        return f"{self.base_asset}/{self.quote_asset}"

    def __hash__(self):
        return hash(self.pair_id)

    def __eq__(self, other):
        if not isinstance(other, TradingPair):
            return False
        return self.pair_id == other.pair_id

class PairScanner:
"""
Scans and generates tradeable pairs from asset universe.
Checks OKX for directly tradeable cross-pairs and generates
synthetic pairs via USDT for others.
"""
def __init__(self, config: MultiPairConfig):
self.config = config
self.exchange: Optional[ccxt.Exchange] = None
self._available_markets: set[str] = set()
def _init_exchange(self) -> None:
"""Initialize exchange connection for market lookup."""
if self.exchange is None:
exchange_class = getattr(ccxt, self.config.exchange_id)
self.exchange = exchange_class({'enableRateLimit': True})
self.exchange.load_markets()
self._available_markets = set(self.exchange.symbols)
logger.info(
"Loaded %d markets from %s",
len(self._available_markets),
self.config.exchange_id
)
def generate_pairs(self, check_exchange: bool = True) -> list[TradingPair]:
"""
Generate all unique pairs from asset universe.
Args:
check_exchange: Whether to check OKX for direct trading
Returns:
List of TradingPair objects
"""
if check_exchange:
self._init_exchange()
pairs = []
assets = self.config.assets
for base, quote in combinations(assets, 2):
pair_id = f"{base}__{quote}"
# Check if directly tradeable as cross-pair on OKX
is_direct = False
exchange_symbol = None
if check_exchange:
# Check perpetual cross-pair (the symbol built below is USDT-settled, e.g., ETH/BTC:USDT)
# OKX perpetuals are typically quoted in USDT
# Cross-pairs like ETH/BTC are less common
cross_symbol = f"{base.replace('-USDT', '')}/{quote.replace('-USDT', '')}:USDT"
if cross_symbol in self._available_markets:
is_direct = True
exchange_symbol = cross_symbol
pair = TradingPair(
base_asset=base,
quote_asset=quote,
pair_id=pair_id,
is_direct=is_direct,
exchange_symbol=exchange_symbol
)
pairs.append(pair)
# Log summary
direct_count = sum(1 for p in pairs if p.is_direct)
logger.info(
"Generated %d pairs: %d direct, %d synthetic",
len(pairs), direct_count, len(pairs) - direct_count
)
return pairs
def get_required_symbols(self, pairs: list[TradingPair]) -> list[str]:
"""
Get list of symbols needed to calculate all pair spreads.
For synthetic pairs, we need both USDT pairs.
For direct pairs, we still load USDT pairs for simplicity.
Args:
pairs: List of trading pairs
Returns:
List of unique symbols to load (e.g., ['BTC-USDT', 'ETH-USDT'])
"""
symbols = set()
for pair in pairs:
symbols.add(pair.base_asset)
symbols.add(pair.quote_asset)
return list(symbols)
def filter_by_assets(
self,
pairs: list[TradingPair],
exclude_assets: list[str]
) -> list[TradingPair]:
"""
Remove pairs that contain any of the excluded assets.
Args:
pairs: List of trading pairs
exclude_assets: Assets to exclude
Returns:
Filtered list of pairs
"""
if not exclude_assets:
return pairs
exclude_set = set(exclude_assets)
return [
p for p in pairs
if p.base_asset not in exclude_set
and p.quote_asset not in exclude_set
]


@@ -0,0 +1,525 @@
"""
Multi-Pair Divergence Selection Strategy.
Main strategy class that orchestrates pair scanning, feature calculation,
model training, and signal generation for backtesting.
"""
from dataclasses import dataclass
from typing import Optional
import pandas as pd
import numpy as np
from strategies.base import BaseStrategy
from engine.market import MarketType
from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import PairScanner, TradingPair
from .correlation import CorrelationFilter
from .feature_engine import MultiPairFeatureEngine
from .divergence_scorer import DivergenceScorer, DivergenceSignal
logger = get_logger(__name__)
@dataclass
class PositionState:
"""Tracks current position state."""
pair: TradingPair | None = None
direction: str | None = None # 'long' or 'short'
entry_price: float = 0.0
entry_idx: int = -1
stop_loss: float = 0.0
take_profit: float = 0.0
atr: float = 0.0 # ATR at entry for reference
last_exit_idx: int = -100 # For cooldown tracking
class MultiPairDivergenceStrategy(BaseStrategy):
"""
Multi-Pair Divergence Selection Strategy.
Scans multiple cryptocurrency pairs for spread divergence,
selects the most divergent pair using ML-enhanced scoring,
and trades mean-reversion opportunities.
Key Features:
- Universal ML model across all pairs
- Correlation-based pair filtering
- Dynamic SL/TP based on volatility
- Walk-forward training
"""
def __init__(
self,
config: MultiPairConfig | None = None,
model_path: str = "data/multi_pair_model.pkl"
):
super().__init__()
self.config = config or MultiPairConfig()
# Initialize components
self.pair_scanner = PairScanner(self.config)
self.correlation_filter = CorrelationFilter(self.config)
self.feature_engine = MultiPairFeatureEngine(self.config)
self.divergence_scorer = DivergenceScorer(self.config, model_path)
# Strategy configuration
self.default_market_type = MarketType.PERPETUAL
self.default_leverage = 1
# Runtime state
self.pairs: list[TradingPair] = []
self.asset_data: dict[str, pd.DataFrame] = {}
self.pair_features: dict[str, pd.DataFrame] = {}
self.position = PositionState()
self.train_end_idx: int = 0
def run(self, close: pd.Series, **kwargs) -> tuple:
"""
Execute the multi-pair divergence strategy.
This method is called by the backtester with the primary asset's
close prices. For multi-pair, we load all assets internally.
Args:
close: Primary close prices (used for index alignment)
**kwargs: Additional data (high, low, volume)
Returns:
Tuple of (long_entries, long_exits, short_entries, short_exits, size)
"""
logger.info("Starting Multi-Pair Divergence Strategy")
# 1. Load all asset data
start_date = close.index.min().strftime("%Y-%m-%d")
end_date = close.index.max().strftime("%Y-%m-%d")
self.asset_data = self.feature_engine.load_all_assets(
start_date=start_date,
end_date=end_date
)
# 1b. Load funding rate data for all assets
self.feature_engine.load_funding_data(
start_date=start_date,
end_date=end_date,
use_cache=True
)
if len(self.asset_data) < 2:
logger.error("Insufficient assets loaded, need at least 2")
return self._empty_signals(close)
# 2. Generate pairs
self.pairs = self.pair_scanner.generate_pairs(check_exchange=False)
# Filter to pairs with available data
available_assets = set(self.asset_data.keys())
self.pairs = [
p for p in self.pairs
if p.base_asset in available_assets
and p.quote_asset in available_assets
]
logger.info("Trading %d pairs from %d assets", len(self.pairs), len(self.asset_data))
# 3. Calculate features for all pairs
self.pair_features = self.feature_engine.calculate_all_pair_features(
self.pairs, self.asset_data
)
if not self.pair_features:
logger.error("No pair features calculated")
return self._empty_signals(close)
# 4. Align to common index
common_index = self._get_common_index()
if len(common_index) < 200:
logger.error("Insufficient common data across pairs")
return self._empty_signals(close)
# 5. Walk-forward split
n_samples = len(common_index)
train_size = int(n_samples * self.config.train_ratio)
self.train_end_idx = train_size
train_end_date = common_index[train_size - 1]
test_start_date = common_index[train_size]
logger.info(
"Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)",
train_size, train_end_date.strftime('%Y-%m-%d'),
n_samples - train_size, test_start_date.strftime('%Y-%m-%d')
)
# 6. Train model on training period
if self.divergence_scorer.model is None:
train_features = {
pid: feat[feat.index <= train_end_date]
for pid, feat in self.pair_features.items()
}
combined = self.feature_engine.get_combined_features(train_features)
self.divergence_scorer.train_model(combined, train_features)
# 7. Generate signals for test period
return self._generate_signals(common_index, train_size, close)
def _generate_signals(
self,
index: pd.DatetimeIndex,
train_size: int,
reference_close: pd.Series
) -> tuple:
"""
Generate entry/exit signals for the test period.
Iterates through each bar in the test period, scoring pairs
and generating signals based on divergence scores.
"""
# Initialize signal arrays aligned to reference close
long_entries = pd.Series(False, index=reference_close.index)
long_exits = pd.Series(False, index=reference_close.index)
short_entries = pd.Series(False, index=reference_close.index)
short_exits = pd.Series(False, index=reference_close.index)
size = pd.Series(1.0, index=reference_close.index)
# Track position state
self.position = PositionState()
# Price data for correlation calculation
price_data = {
symbol: df['close'] for symbol, df in self.asset_data.items()
}
# Iterate through test period
test_indices = index[train_size:]
trade_count = 0
for i, timestamp in enumerate(test_indices):
current_idx = train_size + i
# Check exit conditions first
if self.position.pair is not None:
# Enforce minimum hold period
bars_held = current_idx - self.position.entry_idx
if bars_held < self.config.min_hold_bars:
# Only allow SL/TP exits during min hold period
should_exit, exit_reason = self._check_sl_tp_only(timestamp)
else:
should_exit, exit_reason = self._check_exit(timestamp)
if should_exit:
# Map exit signal to reference index
if timestamp in reference_close.index:
if self.position.direction == 'long':
long_exits.loc[timestamp] = True
else:
short_exits.loc[timestamp] = True
logger.debug(
"Exit %s %s at %s: %s (held %d bars)",
self.position.direction,
self.position.pair.name,
timestamp.strftime('%Y-%m-%d %H:%M'),
exit_reason,
bars_held
)
self.position = PositionState(last_exit_idx=current_idx)
# Score pairs (with correlation filter if position exists)
held_asset = None
if self.position.pair is not None:
held_asset = self.position.pair.base_asset
# Filter pairs by correlation
candidate_pairs = self.correlation_filter.filter_pairs(
self.pairs,
held_asset,
price_data,
current_idx
)
# Get candidate features
candidate_ids = {p.pair_id for p in candidate_pairs}
candidate_features = {
pid: feat for pid, feat in self.pair_features.items()
if pid in candidate_ids
}
# Score pairs
signals = self.divergence_scorer.score_pairs(
candidate_features,
candidate_pairs,
timestamp
)
# Get best signal
best = self.divergence_scorer.select_best_pair(signals)
if best is None:
continue
# Check if we should switch positions or enter new
should_enter = False
# Check cooldown
bars_since_exit = current_idx - self.position.last_exit_idx
in_cooldown = bars_since_exit < self.config.cooldown_bars
if self.position.pair is None and not in_cooldown:
# No position and not in cooldown, can enter
should_enter = True
elif self.position.pair is not None:
# Check if we should switch (requires min hold + significant improvement)
bars_held = current_idx - self.position.entry_idx
current_score = self._get_current_score(timestamp)
if (bars_held >= self.config.min_hold_bars and
best.divergence_score > current_score * self.config.switch_threshold):
# New opportunity is significantly better
if timestamp in reference_close.index:
if self.position.direction == 'long':
long_exits.loc[timestamp] = True
else:
short_exits.loc[timestamp] = True
self.position = PositionState(last_exit_idx=current_idx)
should_enter = True
if should_enter:
# Calculate ATR-based dynamic SL/TP
sl_price, tp_price = self._calculate_sl_tp(
best.base_price,
best.direction,
best.atr,
best.atr_pct
)
# Set position
self.position = PositionState(
pair=best.pair,
direction=best.direction,
entry_price=best.base_price,
entry_idx=current_idx,
stop_loss=sl_price,
take_profit=tp_price,
atr=best.atr
)
# Calculate position size based on divergence
pos_size = self._calculate_size(best.divergence_score)
# Generate entry signal
if timestamp in reference_close.index:
if best.direction == 'long':
long_entries.loc[timestamp] = True
else:
short_entries.loc[timestamp] = True
size.loc[timestamp] = pos_size
trade_count += 1
logger.debug(
"Entry %s %s at %s: z=%.2f, prob=%.2f, score=%.3f",
best.direction,
best.pair.name,
timestamp.strftime('%Y-%m-%d %H:%M'),
best.z_score,
best.probability,
best.divergence_score
)
logger.info("Generated %d trades in test period", trade_count)
return long_entries, long_exits, short_entries, short_exits, size
def _check_exit(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
"""
Check if current position should be exited.
Exit conditions:
1. Z-Score reverted to mean (|Z| < threshold)
2. Stop-loss hit
3. Take-profit hit
Returns:
Tuple of (should_exit, reason)
"""
if self.position.pair is None:
return False, ""
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return True, "pair_data_missing"
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return True, "no_data"
latest = valid.iloc[-1]
z_score = latest['z_score']
current_price = latest['base_close']
# Check mean reversion (primary exit)
if abs(z_score) < self.config.z_exit_threshold:
return True, f"mean_reversion (z={z_score:.2f})"
# Check SL/TP
return self._check_sl_tp(current_price)
def _check_sl_tp_only(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
"""
Check only stop-loss and take-profit conditions.
Used during minimum hold period.
"""
if self.position.pair is None:
return False, ""
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return True, "pair_data_missing"
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return True, "no_data"
latest = valid.iloc[-1]
current_price = latest['base_close']
return self._check_sl_tp(current_price)

    def _check_sl_tp(self, current_price: float) -> tuple[bool, str]:
        """Check stop-loss and take-profit levels."""
        if self.position.direction == 'long':
            if current_price <= self.position.stop_loss:
                return True, f"stop_loss ({current_price:.2f} <= {self.position.stop_loss:.2f})"
            if current_price >= self.position.take_profit:
                return True, f"take_profit ({current_price:.2f} >= {self.position.take_profit:.2f})"
        else:  # short
            if current_price >= self.position.stop_loss:
                return True, f"stop_loss ({current_price:.2f} >= {self.position.stop_loss:.2f})"
            if current_price <= self.position.take_profit:
                return True, f"take_profit ({current_price:.2f} <= {self.position.take_profit:.2f})"
        return False, ""

def _get_current_score(self, timestamp: pd.Timestamp) -> float:
"""Get current position's divergence score for comparison."""
if self.position.pair is None:
return 0.0
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return 0.0
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return 0.0
latest = valid.iloc[-1]
z_score = abs(latest['z_score'])
# Re-score with model
if self.divergence_scorer.model is not None:
feature_row = latest[self.divergence_scorer.feature_cols].fillna(0)
feature_row = feature_row.replace([np.inf, -np.inf], 0)
X = pd.DataFrame(
[feature_row.values],
columns=self.divergence_scorer.feature_cols
)
prob = self.divergence_scorer.model.predict_proba(X)[0, 1]
return z_score * prob
return z_score * 0.5
def _calculate_sl_tp(
self,
entry_price: float,
direction: str,
atr: float,
atr_pct: float
) -> tuple[float, float]:
"""
Calculate ATR-based dynamic stop-loss and take-profit prices.
Uses ATR (Average True Range) to set stops that adapt to
each asset's volatility. More volatile assets get wider stops.
Args:
entry_price: Entry price
direction: 'long' or 'short'
atr: ATR in price units
atr_pct: ATR as percentage of price
Returns:
Tuple of (stop_loss_price, take_profit_price)
"""
# Calculate SL/TP as ATR multiples
if atr > 0 and atr_pct > 0:
# ATR-based calculation
sl_distance = atr * self.config.sl_atr_multiplier
tp_distance = atr * self.config.tp_atr_multiplier
# Convert to percentage for bounds checking
sl_pct = sl_distance / entry_price
tp_pct = tp_distance / entry_price
else:
# Fallback to fixed percentages if ATR unavailable
sl_pct = self.config.base_sl_pct
tp_pct = self.config.base_tp_pct
# Apply bounds to prevent extreme stops
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
# Calculate actual prices
if direction == 'long':
stop_loss = entry_price * (1 - sl_pct)
take_profit = entry_price * (1 + tp_pct)
else: # short
stop_loss = entry_price * (1 + sl_pct)
take_profit = entry_price * (1 - tp_pct)
return stop_loss, take_profit

    def _calculate_size(self, divergence_score: float) -> float:
        """
        Calculate position size based on divergence score.

        Higher divergence = larger position (up to 2x).
        """
        # Base score threshold (Z=1.0, prob=0.5 -> score=0.5)
        base_threshold = 0.5

        # Scale factor
        if divergence_score <= base_threshold:
            return 1.0

        # Linear scaling: 1.0 at threshold, up to 2.0 at 2x threshold
        scale = 1.0 + (divergence_score - base_threshold) / base_threshold
        return min(scale, 2.0)

def _get_common_index(self) -> pd.DatetimeIndex:
"""Get the intersection of all pair feature indices."""
if not self.pair_features:
return pd.DatetimeIndex([])
common = None
for features in self.pair_features.values():
if common is None:
common = features.index
else:
common = common.intersection(features.index)
return common.sort_values()
def _empty_signals(self, close: pd.Series) -> tuple:
"""Return empty signal arrays."""
empty = self.create_empty_signals(close)
size = pd.Series(1.0, index=close.index)
return empty, empty, empty, empty, size


@@ -0,0 +1,321 @@
# PRD: Multi-Pair Divergence Selection Strategy
## 1. Introduction / Overview
This document describes the **Multi-Pair Divergence Selection Strategy**, an extension of the existing BTC/ETH regime reversion system. The strategy expands spread analysis to the **top 10 cryptocurrencies by market cap**, calculates divergence scores for all tradeable pairs, and dynamically selects the **most divergent pair** for trading.
The core hypothesis: by scanning multiple pairs simultaneously, we can identify stronger mean-reversion opportunities than focusing on a single pair, improving net PnL while maintaining the proven ML-based regime detection approach.
---
## 2. Goals
1. **Extend regime detection** to top 10 market cap cryptocurrencies
2. **Dynamically select** the most divergent tradeable pair each cycle
3. **Integrate volatility** into dynamic SL/TP calculations
4. **Filter correlated pairs** to avoid redundant positions
5. **Improve net PnL** compared to single-pair BTC/ETH strategy
6. **Backtest-first** implementation with walk-forward validation
---
## 3. User Stories
### US-1: Multi-Pair Analysis
> As a trader, I want the system to analyze spread divergence across multiple cryptocurrency pairs so that I can identify the best trading opportunity at any given moment.
### US-2: Dynamic Pair Selection
> As a trader, I want the system to automatically select and trade the pair with the highest divergence score (combination of Z-score magnitude and ML probability) so that I maximize mean-reversion profit potential.
### US-3: Volatility-Adjusted Risk
> As a trader, I want stop-loss and take-profit levels to adapt to each pair's volatility so that I avoid being stopped out prematurely on volatile assets while protecting profits on stable ones.
### US-4: Correlation Filtering
> As a trader, I want the system to avoid selecting pairs that are highly correlated with my current position so that I don't inadvertently double-down on the same market exposure.
### US-5: Backtest Validation
> As a researcher, I want to backtest this multi-pair strategy with walk-forward training so that I can validate improvement over the single-pair baseline without look-ahead bias.
---
## 4. Functional Requirements
### 4.1 Data Management
| ID | Requirement |
|----|-------------|
| FR-1.1 | System must support loading OHLCV data for top 10 market cap cryptocurrencies |
| FR-1.2 | Target assets: BTC, ETH, SOL, XRP, BNB, DOGE, ADA, AVAX, LINK, DOT (configurable) |
| FR-1.3 | System must identify all directly tradeable cross-pairs on OKX perpetuals |
| FR-1.4 | System must align timestamps across all pairs for synchronized analysis |
| FR-1.5 | System must handle missing data gracefully (skip pair if insufficient history) |
### 4.2 Pair Generation
| ID | Requirement |
|----|-------------|
| FR-2.1 | Generate all unique pairs from asset universe: N*(N-1)/2 pairs (e.g., 45 pairs for 10 assets) |
| FR-2.2 | Filter pairs to only those directly tradeable on OKX (no USDT intermediate) |
| FR-2.3 | Fallback: If cross-pair not available, calculate synthetic spread via USDT pairs |
| FR-2.4 | Store pair metadata: base asset, quote asset, exchange symbol, tradeable flag |
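
To make FR-2.1 concrete, here is a minimal sketch of the pair enumeration. It assumes assets are written in the `XXX-USDT` form and that pair ids follow a `base__quote` convention; the helper name `generate_pair_ids` is hypothetical and only illustrates the counting.

```python
from itertools import combinations

# Asset universe from FR-1.2, written in the assumed 'XXX-USDT' form.
ASSETS = [
    "BTC-USDT", "ETH-USDT", "SOL-USDT", "XRP-USDT", "BNB-USDT",
    "DOGE-USDT", "ADA-USDT", "AVAX-USDT", "LINK-USDT", "DOT-USDT",
]


def generate_pair_ids(assets: list) -> list:
    """Return one id per unordered pair, N*(N-1)/2 in total (hypothetical helper)."""
    return [f"{base}__{quote}" for base, quote in combinations(assets, 2)]


pair_ids = generate_pair_ids(ASSETS)
print(len(pair_ids))  # 45 for 10 assets
print(pair_ids[0])    # BTC-USDT__ETH-USDT
```

Because `combinations` yields each unordered pair once, the order of the asset list decides which asset ends up as the numerator of the spread.
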
### 4.3 Feature Engineering (Per Pair)
| ID | Requirement |
|----|-------------|
| FR-3.1 | Calculate spread ratio: `asset_a_close / asset_b_close` |
| FR-3.2 | Calculate Z-Score with configurable rolling window (default: 24h) |
| FR-3.3 | Calculate spread technicals: RSI(14), ROC(5), 1h change |
| FR-3.4 | Calculate volume ratio and relative volume |
| FR-3.5 | Calculate volatility ratio: `std(returns_a) / std(returns_b)` over Z-window |
| FR-3.6 | Calculate realized volatility for each asset (for dynamic SL/TP) |
| FR-3.7 | Merge on-chain data (funding rates, inflows) if available per asset |
| FR-3.8 | Add pair identifier as categorical feature for universal model |
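
FR-3.1 and FR-3.2 can be sketched with the standard library alone. This is an illustration under assumptions, not the feature engine itself: the trailing window includes the current bar, the sample standard deviation is used, and the function names are hypothetical.

```python
from statistics import mean, stdev
from typing import List, Optional


def spread_ratio(close_a: List[float], close_b: List[float]) -> List[float]:
    """FR-3.1: element-wise ratio of the two close series."""
    return [a / b for a, b in zip(close_a, close_b)]


def rolling_z_score(spread: List[float], window: int = 24) -> List[Optional[float]]:
    """FR-3.2: z-score of each point against its trailing window.

    Yields None until a full window is available, or when the window is flat.
    """
    out: List[Optional[float]] = []
    for i in range(len(spread)):
        if i + 1 < window:
            out.append(None)
            continue
        chunk = spread[i + 1 - window: i + 1]  # assumption: window includes bar i
        sigma = stdev(chunk)
        out.append((spread[i] - mean(chunk)) / sigma if sigma > 0 else None)
    return out


spread = spread_ratio([10.0, 11.0, 12.0, 13.0], [5.0, 5.0, 5.0, 5.0])
z = rolling_z_score(spread, window=3)
print(z)
```

With 1h candles the default `window=24` corresponds to the 24h lookback named in FR-3.2.
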
### 4.4 Correlation Filtering
| ID | Requirement |
|----|-------------|
| FR-4.1 | Calculate rolling correlation matrix between all assets (default: 168h / 7 days) |
| FR-4.2 | Define correlation threshold (default: 0.85) |
| FR-4.3 | If current position exists, exclude pairs where either asset has correlation > threshold with held asset |
| FR-4.4 | Log filtered pairs with reason for exclusion |
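
A minimal sketch of the FR-4.3 rule, assuming per-asset return series held in plain lists. The function names and the toy data are hypothetical. Read literally, the rule also drops every pair that contains the held asset itself, since its correlation with itself is 1.

```python
from math import sqrt
from typing import Dict, List, Optional, Tuple


def pearson(x: List[float], y: List[float]) -> float:
    """Pearson correlation of two equal-length series (0.0 if either is flat)."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    var_x = sum((a - mx) ** 2 for a in x)
    var_y = sum((b - my) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        return 0.0
    return cov / sqrt(var_x * var_y)


def filter_pairs(
    pairs: List[Tuple[str, str]],
    held_asset: Optional[str],
    returns: Dict[str, List[float]],
    window: int = 168,
    threshold: float = 0.85,
) -> List[Tuple[str, str]]:
    """Drop pairs where either leg is too correlated with the held asset."""
    if held_asset is None:
        return list(pairs)
    held = returns[held_asset][-window:]
    kept = []
    for base, quote in pairs:
        correlated = any(
            pearson(held, returns[asset][-window:]) > threshold
            for asset in (base, quote)
        )
        if not correlated:
            kept.append((base, quote))
    return kept


# Toy returns: B moves exactly with A, while C and D do not.
returns = {
    "A": [0.01, -0.02, 0.03, -0.01],
    "B": [0.02, -0.04, 0.06, -0.02],
    "C": [0.01, 0.02, -0.03, 0.01],
    "D": [-0.02, 0.01, 0.01, 0.03],
}
pairs = [("A", "B"), ("A", "C"), ("B", "C"), ("C", "D")]
print(filter_pairs(pairs, "A", returns))  # [('C', 'D')]
```
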
### 4.5 Divergence Scoring & Pair Selection
| ID | Requirement |
|----|-------------|
| FR-5.1 | Calculate divergence score: `abs(z_score) * model_probability` |
| FR-5.2 | Only consider pairs where `abs(z_score) > z_entry_threshold` (default: 1.0) |
| FR-5.3 | Only consider pairs where `model_probability > prob_threshold` (default: 0.5) |
| FR-5.4 | Apply correlation filter to eligible pairs |
| FR-5.5 | Select pair with highest divergence score |
| FR-5.6 | If no pair qualifies, signal "hold" |
| FR-5.7 | Log all pair scores for analysis/debugging |
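
The selection rule in FR-5.1 through FR-5.6 can be sketched as follows. `PairReading` and `select_best` are hypothetical names used only for illustration, and the correlation filter of FR-5.4 is assumed to have been applied to the readings beforehand.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PairReading:
    """Hypothetical container for one pair's latest z-score and model output."""
    pair_id: str
    z_score: float
    probability: float

    @property
    def divergence_score(self) -> float:
        # FR-5.1: magnitude of the z-score weighted by model probability
        return abs(self.z_score) * self.probability


def select_best(
    readings: List[PairReading],
    z_entry_threshold: float = 1.0,
    prob_threshold: float = 0.5,
) -> Optional[PairReading]:
    """Return the eligible pair with the highest score, or None to hold."""
    eligible = [
        r for r in readings
        if abs(r.z_score) > z_entry_threshold      # FR-5.2
        and r.probability > prob_threshold         # FR-5.3
    ]
    if not eligible:
        return None                                # FR-5.6: hold
    return max(eligible, key=lambda r: r.divergence_score)  # FR-5.5


readings = [
    PairReading("BTC-USDT__ETH-USDT", 1.8, 0.62),
    PairReading("SOL-USDT__ADA-USDT", -2.4, 0.55),
    PairReading("XRP-USDT__DOT-USDT", 3.0, 0.45),   # fails probability gate
    PairReading("ETH-USDT__LINK-USDT", 0.9, 0.90),  # fails z-score gate
]
best = select_best(readings)
print(best.pair_id)  # SOL-USDT__ADA-USDT
```

Note that the pair with the largest raw z-score is not selected here, because its model probability sits below the gate.
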
### 4.6 ML Model (Universal)
| ID | Requirement |
|----|-------------|
| FR-6.1 | Train single Random Forest model on all pairs combined |
| FR-6.2 | Include `pair_id` as one-hot encoded or label-encoded feature |
| FR-6.3 | Target: binary (1 = profitable reversion within horizon, 0 = no reversion) |
| FR-6.4 | Walk-forward training: 70% train / 30% test split |
| FR-6.5 | Daily retraining schedule (for live, configurable for backtest) |
| FR-6.6 | Model hyperparameters: `n_estimators=300, max_depth=5, min_samples_leaf=30, class_weight={0:1, 1:3}` |
| FR-6.7 | Save/load model with feature column metadata |
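A small sketch of FR-6.4 and FR-6.6. The split helper is hypothetical and simply cuts time-ordered rows without shuffling. The parameter dictionary mirrors FR-6.6; assuming the project uses scikit-learn, it could be passed as `RandomForestClassifier(**RF_PARAMS)`.

```python
# Hyperparameters from FR-6.6, keyed by scikit-learn argument names.
RF_PARAMS = {
    "n_estimators": 300,
    "max_depth": 5,
    "min_samples_leaf": 30,
    "class_weight": {0: 1, 1: 3},  # up-weight the positive (reversion) class
}


def chronological_split(rows, train_ratio=0.7):
    """Split time-ordered rows into (train, test) without shuffling.

    Keeping the order intact avoids training on data that comes after
    the test period.
    """
    cut = int(len(rows) * train_ratio)
    return rows[:cut], rows[cut:]
```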
### 4.7 Signal Generation
| ID | Requirement |
|----|-------------|
| FR-7.1 | Direction: if `z_score > threshold`, short the spread (short asset_a); if `z_score < -threshold`, long the spread (long asset_a) |
| FR-7.2 | Apply funding rate filter per asset (block if extreme funding opposes direction) |
| FR-7.3 | Output signal: `{pair, action, side, probability, z_score, divergence_score, reason}` |
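A sketch of how FR-7.1 and FR-7.3 might fit together. The `action` values (`"entry"`, `"hold"`) and the `reason` strings are placeholders chosen for illustration, and the funding rate filter from FR-7.2 is omitted.

```python
def build_signal(pair, z_score, probability, threshold=1.0):
    """Map a z-score to a spread direction and build the signal dict."""
    if z_score > threshold:
        # Spread is stretched upward: expect it to fall, so short asset_a
        action, side, reason = "entry", "short", "z_score above threshold"
    elif z_score < -threshold:
        # Spread is stretched downward: expect it to rise, so long asset_a
        action, side, reason = "entry", "long", "z_score below -threshold"
    else:
        action, side, reason = "hold", None, "z_score inside threshold band"
    return {
        "pair": pair,
        "action": action,
        "side": side,
        "probability": probability,
        "z_score": z_score,
        "divergence_score": abs(z_score) * probability,
        "reason": reason,
    }
```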
### 4.8 Position Sizing
| ID | Requirement |
|----|-------------|
| FR-8.1 | Base size: 100% of available subaccount balance |
| FR-8.2 | Scale by divergence: `size_multiplier = 1.0 + (divergence_score - base_threshold) * scaling_factor` |
| FR-8.3 | Cap multiplier between 1.0x and 2.0x |
| FR-8.4 | Respect exchange minimum order size per asset |
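A sketch of FR-8.2 to FR-8.4. The defaults for `base_threshold` and `scaling_factor` are not specified in this document; the values below are placeholders (0.5 is simply the default `z_entry_threshold` times the default `prob_threshold`). `order_quantity` is a hypothetical helper.

```python
def size_multiplier(divergence_score, base_threshold=0.5, scaling_factor=0.5):
    """FR-8.2 / FR-8.3: scale size by divergence, capped to [1.0, 2.0]."""
    raw = 1.0 + (divergence_score - base_threshold) * scaling_factor
    return max(1.0, min(2.0, raw))


def order_quantity(balance, price, multiplier, min_order_size):
    """Convert balance to an asset quantity; 0.0 if below the exchange minimum."""
    qty = balance * multiplier / price
    return qty if qty >= min_order_size else 0.0
```

With these placeholder defaults, a divergence score of 1.5 gives a 1.5x multiplier, and any score of 2.5 or more is capped at 2.0x.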
### 4.9 Dynamic SL/TP (Volatility-Adjusted)
| ID | Requirement |
|----|-------------|
| FR-9.1 | Calculate asset realized volatility: `std(returns) * sqrt(24)` on 1h returns, giving daily volatility |
| FR-9.2 | Base SL: `entry_price * (1 - base_sl_pct * vol_multiplier)` for longs |
| FR-9.3 | Base TP: `entry_price * (1 + base_tp_pct * vol_multiplier)` for longs |
| FR-9.4 | `vol_multiplier = asset_volatility / baseline_volatility` (baseline = BTC volatility) |
| FR-9.5 | Cap vol_multiplier between 0.5x and 2.0x to prevent extreme values |
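| FR-9.6 | Invert for short positions: SL at `entry_price * (1 + base_sl_pct * vol_multiplier)`, TP at `entry_price * (1 - base_tp_pct * vol_multiplier)` |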
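The volatility-adjusted levels can be sketched as below, using the default percentages from the configuration in section 6.3. The function name and the `side` strings are assumptions made for the example.

```python
def dynamic_sl_tp(entry_price, side, asset_vol, baseline_vol,
                  base_sl_pct=0.06, base_tp_pct=0.05,
                  vol_min=0.5, vol_max=2.0):
    """Return (stop_loss, take_profit) prices scaled by relative volatility."""
    # FR-9.4 / FR-9.5: ratio to baseline (BTC) volatility, then clamp
    vol_multiplier = max(vol_min, min(vol_max, asset_vol / baseline_vol))
    sl_dist = base_sl_pct * vol_multiplier
    tp_dist = base_tp_pct * vol_multiplier
    if side == "long":
        return entry_price * (1 - sl_dist), entry_price * (1 + tp_dist)
    if side == "short":
        # FR-9.6: mirrored levels, stop above entry and target below
        return entry_price * (1 + sl_dist), entry_price * (1 - tp_dist)
    raise ValueError(f"unknown side: {side!r}")
```

For an asset twice as volatile as the baseline, a long entered at 100 gets a stop near 88 and a target near 110.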
### 4.10 Exit Conditions
| ID | Requirement |
|----|-------------|
| FR-10.1 | Exit when Z-score crosses back through 0 (mean reversion complete) |
| FR-10.2 | Exit when dynamic SL or TP hit |
| FR-10.3 | No minimum holding period (can switch pairs immediately) |
| FR-10.4 | If new pair has higher divergence score, close current and open new |
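A sketch of FR-10.1 and FR-10.2. It assumes `price` is the traded asset's current price and that SL/TP checks take priority over the Z-Score check; neither ordering nor the reason strings are specified in this document. The pair-switching rule in FR-10.4 is not covered.

```python
def exit_reason(prev_z, curr_z, price, side, stop_loss, take_profit):
    """Return why a position should close, or None to keep holding."""
    if side == "long":
        if price <= stop_loss:
            return "stop_loss"
        if price >= take_profit:
            return "take_profit"
    else:  # short: levels are mirrored around the entry price
        if price >= stop_loss:
            return "stop_loss"
        if price <= take_profit:
            return "take_profit"
    # FR-10.1: a sign change (or landing exactly on 0) means the
    # spread has reverted to its rolling mean
    if prev_z * curr_z <= 0:
        return "z_score_cross"
    return None
```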
### 4.11 Backtest Integration
| ID | Requirement |
|----|-------------|
| FR-11.1 | Integrate with existing `engine/backtester.py` framework |
| FR-11.2 | Support 1h timeframe (matching live trading) |
| FR-11.3 | Walk-forward validation: train on 70%, test on 30% |
| FR-11.4 | Output: trades log, equity curve, performance metrics |
| FR-11.5 | Compare against single-pair BTC/ETH baseline |
---
## 5. Non-Goals (Out of Scope)
1. **Live trading implementation** - Backtest validation first
2. **Multi-position portfolio** - Single pair at a time for v1
3. **Cross-exchange arbitrage** - OKX only
4. **Alternative ML models** - Stick with Random Forest for consistency
5. **Sub-1h timeframes** - 1h candles only for initial version
6. **Leveraged positions** - 1x leverage for backtest
7. **Portfolio-level VaR/risk budgeting** - Full subaccount allocation
---
## 6. Design Considerations
### 6.1 Architecture
```
strategies/
  multi_pair/
    __init__.py
    pair_scanner.py        # Generates all pairs, filters tradeable
    feature_engine.py      # Calculates features for all pairs
    correlation.py         # Rolling correlation matrix & filtering
    divergence_scorer.py   # Ranks pairs by divergence score
    strategy.py            # Main strategy orchestration
```
### 6.2 Data Flow
```
1. Load OHLCV for all 10 assets
2. Generate pair combinations (45 pairs)
3. Filter to tradeable pairs (OKX check)
4. Calculate features for each pair
5. Train/load universal ML model
6. Predict probability for all pairs
7. Calculate divergence scores
8. Apply correlation filter
9. Select top pair
10. Generate signal with dynamic SL/TP
11. Execute in backtest engine
```
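Step 2 of the flow above can be illustrated with the standard library. This is only a sketch: `generate_pairs` is a hypothetical name, and the asset list is copied from the configuration in section 6.3.

```python
from itertools import combinations

ASSETS = ["BTC", "ETH", "SOL", "XRP", "BNB",
          "DOGE", "ADA", "AVAX", "LINK", "DOT"]


def generate_pairs(assets):
    """Return every unordered pair of distinct assets."""
    return list(combinations(assets, 2))


pairs = generate_pairs(ASSETS)
print(len(pairs))  # 10 choose 2 = 45
```

Because the pairs are unordered, `("BTC", "ETH")` appears but `("ETH", "BTC")` does not, which is where the 45 figure comes from.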
### 6.3 Configuration
```python
from dataclasses import dataclass, field


@dataclass
class MultiPairConfig:
    # Assets
    assets: list[str] = field(default_factory=lambda: [
        "BTC", "ETH", "SOL", "XRP", "BNB",
        "DOGE", "ADA", "AVAX", "LINK", "DOT"
    ])
    # Thresholds
    z_window: int = 24  # rolling Z-Score window (24h)
    z_entry_threshold: float = 1.0
    prob_threshold: float = 0.5
    correlation_threshold: float = 0.85
    correlation_window: int = 168  # 7 days in hours
    # Risk
    base_sl_pct: float = 0.06  # 6% base stop-loss
    base_tp_pct: float = 0.05  # 5% base take-profit
    vol_multiplier_min: float = 0.5
    vol_multiplier_max: float = 2.0
    # Model
    train_ratio: float = 0.7
    horizon: int = 102
    profit_target: float = 0.005  # 0.5%
```
---
## 7. Technical Considerations
### 7.1 Dependencies
- Extend `DataManager` to load multiple symbols
- Query OKX API for available perpetual cross-pairs
- Reuse existing feature engineering from `RegimeReversionStrategy`
### 7.2 Performance
- Pre-calculate all pair features in batch (vectorized)
- Cache correlation matrix (update every N candles, not on every candle)
- Model inference is fast (single predict call with all pairs as rows)
### 7.3 Edge Cases
- Handle pairs with insufficient history (< 200 bars) - exclude
- Handle assets delisted mid-backtest - skip pair
- Handle zero-volume periods - use last valid price
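Two of the edge cases above can be sketched with plain lists. Both helpers are hypothetical; the 200-bar minimum comes from the first bullet, and "last valid price" is interpreted here as the close of the most recent bar with non-zero volume.

```python
MIN_BARS = 200


def has_enough_history(closes_a, closes_b, min_bars=MIN_BARS):
    """A pair is usable only if both legs have at least min_bars candles."""
    return min(len(closes_a), len(closes_b)) >= min_bars


def fill_zero_volume(closes, volumes):
    """Replace closes on zero-volume bars with the last valid close.

    If the series starts with zero-volume bars, their own closes are
    kept because no earlier valid price exists.
    """
    filled, last_valid = [], None
    for close, volume in zip(closes, volumes):
        if volume > 0 or last_valid is None:
            last_valid = close
        filled.append(last_valid)
    return filled
```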
---
## 8. Success Metrics
| Metric | Baseline (BTC/ETH) | Target |
|--------|-------------------|--------|
| Net PnL | Current performance | > 10% improvement |
| Number of Trades | N | Comparable or higher |
| Win Rate | Baseline % | Maintain or improve |
| Average Trade Duration | Baseline hours | Flexible |
| Max Drawdown | Baseline % | Not significantly worse |
---
## 9. Open Questions
1. **OKX Cross-Pairs**: Need to verify which cross-pairs are available on OKX perpetuals. May need to fall back to synthetic spreads for most pairs.
2. **On-Chain Data**: CryptoQuant data currently covers BTC/ETH. Should we:
   - Run without on-chain features for other assets?
   - Source alternative on-chain data?
   - Use funding rates only (available from OKX)?
3. **Pair ID Encoding**: For the universal model, should pair_id be:
   - One-hot encoded (adds 45 features)?
   - Label encoded (single ordinal feature)?
   - Hierarchical (base_asset + quote_asset as separate features)?
4. **Synthetic Spreads**: If we trade the SOL/DOT spread but only USDT pairs are available:
   - Calculate spread synthetically: `SOL-USDT / DOT-USDT`
   - Execute as two legs: Long SOL-USDT, Short DOT-USDT
   - This doubles fees and adds execution complexity. Include in v1?
---
## 10. Implementation Phases
### Phase 1: Data & Infrastructure (Est. 2-3 days)
- Extend DataManager for multi-symbol loading
- Build pair scanner with OKX tradeable filter
- Implement correlation matrix calculation
### Phase 2: Feature Engineering (Est. 2 days)
- Adapt existing feature calculation for arbitrary pairs
- Add pair identifier feature
- Batch feature calculation for all pairs
### Phase 3: Model & Scoring (Est. 2 days)
- Train universal model on all pairs
- Implement divergence scoring
- Add correlation filtering to pair selection
### Phase 4: Strategy Integration (Est. 2-3 days)
- Implement dynamic SL/TP with volatility
- Integrate with backtester
- Build strategy orchestration class
### Phase 5: Validation & Comparison (Est. 2 days)
- Run walk-forward backtest
- Compare against BTC/ETH baseline
- Generate performance report
**Total Estimated Effort: 10-12 days**
---
*Document Version: 1.0*
*Created: 2026-01-15*
*Author: AI Assistant*
*Status: Draft - Awaiting Review*