Files
lowkey_backtest/live_trading/multi_pair/strategy.py
Simon Moisy 1af0aab5fa feat: Add Multi-Pair Divergence Live Trading Module
- Introduced a new module for live trading based on the Multi-Pair Divergence Strategy.
- Implemented configuration classes for OKX API and multi-pair settings.
- Developed data feed functionality to fetch real-time OHLCV and funding data for multiple assets.
- Created a trading bot orchestrator to manage trading cycles, including entry and exit signals based on ML model predictions.
- Added comprehensive logging and error handling for robust operation.
- Included a README with setup instructions and usage guidelines for the new module.
2026-01-15 22:17:13 +08:00

397 lines
12 KiB
Python

"""
Live Multi-Pair Divergence Strategy.

Scores all pairs and selects the best divergence opportunity for trading.
Uses the pre-trained universal ML model from backtesting: a pickled dict
with ``'model'`` and ``'feature_cols'`` keys, loaded from
``PathConfig.model_path``.

Entry candidates need |Z-Score| at or above the entry threshold, a model
probability at or above the probability threshold, and must pass a
base-asset funding-rate filter. Candidates are ranked by |Z| * probability.
Positions are exited once |Z-Score| falls back below the exit threshold.
The module also derives ATR-based stop-loss / take-profit prices and
score-scaled position sizes.
"""
import logging
import pickle
from dataclasses import dataclass
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
# Opt in to pandas' future "no silent downcasting" behaviour. This silences
# the FutureWarning that fillna() emits on object-dtype feature rows.
# NOTE: this changes a process-wide pandas option at import time.
try:
    pd.set_option('future.no_silent_downcasting', True)
except KeyError:
    # Older pandas has no such option and raises OptionError, which
    # subclasses KeyError. Those versions never emit the downcasting
    # FutureWarning either, so there is nothing to silence. Importing this
    # module must not fail because of a warning filter.
    pass
from .config import MultiPairLiveConfig, PathConfig
from .data_feed import TradingPair
# Logger named after this module, so its output can be filtered through the
# standard logging name hierarchy.
logger = logging.getLogger(name=__name__)
@dataclass
class DivergenceSignal:
    """
    Signal for a divergent pair.

    Instances are built by ``LiveMultiPairStrategy.score_pairs`` from the
    latest feature row of a pair.

    Attributes:
        pair: Trading pair
        z_score: Current Z-Score of the spread. In this module a positive
            value yields a 'short' signal and a negative one a 'long' signal.
        probability: ML model probability of profitable reversion, i.e. the
            class-1 column of the model's ``predict_proba`` output
        divergence_score: Combined score (|z_score| * probability), used to
            rank candidates (higher is better)
        direction: 'long' or 'short' (relative to base asset)
        base_price: Current price of base asset
        quote_price: Current price of quote asset
        atr: Average True Range in price units
        atr_pct: ATR relative to price. It looks like a fraction (the scorer
            falls back to 0.02); confirm against the feature pipeline.
        base_funding: Funding rate of the base asset at signal time;
            defaults to 0.0
    """
    pair: TradingPair
    z_score: float
    probability: float
    divergence_score: float
    direction: str
    base_price: float
    quote_price: float
    atr: float
    atr_pct: float
    base_funding: float = 0.0  # optional; 0.0 when not supplied
class LiveMultiPairStrategy:
    """
    Live trading implementation of the multi-pair divergence strategy.

    Scores all pairs using the universal ML model and selects the best
    opportunity for mean-reversion trading. It also decides exits, derives
    ATR-based stop-loss / take-profit prices and sizes positions.
    """

    def __init__(
        self,
        config: MultiPairLiveConfig,
        path_config: PathConfig
    ):
        """
        Create the strategy and load the pre-trained model.

        Args:
            config: Strategy thresholds and risk limits.
            path_config: Paths; ``model_path`` must point at the pickled
                model bundle produced by the backtest.

        Raises:
            ValueError: If the model file is missing or cannot be loaded.
        """
        self.config = config
        self.paths = path_config
        self.model: RandomForestClassifier | None = None
        self.feature_cols: list[str] | None = None
        self._load_model()

    def _load_model(self) -> None:
        """
        Load the pre-trained model bundle written by the backtest.

        The pickle must contain a dict with ``'model'`` (the fitted
        classifier) and ``'feature_cols'`` (the ordered feature columns).

        Raises:
            ValueError: If the file does not exist, cannot be unpickled,
                or lacks the expected keys. The original error is chained.
        """
        model_path = self.paths.model_path
        if not model_path.exists():
            raise ValueError(
                f"Multi-pair model not found at {model_path}. "
                "Run the backtest first to train the model."
            )

        # SECURITY: pickle.load can execute arbitrary code from the file.
        # Only load model files produced by our own backtest.
        try:
            with open(model_path, 'rb') as f:
                saved = pickle.load(f)
            model = saved['model']
            feature_cols = list(saved['feature_cols'])
        except Exception as e:
            logger.error("Could not load model: %s", e)
            # The file exists here, so report a load failure rather than
            # "not found", and keep the cause for debugging.
            raise ValueError(
                f"Could not load multi-pair model from {model_path}: {e}"
            ) from e

        self.model = model
        self.feature_cols = feature_cols
        logger.info("Loaded model from %s", model_path)

    @staticmethod
    def _finite_or(value: object, default: float | None) -> float | None:
        """Return ``value`` as a float, or ``default`` if it is missing,
        non-numeric, NaN or infinite."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return default
        return number if np.isfinite(number) else default

    def _build_feature_vector(self, latest: pd.Series) -> pd.DataFrame:
        """
        Build a one-row model input from a pair's latest feature row.

        Columns follow ``self.feature_cols`` (training order). Features
        missing from ``latest`` are filled with 0, and NaN/inf values are
        replaced with 0.
        """
        available_cols = [c for c in self.feature_cols if c in latest.index]
        missing_cols = [c for c in self.feature_cols if c not in latest.index]
        if missing_cols:
            logger.debug("Missing feature columns: %s", missing_cols)

        feature_row = (
            latest[available_cols]
            .fillna(0)
            .replace([np.inf, -np.inf], 0)
        )

        # Full feature vector in training order, zeros for missing columns.
        X_dict = {c: 0 for c in self.feature_cols}
        for col in available_cols:
            X_dict[col] = feature_row[col]
        return pd.DataFrame([X_dict])

    def score_pairs(
        self,
        pair_features: dict[str, pd.DataFrame],
        pairs: list[TradingPair]
    ) -> list[DivergenceSignal]:
        """
        Score all pairs and return ranked signals.

        A pair produces a signal only if its latest feature row meets all
        of the following:
        - a finite Z-Score with ``|z| >= z_entry_threshold``;
        - a model probability ``>= prob_threshold``;
        - a base-asset funding rate that passes the funding filter.

        Args:
            pair_features: Feature DataFrames by pair_id
            pairs: List of TradingPair objects

        Returns:
            List of DivergenceSignal sorted by score (descending)
        """
        if self.model is None or self.feature_cols is None:
            logger.warning("Model not loaded")
            return []

        signals = []
        pair_map = {p.pair_id: p for p in pairs}

        for pair_id, features in pair_features.items():
            pair = pair_map.get(pair_id)
            if pair is None or len(features) == 0:
                continue

            latest = features.iloc[-1]

            # Every comparison with NaN is False, so a NaN Z-Score (e.g. an
            # incomplete rolling window) would pass the threshold check.
            # It would then become a 'long' signal with a NaN score.
            z_score = self._finite_or(latest['z_score'], None)
            if z_score is None:
                logger.debug(
                    "Skipping %s: non-finite z_score (%s)",
                    pair.name, latest['z_score']
                )
                continue

            # Skip if Z-score below threshold
            if abs(z_score) < self.config.z_entry_threshold:
                continue

            prob = self.model.predict_proba(
                self._build_feature_vector(latest)
            )[0, 1]

            # Skip if probability below threshold
            if prob < self.config.prob_threshold:
                continue

            # Spread above its mean -> short the base asset, else long.
            direction = 'short' if z_score > 0 else 'long'

            # Funding filter: skip a side when the base funding rate goes
            # against it beyond the threshold. Presumably this avoids paying
            # funding on perpetuals; confirm the sign convention. Missing or
            # NaN funding counts as 0 so it cannot slip past the comparisons.
            base_funding = self._finite_or(latest.get('base_funding', 0), 0.0)
            funding_thresh = self.config.funding_threshold
            if direction == 'short' and base_funding < -funding_thresh:
                logger.debug(
                    "Skipping %s short: funding too negative (%.4f)",
                    pair.name, base_funding
                )
                continue
            if direction == 'long' and base_funding > funding_thresh:
                logger.debug(
                    "Skipping %s long: funding too positive (%.4f)",
                    pair.name, base_funding
                )
                continue

            signals.append(DivergenceSignal(
                pair=pair,
                z_score=z_score,
                probability=prob,
                divergence_score=abs(z_score) * prob,
                direction=direction,
                base_price=latest['base_close'],
                quote_price=latest['quote_close'],
                atr=latest.get('atr_base', 0),
                atr_pct=latest.get('atr_pct_base', 0.02),
                base_funding=base_funding
            ))

        # Sort by divergence score (highest first)
        signals.sort(key=lambda s: s.divergence_score, reverse=True)

        if signals:
            top = signals[0]
            logger.info(
                "Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f, dir=%s)",
                len(signals),
                top.pair.name,
                top.divergence_score,
                top.z_score,
                top.probability,
                top.direction
            )
        else:
            logger.info("No pairs meet entry criteria")

        return signals

    def select_best_pair(
        self,
        signals: list[DivergenceSignal]
    ) -> DivergenceSignal | None:
        """
        Select the best pair from scored signals.

        Args:
            signals: List of DivergenceSignal (pre-sorted by score)

        Returns:
            The first (highest-scored) signal, or None if the list is empty
        """
        return signals[0] if signals else None

    def generate_signal(
        self,
        pair_features: dict[str, pd.DataFrame],
        pairs: list[TradingPair]
    ) -> dict:
        """
        Generate a trading signal from the latest features.

        Args:
            pair_features: Feature DataFrames by pair_id
            pairs: List of TradingPair objects

        Returns:
            ``{'action': 'hold', 'reason': 'no_valid_signals'}`` when nothing
            qualifies. Otherwise an ``'entry'`` dict carrying the best
            signal's pair, direction, scores, prices, ATR and funding.
        """
        best = self.select_best_pair(self.score_pairs(pair_features, pairs))
        if best is None:
            return {
                'action': 'hold',
                'reason': 'no_valid_signals'
            }

        return {
            'action': 'entry',
            'pair': best.pair,
            'pair_id': best.pair.pair_id,
            'direction': best.direction,
            'z_score': best.z_score,
            'probability': best.probability,
            'divergence_score': best.divergence_score,
            'base_price': best.base_price,
            'quote_price': best.quote_price,
            'atr': best.atr,
            'atr_pct': best.atr_pct,
            'base_funding': best.base_funding,
            'reason': f'{best.pair.name} z={best.z_score:.2f} p={best.probability:.2f}'
        }

    def check_exit_signal(
        self,
        pair_features: dict[str, pd.DataFrame],
        current_pair_id: str
    ) -> dict:
        """
        Check if the current position should be exited.

        Exit conditions:
            1. Z-Score reverted to mean (|Z| < z_exit_threshold)
            2. No usable data for the pair. This covers a missing or empty
               DataFrame and a non-finite latest Z-Score.

        Args:
            pair_features: Feature DataFrames by pair_id
            current_pair_id: Current position's pair ID

        Returns:
            ``{'action': 'exit', 'reason': ...}``, or
            ``{'action': 'hold', 'z_score': ..., 'reason': ...}``
        """
        if current_pair_id not in pair_features:
            return {
                'action': 'exit',
                'reason': 'pair_data_missing'
            }

        features = pair_features[current_pair_id]
        if len(features) == 0:
            return {
                'action': 'exit',
                'reason': 'no_data'
            }

        z_score = self._finite_or(features.iloc[-1]['z_score'], None)
        if z_score is None:
            # A NaN Z-Score compares False with the exit threshold, so the
            # mean-reversion exit could never fire. Treat it as missing data.
            return {
                'action': 'exit',
                'reason': 'invalid_z_score'
            }

        # Check mean reversion
        if abs(z_score) < self.config.z_exit_threshold:
            return {
                'action': 'exit',
                'reason': f'mean_reversion (z={z_score:.2f})'
            }

        return {
            'action': 'hold',
            'z_score': z_score,
            'reason': f'holding (z={z_score:.2f})'
        }

    def calculate_sl_tp(
        self,
        entry_price: float,
        direction: str,
        atr: float,
        atr_pct: float
    ) -> tuple[float, float]:
        """
        Calculate ATR-based dynamic stop-loss and take-profit prices.

        Distances are ``atr * sl_atr_multiplier`` and
        ``atr * tp_atr_multiplier``. If ``atr`` or ``atr_pct`` is not
        positive (including NaN), the code falls back to
        ``base_sl_pct`` / ``base_tp_pct``. Either way the percentages are
        clamped to the configured min/max bounds. ``atr_pct`` only gates
        the ATR path; it is not used in the distances.

        Args:
            entry_price: Entry price (must be positive and finite)
            direction: 'long' or 'short'
            atr: ATR in price units
            atr_pct: ATR as percentage of price

        Returns:
            Tuple of (stop_loss_price, take_profit_price)

        Raises:
            ValueError: If ``direction`` is not 'long'/'short' or
                ``entry_price`` is not a positive finite number. Previously
                any other direction was silently treated as short, which
                puts a long's stop-loss above the entry.
        """
        if direction not in ('long', 'short'):
            raise ValueError(
                f"direction must be 'long' or 'short', got {direction!r}"
            )
        if not (np.isfinite(entry_price) and entry_price > 0):
            raise ValueError(
                f"entry_price must be a positive finite number, got {entry_price!r}"
            )

        if atr > 0 and atr_pct > 0:
            sl_pct = atr * self.config.sl_atr_multiplier / entry_price
            tp_pct = atr * self.config.tp_atr_multiplier / entry_price
        else:
            sl_pct = self.config.base_sl_pct
            tp_pct = self.config.base_tp_pct

        # Apply bounds
        sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
        tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))

        if direction == 'long':
            return entry_price * (1 - sl_pct), entry_price * (1 + tp_pct)
        return entry_price * (1 + sl_pct), entry_price * (1 - tp_pct)

    def calculate_position_size(
        self,
        divergence_score: float,
        available_usdt: float
    ) -> float:
        """
        Calculate position size based on divergence score.

        The base size is the available balance, capped by
        ``max_position_usdt`` when that is positive. It is scaled linearly
        from 1.0x at a score of 0.5 up to 2.0x at a score of 1.0 or more,
        then capped at 95% of the available balance. If the result is
        below ``min_position_usdt``, the method returns 0.0.

        Args:
            divergence_score: Combined score (|z| * prob)
            available_usdt: Available USDT balance

        Returns:
            Position size in USDT. The result is 0.0 when the trade should
            be skipped: sub-minimum size, non-finite inputs, or no balance.
        """
        score = self._finite_or(divergence_score, None)
        balance = self._finite_or(available_usdt, None)
        # Never size a live order from NaN/inf inputs or an empty balance.
        if score is None or balance is None or balance <= 0:
            return 0.0

        if self.config.max_position_usdt <= 0:
            base_size = balance
        else:
            base_size = min(balance, self.config.max_position_usdt)

        # Scale by divergence (1.0 at 0.5 score, up to 2.0 at 1.0+ score).
        # NOTE(review): the scale applies after the max_position_usdt cap,
        # so a position can reach 2x max_position_usdt. It is still bounded
        # by 95% of the balance. Confirm this is intended.
        base_threshold = 0.5
        if score <= base_threshold:
            scale = 1.0
        else:
            scale = min(1.0 + (score - base_threshold) / base_threshold, 2.0)

        # Apply the 95% balance cap BEFORE the minimum check. Otherwise the
        # cap could shrink an accepted size below min_position_usdt.
        size = min(base_size * scale, balance * 0.95)
        if size < self.config.min_position_usdt:
            return 0.0
        return size