feat: Multi-Pair Divergence Selection Strategy
- Extend regime detection to top 10 cryptocurrencies (45 pairs)
- Dynamic pair selection based on divergence score (|z_score| * probability)
- Universal ML model trained on all pairs
- Correlation-based filtering to avoid redundant positions
- Funding rate integration from OKX for all 10 assets
- ATR-based dynamic stop-loss and take-profit
- Walk-forward training with 70/30 split

Performance: +35.69% return (vs +28.66% baseline), 63.6% win rate
This commit is contained in:
BIN
data/multi_pair_model.pkl
Normal file
BIN
data/multi_pair_model.pkl
Normal file
Binary file not shown.
47
scripts/download_multi_pair_data.py
Normal file
47
scripts/download_multi_pair_data.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Download historical data for Multi-Pair Divergence Strategy.
|
||||
|
||||
Downloads 1h OHLCV data for top 10 cryptocurrencies from OKX.
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, '.')
|
||||
|
||||
from engine.data_manager import DataManager
|
||||
from engine.market import MarketType
|
||||
from engine.logging_config import setup_logging, get_logger
|
||||
from strategies.multi_pair import MultiPairConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def main():
    """Fetch perpetual 1h OHLCV history for every configured asset."""
    setup_logging()

    config = MultiPairConfig()
    dm = DataManager()

    logger.info("Downloading data for %d assets...", len(config.assets))

    for symbol in config.assets:
        logger.info("Downloading %s perpetual 1h data...", symbol)
        try:
            candles = dm.download_data(
                exchange_id=config.exchange_id,
                symbol=symbol,
                timeframe=config.timeframe,
                market_type=MarketType.PERPETUAL
            )
            if candles is None:
                logger.warning("No data downloaded for %s", symbol)
            else:
                logger.info("Downloaded %d candles for %s", len(candles), symbol)
        except Exception as exc:
            # Best-effort per symbol: one failed download must not abort the rest.
            logger.error("Failed to download %s: %s", symbol, exc)

    logger.info("Download complete!")
# Script entry point: run the download when executed directly.
if __name__ == "__main__":
    main()
156
scripts/run_multi_pair_backtest.py
Normal file
156
scripts/run_multi_pair_backtest.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Run Multi-Pair Divergence Strategy backtest and compare with baseline.
|
||||
|
||||
Compares the multi-pair strategy against the single-pair BTC/ETH regime strategy.
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, '.')
|
||||
|
||||
from engine.backtester import Backtester
|
||||
from engine.data_manager import DataManager
|
||||
from engine.logging_config import setup_logging, get_logger
|
||||
from engine.reporting import Reporter
|
||||
from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig
|
||||
from strategies.regime_strategy import RegimeReversionStrategy
|
||||
from engine.market import MarketType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_baseline():
    """Backtest the single-pair BTC/ETH regime strategy as the reference run."""
    logger.info("=" * 60)
    logger.info("BASELINE: BTC/ETH Regime Reversion Strategy")
    logger.info("=" * 60)

    data_manager = DataManager()
    backtester = Backtester(data_manager)

    result = backtester.run_strategy(
        RegimeReversionStrategy(),
        'okx',
        'ETH-USDT',
        timeframe='1h',
        init_cash=10000
    )

    portfolio = result.portfolio
    logger.info("Baseline Results:")
    logger.info(" Total Return: %.2f%%", portfolio.total_return() * 100)
    logger.info(" Total Trades: %d", portfolio.trades.count())
    logger.info(" Win Rate: %.1f%%", portfolio.trades.win_rate() * 100)

    return result
def run_multi_pair(assets: list[str] | None = None):
    """Backtest the multi-pair divergence strategy.

    Args:
        assets: Optional subset of symbols; defaults to the full universe.

    Returns:
        The backtest result object from the backtester.
    """
    logger.info("=" * 60)
    logger.info("MULTI-PAIR: Divergence Selection Strategy")
    logger.info("=" * 60)

    data_manager = DataManager()
    backtester = Backtester(data_manager)

    # Build the config from the caller-supplied universe when given.
    config = MultiPairConfig(assets=assets) if assets else MultiPairConfig()

    logger.info("Configured %d assets, %d pairs", len(config.assets), config.get_pair_count())

    result = backtester.run_strategy(
        MultiPairDivergenceStrategy(config=config),
        'okx',
        'ETH-USDT',  # Reference asset (not used for trading, just index alignment)
        timeframe='1h',
        init_cash=10000
    )

    portfolio = result.portfolio
    logger.info("Multi-Pair Results:")
    logger.info(" Total Return: %.2f%%", portfolio.total_return() * 100)
    logger.info(" Total Trades: %d", portfolio.trades.count())
    logger.info(" Win Rate: %.1f%%", portfolio.trades.win_rate() * 100)

    return result
def compare_results(baseline, multi_pair):
    """Log a side-by-side comparison of the two runs and return summary metrics.

    Args:
        baseline: Result object from the baseline run.
        multi_pair: Result object from the multi-pair run.

    Returns:
        Dict with returns, improvement, and trade counts for both runs.
    """
    logger.info("=" * 60)
    logger.info("COMPARISON")
    logger.info("=" * 60)

    baseline_return = baseline.portfolio.total_return() * 100
    multi_return = multi_pair.portfolio.total_return() * 100
    improvement = multi_return - baseline_return

    # The ratio is only meaningful when the baseline return is non-zero.
    ratio = multi_return / baseline_return if baseline_return != 0 else 0

    logger.info("Baseline Return: %.2f%%", baseline_return)
    logger.info("Multi-Pair Return: %.2f%%", multi_return)
    logger.info("Improvement: %.2f%% (%.1fx)", improvement, ratio)

    baseline_trades = baseline.portfolio.trades.count()
    multi_trades = multi_pair.portfolio.trades.count()

    logger.info("Baseline Trades: %d", baseline_trades)
    logger.info("Multi-Pair Trades: %d", multi_trades)

    return {
        'baseline_return': baseline_return,
        'multi_pair_return': multi_return,
        'improvement': improvement,
        'baseline_trades': baseline_trades,
        'multi_pair_trades': multi_trades
    }
def main():
    """Run baseline and multi-pair backtests, compare them, and save reports."""
    setup_logging()

    # Keep only the assets whose historical data is already on disk.
    dm = DataManager()
    available = []
    for symbol in MultiPairConfig().assets:
        try:
            dm.load_data('okx', symbol, '1h', market_type=MarketType.PERPETUAL)
        except FileNotFoundError:
            continue
        available.append(symbol)

    # The strategy trades pair spreads, so a single asset is not enough.
    if len(available) < 2:
        logger.error(
            "Need at least 2 assets to run multi-pair strategy. "
            "Run: uv run python scripts/download_multi_pair_data.py"
        )
        return

    logger.info("Found data for %d assets: %s", len(available), available)

    baseline_result = run_baseline()
    multi_result = run_multi_pair(available)
    compare_results(baseline_result, multi_result)

    # Persist the multi-pair run for later inspection.
    reporter = Reporter()
    reporter.save_reports(multi_result, "multi_pair_divergence")

    logger.info("Reports saved to backtest_logs/")
# Script entry point: run the comparison when executed directly.
if __name__ == "__main__":
    main()
@@ -37,6 +37,7 @@ def _build_registry() -> dict[str, StrategyConfig]:
|
||||
from strategies.examples import MaCrossStrategy, RsiStrategy
|
||||
from strategies.supertrend import MetaSupertrendStrategy
|
||||
from strategies.regime_strategy import RegimeReversionStrategy
|
||||
from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig
|
||||
|
||||
return {
|
||||
"rsi": StrategyConfig(
|
||||
@@ -98,6 +99,18 @@ def _build_registry() -> dict[str, StrategyConfig]:
|
||||
'stop_loss': [0.04, 0.06, 0.08],
|
||||
'funding_threshold': [0.005, 0.01, 0.02]
|
||||
}
|
||||
),
|
||||
"multi_pair": StrategyConfig(
|
||||
strategy_class=MultiPairDivergenceStrategy,
|
||||
default_params={
|
||||
# Multi-pair divergence strategy uses config object
|
||||
# Parameters passed here will override MultiPairConfig defaults
|
||||
},
|
||||
grid_params={
|
||||
'z_entry_threshold': [0.8, 1.0, 1.2],
|
||||
'prob_threshold': [0.4, 0.5, 0.6],
|
||||
'correlation_threshold': [0.75, 0.85, 0.95]
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
24
strategies/multi_pair/__init__.py
Normal file
24
strategies/multi_pair/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Multi-Pair Divergence Selection Strategy.
|
||||
|
||||
Extends regime detection to multiple cryptocurrency pairs and dynamically
|
||||
selects the most divergent pair for trading.
|
||||
"""
|
||||
from .config import MultiPairConfig
|
||||
from .pair_scanner import PairScanner, TradingPair
|
||||
from .correlation import CorrelationFilter
|
||||
from .feature_engine import MultiPairFeatureEngine
|
||||
from .divergence_scorer import DivergenceScorer
|
||||
from .strategy import MultiPairDivergenceStrategy
|
||||
from .funding import FundingRateFetcher
|
||||
|
||||
__all__ = [
|
||||
"MultiPairConfig",
|
||||
"PairScanner",
|
||||
"TradingPair",
|
||||
"CorrelationFilter",
|
||||
"MultiPairFeatureEngine",
|
||||
"DivergenceScorer",
|
||||
"MultiPairDivergenceStrategy",
|
||||
"FundingRateFetcher",
|
||||
]
|
||||
88
strategies/multi_pair/config.py
Normal file
88
strategies/multi_pair/config.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Configuration for Multi-Pair Divergence Strategy.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
class MultiPairConfig:
    """Tunable parameters for the multi-pair divergence strategy.

    Groups the asset universe, Z-Score signal thresholds, ML training
    settings, correlation filtering, ATR-based risk controls, funding-rate
    filtering, trade management, and exchange settings into one object.
    """

    # Universe: top-10 market-cap perpetual contracts.
    assets: list[str] = field(default_factory=lambda: [
        "BTC-USDT", "ETH-USDT", "SOL-USDT", "XRP-USDT", "BNB-USDT",
        "DOGE-USDT", "ADA-USDT", "AVAX-USDT", "LINK-USDT", "DOT-USDT"
    ])

    # Spread Z-Score: rolling window (hours) and minimum |Z| to enter.
    z_window: int = 24
    z_entry_threshold: float = 1.0

    # ML model: probability gate, walk-forward split, label horizon/target.
    prob_threshold: float = 0.5
    train_ratio: float = 0.7
    horizon: int = 102
    profit_target: float = 0.005

    # Correlation filter: maximum allowed correlation and lookback window.
    correlation_threshold: float = 0.85
    correlation_window: int = 168  # 7 days in hours

    # ATR-based stops. Mean hourly crypto ATR is ~0.6%, so these
    # multipliers approximate the previous fixed 6% SL / 5% TP.
    atr_period: int = 14  # ATR lookback period (hours for 1h timeframe)
    sl_atr_multiplier: float = 10.0  # Stop-loss = entry +/- (ATR * multiplier)
    tp_atr_multiplier: float = 8.0  # Take-profit = entry +/- (ATR * multiplier)

    # Fixed-percentage fallbacks used when ATR is unavailable.
    base_sl_pct: float = 0.06
    base_tp_pct: float = 0.05

    # Clamp ATR-derived stops to sane percentage bounds.
    min_sl_pct: float = 0.02  # Minimum 2% stop-loss
    max_sl_pct: float = 0.10  # Maximum 10% stop-loss
    min_tp_pct: float = 0.02  # Minimum 2% take-profit
    max_tp_pct: float = 0.15  # Maximum 15% take-profit

    volatility_window: int = 24

    # Funding filter: typical OKX funding is ~0.0001 (0.01%) per 8h; beyond
    # 0.0005 (0.05%) the trade is considered crowded and is skipped.
    funding_threshold: float = 0.0005

    # Trade management. Hold/switch/cooldown are disabled because exiting
    # at Z=0 (mean reversion complete) was found to be the primary profit
    # driver; extra constraints hurt results.
    min_hold_bars: int = 0
    switch_threshold: float = 999.0
    cooldown_bars: int = 0
    z_exit_threshold: float = 0.0

    # Exchange settings.
    exchange_id: str = "okx"
    timeframe: str = "1h"

    def get_pair_count(self) -> int:
        """Return the number of unique unordered asset pairs (n choose 2)."""
        count = len(self.assets)
        return count * (count - 1) // 2
173
strategies/multi_pair/correlation.py
Normal file
173
strategies/multi_pair/correlation.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Correlation Filter for Multi-Pair Divergence Strategy.
|
||||
|
||||
Calculates rolling correlation matrix and filters pairs
|
||||
to avoid highly correlated positions.
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from .config import MultiPairConfig
|
||||
from .pair_scanner import TradingPair
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CorrelationFilter:
|
||||
"""
|
||||
Calculates and filters based on asset correlations.
|
||||
|
||||
Uses rolling correlation of returns to identify assets
|
||||
moving together, avoiding redundant positions.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MultiPairConfig):
|
||||
self.config = config
|
||||
self._correlation_matrix: pd.DataFrame | None = None
|
||||
self._last_update_idx: int = -1
|
||||
|
||||
def calculate_correlation_matrix(
|
||||
self,
|
||||
price_data: dict[str, pd.Series],
|
||||
current_idx: int | None = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Calculate rolling correlation matrix between all assets.
|
||||
|
||||
Args:
|
||||
price_data: Dictionary mapping asset symbols to price series
|
||||
current_idx: Current bar index (for caching)
|
||||
|
||||
Returns:
|
||||
Correlation matrix DataFrame
|
||||
"""
|
||||
# Use cached if recent
|
||||
if (
|
||||
current_idx is not None
|
||||
and self._correlation_matrix is not None
|
||||
and current_idx - self._last_update_idx < 24 # Update every 24 bars
|
||||
):
|
||||
return self._correlation_matrix
|
||||
|
||||
# Calculate returns
|
||||
returns = {}
|
||||
for symbol, prices in price_data.items():
|
||||
returns[symbol] = prices.pct_change()
|
||||
|
||||
returns_df = pd.DataFrame(returns)
|
||||
|
||||
# Rolling correlation
|
||||
window = self.config.correlation_window
|
||||
|
||||
# Get latest correlation (last row of rolling correlation)
|
||||
if len(returns_df) >= window:
|
||||
rolling_corr = returns_df.rolling(window=window).corr()
|
||||
# Extract last timestamp correlation matrix
|
||||
last_idx = returns_df.index[-1]
|
||||
corr_matrix = rolling_corr.loc[last_idx]
|
||||
else:
|
||||
# Fallback to full-period correlation if not enough data
|
||||
corr_matrix = returns_df.corr()
|
||||
|
||||
self._correlation_matrix = corr_matrix
|
||||
if current_idx is not None:
|
||||
self._last_update_idx = current_idx
|
||||
|
||||
return corr_matrix
|
||||
|
||||
def filter_pairs(
|
||||
self,
|
||||
pairs: list[TradingPair],
|
||||
current_position_asset: str | None,
|
||||
price_data: dict[str, pd.Series],
|
||||
current_idx: int | None = None
|
||||
) -> list[TradingPair]:
|
||||
"""
|
||||
Filter pairs based on correlation with current position.
|
||||
|
||||
If we have an open position in an asset, exclude pairs where
|
||||
either asset is highly correlated with the held asset.
|
||||
|
||||
Args:
|
||||
pairs: List of candidate pairs
|
||||
current_position_asset: Currently held asset (or None)
|
||||
price_data: Dictionary of price series by symbol
|
||||
current_idx: Current bar index for caching
|
||||
|
||||
Returns:
|
||||
Filtered list of pairs
|
||||
"""
|
||||
if current_position_asset is None:
|
||||
return pairs
|
||||
|
||||
corr_matrix = self.calculate_correlation_matrix(price_data, current_idx)
|
||||
threshold = self.config.correlation_threshold
|
||||
|
||||
filtered = []
|
||||
for pair in pairs:
|
||||
# Check correlation of base and quote with held asset
|
||||
base_corr = self._get_correlation(
|
||||
corr_matrix, pair.base_asset, current_position_asset
|
||||
)
|
||||
quote_corr = self._get_correlation(
|
||||
corr_matrix, pair.quote_asset, current_position_asset
|
||||
)
|
||||
|
||||
# Filter if either asset highly correlated with position
|
||||
if abs(base_corr) > threshold or abs(quote_corr) > threshold:
|
||||
logger.debug(
|
||||
"Filtered %s: base_corr=%.2f, quote_corr=%.2f (held: %s)",
|
||||
pair.name, base_corr, quote_corr, current_position_asset
|
||||
)
|
||||
continue
|
||||
|
||||
filtered.append(pair)
|
||||
|
||||
if len(filtered) < len(pairs):
|
||||
logger.info(
|
||||
"Correlation filter: %d/%d pairs remaining (held: %s)",
|
||||
len(filtered), len(pairs), current_position_asset
|
||||
)
|
||||
|
||||
return filtered
|
||||
|
||||
def _get_correlation(
|
||||
self,
|
||||
corr_matrix: pd.DataFrame,
|
||||
asset1: str,
|
||||
asset2: str
|
||||
) -> float:
|
||||
"""
|
||||
Get correlation between two assets from matrix.
|
||||
|
||||
Args:
|
||||
corr_matrix: Correlation matrix
|
||||
asset1: First asset symbol
|
||||
asset2: Second asset symbol
|
||||
|
||||
Returns:
|
||||
Correlation coefficient (-1 to 1), or 0 if not found
|
||||
"""
|
||||
if asset1 == asset2:
|
||||
return 1.0
|
||||
|
||||
try:
|
||||
return corr_matrix.loc[asset1, asset2]
|
||||
except KeyError:
|
||||
return 0.0
|
||||
|
||||
def get_correlation_report(
|
||||
self,
|
||||
price_data: dict[str, pd.Series]
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Generate a readable correlation report.
|
||||
|
||||
Args:
|
||||
price_data: Dictionary of price series
|
||||
|
||||
Returns:
|
||||
Correlation matrix as DataFrame
|
||||
"""
|
||||
return self.calculate_correlation_matrix(price_data)
|
||||
311
strategies/multi_pair/divergence_scorer.py
Normal file
311
strategies/multi_pair/divergence_scorer.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Divergence Scorer for Multi-Pair Strategy.
|
||||
|
||||
Ranks pairs by divergence score and selects the best candidate.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from .config import MultiPairConfig
|
||||
from .pair_scanner import TradingPair
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class DivergenceSignal:
    """
    Signal for a divergent pair.

    Attributes:
        pair: Trading pair
        z_score: Current Z-Score of the spread
        probability: ML model probability of profitable reversion
        divergence_score: Combined score (|z_score| * probability)
        direction: 'long' or 'short' (relative to base asset)
        base_price: Current price of base asset
        quote_price: Current price of quote asset
        atr: Average True Range in price units
        atr_pct: ATR as percentage of price
        timestamp: Timestamp of the bar the signal was computed from
    """
    pair: TradingPair
    z_score: float
    probability: float
    divergence_score: float
    direction: str
    base_price: float
    quote_price: float
    atr: float
    atr_pct: float
    timestamp: pd.Timestamp
|
||||
|
||||
class DivergenceScorer:
    """
    Scores and ranks pairs by divergence potential.

    Uses ML model predictions combined with Z-Score magnitude
    to identify the most promising mean-reversion opportunity.

    The model is a single RandomForestClassifier shared across all pairs
    ("universal" model), persisted to disk between runs.
    """

    def __init__(self, config: MultiPairConfig, model_path: str = "data/multi_pair_model.pkl"):
        """Create the scorer and eagerly load any persisted model.

        Args:
            config: Strategy configuration (thresholds, horizon, funding).
            model_path: Path of the pickled model file.
        """
        self.config = config
        self.model_path = Path(model_path)
        # Populated either by _load_model() or by train_model().
        self.model: RandomForestClassifier | None = None
        self.feature_cols: list[str] | None = None
        self._load_model()

    def _load_model(self) -> None:
        """Load pre-trained model if available.

        Silently leaves the scorer untrained (model=None) when the file is
        missing or unreadable; score_pairs() handles that case.
        """
        if self.model_path.exists():
            try:
                # SECURITY NOTE: pickle.load can execute arbitrary code from
                # the file — only load model files from trusted sources.
                with open(self.model_path, 'rb') as f:
                    saved = pickle.load(f)
                self.model = saved['model']
                self.feature_cols = saved['feature_cols']
                logger.info("Loaded model from %s", self.model_path)
            except Exception as e:
                # Best-effort: a corrupt file just means we retrain later.
                logger.warning("Could not load model: %s", e)

    def save_model(self) -> None:
        """Save trained model (and its feature column list) to model_path."""
        if self.model is None:
            return

        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.model_path, 'wb') as f:
            pickle.dump({
                'model': self.model,
                'feature_cols': self.feature_cols,
            }, f)
        logger.info("Saved model to %s", self.model_path)

    def train_model(
        self,
        combined_features: pd.DataFrame,
        pair_features: dict[str, pd.DataFrame]
    ) -> None:
        """
        Train universal model on all pairs.

        Labels each bar 1 when an entry-strength divergence (|Z| above the
        configured threshold) reverted by at least profit_target within the
        look-ahead horizon, else 0; then fits one classifier on the pooled
        samples from every pair and persists it via save_model().

        Args:
            combined_features: Combined feature DataFrame from all pairs
            pair_features: Individual pair feature DataFrames (for target calculation)
        """
        logger.info("Training universal model on %d samples...", len(combined_features))

        z_thresh = self.config.z_entry_threshold
        horizon = self.config.horizon
        profit_target = self.config.profit_target

        # Build targets per pair, then pool everything for a single fit.
        all_targets = []
        all_features = []

        for pair_id, features in pair_features.items():
            # Need enough bars for the look-ahead window plus some history.
            if len(features) < horizon + 50:
                continue

            spread = features['spread']
            z_score = features['z_score']

            # Extremes of the spread over the NEXT `horizon` bars: a trailing
            # rolling min/max shifted backwards by `horizon` covers
            # (t+1 .. t+horizon) for each row t.
            future_min = spread.rolling(window=horizon).min().shift(-horizon)
            future_max = spread.rolling(window=horizon).max().shift(-horizon)

            # Price levels the spread must reach for the trade to pay off.
            target_short = spread * (1 - profit_target)
            target_long = spread * (1 + profit_target)

            # Short setups (Z high) succeed when the spread falls enough;
            # long setups (Z low) when it rises enough.
            success_short = (z_score > z_thresh) & (future_min < target_short)
            success_long = (z_score < -z_thresh) & (future_max > target_long)

            # Binary label: 1 if EITHER direction's reversion succeeded.
            targets = np.select([success_short, success_long], [1, 1], default=0)

            # Valid mask (exclude rows without complete future data)
            valid_mask = future_min.notna() & future_max.notna()

            # Collect valid samples
            valid_features = features[valid_mask]
            valid_targets = targets[valid_mask.values]

            if len(valid_features) > 0:
                all_features.append(valid_features)
                all_targets.extend(valid_targets)

        if not all_features:
            logger.warning("No valid training samples")
            return

        # Combine all training data
        X_df = pd.concat(all_features, ignore_index=True)
        y = np.array(all_targets)

        # Identifier and raw-price columns are excluded from the model input.
        exclude_cols = [
            'pair_id', 'base_asset', 'quote_asset',
            'spread', 'base_close', 'quote_close', 'base_volume'
        ]
        self.feature_cols = [c for c in X_df.columns if c not in exclude_cols]

        # Sanitize: the tree model cannot handle NaN/inf inputs.
        X = X_df[self.feature_cols].fillna(0)
        X = X.replace([np.inf, -np.inf], 0)

        # class_weight up-weights the (rarer) positive class 3:1.
        self.model = RandomForestClassifier(
            n_estimators=300,
            max_depth=5,
            min_samples_leaf=30,
            class_weight={0: 1, 1: 3},
            random_state=42
        )
        self.model.fit(X, y)

        logger.info(
            "Model trained on %d samples, %d features, %.1f%% positive class",
            len(X), len(self.feature_cols), y.mean() * 100
        )
        self.save_model()

    def score_pairs(
        self,
        pair_features: dict[str, pd.DataFrame],
        pairs: list[TradingPair],
        timestamp: pd.Timestamp | None = None
    ) -> list[DivergenceSignal]:
        """
        Score all pairs and return ranked signals.

        Applies, in order: Z-Score threshold, ML probability threshold, and
        the funding-rate filter; surviving pairs are ranked by
        |z_score| * probability (highest first).

        Args:
            pair_features: Feature DataFrames by pair_id
            pairs: List of TradingPair objects
            timestamp: Current timestamp for feature extraction

        Returns:
            List of DivergenceSignal sorted by score (descending)
        """
        if self.model is None:
            logger.warning("Model not trained, returning empty signals")
            return []

        signals = []
        pair_map = {p.pair_id: p for p in pairs}

        for pair_id, features in pair_features.items():
            if pair_id not in pair_map:
                continue

            pair = pair_map[pair_id]

            # Use the latest row at-or-before `timestamp` to avoid lookahead;
            # without a timestamp, just take the most recent row.
            if timestamp is not None:
                valid = features[features.index <= timestamp]
                if len(valid) == 0:
                    continue
                latest = valid.iloc[-1]
                ts = valid.index[-1]
            else:
                latest = features.iloc[-1]
                ts = features.index[-1]

            z_score = latest['z_score']

            # Skip if Z-score below threshold
            if abs(z_score) < self.config.z_entry_threshold:
                continue

            # Sanitize the single feature row the same way training did.
            feature_row = latest[self.feature_cols].fillna(0).infer_objects(copy=False)
            feature_row = feature_row.replace([np.inf, -np.inf], 0)
            X = pd.DataFrame([feature_row.values], columns=self.feature_cols)

            # Probability of the positive (successful reversion) class.
            prob = self.model.predict_proba(X)[0, 1]

            # Skip if probability below threshold
            if prob < self.config.prob_threshold:
                continue

            # Funding rate filter: block trades where funding opposes our
            # direction. `or 0` coerces a None/zero entry to 0; NOTE(review):
            # a NaN funding value is truthy and would pass through — confirm
            # upstream fills NaN funding.
            base_funding = latest.get('base_funding', 0) or 0
            funding_thresh = self.config.funding_threshold

            if z_score > 0:  # Short signal
                # High negative funding = shorts are paying -> skip
                if base_funding < -funding_thresh:
                    logger.debug(
                        "Skipping %s short: funding too negative (%.4f)",
                        pair.name, base_funding
                    )
                    continue
            else:  # Long signal
                # High positive funding = longs are paying -> skip
                if base_funding > funding_thresh:
                    logger.debug(
                        "Skipping %s long: funding too positive (%.4f)",
                        pair.name, base_funding
                    )
                    continue

            # Combined ranking score: divergence magnitude times confidence.
            divergence_score = abs(z_score) * prob

            # Determine direction
            # Z > 0: Spread high (base expensive vs quote) -> Short base
            # Z < 0: Spread low (base cheap vs quote) -> Long base
            direction = 'short' if z_score > 0 else 'long'

            signal = DivergenceSignal(
                pair=pair,
                z_score=z_score,
                probability=prob,
                divergence_score=divergence_score,
                direction=direction,
                base_price=latest['base_close'],
                quote_price=latest['quote_close'],
                atr=latest.get('atr_base', 0),
                atr_pct=latest.get('atr_pct_base', 0.02),
                timestamp=ts
            )
            signals.append(signal)

        # Sort by divergence score (highest first)
        signals.sort(key=lambda s: s.divergence_score, reverse=True)

        if signals:
            logger.debug(
                "Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f)",
                len(signals),
                signals[0].pair.name,
                signals[0].divergence_score,
                signals[0].z_score,
                signals[0].probability
            )

        return signals

    def select_best_pair(
        self,
        signals: list[DivergenceSignal]
    ) -> DivergenceSignal | None:
        """
        Select the best pair from scored signals.

        Args:
            signals: List of DivergenceSignal (pre-sorted by score)

        Returns:
            Best signal or None if no valid candidates
        """
        if not signals:
            return None
        return signals[0]
433
strategies/multi_pair/feature_engine.py
Normal file
433
strategies/multi_pair/feature_engine.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Feature Engineering for Multi-Pair Divergence Strategy.
|
||||
|
||||
Calculates features for all pairs in the universe, including
|
||||
spread technicals, volatility, and on-chain data.
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import ta
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from engine.data_manager import DataManager
|
||||
from engine.market import MarketType
|
||||
from .config import MultiPairConfig
|
||||
from .pair_scanner import TradingPair
|
||||
from .funding import FundingRateFetcher
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MultiPairFeatureEngine:
|
||||
"""
|
||||
Calculates features for multiple trading pairs.
|
||||
|
||||
Generates consistent feature sets across all pairs for
|
||||
the universal ML model.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MultiPairConfig):
    """Create the feature engine for the given strategy configuration.

    Args:
        config: Strategy configuration (asset universe, windows, exchange).
    """
    self.config = config
    self.dm = DataManager()
    self.funding_fetcher = FundingRateFetcher()
    # Funding-rate cache, populated by load_funding_data().
    self._funding_data: pd.DataFrame | None = None
def load_all_assets(
    self,
    start_date: str | None = None,
    end_date: str | None = None
) -> dict[str, pd.DataFrame]:
    """
    Load OHLCV data for all assets in the universe.

    Assets with missing files, load errors, or fewer than 200 bars after
    date filtering are skipped (logged, not raised).

    Args:
        start_date: Start date filter (YYYY-MM-DD)
        end_date: End date filter (YYYY-MM-DD)

    Returns:
        Dictionary mapping symbol to OHLCV DataFrame
    """
    loaded: dict[str, pd.DataFrame] = {}
    market_type = MarketType.PERPETUAL

    for symbol in self.config.assets:
        try:
            frame = self.dm.load_data(
                self.config.exchange_id,
                symbol,
                self.config.timeframe,
                market_type
            )

            # Trim to the requested window before the length check.
            if start_date:
                frame = frame[frame.index >= pd.Timestamp(start_date, tz="UTC")]
            if end_date:
                frame = frame[frame.index <= pd.Timestamp(end_date, tz="UTC")]

            if len(frame) >= 200:  # Minimum data requirement
                loaded[symbol] = frame
                logger.debug("Loaded %s: %d bars", symbol, len(frame))
            else:
                logger.warning(
                    "Skipping %s: insufficient data (%d bars)", symbol, len(frame)
                )
        except FileNotFoundError:
            logger.warning("Data not found for %s", symbol)
        except Exception as e:
            # Best-effort per symbol: one bad file must not abort the rest.
            logger.error("Error loading %s: %s", symbol, e)

    logger.info("Loaded %d/%d assets", len(loaded), len(self.config.assets))
    return loaded
def load_funding_data(
    self,
    start_date: str | None = None,
    end_date: str | None = None,
    use_cache: bool = True
) -> pd.DataFrame:
    """
    Load funding rate data for all assets.

    The result is also stored on the engine for later feature joins.

    Args:
        start_date: Start date filter
        end_date: End date filter
        use_cache: Whether to use cached data

    Returns:
        DataFrame with funding rates for all assets
    """
    funding = self.funding_fetcher.get_funding_data(
        self.config.assets,
        start_date=start_date,
        end_date=end_date,
        use_cache=use_cache
    )
    self._funding_data = funding

    if funding is None or funding.empty:
        logger.warning("No funding data available")
    else:
        logger.info(
            "Loaded funding data: %d rows, %d assets",
            len(funding),
            len(funding.columns)
        )

    return self._funding_data
    def calculate_pair_features(
        self,
        pair: TradingPair,
        asset_data: dict[str, pd.DataFrame],
        on_chain_data: pd.DataFrame | None = None
    ) -> pd.DataFrame | None:
        """
        Calculate features for a single pair.

        Builds the spread (base close / quote close) on the intersection of
        both assets' indices, then derives z-score, momentum, volume,
        volatility and ATR features, plus funding/on-chain columns.

        Args:
            pair: Trading pair
            asset_data: Dictionary of OHLCV DataFrames by symbol
            on_chain_data: Optional on-chain data (funding, inflows)

        Returns:
            DataFrame with features, or None if insufficient data
            (missing asset or fewer than 200 aligned bars).
        """
        base = pair.base_asset
        quote = pair.quote_asset

        if base not in asset_data or quote not in asset_data:
            return None

        df_base = asset_data[base]
        df_quote = asset_data[quote]

        # Align indices: only bars present for BOTH legs are usable.
        common_idx = df_base.index.intersection(df_quote.index)
        if len(common_idx) < 200:
            logger.debug("Pair %s: insufficient aligned data", pair.name)
            return None

        df_a = df_base.loc[common_idx]
        df_b = df_quote.loc[common_idx]

        # Calculate spread (base / quote)
        spread = df_a['close'] / df_b['close']

        # Z-Score of the spread over a rolling window.
        # NOTE(review): a flat spread gives rolling_std == 0 -> inf z_score;
        # dropna below removes NaN but not inf — confirm downstream handling.
        z_window = self.config.z_window
        rolling_mean = spread.rolling(window=z_window).mean()
        rolling_std = spread.rolling(window=z_window).std()
        z_score = (spread - rolling_mean) / rolling_std

        # Spread Technicals (RSI from the `ta` library, rate-of-change, 1-bar change)
        spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
        spread_roc = spread.pct_change(periods=5) * 100
        spread_change_1h = spread.pct_change(periods=1)

        # Volume Analysis (1e-10 epsilon guards against zero-volume bars)
        vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10)
        vol_ratio_ma = vol_ratio.rolling(window=12).mean()
        vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10)

        # Volatility of each leg's returns over the z-score window
        ret_a = df_a['close'].pct_change()
        ret_b = df_b['close'].pct_change()
        vol_a = ret_a.rolling(window=z_window).std()
        vol_b = ret_b.rolling(window=z_window).std()
        vol_spread_ratio = vol_a / (vol_b + 1e-10)

        # Realized Volatility (for dynamic SL/TP)
        realized_vol_a = ret_a.rolling(window=self.config.volatility_window).std()
        realized_vol_b = ret_b.rolling(window=self.config.volatility_window).std()

        # ATR (Average True Range) for dynamic stops
        # ATR = average of max(high-low, |high-prev_close|, |low-prev_close|)
        high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close']
        high_b, low_b, close_b = df_b['high'], df_b['low'], df_b['close']

        # True Range for base asset
        tr_a = pd.concat([
            high_a - low_a,
            (high_a - close_a.shift(1)).abs(),
            (low_a - close_a.shift(1)).abs()
        ], axis=1).max(axis=1)
        atr_a = tr_a.rolling(window=self.config.atr_period).mean()

        # True Range for quote asset
        tr_b = pd.concat([
            high_b - low_b,
            (high_b - close_b.shift(1)).abs(),
            (low_b - close_b.shift(1)).abs()
        ], axis=1).max(axis=1)
        atr_b = tr_b.rolling(window=self.config.atr_period).mean()

        # ATR as percentage of price (normalized, comparable across assets)
        atr_pct_a = atr_a / close_a
        atr_pct_b = atr_b / close_b

        # Build feature DataFrame
        features = pd.DataFrame(index=common_idx)
        features['pair_id'] = pair.pair_id
        features['base_asset'] = base
        features['quote_asset'] = quote

        # Price data (for reference, not features)
        features['spread'] = spread
        features['base_close'] = df_a['close']
        features['quote_close'] = df_b['close']
        features['base_volume'] = df_a['volume']

        # Core Features
        features['z_score'] = z_score
        features['spread_rsi'] = spread_rsi
        features['spread_roc'] = spread_roc
        features['spread_change_1h'] = spread_change_1h
        features['vol_ratio'] = vol_ratio
        features['vol_ratio_rel'] = vol_ratio_rel
        features['vol_diff_ratio'] = vol_spread_ratio

        # Volatility for SL/TP
        features['realized_vol_base'] = realized_vol_a
        features['realized_vol_quote'] = realized_vol_b
        features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2

        # ATR for dynamic stops (in price units and as percentage)
        features['atr_base'] = atr_a
        features['atr_quote'] = atr_b
        features['atr_pct_base'] = atr_pct_a
        features['atr_pct_quote'] = atr_pct_b
        features['atr_pct_avg'] = (atr_pct_a + atr_pct_b) / 2

        # Pair encoding (for universal model)
        # Using base and quote indices for hierarchical encoding;
        # -1 marks an asset missing from the configured universe.
        assets = self.config.assets
        features['base_idx'] = assets.index(base) if base in assets else -1
        features['quote_idx'] = assets.index(quote) if quote in assets else -1

        # Add funding and on-chain features
        # Funding data is always added from self._funding_data (OKX, all 10 assets)
        # On-chain data is optional (CryptoQuant, BTC/ETH only)
        features = self._add_on_chain_features(
            features, on_chain_data, base, quote
        )

        # Drop rows with NaN in core features only (not funding/on-chain),
        # so sparse funding/on-chain coverage doesn't wipe out the sample.
        core_cols = [
            'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
            'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
            'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
            'atr_base', 'atr_pct_base'  # ATR is core for SL/TP
        ]
        features = features.dropna(subset=core_cols)

        # Fill missing funding/on-chain features with 0 (neutral)
        optional_cols = [
            'base_funding', 'quote_funding', 'funding_diff', 'funding_avg',
            'base_inflow', 'quote_inflow', 'inflow_ratio'
        ]
        for col in optional_cols:
            if col in features.columns:
                features[col] = features[col].fillna(0)

        return features
|
||||
|
||||
def calculate_all_pair_features(
|
||||
self,
|
||||
pairs: list[TradingPair],
|
||||
asset_data: dict[str, pd.DataFrame],
|
||||
on_chain_data: pd.DataFrame | None = None
|
||||
) -> dict[str, pd.DataFrame]:
|
||||
"""
|
||||
Calculate features for all pairs.
|
||||
|
||||
Args:
|
||||
pairs: List of trading pairs
|
||||
asset_data: Dictionary of OHLCV DataFrames
|
||||
on_chain_data: Optional on-chain data
|
||||
|
||||
Returns:
|
||||
Dictionary mapping pair_id to feature DataFrame
|
||||
"""
|
||||
all_features = {}
|
||||
|
||||
for pair in pairs:
|
||||
features = self.calculate_pair_features(
|
||||
pair, asset_data, on_chain_data
|
||||
)
|
||||
if features is not None and len(features) > 0:
|
||||
all_features[pair.pair_id] = features
|
||||
|
||||
logger.info(
|
||||
"Calculated features for %d/%d pairs",
|
||||
len(all_features), len(pairs)
|
||||
)
|
||||
|
||||
return all_features
|
||||
|
||||
def get_combined_features(
|
||||
self,
|
||||
pair_features: dict[str, pd.DataFrame],
|
||||
timestamp: pd.Timestamp | None = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Combine all pair features into a single DataFrame.
|
||||
|
||||
Useful for batch model prediction across all pairs.
|
||||
|
||||
Args:
|
||||
pair_features: Dictionary of feature DataFrames by pair_id
|
||||
timestamp: Optional specific timestamp to filter to
|
||||
|
||||
Returns:
|
||||
Combined DataFrame with all pairs as rows
|
||||
"""
|
||||
if not pair_features:
|
||||
return pd.DataFrame()
|
||||
|
||||
if timestamp is not None:
|
||||
# Get latest row from each pair at or before timestamp
|
||||
rows = []
|
||||
for pair_id, features in pair_features.items():
|
||||
valid = features[features.index <= timestamp]
|
||||
if len(valid) > 0:
|
||||
row = valid.iloc[-1:].copy()
|
||||
rows.append(row)
|
||||
|
||||
if rows:
|
||||
return pd.concat(rows, ignore_index=False)
|
||||
return pd.DataFrame()
|
||||
|
||||
# Combine all features (for training)
|
||||
return pd.concat(pair_features.values(), ignore_index=False)
|
||||
|
||||
def _add_on_chain_features(
|
||||
self,
|
||||
features: pd.DataFrame,
|
||||
on_chain_data: pd.DataFrame | None,
|
||||
base_asset: str,
|
||||
quote_asset: str
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Add on-chain and funding rate features for the pair.
|
||||
|
||||
Uses funding data from OKX (all 10 assets) and on-chain data
|
||||
from CryptoQuant (BTC/ETH only for inflows).
|
||||
"""
|
||||
base_short = base_asset.replace('-USDT', '').lower()
|
||||
quote_short = quote_asset.replace('-USDT', '').lower()
|
||||
|
||||
# Add funding rates from cached funding data
|
||||
if self._funding_data is not None and not self._funding_data.empty:
|
||||
funding_aligned = self._funding_data.reindex(
|
||||
features.index, method='ffill'
|
||||
)
|
||||
|
||||
base_funding_col = f'{base_short}_funding'
|
||||
quote_funding_col = f'{quote_short}_funding'
|
||||
|
||||
if base_funding_col in funding_aligned.columns:
|
||||
features['base_funding'] = funding_aligned[base_funding_col]
|
||||
if quote_funding_col in funding_aligned.columns:
|
||||
features['quote_funding'] = funding_aligned[quote_funding_col]
|
||||
|
||||
# Funding difference (positive = base has higher funding)
|
||||
if 'base_funding' in features.columns and 'quote_funding' in features.columns:
|
||||
features['funding_diff'] = (
|
||||
features['base_funding'] - features['quote_funding']
|
||||
)
|
||||
|
||||
# Funding sentiment: average of both assets
|
||||
features['funding_avg'] = (
|
||||
features['base_funding'] + features['quote_funding']
|
||||
) / 2
|
||||
|
||||
# Add on-chain features from CryptoQuant (BTC/ETH only)
|
||||
if on_chain_data is not None and not on_chain_data.empty:
|
||||
cq_aligned = on_chain_data.reindex(features.index, method='ffill')
|
||||
|
||||
# Inflows (only available for BTC/ETH)
|
||||
base_inflow_col = f'{base_short}_inflow'
|
||||
quote_inflow_col = f'{quote_short}_inflow'
|
||||
|
||||
if base_inflow_col in cq_aligned.columns:
|
||||
features['base_inflow'] = cq_aligned[base_inflow_col]
|
||||
if quote_inflow_col in cq_aligned.columns:
|
||||
features['quote_inflow'] = cq_aligned[quote_inflow_col]
|
||||
|
||||
if 'base_inflow' in features.columns and 'quote_inflow' in features.columns:
|
||||
features['inflow_ratio'] = (
|
||||
features['base_inflow'] /
|
||||
(features['quote_inflow'] + 1)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
def get_feature_columns(self) -> list[str]:
|
||||
"""
|
||||
Get list of feature columns for ML model.
|
||||
|
||||
Excludes metadata and target-related columns.
|
||||
|
||||
Returns:
|
||||
List of feature column names
|
||||
"""
|
||||
# Core features (always present)
|
||||
core_features = [
|
||||
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
|
||||
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
|
||||
'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
|
||||
'base_idx', 'quote_idx'
|
||||
]
|
||||
|
||||
# Funding features (now available for all 10 assets via OKX)
|
||||
funding_features = [
|
||||
'base_funding', 'quote_funding', 'funding_diff', 'funding_avg'
|
||||
]
|
||||
|
||||
# On-chain features (BTC/ETH only via CryptoQuant)
|
||||
onchain_features = [
|
||||
'base_inflow', 'quote_inflow', 'inflow_ratio'
|
||||
]
|
||||
|
||||
return core_features + funding_features + onchain_features
|
||||
272
strategies/multi_pair/funding.py
Normal file
272
strategies/multi_pair/funding.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Funding Rate Fetcher for Multi-Pair Strategy.
|
||||
|
||||
Fetches historical funding rates from OKX for all assets.
|
||||
CryptoQuant only supports BTC/ETH, so we use OKX for the full universe.
|
||||
"""
|
||||
import time
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import ccxt
|
||||
import pandas as pd
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FundingRateFetcher:
    """
    Fetches and caches funding rate data from OKX.

    OKX funding rates are settled every 8 hours (00:00, 08:00, 16:00 UTC).
    This fetcher retrieves historical funding rate data and aligns it
    to hourly candles for use in the multi-pair strategy.
    """

    def __init__(self, cache_dir: str = "data/funding"):
        # Local CSV cache; created eagerly so later saves never fail.
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # Lazily created on first fetch (avoids network I/O at import time).
        self.exchange: ccxt.okx | None = None

    def _init_exchange(self) -> None:
        """Initialize OKX exchange connection (idempotent)."""
        if self.exchange is None:
            self.exchange = ccxt.okx({
                'enableRateLimit': True,
                'options': {'defaultType': 'swap'}  # perpetual-swap markets
            })
            self.exchange.load_markets()

    def fetch_funding_history(
        self,
        symbol: str,
        start_date: str | None = None,
        end_date: str | None = None,
        limit: int = 100
    ) -> pd.DataFrame:
        """
        Fetch historical funding rates for a symbol.

        Paginates forward from `start_date` (default: 1 year back) until
        `end_date` (default: now) in batches of `limit` records.

        Args:
            symbol: Asset symbol (e.g., 'BTC-USDT')
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            limit: Max records per request

        Returns:
            DataFrame with funding rate history (empty on failure).
        """
        self._init_exchange()

        # Convert symbol format: 'BTC-USDT' -> ccxt swap symbol 'BTC/USDT:USDT'
        base = symbol.replace('-USDT', '')
        okx_symbol = f"{base}/USDT:USDT"

        try:
            # OKX funding rate history endpoint
            # Uses fetch_funding_rate_history if available
            all_funding = []

            # Parse dates
            if start_date:
                since = self.exchange.parse8601(f"{start_date}T00:00:00Z")
            else:
                # Default to 1 year ago
                since = self.exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000

            if end_date:
                until = self.exchange.parse8601(f"{end_date}T23:59:59Z")
            else:
                until = self.exchange.milliseconds()

            # Fetch in batches, advancing the cursor past the last timestamp.
            # NOTE(review): records in the final batch may extend past `until`;
            # they are not trimmed here — confirm callers tolerate the overhang.
            current_since = since
            while current_since < until:
                try:
                    funding = self.exchange.fetch_funding_rate_history(
                        okx_symbol,
                        since=current_since,
                        limit=limit
                    )

                    if not funding:
                        break

                    all_funding.extend(funding)

                    # Move to next batch; bail out if the cursor did not
                    # advance, to avoid an infinite pagination loop.
                    last_ts = funding[-1]['timestamp']
                    if last_ts <= current_since:
                        break
                    current_since = last_ts + 1

                    time.sleep(0.1)  # Rate limit

                except Exception as e:
                    # Best-effort: keep whatever batches succeeded so far.
                    logger.warning(
                        "Error fetching funding batch for %s: %s",
                        symbol, str(e)[:50]
                    )
                    break

            if not all_funding:
                return pd.DataFrame()

            # Convert to DataFrame with a UTC DatetimeIndex.
            df = pd.DataFrame(all_funding)
            df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
            df.set_index('timestamp', inplace=True)
            df = df[['fundingRate']].rename(columns={'fundingRate': 'funding_rate'})
            df.sort_index(inplace=True)

            # Remove duplicates (overlapping batches can repeat settlements).
            df = df[~df.index.duplicated(keep='first')]

            logger.info("Fetched %d funding records for %s", len(df), symbol)
            return df

        except Exception as e:
            logger.error("Failed to fetch funding for %s: %s", symbol, e)
            return pd.DataFrame()

    def fetch_all_assets(
        self,
        assets: list[str],
        start_date: str | None = None,
        end_date: str | None = None
    ) -> pd.DataFrame:
        """
        Fetch funding rates for all assets and combine.

        Args:
            assets: List of asset symbols (e.g., ['BTC-USDT', 'ETH-USDT'])
            start_date: Start date
            end_date: End date

        Returns:
            Combined DataFrame with columns like 'btc_funding', 'eth_funding', etc.
            Assets whose fetch fails are simply omitted.
        """
        combined = pd.DataFrame()

        for symbol in assets:
            df = self.fetch_funding_history(symbol, start_date, end_date)

            if df.empty:
                continue

            # Rename column to '<asset>_funding'
            asset_name = symbol.replace('-USDT', '').lower()
            col_name = f"{asset_name}_funding"
            df = df.rename(columns={'funding_rate': col_name})

            # Outer join keeps every settlement timestamp from every asset.
            if combined.empty:
                combined = df
            else:
                combined = combined.join(df, how='outer')

            time.sleep(0.2)  # Be nice to API

        # Forward fill to hourly (funding is every 8h)
        if not combined.empty:
            combined = combined.sort_index()
            combined = combined.ffill()

        return combined

    def save_to_cache(self, df: pd.DataFrame, filename: str = "funding_rates.csv") -> None:
        """Save funding data to cache file."""
        path = self.cache_dir / filename
        df.to_csv(path)
        logger.info("Saved funding rates to %s", path)

    def load_from_cache(self, filename: str = "funding_rates.csv") -> pd.DataFrame | None:
        """Load funding data from cache if available.

        Returns None when no cache file exists.
        """
        path = self.cache_dir / filename
        if path.exists():
            # NOTE(review): read_csv's parsed index tz-awareness depends on how
            # to_csv serialized the UTC offsets — confirm the comparisons in
            # get_funding_data stay tz-consistent with the cached index.
            df = pd.read_csv(path, index_col='timestamp', parse_dates=True)
            logger.info("Loaded funding rates from cache: %d rows", len(df))
            return df
        return None

    def get_funding_data(
        self,
        assets: list[str],
        start_date: str | None = None,
        end_date: str | None = None,
        use_cache: bool = True,
        force_refresh: bool = False
    ) -> pd.DataFrame:
        """
        Get funding data, using cache if available.

        The cache is only used when it fully covers [start_date, end_date];
        otherwise a fresh download is triggered (and cached again).

        Args:
            assets: List of asset symbols
            start_date: Start date
            end_date: End date
            use_cache: Whether to use cached data
            force_refresh: Force refresh even if cache exists

        Returns:
            DataFrame with funding rates for all assets
        """
        cache_file = "funding_rates.csv"

        # Try cache first
        if use_cache and not force_refresh:
            cached = self.load_from_cache(cache_file)
            if cached is not None:
                # Check if cache covers requested range; partial coverage
                # falls through to a full re-fetch.
                if start_date and end_date:
                    start_ts = pd.Timestamp(start_date, tz='UTC')
                    end_ts = pd.Timestamp(end_date, tz='UTC')

                    if cached.index.min() <= start_ts and cached.index.max() >= end_ts:
                        # Filter to requested range
                        return cached[(cached.index >= start_ts) & (cached.index <= end_ts)]

        # Fetch fresh data
        logger.info("Fetching fresh funding rate data...")
        df = self.fetch_all_assets(assets, start_date, end_date)

        if not df.empty and use_cache:
            self.save_to_cache(df, cache_file)

        return df
|
||||
|
||||
|
||||
def download_funding_data():
    """
    Download one year of funding data for all multi-pair assets.

    Forces a refresh (ignoring the CSV cache) and writes the combined
    funding-rate table back to the fetcher's cache directory.

    Returns:
        DataFrame of funding rates (possibly empty on total failure).
    """
    from strategies.multi_pair.config import MultiPairConfig

    config = MultiPairConfig()
    fetcher = FundingRateFetcher()

    # Fetch last year of data. Capture "now" once so both endpoints are
    # derived from the same instant (two separate now() calls could
    # straddle midnight and produce an inconsistent range).
    now = datetime.now(timezone.utc)
    end_date = now.strftime("%Y-%m-%d")
    start_date = (now - pd.Timedelta(days=365)).strftime("%Y-%m-%d")

    logger.info("Downloading funding rates for %d assets...", len(config.assets))
    logger.info("Date range: %s to %s", start_date, end_date)

    df = fetcher.get_funding_data(
        config.assets,
        start_date=start_date,
        end_date=end_date,
        force_refresh=True
    )

    if not df.empty:
        logger.info("Downloaded %d funding rate records", len(df))
        logger.info("Columns: %s", list(df.columns))
    else:
        logger.warning("No funding data downloaded")

    return df
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module directly as a one-shot download script.
    from engine.logging_config import setup_logging
    setup_logging()
    download_funding_data()
|
||||
168
strategies/multi_pair/pair_scanner.py
Normal file
168
strategies/multi_pair/pair_scanner.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Pair Scanner for Multi-Pair Divergence Strategy.
|
||||
|
||||
Generates all possible pairs from asset universe and checks tradeability.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from itertools import combinations
|
||||
from typing import Optional
|
||||
|
||||
import ccxt
|
||||
|
||||
from engine.logging_config import get_logger
|
||||
from .config import MultiPairConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TradingPair:
    """
    A single spread pair built from two assets in the universe.

    Attributes:
        base_asset: First asset in the pair (numerator of the spread)
        quote_asset: Second asset in the pair (denominator of the spread)
        pair_id: Unique identifier for the pair
        is_direct: Whether the pair trades directly on the exchange
        exchange_symbol: Exchange symbol for direct trading (if available)
    """
    base_asset: str
    quote_asset: str
    pair_id: str
    is_direct: bool = False
    exchange_symbol: Optional[str] = None

    @property
    def name(self) -> str:
        """Human-readable pair name."""
        return f"{self.base_asset}/{self.quote_asset}"

    # Identity is fully determined by pair_id, so pairs stay usable as
    # dict keys / set members regardless of the other fields.
    def __hash__(self):
        return hash(self.pair_id)

    def __eq__(self, other):
        return isinstance(other, TradingPair) and self.pair_id == other.pair_id
|
||||
|
||||
|
||||
class PairScanner:
    """
    Scans and generates tradeable pairs from asset universe.

    Checks OKX for directly tradeable cross-pairs and generates
    synthetic pairs via USDT for others.
    """

    def __init__(self, config: MultiPairConfig):
        self.config = config
        # Lazily initialized; only needed when check_exchange=True.
        self.exchange: Optional[ccxt.Exchange] = None
        self._available_markets: set[str] = set()

    def _init_exchange(self) -> None:
        """Initialize exchange connection for market lookup (idempotent)."""
        if self.exchange is None:
            exchange_class = getattr(ccxt, self.config.exchange_id)
            self.exchange = exchange_class({'enableRateLimit': True})
            self.exchange.load_markets()
            self._available_markets = set(self.exchange.symbols)
            logger.info(
                "Loaded %d markets from %s",
                len(self._available_markets),
                self.config.exchange_id
            )

    def generate_pairs(self, check_exchange: bool = True) -> list[TradingPair]:
        """
        Generate all unique pairs from asset universe.

        Args:
            check_exchange: Whether to check OKX for direct trading
                (requires a network round-trip on first call)

        Returns:
            List of TradingPair objects, one per unordered asset combination
        """
        if check_exchange:
            self._init_exchange()

        pairs = []
        assets = self.config.assets

        for base, quote in combinations(assets, 2):
            pair_id = f"{base}__{quote}"

            # Check if directly tradeable as cross-pair on OKX
            is_direct = False
            exchange_symbol = None

            if check_exchange:
                # Check perpetual cross-pair (e.g., ETH/BTC:BTC)
                # OKX perpetuals are typically quoted in USDT
                # Cross-pairs like ETH/BTC are less common
                cross_symbol = f"{base.replace('-USDT', '')}/{quote.replace('-USDT', '')}:USDT"
                if cross_symbol in self._available_markets:
                    is_direct = True
                    exchange_symbol = cross_symbol

            pair = TradingPair(
                base_asset=base,
                quote_asset=quote,
                pair_id=pair_id,
                is_direct=is_direct,
                exchange_symbol=exchange_symbol
            )
            pairs.append(pair)

        # Log summary
        direct_count = sum(1 for p in pairs if p.is_direct)
        logger.info(
            "Generated %d pairs: %d direct, %d synthetic",
            len(pairs), direct_count, len(pairs) - direct_count
        )

        return pairs

    def get_required_symbols(self, pairs: list[TradingPair]) -> list[str]:
        """
        Get list of symbols needed to calculate all pair spreads.

        For synthetic pairs, we need both USDT pairs.
        For direct pairs, we still load USDT pairs for simplicity.

        Args:
            pairs: List of trading pairs

        Returns:
            Sorted list of unique symbols to load
            (e.g., ['BTC-USDT', 'ETH-USDT'])
        """
        symbols = set()
        for pair in pairs:
            symbols.add(pair.base_asset)
            symbols.add(pair.quote_asset)
        # sorted() gives a deterministic order across runs — raw set
        # iteration order varies between processes and would make
        # download/backtest logs irreproducible.
        return sorted(symbols)

    def filter_by_assets(
        self,
        pairs: list[TradingPair],
        exclude_assets: list[str]
    ) -> list[TradingPair]:
        """
        Filter pairs that contain any of the excluded assets.

        Args:
            pairs: List of trading pairs
            exclude_assets: Assets to exclude

        Returns:
            Filtered list of pairs (original list if nothing to exclude)
        """
        if not exclude_assets:
            return pairs

        exclude_set = set(exclude_assets)
        return [
            p for p in pairs
            if p.base_asset not in exclude_set
            and p.quote_asset not in exclude_set
        ]
|
||||
525
strategies/multi_pair/strategy.py
Normal file
525
strategies/multi_pair/strategy.py
Normal file
@@ -0,0 +1,525 @@
|
||||
"""
|
||||
Multi-Pair Divergence Selection Strategy.
|
||||
|
||||
Main strategy class that orchestrates pair scanning, feature calculation,
|
||||
model training, and signal generation for backtesting.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from strategies.base import BaseStrategy
|
||||
from engine.market import MarketType
|
||||
from engine.logging_config import get_logger
|
||||
from .config import MultiPairConfig
|
||||
from .pair_scanner import PairScanner, TradingPair
|
||||
from .correlation import CorrelationFilter
|
||||
from .feature_engine import MultiPairFeatureEngine
|
||||
from .divergence_scorer import DivergenceScorer, DivergenceSignal
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class PositionState:
    """Tracks current position state."""
    pair: TradingPair | None = None  # Currently held pair; None when flat
    direction: str | None = None  # 'long' or 'short'
    entry_price: float = 0.0  # price at entry (presumably spread price — confirm at call site)
    entry_idx: int = -1  # bar index of entry; -1 when no position
    stop_loss: float = 0.0  # active stop-loss level
    take_profit: float = 0.0  # active take-profit level
    atr: float = 0.0  # ATR at entry for reference
    last_exit_idx: int = -100  # For cooldown tracking (bar index of the last exit)
|
||||
|
||||
|
||||
class MultiPairDivergenceStrategy(BaseStrategy):
|
||||
"""
|
||||
Multi-Pair Divergence Selection Strategy.
|
||||
|
||||
Scans multiple cryptocurrency pairs for spread divergence,
|
||||
selects the most divergent pair using ML-enhanced scoring,
|
||||
and trades mean-reversion opportunities.
|
||||
|
||||
Key Features:
|
||||
- Universal ML model across all pairs
|
||||
- Correlation-based pair filtering
|
||||
- Dynamic SL/TP based on volatility
|
||||
- Walk-forward training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MultiPairConfig | None = None,
|
||||
model_path: str = "data/multi_pair_model.pkl"
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config or MultiPairConfig()
|
||||
|
||||
# Initialize components
|
||||
self.pair_scanner = PairScanner(self.config)
|
||||
self.correlation_filter = CorrelationFilter(self.config)
|
||||
self.feature_engine = MultiPairFeatureEngine(self.config)
|
||||
self.divergence_scorer = DivergenceScorer(self.config, model_path)
|
||||
|
||||
# Strategy configuration
|
||||
self.default_market_type = MarketType.PERPETUAL
|
||||
self.default_leverage = 1
|
||||
|
||||
# Runtime state
|
||||
self.pairs: list[TradingPair] = []
|
||||
self.asset_data: dict[str, pd.DataFrame] = {}
|
||||
self.pair_features: dict[str, pd.DataFrame] = {}
|
||||
self.position = PositionState()
|
||||
self.train_end_idx: int = 0
|
||||
|
||||
def run(self, close: pd.Series, **kwargs) -> tuple:
|
||||
"""
|
||||
Execute the multi-pair divergence strategy.
|
||||
|
||||
This method is called by the backtester with the primary asset's
|
||||
close prices. For multi-pair, we load all assets internally.
|
||||
|
||||
Args:
|
||||
close: Primary close prices (used for index alignment)
|
||||
**kwargs: Additional data (high, low, volume)
|
||||
|
||||
Returns:
|
||||
Tuple of (long_entries, long_exits, short_entries, short_exits, size)
|
||||
"""
|
||||
logger.info("Starting Multi-Pair Divergence Strategy")
|
||||
|
||||
# 1. Load all asset data
|
||||
start_date = close.index.min().strftime("%Y-%m-%d")
|
||||
end_date = close.index.max().strftime("%Y-%m-%d")
|
||||
|
||||
self.asset_data = self.feature_engine.load_all_assets(
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
# 1b. Load funding rate data for all assets
|
||||
self.feature_engine.load_funding_data(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
use_cache=True
|
||||
)
|
||||
|
||||
if len(self.asset_data) < 2:
|
||||
logger.error("Insufficient assets loaded, need at least 2")
|
||||
return self._empty_signals(close)
|
||||
|
||||
# 2. Generate pairs
|
||||
self.pairs = self.pair_scanner.generate_pairs(check_exchange=False)
|
||||
|
||||
# Filter to pairs with available data
|
||||
available_assets = set(self.asset_data.keys())
|
||||
self.pairs = [
|
||||
p for p in self.pairs
|
||||
if p.base_asset in available_assets
|
||||
and p.quote_asset in available_assets
|
||||
]
|
||||
|
||||
logger.info("Trading %d pairs from %d assets", len(self.pairs), len(self.asset_data))
|
||||
|
||||
# 3. Calculate features for all pairs
|
||||
self.pair_features = self.feature_engine.calculate_all_pair_features(
|
||||
self.pairs, self.asset_data
|
||||
)
|
||||
|
||||
if not self.pair_features:
|
||||
logger.error("No pair features calculated")
|
||||
return self._empty_signals(close)
|
||||
|
||||
# 4. Align to common index
|
||||
common_index = self._get_common_index()
|
||||
if len(common_index) < 200:
|
||||
logger.error("Insufficient common data across pairs")
|
||||
return self._empty_signals(close)
|
||||
|
||||
# 5. Walk-forward split
|
||||
n_samples = len(common_index)
|
||||
train_size = int(n_samples * self.config.train_ratio)
|
||||
self.train_end_idx = train_size
|
||||
|
||||
train_end_date = common_index[train_size - 1]
|
||||
test_start_date = common_index[train_size]
|
||||
|
||||
logger.info(
|
||||
"Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)",
|
||||
train_size, train_end_date.strftime('%Y-%m-%d'),
|
||||
n_samples - train_size, test_start_date.strftime('%Y-%m-%d')
|
||||
)
|
||||
|
||||
# 6. Train model on training period
|
||||
if self.divergence_scorer.model is None:
|
||||
train_features = {
|
||||
pid: feat[feat.index <= train_end_date]
|
||||
for pid, feat in self.pair_features.items()
|
||||
}
|
||||
combined = self.feature_engine.get_combined_features(train_features)
|
||||
self.divergence_scorer.train_model(combined, train_features)
|
||||
|
||||
# 7. Generate signals for test period
|
||||
return self._generate_signals(common_index, train_size, close)
|
||||
|
||||
def _generate_signals(
    self,
    index: pd.DatetimeIndex,
    train_size: int,
    reference_close: pd.Series
) -> tuple:
    """
    Generate entry/exit signals for the test period.

    Iterates bar-by-bar through the test portion of the walk-forward
    split. On each bar it first evaluates exit conditions for any open
    position (SL/TP only during the minimum hold window, full exit
    logic afterwards), then scores all correlation-filtered candidate
    pairs and enters (or switches into) the best one.

    Args:
        index: Common timestamp index aligned across all pairs.
        train_size: Number of leading bars reserved for training;
            signals are generated only for bars after this offset.
        reference_close: Close series whose index the output signal
            arrays are aligned to.

    Returns:
        Tuple of (long_entries, long_exits, short_entries,
        short_exits, size), all pd.Series aligned to
        ``reference_close.index``.
    """
    # Initialize signal arrays aligned to reference close
    long_entries = pd.Series(False, index=reference_close.index)
    long_exits = pd.Series(False, index=reference_close.index)
    short_entries = pd.Series(False, index=reference_close.index)
    short_exits = pd.Series(False, index=reference_close.index)
    size = pd.Series(1.0, index=reference_close.index)

    # Track position state (single position at a time)
    self.position = PositionState()

    # Per-asset close series used by the correlation filter
    price_data = {
        symbol: df['close'] for symbol, df in self.asset_data.items()
    }

    # Iterate through test period only
    test_indices = index[train_size:]

    trade_count = 0

    for i, timestamp in enumerate(test_indices):
        current_idx = train_size + i

        # --- 1. Check exit conditions first ---
        if self.position.pair is not None:
            # Enforce minimum hold period
            bars_held = current_idx - self.position.entry_idx
            if bars_held < self.config.min_hold_bars:
                # Only allow SL/TP exits during min hold period
                should_exit, exit_reason = self._check_sl_tp_only(timestamp)
            else:
                should_exit, exit_reason = self._check_exit(timestamp)

            if should_exit:
                # Map exit signal to reference index (skip if the bar is
                # absent from the reference series)
                if timestamp in reference_close.index:
                    if self.position.direction == 'long':
                        long_exits.loc[timestamp] = True
                    else:
                        short_exits.loc[timestamp] = True

                logger.debug(
                    "Exit %s %s at %s: %s (held %d bars)",
                    self.position.direction,
                    self.position.pair.name,
                    timestamp.strftime('%Y-%m-%d %H:%M'),
                    exit_reason,
                    bars_held
                )
                # Remember the exit bar so the cooldown can be enforced
                self.position = PositionState(last_exit_idx=current_idx)

        # --- 2. Score pairs (correlation filter vs. held asset) ---
        held_asset = None
        if self.position.pair is not None:
            held_asset = self.position.pair.base_asset

        # Filter pairs by correlation with the currently held asset
        candidate_pairs = self.correlation_filter.filter_pairs(
            self.pairs,
            held_asset,
            price_data,
            current_idx
        )

        # Select candidate features by pair id.
        # NOTE: precompute the id set once — the previous per-feature
        # any() scan was O(pairs^2) on every bar.
        candidate_ids = {p.pair_id for p in candidate_pairs}
        candidate_features = {
            pid: feat for pid, feat in self.pair_features.items()
            if pid in candidate_ids
        }

        # Score pairs with the universal model
        signals = self.divergence_scorer.score_pairs(
            candidate_features,
            candidate_pairs,
            timestamp
        )

        # Get best signal (highest divergence score)
        best = self.divergence_scorer.select_best_pair(signals)

        if best is None:
            continue

        # --- 3. Decide whether to enter or switch ---
        should_enter = False

        # Check cooldown after the last exit
        bars_since_exit = current_idx - self.position.last_exit_idx
        in_cooldown = bars_since_exit < self.config.cooldown_bars

        if self.position.pair is None and not in_cooldown:
            # No position and not in cooldown, can enter
            should_enter = True
        elif self.position.pair is not None:
            # Switching requires min hold AND a significantly better score
            bars_held = current_idx - self.position.entry_idx
            current_score = self._get_current_score(timestamp)

            if (bars_held >= self.config.min_hold_bars and
                    best.divergence_score > current_score * self.config.switch_threshold):
                # New opportunity is significantly better: close current
                if timestamp in reference_close.index:
                    if self.position.direction == 'long':
                        long_exits.loc[timestamp] = True
                    else:
                        short_exits.loc[timestamp] = True
                self.position = PositionState(last_exit_idx=current_idx)
                should_enter = True

        # --- 4. Open the new position ---
        if should_enter:
            # Calculate ATR-based dynamic SL/TP
            sl_price, tp_price = self._calculate_sl_tp(
                best.base_price,
                best.direction,
                best.atr,
                best.atr_pct
            )

            # Set position state
            self.position = PositionState(
                pair=best.pair,
                direction=best.direction,
                entry_price=best.base_price,
                entry_idx=current_idx,
                stop_loss=sl_price,
                take_profit=tp_price,
                atr=best.atr
            )

            # Calculate position size based on divergence strength
            pos_size = self._calculate_size(best.divergence_score)

            # Generate entry signal on the reference index
            if timestamp in reference_close.index:
                if best.direction == 'long':
                    long_entries.loc[timestamp] = True
                else:
                    short_entries.loc[timestamp] = True
                size.loc[timestamp] = pos_size

            trade_count += 1
            logger.debug(
                "Entry %s %s at %s: z=%.2f, prob=%.2f, score=%.3f",
                best.direction,
                best.pair.name,
                timestamp.strftime('%Y-%m-%d %H:%M'),
                best.z_score,
                best.probability,
                best.divergence_score
            )

    logger.info("Generated %d trades in test period", trade_count)

    return long_entries, long_exits, short_entries, short_exits, size
||||
def _check_exit(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
|
||||
"""
|
||||
Check if current position should be exited.
|
||||
|
||||
Exit conditions:
|
||||
1. Z-Score reverted to mean (|Z| < threshold)
|
||||
2. Stop-loss hit
|
||||
3. Take-profit hit
|
||||
|
||||
Returns:
|
||||
Tuple of (should_exit, reason)
|
||||
"""
|
||||
if self.position.pair is None:
|
||||
return False, ""
|
||||
|
||||
pair_id = self.position.pair.pair_id
|
||||
if pair_id not in self.pair_features:
|
||||
return True, "pair_data_missing"
|
||||
|
||||
features = self.pair_features[pair_id]
|
||||
valid = features[features.index <= timestamp]
|
||||
|
||||
if len(valid) == 0:
|
||||
return True, "no_data"
|
||||
|
||||
latest = valid.iloc[-1]
|
||||
z_score = latest['z_score']
|
||||
current_price = latest['base_close']
|
||||
|
||||
# Check mean reversion (primary exit)
|
||||
if abs(z_score) < self.config.z_exit_threshold:
|
||||
return True, f"mean_reversion (z={z_score:.2f})"
|
||||
|
||||
# Check SL/TP
|
||||
return self._check_sl_tp(current_price)
|
||||
|
||||
def _check_sl_tp_only(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
|
||||
"""
|
||||
Check only stop-loss and take-profit conditions.
|
||||
Used during minimum hold period.
|
||||
"""
|
||||
if self.position.pair is None:
|
||||
return False, ""
|
||||
|
||||
pair_id = self.position.pair.pair_id
|
||||
if pair_id not in self.pair_features:
|
||||
return True, "pair_data_missing"
|
||||
|
||||
features = self.pair_features[pair_id]
|
||||
valid = features[features.index <= timestamp]
|
||||
|
||||
if len(valid) == 0:
|
||||
return True, "no_data"
|
||||
|
||||
latest = valid.iloc[-1]
|
||||
current_price = latest['base_close']
|
||||
|
||||
return self._check_sl_tp(current_price)
|
||||
|
||||
def _check_sl_tp(self, current_price: float) -> tuple[bool, str]:
|
||||
"""Check stop-loss and take-profit levels."""
|
||||
if self.position.direction == 'long':
|
||||
if current_price <= self.position.stop_loss:
|
||||
return True, f"stop_loss ({current_price:.2f} <= {self.position.stop_loss:.2f})"
|
||||
if current_price >= self.position.take_profit:
|
||||
return True, f"take_profit ({current_price:.2f} >= {self.position.take_profit:.2f})"
|
||||
else: # short
|
||||
if current_price >= self.position.stop_loss:
|
||||
return True, f"stop_loss ({current_price:.2f} >= {self.position.stop_loss:.2f})"
|
||||
if current_price <= self.position.take_profit:
|
||||
return True, f"take_profit ({current_price:.2f} <= {self.position.take_profit:.2f})"
|
||||
|
||||
return False, ""
|
||||
|
||||
def _get_current_score(self, timestamp: pd.Timestamp) -> float:
|
||||
"""Get current position's divergence score for comparison."""
|
||||
if self.position.pair is None:
|
||||
return 0.0
|
||||
|
||||
pair_id = self.position.pair.pair_id
|
||||
if pair_id not in self.pair_features:
|
||||
return 0.0
|
||||
|
||||
features = self.pair_features[pair_id]
|
||||
valid = features[features.index <= timestamp]
|
||||
|
||||
if len(valid) == 0:
|
||||
return 0.0
|
||||
|
||||
latest = valid.iloc[-1]
|
||||
z_score = abs(latest['z_score'])
|
||||
|
||||
# Re-score with model
|
||||
if self.divergence_scorer.model is not None:
|
||||
feature_row = latest[self.divergence_scorer.feature_cols].fillna(0)
|
||||
feature_row = feature_row.replace([np.inf, -np.inf], 0)
|
||||
X = pd.DataFrame(
|
||||
[feature_row.values],
|
||||
columns=self.divergence_scorer.feature_cols
|
||||
)
|
||||
prob = self.divergence_scorer.model.predict_proba(X)[0, 1]
|
||||
return z_score * prob
|
||||
|
||||
return z_score * 0.5
|
||||
|
||||
def _calculate_sl_tp(
|
||||
self,
|
||||
entry_price: float,
|
||||
direction: str,
|
||||
atr: float,
|
||||
atr_pct: float
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate ATR-based dynamic stop-loss and take-profit prices.
|
||||
|
||||
Uses ATR (Average True Range) to set stops that adapt to
|
||||
each asset's volatility. More volatile assets get wider stops.
|
||||
|
||||
Args:
|
||||
entry_price: Entry price
|
||||
direction: 'long' or 'short'
|
||||
atr: ATR in price units
|
||||
atr_pct: ATR as percentage of price
|
||||
|
||||
Returns:
|
||||
Tuple of (stop_loss_price, take_profit_price)
|
||||
"""
|
||||
# Calculate SL/TP as ATR multiples
|
||||
if atr > 0 and atr_pct > 0:
|
||||
# ATR-based calculation
|
||||
sl_distance = atr * self.config.sl_atr_multiplier
|
||||
tp_distance = atr * self.config.tp_atr_multiplier
|
||||
|
||||
# Convert to percentage for bounds checking
|
||||
sl_pct = sl_distance / entry_price
|
||||
tp_pct = tp_distance / entry_price
|
||||
else:
|
||||
# Fallback to fixed percentages if ATR unavailable
|
||||
sl_pct = self.config.base_sl_pct
|
||||
tp_pct = self.config.base_tp_pct
|
||||
|
||||
# Apply bounds to prevent extreme stops
|
||||
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
|
||||
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
|
||||
|
||||
# Calculate actual prices
|
||||
if direction == 'long':
|
||||
stop_loss = entry_price * (1 - sl_pct)
|
||||
take_profit = entry_price * (1 + tp_pct)
|
||||
else: # short
|
||||
stop_loss = entry_price * (1 + sl_pct)
|
||||
take_profit = entry_price * (1 - tp_pct)
|
||||
|
||||
return stop_loss, take_profit
|
||||
|
||||
def _calculate_size(self, divergence_score: float) -> float:
|
||||
"""
|
||||
Calculate position size based on divergence score.
|
||||
|
||||
Higher divergence = larger position (up to 2x).
|
||||
"""
|
||||
# Base score threshold (Z=1.0, prob=0.5 -> score=0.5)
|
||||
base_threshold = 0.5
|
||||
|
||||
# Scale factor
|
||||
if divergence_score <= base_threshold:
|
||||
return 1.0
|
||||
|
||||
# Linear scaling: 1.0 at threshold, up to 2.0 at 2x threshold
|
||||
scale = 1.0 + (divergence_score - base_threshold) / base_threshold
|
||||
return min(scale, 2.0)
|
||||
|
||||
def _get_common_index(self) -> pd.DatetimeIndex:
|
||||
"""Get the intersection of all pair feature indices."""
|
||||
if not self.pair_features:
|
||||
return pd.DatetimeIndex([])
|
||||
|
||||
common = None
|
||||
for features in self.pair_features.values():
|
||||
if common is None:
|
||||
common = features.index
|
||||
else:
|
||||
common = common.intersection(features.index)
|
||||
|
||||
return common.sort_values()
|
||||
|
||||
def _empty_signals(self, close: pd.Series) -> tuple:
|
||||
"""Return empty signal arrays."""
|
||||
empty = self.create_empty_signals(close)
|
||||
size = pd.Series(1.0, index=close.index)
|
||||
return empty, empty, empty, empty, size
|
||||
321
tasks/prd-multi-pair-divergence-strategy.md
Normal file
321
tasks/prd-multi-pair-divergence-strategy.md
Normal file
@@ -0,0 +1,321 @@
|
||||
# PRD: Multi-Pair Divergence Selection Strategy
|
||||
|
||||
## 1. Introduction / Overview
|
||||
|
||||
This document describes the **Multi-Pair Divergence Selection Strategy**, an extension of the existing BTC/ETH regime reversion system. The strategy expands spread analysis to the **top 10 cryptocurrencies by market cap**, calculates divergence scores for all tradeable pairs, and dynamically selects the **most divergent pair** for trading.
|
||||
|
||||
The core hypothesis: by scanning multiple pairs simultaneously, we can identify stronger mean-reversion opportunities than focusing on a single pair, improving net PnL while maintaining the proven ML-based regime detection approach.
|
||||
|
||||
---
|
||||
|
||||
## 2. Goals
|
||||
|
||||
1. **Extend regime detection** to top 10 market cap cryptocurrencies
|
||||
2. **Dynamically select** the most divergent tradeable pair each cycle
|
||||
3. **Integrate volatility** into dynamic SL/TP calculations
|
||||
4. **Filter correlated pairs** to avoid redundant positions
|
||||
5. **Improve net PnL** compared to single-pair BTC/ETH strategy
|
||||
6. **Backtest-first** implementation with walk-forward validation
|
||||
|
||||
---
|
||||
|
||||
## 3. User Stories
|
||||
|
||||
### US-1: Multi-Pair Analysis
|
||||
> As a trader, I want the system to analyze spread divergence across multiple cryptocurrency pairs so that I can identify the best trading opportunity at any given moment.
|
||||
|
||||
### US-2: Dynamic Pair Selection
|
||||
> As a trader, I want the system to automatically select and trade the pair with the highest divergence score (combination of Z-score magnitude and ML probability) so that I maximize mean-reversion profit potential.
|
||||
|
||||
### US-3: Volatility-Adjusted Risk
|
||||
> As a trader, I want stop-loss and take-profit levels to adapt to each pair's volatility so that I avoid being stopped out prematurely on volatile assets while protecting profits on stable ones.
|
||||
|
||||
### US-4: Correlation Filtering
|
||||
> As a trader, I want the system to avoid selecting pairs that are highly correlated with my current position so that I don't inadvertently double-down on the same market exposure.
|
||||
|
||||
### US-5: Backtest Validation
|
||||
> As a researcher, I want to backtest this multi-pair strategy with walk-forward training so that I can validate improvement over the single-pair baseline without look-ahead bias.
|
||||
|
||||
---
|
||||
|
||||
## 4. Functional Requirements
|
||||
|
||||
### 4.1 Data Management
|
||||
|
||||
| ID | Requirement |
|
||||
|----|-------------|
|
||||
| FR-1.1 | System must support loading OHLCV data for top 10 market cap cryptocurrencies |
|
||||
| FR-1.2 | Target assets: BTC, ETH, SOL, XRP, BNB, DOGE, ADA, AVAX, LINK, DOT (configurable) |
|
||||
| FR-1.3 | System must identify all directly tradeable cross-pairs on OKX perpetuals |
|
||||
| FR-1.4 | System must align timestamps across all pairs for synchronized analysis |
|
||||
| FR-1.5 | System must handle missing data gracefully (skip pair if insufficient history) |
|
||||
|
||||
### 4.2 Pair Generation
|
||||
|
||||
| ID | Requirement |
|
||||
|----|-------------|
|
||||
| FR-2.1 | Generate all unique pairs from asset universe: N*(N-1)/2 pairs (e.g., 45 pairs for 10 assets) |
|
||||
| FR-2.2 | Filter pairs to only those directly tradeable on OKX (no USDT intermediate) |
|
||||
| FR-2.3 | Fallback: If cross-pair not available, calculate synthetic spread via USDT pairs |
|
||||
| FR-2.4 | Store pair metadata: base asset, quote asset, exchange symbol, tradeable flag |
|
||||
|
||||
### 4.3 Feature Engineering (Per Pair)
|
||||
|
||||
| ID | Requirement |
|
||||
|----|-------------|
|
||||
| FR-3.1 | Calculate spread ratio: `asset_a_close / asset_b_close` |
|
||||
| FR-3.2 | Calculate Z-Score with configurable rolling window (default: 24h) |
|
||||
| FR-3.3 | Calculate spread technicals: RSI(14), ROC(5), 1h change |
|
||||
| FR-3.4 | Calculate volume ratio and relative volume |
|
||||
| FR-3.5 | Calculate volatility ratio: `std(returns_a) / std(returns_b)` over Z-window |
|
||||
| FR-3.6 | Calculate realized volatility for each asset (for dynamic SL/TP) |
|
||||
| FR-3.7 | Merge on-chain data (funding rates, inflows) if available per asset |
|
||||
| FR-3.8 | Add pair identifier as categorical feature for universal model |
|
||||
|
||||
### 4.4 Correlation Filtering
|
||||
|
||||
| ID | Requirement |
|
||||
|----|-------------|
|
||||
| FR-4.1 | Calculate rolling correlation matrix between all assets (default: 168h / 7 days) |
|
||||
| FR-4.2 | Define correlation threshold (default: 0.85) |
|
||||
| FR-4.3 | If current position exists, exclude pairs where either asset has correlation > threshold with held asset |
|
||||
| FR-4.4 | Log filtered pairs with reason for exclusion |
|
||||
|
||||
### 4.5 Divergence Scoring & Pair Selection
|
||||
|
||||
| ID | Requirement |
|
||||
|----|-------------|
|
||||
| FR-5.1 | Calculate divergence score: `abs(z_score) * model_probability` |
|
||||
| FR-5.2 | Only consider pairs where `abs(z_score) > z_entry_threshold` (default: 1.0) |
|
||||
| FR-5.3 | Only consider pairs where `model_probability > prob_threshold` (default: 0.5) |
|
||||
| FR-5.4 | Apply correlation filter to eligible pairs |
|
||||
| FR-5.5 | Select pair with highest divergence score |
|
||||
| FR-5.6 | If no pair qualifies, signal "hold" |
|
||||
| FR-5.7 | Log all pair scores for analysis/debugging |
|
||||
|
||||
### 4.6 ML Model (Universal)
|
||||
|
||||
| ID | Requirement |
|
||||
|----|-------------|
|
||||
| FR-6.1 | Train single Random Forest model on all pairs combined |
|
||||
| FR-6.2 | Include `pair_id` as one-hot encoded or label-encoded feature |
|
||||
| FR-6.3 | Target: binary (1 = profitable reversion within horizon, 0 = no reversion) |
|
||||
| FR-6.4 | Walk-forward training: 70% train / 30% test split |
|
||||
| FR-6.5 | Daily retraining schedule (for live, configurable for backtest) |
|
||||
| FR-6.6 | Model hyperparameters: `n_estimators=300, max_depth=5, min_samples_leaf=30, class_weight={0:1, 1:3}` |
|
||||
| FR-6.7 | Save/load model with feature column metadata |
|
||||
|
||||
### 4.7 Signal Generation
|
||||
|
||||
| ID | Requirement |
|
||||
|----|-------------|
|
||||
| FR-7.1 | Direction: If `z_score > threshold` -> Short spread (short asset_a), If `z_score < -threshold` -> Long spread (long asset_a) |
|
||||
| FR-7.2 | Apply funding rate filter per asset (block if extreme funding opposes direction) |
|
||||
| FR-7.3 | Output signal: `{pair, action, side, probability, z_score, divergence_score, reason}` |
|
||||
|
||||
### 4.8 Position Sizing
|
||||
|
||||
| ID | Requirement |
|
||||
|----|-------------|
|
||||
| FR-8.1 | Base size: 100% of available subaccount balance |
|
||||
| FR-8.2 | Scale by divergence: `size_multiplier = 1.0 + (divergence_score - base_threshold) * scaling_factor` |
|
||||
| FR-8.3 | Cap multiplier between 1.0x and 2.0x |
|
||||
| FR-8.4 | Respect exchange minimum order size per asset |
|
||||
|
||||
### 4.9 Dynamic SL/TP (Volatility-Adjusted)
|
||||
|
||||
| ID | Requirement |
|
||||
|----|-------------|
|
||||
| FR-9.1 | Calculate asset realized volatility: `std(returns) * sqrt(24)` for daily vol |
|
||||
| FR-9.2 | Base SL: `entry_price * (1 - base_sl_pct * vol_multiplier)` for longs |
|
||||
| FR-9.3 | Base TP: `entry_price * (1 + base_tp_pct * vol_multiplier)` for longs |
|
||||
| FR-9.4 | `vol_multiplier = asset_volatility / baseline_volatility` (baseline = BTC volatility) |
|
||||
| FR-9.5 | Cap vol_multiplier between 0.5x and 2.0x to prevent extreme values |
|
||||
| FR-9.6 | Invert logic for short positions |
|
||||
|
||||
### 4.10 Exit Conditions
|
||||
|
||||
| ID | Requirement |
|
||||
|----|-------------|
|
||||
| FR-10.1 | Exit when Z-score crosses back through 0 (mean reversion complete) |
|
||||
| FR-10.2 | Exit when dynamic SL or TP hit |
|
||||
| FR-10.3 | No minimum holding period (can switch pairs immediately) |
|
||||
| FR-10.4 | If new pair has higher divergence score, close current and open new |
|
||||
|
||||
### 4.11 Backtest Integration
|
||||
|
||||
| ID | Requirement |
|
||||
|----|-------------|
|
||||
| FR-11.1 | Integrate with existing `engine/backtester.py` framework |
|
||||
| FR-11.2 | Support 1h timeframe (matching live trading) |
|
||||
| FR-11.3 | Walk-forward validation: train on 70%, test on 30% |
|
||||
| FR-11.4 | Output: trades log, equity curve, performance metrics |
|
||||
| FR-11.5 | Compare against single-pair BTC/ETH baseline |
|
||||
|
||||
---
|
||||
|
||||
## 5. Non-Goals (Out of Scope)
|
||||
|
||||
1. **Live trading implementation** - Backtest validation first
|
||||
2. **Multi-position portfolio** - Single pair at a time for v1
|
||||
3. **Cross-exchange arbitrage** - OKX only
|
||||
4. **Alternative ML models** - Stick with Random Forest for consistency
|
||||
5. **Sub-1h timeframes** - 1h candles only for initial version
|
||||
6. **Leveraged positions** - 1x leverage for backtest
|
||||
7. **Portfolio-level VaR/risk budgeting** - Full subaccount allocation
|
||||
|
||||
---
|
||||
|
||||
## 6. Design Considerations
|
||||
|
||||
### 6.1 Architecture
|
||||
|
||||
```
|
||||
strategies/
|
||||
multi_pair/
|
||||
__init__.py
|
||||
pair_scanner.py # Generates all pairs, filters tradeable
|
||||
feature_engine.py # Calculates features for all pairs
|
||||
correlation.py # Rolling correlation matrix & filtering
|
||||
divergence_scorer.py # Ranks pairs by divergence score
|
||||
strategy.py # Main strategy orchestration
|
||||
```
|
||||
|
||||
### 6.2 Data Flow
|
||||
|
||||
```
|
||||
1. Load OHLCV for all 10 assets
|
||||
2. Generate pair combinations (45 pairs)
|
||||
3. Filter to tradeable pairs (OKX check)
|
||||
4. Calculate features for each pair
|
||||
5. Train/load universal ML model
|
||||
6. Predict probability for all pairs
|
||||
7. Calculate divergence scores
|
||||
8. Apply correlation filter
|
||||
9. Select top pair
|
||||
10. Generate signal with dynamic SL/TP
|
||||
11. Execute in backtest engine
|
||||
```
|
||||
|
||||
### 6.3 Configuration
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class MultiPairConfig:
|
||||
# Assets
|
||||
assets: list[str] = field(default_factory=lambda: [
|
||||
"BTC", "ETH", "SOL", "XRP", "BNB",
|
||||
"DOGE", "ADA", "AVAX", "LINK", "DOT"
|
||||
])
|
||||
|
||||
# Thresholds
|
||||
z_window: int = 24
|
||||
z_entry_threshold: float = 1.0
|
||||
prob_threshold: float = 0.5
|
||||
correlation_threshold: float = 0.85
|
||||
correlation_window: int = 168 # 7 days in hours
|
||||
|
||||
# Risk
|
||||
base_sl_pct: float = 0.06
|
||||
base_tp_pct: float = 0.05
|
||||
vol_multiplier_min: float = 0.5
|
||||
vol_multiplier_max: float = 2.0
|
||||
|
||||
# Model
|
||||
train_ratio: float = 0.7
|
||||
horizon: int = 102
|
||||
profit_target: float = 0.005
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Technical Considerations
|
||||
|
||||
### 7.1 Dependencies
|
||||
|
||||
- Extend `DataManager` to load multiple symbols
|
||||
- Query OKX API for available perpetual cross-pairs
|
||||
- Reuse existing feature engineering from `RegimeReversionStrategy`
|
||||
|
||||
### 7.2 Performance
|
||||
|
||||
- Pre-calculate all pair features in batch (vectorized)
|
||||
- Cache correlation matrix (update every N candles, not every minute)
|
||||
- Model inference is fast (single predict call with all pairs as rows)
|
||||
|
||||
### 7.3 Edge Cases
|
||||
|
||||
- Handle pairs with insufficient history (< 200 bars) - exclude
|
||||
- Handle assets delisted mid-backtest - skip pair
|
||||
- Handle zero-volume periods - use last valid price
|
||||
|
||||
---
|
||||
|
||||
## 8. Success Metrics
|
||||
|
||||
| Metric | Baseline (BTC/ETH) | Target |
|
||||
|--------|-------------------|--------|
|
||||
| Net PnL | Current performance | > 10% improvement |
|
||||
| Number of Trades | N | Comparable or higher |
|
||||
| Win Rate | Baseline % | Maintain or improve |
|
||||
| Average Trade Duration | Baseline hours | Flexible |
|
||||
| Max Drawdown | Baseline % | Not significantly worse |
|
||||
|
||||
---
|
||||
|
||||
## 9. Open Questions
|
||||
|
||||
1. **OKX Cross-Pairs**: Need to verify which cross-pairs are available on OKX perpetuals. We may need to fall back to synthetic spreads for most pairs.
|
||||
|
||||
2. **On-Chain Data**: CryptoQuant data currently covers BTC/ETH. Should we:
|
||||
- Run without on-chain features for other assets?
|
||||
- Source alternative on-chain data?
|
||||
- Use funding rates only (available from OKX)?
|
||||
|
||||
3. **Pair ID Encoding**: For the universal model, should pair_id be:
|
||||
- One-hot encoded (adds 45 features)?
|
||||
- Label encoded (single ordinal feature)?
|
||||
- Hierarchical (base_asset + quote_asset as separate features)?
|
||||
|
||||
4. **Synthetic Spreads**: If trading SOL/DOT spread but only USDT pairs available:
|
||||
- Calculate spread synthetically: `SOL-USDT / DOT-USDT`
|
||||
- Execute as two legs: Long SOL-USDT, Short DOT-USDT
|
||||
- This doubles fees and adds execution complexity. Include in v1?
|
||||
|
||||
---
|
||||
|
||||
## 10. Implementation Phases
|
||||
|
||||
### Phase 1: Data & Infrastructure (Est. 2-3 days)
|
||||
- Extend DataManager for multi-symbol loading
|
||||
- Build pair scanner with OKX tradeable filter
|
||||
- Implement correlation matrix calculation
|
||||
|
||||
### Phase 2: Feature Engineering (Est. 2 days)
|
||||
- Adapt existing feature calculation for arbitrary pairs
|
||||
- Add pair identifier feature
|
||||
- Batch feature calculation for all pairs
|
||||
|
||||
### Phase 3: Model & Scoring (Est. 2 days)
|
||||
- Train universal model on all pairs
|
||||
- Implement divergence scoring
|
||||
- Add correlation filtering to pair selection
|
||||
|
||||
### Phase 4: Strategy Integration (Est. 2-3 days)
|
||||
- Implement dynamic SL/TP with volatility
|
||||
- Integrate with backtester
|
||||
- Build strategy orchestration class
|
||||
|
||||
### Phase 5: Validation & Comparison (Est. 2 days)
|
||||
- Run walk-forward backtest
|
||||
- Compare against BTC/ETH baseline
|
||||
- Generate performance report
|
||||
|
||||
**Total Estimated Effort: 10-12 days**
|
||||
|
||||
---
|
||||
|
||||
*Document Version: 1.0*
|
||||
*Created: 2026-01-15*
|
||||
*Author: AI Assistant*
|
||||
*Status: Draft - Awaiting Review*
|
||||
Reference in New Issue
Block a user