From 1af0aab5fab6e598058e3a9ca4c20ad5888acb14 Mon Sep 17 00:00:00 2001 From: Simon Moisy Date: Thu, 15 Jan 2026 22:17:13 +0800 Subject: [PATCH] feat: Add Multi-Pair Divergence Live Trading Module - Introduced a new module for live trading based on the Multi-Pair Divergence Strategy. - Implemented configuration classes for OKX API and multi-pair settings. - Developed data feed functionality to fetch real-time OHLCV and funding data for multiple assets. - Created a trading bot orchestrator to manage trading cycles, including entry and exit signals based on ML model predictions. - Added comprehensive logging and error handling for robust operation. - Included a README with setup instructions and usage guidelines for the new module. --- check_demo_account.py | 98 +++++ live_trading/config.py | 8 +- live_trading/multi_pair/README.md | 145 +++++++ live_trading/multi_pair/__init__.py | 11 + live_trading/multi_pair/config.py | 145 +++++++ live_trading/multi_pair/data_feed.py | 336 +++++++++++++++ live_trading/multi_pair/main.py | 609 +++++++++++++++++++++++++++ live_trading/multi_pair/strategy.py | 396 +++++++++++++++++ 8 files changed, 1741 insertions(+), 7 deletions(-) create mode 100644 check_demo_account.py create mode 100644 live_trading/multi_pair/README.md create mode 100644 live_trading/multi_pair/__init__.py create mode 100644 live_trading/multi_pair/config.py create mode 100644 live_trading/multi_pair/data_feed.py create mode 100644 live_trading/multi_pair/main.py create mode 100644 live_trading/multi_pair/strategy.py diff --git a/check_demo_account.py b/check_demo_account.py new file mode 100644 index 0000000..6df755f --- /dev/null +++ b/check_demo_account.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +""" +Check OKX demo account positions and recent orders. + +Usage: + uv run python check_demo_account.py +""" +import sys +from pathlib import Path +from datetime import datetime, timezone + +sys.path.insert(0, str(Path(__file__).parent)) + +from live_trading.config import OKXConfig +import ccxt + + +def main(): + """Check demo account status.""" + config = OKXConfig() + + print(f"\n{'='*60}") + print(f" OKX Demo Account Check") + print(f"{'='*60}") + print(f" Demo Mode: {config.demo_mode}") + print(f" API Key: {config.api_key[:8]}..." if config.api_key else " API Key: NOT SET") + print(f"{'='*60}\n") + + exchange = ccxt.okx({ + 'apiKey': config.api_key, + 'secret': config.secret, + 'password': config.password, + 'sandbox': config.demo_mode, + 'options': {'defaultType': 'swap'}, + 'enableRateLimit': True, + }) + + # Check balance + print("--- BALANCE ---") + balance = exchange.fetch_balance() + usdt = balance.get('USDT', {}) + print(f"USDT Total: {usdt.get('total', 0):.2f}") + print(f"USDT Free: {usdt.get('free', 0):.2f}") + print(f"USDT Used: {usdt.get('used', 0):.2f}") + + # Check all balances + print("\n--- ALL NON-ZERO BALANCES ---") + for currency, data in balance.items(): + if isinstance(data, dict) and data.get('total', 0) > 0: + print(f"{currency}: total={data.get('total', 0):.6f}, free={data.get('free', 0):.6f}") + + # Check open positions + print("\n--- OPEN POSITIONS ---") + positions = exchange.fetch_positions() + open_positions = [p for p in positions if abs(float(p.get('contracts', 0))) > 0] + + if open_positions: + for pos in open_positions: + print(f" {pos['symbol']}: {pos['side']} {pos['contracts']} contracts @ {pos.get('entryPrice', 'N/A')}") + print(f" Unrealized PnL: {pos.get('unrealizedPnl', 'N/A')}") + else: + print(" No open positions") + + # Check recent orders (last 50) + print("\n--- RECENT ORDERS (last 24h) ---") + try: + # Fetch closed orders for AVAX + orders = exchange.fetch_orders('AVAX/USDT:USDT', limit=20) + if orders: + for order in orders[-10:]: # Last 10 + ts = datetime.fromtimestamp(order['timestamp']/1000, tz=timezone.utc) + print(f" [{ts.strftime('%H:%M:%S')}] {order['side'].upper()} {order['amount']} AVAX @ {order.get('average', order.get('price', 'market'))}") + print(f" Status: {order['status']}, Filled: {order.get('filled', 0)}, ID: {order['id']}") + else: + print(" No recent AVAX orders") + except Exception as e: + print(f" Could not fetch orders: {e}") + + # Check order history more broadly + print("\n--- ORDER HISTORY (AVAX) ---") + try: + # Try fetching my trades + trades = exchange.fetch_my_trades('AVAX/USDT:USDT', limit=10) + if trades: + for trade in trades[-5:]: + ts = datetime.fromtimestamp(trade['timestamp']/1000, tz=timezone.utc) + print(f" [{ts.strftime('%Y-%m-%d %H:%M:%S')}] {trade['side'].upper()} {trade['amount']} @ {trade['price']}") + print(f" Fee: {trade.get('fee', {}).get('cost', 'N/A')} {trade.get('fee', {}).get('currency', '')}") + else: + print(" No recent AVAX trades") + except Exception as e: + print(f" Could not fetch trades: {e}") + + print(f"\n{'='*60}\n") + + +if __name__ == "__main__": + main() diff --git a/live_trading/config.py b/live_trading/config.py index 39172cd..47fb507 100644 --- a/live_trading/config.py +++ b/live_trading/config.py @@ -9,13 +9,7 @@ from pathlib import Path from dataclasses import dataclass, field from dotenv import load_dotenv -# Load .env from sibling project (BTC_spot_MVRV) -ENV_PATH = Path(__file__).parent.parent.parent / "BTC_spot_MVRV" / ".env" -if ENV_PATH.exists(): - load_dotenv(ENV_PATH) -else: - # Fallback to local .env - load_dotenv() +load_dotenv() @dataclass diff --git a/live_trading/multi_pair/README.md b/live_trading/multi_pair/README.md new file mode 100644 index 0000000..abd7673 --- /dev/null +++ b/live_trading/multi_pair/README.md @@ -0,0 +1,145 @@ +# Multi-Pair Divergence Live Trading + +This module implements live trading for the Multi-Pair Divergence Selection Strategy on OKX perpetual futures. + +## Overview + +The strategy scans 10 cryptocurrency pairs for spread divergence opportunities: + +1. **Pair Universe**: Top 10 assets by market cap (BTC, ETH, SOL, XRP, BNB, DOGE, ADA, AVAX, LINK, DOT) +2. **Spread Z-Score**: Identifies when pairs are divergent from their historical mean +3. **Universal ML Model**: Predicts probability of successful mean reversion +4. **Dynamic Selection**: Trades the pair with highest divergence score + +## Prerequisites + +Before running live trading, you must train the model via backtesting: + +```bash +uv run python scripts/run_multi_pair_backtest.py +``` + +This creates `data/multi_pair_model.pkl` which the live trading bot requires. + +## Setup + +### 1. API Keys + +Same as single-pair trading. Set in `.env`: + +```env +OKX_API_KEY=your_api_key +OKX_SECRET=your_secret +OKX_PASSWORD=your_passphrase +OKX_DEMO_MODE=true # Use demo for testing +``` + +### 2. Dependencies + +All dependencies are in `pyproject.toml`. No additional installation needed. + +## Usage + +### Run with Demo Account (Recommended First) + +```bash +uv run python -m live_trading.multi_pair.main +``` + +### Command Line Options + +```bash +# Custom position size +uv run python -m live_trading.multi_pair.main --max-position 500 + +# Custom leverage +uv run python -m live_trading.multi_pair.main --leverage 2 + +# Custom cycle interval (in seconds) +uv run python -m live_trading.multi_pair.main --interval 1800 + +# Combine options +uv run python -m live_trading.multi_pair.main --max-position 1000 --leverage 3 --interval 3600 +``` + +### Live Trading (Use with Caution) + +```bash +uv run python -m live_trading.multi_pair.main --live +``` + +## How It Works + +### Each Trading Cycle + +1. **Fetch Data**: Gets OHLCV for all 10 assets from OKX +2. **Calculate Features**: Computes Z-Score, RSI, volatility for all 45 pair combinations +3. **Score Pairs**: Uses ML model to rank pairs by divergence score (|Z| x probability) +4. **Check Exits**: If holding, check mean reversion or SL/TP +5. **Enter Best**: If no position, enter the highest-scoring divergent pair + +### Entry Conditions + +- |Z-Score| > 1.0 (spread diverged from mean) +- ML probability > 0.5 (model predicts successful reversion) +- Funding rate filter passes (avoid crowded trades) + +### Exit Conditions + +- Mean reversion: |Z-Score| returns to ~0 +- Stop-loss: ATR-based (default ~6%) +- Take-profit: ATR-based (default ~5%) + +## Strategy Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `z_entry_threshold` | 1.0 | Enter when \|Z-Score\| > threshold | +| `z_exit_threshold` | 0.0 | Exit when Z reverts to mean | +| `z_window` | 24 | Rolling window for Z-Score (hours) | +| `prob_threshold` | 0.5 | ML probability threshold for entry | +| `funding_threshold` | 0.0005 | Funding rate filter (0.05%) | +| `sl_atr_multiplier` | 10.0 | Stop-loss as ATR multiple | +| `tp_atr_multiplier` | 8.0 | Take-profit as ATR multiple | + +## Files + +### Input + +- `data/multi_pair_model.pkl` - Pre-trained ML model (required) + +### Output + +- `logs/multi_pair_live.log` - Trading logs +- `live_trading/multi_pair_positions.json` - Position persistence +- `live_trading/multi_pair_trade_log.csv` - Trade history + +## Architecture + +``` +live_trading/multi_pair/ + __init__.py # Module exports + config.py # Configuration classes + data_feed.py # Multi-asset OHLCV fetcher + strategy.py # ML scoring and signal generation + main.py # Bot orchestrator + README.md # This file +``` + +## Differences from Single-Pair + +| Aspect | Single-Pair | Multi-Pair | +|--------|-------------|------------| +| Assets | ETH only (BTC context) | 10 assets, 45 pairs | +| Model | ETH-specific | Universal across pairs | +| Selection | Fixed pair | Dynamic best pair | +| Stops | Fixed 6%/5% | ATR-based dynamic | + +## Risk Warning + +This is experimental trading software. Use at your own risk: + +- Always start with demo trading +- Never risk more than you can afford to lose +- Monitor the bot regularly +- The model was trained on historical data and may not predict future performance diff --git a/live_trading/multi_pair/__init__.py b/live_trading/multi_pair/__init__.py new file mode 100644 index 0000000..c9ae01a --- /dev/null +++ b/live_trading/multi_pair/__init__.py @@ -0,0 +1,11 @@ +"""Multi-Pair Divergence Live Trading Module.""" +from .config import MultiPairLiveConfig, get_multi_pair_config +from .data_feed import MultiPairDataFeed +from .strategy import LiveMultiPairStrategy + +__all__ = [ + "MultiPairLiveConfig", + "get_multi_pair_config", + "MultiPairDataFeed", + "LiveMultiPairStrategy", +] diff --git a/live_trading/multi_pair/config.py b/live_trading/multi_pair/config.py new file mode 100644 index 0000000..d8b0a18 --- /dev/null +++ b/live_trading/multi_pair/config.py @@ -0,0 +1,145 @@ +""" +Configuration for Multi-Pair Live Trading. + +Extends the base live trading config with multi-pair specific settings. +""" +import os +from pathlib import Path +from dataclasses import dataclass, field +from dotenv import load_dotenv + +load_dotenv() + + +@dataclass +class OKXConfig: + """OKX API configuration.""" + api_key: str = field(default_factory=lambda: "") + secret: str = field(default_factory=lambda: "") + password: str = field(default_factory=lambda: "") + demo_mode: bool = field(default_factory=lambda: True) + + def __post_init__(self): + """Load credentials based on demo mode setting.""" + self.demo_mode = os.getenv("OKX_DEMO_MODE", "true").lower() in ("true", "1", "yes") + + if self.demo_mode: + self.api_key = os.getenv("OKX_DEMO_API_KEY", os.getenv("OKX_API_KEY", "")) + self.secret = os.getenv("OKX_DEMO_SECRET", os.getenv("OKX_SECRET", "")) + self.password = os.getenv("OKX_DEMO_PASSWORD", os.getenv("OKX_PASSWORD", "")) + else: + self.api_key = os.getenv("OKX_API_KEY", "") + self.secret = os.getenv("OKX_SECRET", "") + self.password = os.getenv("OKX_PASSWORD", "") + + def validate(self) -> None: + """Validate that required credentials are present.""" + mode = "demo" if self.demo_mode else "live" + if not self.api_key: + raise ValueError(f"OKX API key not set for {mode} mode") + if not self.secret: + raise ValueError(f"OKX secret not set for {mode} mode") + if not self.password: + raise ValueError(f"OKX password not set for {mode} mode") + + +@dataclass +class MultiPairLiveConfig: + """ + Configuration for multi-pair live trading. + + Combines trading parameters, strategy settings, and risk management. + """ + # Asset Universe (top 10 by market cap perpetuals) + assets: list[str] = field(default_factory=lambda: [ + "BTC/USDT:USDT", "ETH/USDT:USDT", "SOL/USDT:USDT", "XRP/USDT:USDT", + "BNB/USDT:USDT", "DOGE/USDT:USDT", "ADA/USDT:USDT", "AVAX/USDT:USDT", + "LINK/USDT:USDT", "DOT/USDT:USDT" + ]) + + # Timeframe + timeframe: str = "1h" + candles_to_fetch: int = 500 # Enough for feature calculation + + # Z-Score Thresholds + z_window: int = 24 + z_entry_threshold: float = 1.0 + z_exit_threshold: float = 0.0 # Exit at mean reversion + + # ML Thresholds + prob_threshold: float = 0.5 + + # Position sizing + max_position_usdt: float = -1.0 # If <= 0, use all available funds + min_position_usdt: float = 10.0 + leverage: int = 1 + margin_mode: str = "cross" + max_concurrent_positions: int = 1 # Trade one pair at a time + + # Risk Management - ATR-Based Stops + atr_period: int = 14 + sl_atr_multiplier: float = 10.0 + tp_atr_multiplier: float = 8.0 + + # Fallback fixed percentages + base_sl_pct: float = 0.06 + base_tp_pct: float = 0.05 + + # ATR bounds + min_sl_pct: float = 0.02 + max_sl_pct: float = 0.10 + min_tp_pct: float = 0.02 + max_tp_pct: float = 0.15 + + # Funding Rate Filter + funding_threshold: float = 0.0005 # 0.05% + + # Trade Management + min_hold_bars: int = 0 + cooldown_bars: int = 0 + + # Execution + sleep_seconds: int = 3600 # Run every hour + slippage_pct: float = 0.001 + + def get_asset_short_name(self, symbol: str) -> str: + """Convert symbol to short name (e.g., BTC/USDT:USDT -> btc).""" + return symbol.split("/")[0].lower() + + def get_pair_count(self) -> int: + """Calculate number of unique pairs from asset list.""" + n = len(self.assets) + return n * (n - 1) // 2 + + +@dataclass +class PathConfig: + """File paths configuration.""" + base_dir: Path = field( + default_factory=lambda: Path(__file__).parent.parent.parent + ) + data_dir: Path = field(default=None) + logs_dir: Path = field(default=None) + model_path: Path = field(default=None) + positions_file: Path = field(default=None) + trade_log_file: Path = field(default=None) + + def __post_init__(self): + self.data_dir = self.base_dir / "data" + self.logs_dir = self.base_dir / "logs" + # Use the same model as backtesting + self.model_path = self.base_dir / "data" / "multi_pair_model.pkl" + self.positions_file = self.base_dir / "live_trading" / "multi_pair_positions.json" + self.trade_log_file = self.base_dir / "live_trading" / "multi_pair_trade_log.csv" + + # Ensure directories exist + self.data_dir.mkdir(parents=True, exist_ok=True) + self.logs_dir.mkdir(parents=True, exist_ok=True) + + +def get_multi_pair_config() -> tuple[OKXConfig, MultiPairLiveConfig, PathConfig]: + """Get all configuration objects for multi-pair trading.""" + okx = OKXConfig() + trading = MultiPairLiveConfig() + paths = PathConfig() + return okx, trading, paths diff --git a/live_trading/multi_pair/data_feed.py b/live_trading/multi_pair/data_feed.py new file mode 100644 index 0000000..8579d6d --- /dev/null +++ b/live_trading/multi_pair/data_feed.py @@ -0,0 +1,336 @@ +""" +Multi-Pair Data Feed for Live Trading. + +Fetches real-time OHLCV and funding data for all assets in the universe. +""" +import logging +from itertools import combinations + +import pandas as pd +import numpy as np +import ta + +from live_trading.okx_client import OKXClient +from .config import MultiPairLiveConfig, PathConfig + +logger = logging.getLogger(__name__) + + +class TradingPair: + """ + Represents a tradeable pair for spread analysis. + + Attributes: + base_asset: First asset symbol (e.g., ETH/USDT:USDT) + quote_asset: Second asset symbol (e.g., BTC/USDT:USDT) + pair_id: Unique identifier + """ + def __init__(self, base_asset: str, quote_asset: str): + self.base_asset = base_asset + self.quote_asset = quote_asset + self.pair_id = f"{base_asset}__{quote_asset}" + + @property + def name(self) -> str: + """Human-readable pair name.""" + base = self.base_asset.split("/")[0] + quote = self.quote_asset.split("/")[0] + return f"{base}/{quote}" + + def __hash__(self): + return hash(self.pair_id) + + def __eq__(self, other): + if not isinstance(other, TradingPair): + return False + return self.pair_id == other.pair_id + + +class MultiPairDataFeed: + """ + Real-time data feed for multi-pair strategy. + + Fetches OHLCV data for all assets and calculates spread features + for all pair combinations. + """ + + def __init__( + self, + okx_client: OKXClient, + config: MultiPairLiveConfig, + path_config: PathConfig + ): + self.client = okx_client + self.config = config + self.paths = path_config + + # Cache for asset data + self._asset_data: dict[str, pd.DataFrame] = {} + self._funding_rates: dict[str, float] = {} + self._pairs: list[TradingPair] = [] + + # Generate pairs + self._generate_pairs() + + def _generate_pairs(self) -> None: + """Generate all unique pairs from asset universe.""" + self._pairs = [] + for base, quote in combinations(self.config.assets, 2): + pair = TradingPair(base_asset=base, quote_asset=quote) + self._pairs.append(pair) + + logger.info("Generated %d pairs from %d assets", + len(self._pairs), len(self.config.assets)) + + @property + def pairs(self) -> list[TradingPair]: + """Get list of trading pairs.""" + return self._pairs + + def fetch_all_ohlcv(self) -> dict[str, pd.DataFrame]: + """ + Fetch OHLCV data for all assets. + + Returns: + Dictionary mapping symbol to OHLCV DataFrame + """ + self._asset_data = {} + + for symbol in self.config.assets: + try: + ohlcv = self.client.fetch_ohlcv( + symbol, + self.config.timeframe, + self.config.candles_to_fetch + ) + df = self._ohlcv_to_dataframe(ohlcv) + + if len(df) >= 200: + self._asset_data[symbol] = df + logger.debug("Fetched %s: %d candles", symbol, len(df)) + else: + logger.warning("Skipping %s: insufficient data (%d)", + symbol, len(df)) + except Exception as e: + logger.error("Error fetching %s: %s", symbol, e) + + logger.info("Fetched data for %d/%d assets", + len(self._asset_data), len(self.config.assets)) + return self._asset_data + + def _ohlcv_to_dataframe(self, ohlcv: list) -> pd.DataFrame: + """Convert OHLCV list to DataFrame.""" + df = pd.DataFrame( + ohlcv, + columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'] + ) + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True) + df.set_index('timestamp', inplace=True) + return df + + def fetch_all_funding_rates(self) -> dict[str, float]: + """ + Fetch current funding rates for all assets. + + Returns: + Dictionary mapping symbol to funding rate + """ + self._funding_rates = {} + + for symbol in self.config.assets: + try: + rate = self.client.get_funding_rate(symbol) + self._funding_rates[symbol] = rate + except Exception as e: + logger.warning("Could not get funding for %s: %s", symbol, e) + self._funding_rates[symbol] = 0.0 + + return self._funding_rates + + def calculate_pair_features( + self, + pair: TradingPair + ) -> pd.DataFrame | None: + """ + Calculate features for a single pair. + + Args: + pair: Trading pair + + Returns: + DataFrame with features, or None if insufficient data + """ + base = pair.base_asset + quote = pair.quote_asset + + if base not in self._asset_data or quote not in self._asset_data: + return None + + df_base = self._asset_data[base] + df_quote = self._asset_data[quote] + + # Align indices + common_idx = df_base.index.intersection(df_quote.index) + if len(common_idx) < 200: + return None + + df_a = df_base.loc[common_idx] + df_b = df_quote.loc[common_idx] + + # Calculate spread (base / quote) + spread = df_a['close'] / df_b['close'] + + # Z-Score + z_window = self.config.z_window + rolling_mean = spread.rolling(window=z_window).mean() + rolling_std = spread.rolling(window=z_window).std() + z_score = (spread - rolling_mean) / rolling_std + + # Spread Technicals + spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi() + spread_roc = spread.pct_change(periods=5) * 100 + spread_change_1h = spread.pct_change(periods=1) + + # Volume Analysis + vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10) + vol_ratio_ma = vol_ratio.rolling(window=12).mean() + vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10) + + # Volatility + ret_a = df_a['close'].pct_change() + ret_b = df_b['close'].pct_change() + vol_a = ret_a.rolling(window=z_window).std() + vol_b = ret_b.rolling(window=z_window).std() + vol_spread_ratio = vol_a / (vol_b + 1e-10) + + # Realized Volatility + realized_vol_a = ret_a.rolling(window=24).std() + realized_vol_b = ret_b.rolling(window=24).std() + + # ATR (Average True Range) + high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close'] + + tr_a = pd.concat([ + high_a - low_a, + (high_a - close_a.shift(1)).abs(), + (low_a - close_a.shift(1)).abs() + ], axis=1).max(axis=1) + atr_a = tr_a.rolling(window=self.config.atr_period).mean() + atr_pct_a = atr_a / close_a + + # Build feature DataFrame + features = pd.DataFrame(index=common_idx) + features['pair_id'] = pair.pair_id + features['base_asset'] = base + features['quote_asset'] = quote + + # Price data + features['spread'] = spread + features['base_close'] = df_a['close'] + features['quote_close'] = df_b['close'] + features['base_volume'] = df_a['volume'] + + # Core Features + features['z_score'] = z_score + features['spread_rsi'] = spread_rsi + features['spread_roc'] = spread_roc + features['spread_change_1h'] = spread_change_1h + features['vol_ratio'] = vol_ratio + features['vol_ratio_rel'] = vol_ratio_rel + features['vol_diff_ratio'] = vol_spread_ratio + + # Volatility + features['realized_vol_base'] = realized_vol_a + features['realized_vol_quote'] = realized_vol_b + features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2 + + # ATR + features['atr_base'] = atr_a + features['atr_pct_base'] = atr_pct_a + + # Pair encoding + assets = self.config.assets + features['base_idx'] = assets.index(base) if base in assets else -1 + features['quote_idx'] = assets.index(quote) if quote in assets else -1 + + # Funding rates + base_funding = self._funding_rates.get(base, 0.0) + quote_funding = self._funding_rates.get(quote, 0.0) + features['base_funding'] = base_funding + features['quote_funding'] = quote_funding + features['funding_diff'] = base_funding - quote_funding + features['funding_avg'] = (base_funding + quote_funding) / 2 + + # Drop NaN rows in core features + core_cols = [ + 'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h', + 'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio', + 'realized_vol_base', 'atr_base', 'atr_pct_base' + ] + features = features.dropna(subset=core_cols) + + return features + + def calculate_all_pair_features(self) -> dict[str, pd.DataFrame]: + """ + Calculate features for all pairs. + + Returns: + Dictionary mapping pair_id to feature DataFrame + """ + all_features = {} + + for pair in self._pairs: + features = self.calculate_pair_features(pair) + if features is not None and len(features) > 0: + all_features[pair.pair_id] = features + + logger.info("Calculated features for %d/%d pairs", + len(all_features), len(self._pairs)) + + return all_features + + def get_latest_data(self) -> dict[str, pd.DataFrame] | None: + """ + Fetch and process latest market data for all pairs. + + Returns: + Dictionary of pair features or None on error + """ + try: + # Fetch OHLCV for all assets + self.fetch_all_ohlcv() + + if len(self._asset_data) < 2: + logger.warning("Insufficient assets fetched") + return None + + # Fetch funding rates + self.fetch_all_funding_rates() + + # Calculate features for all pairs + pair_features = self.calculate_all_pair_features() + + if not pair_features: + logger.warning("No pair features calculated") + return None + + logger.info("Processed %d pairs with valid features", len(pair_features)) + return pair_features + + except Exception as e: + logger.error("Error fetching market data: %s", e, exc_info=True) + return None + + def get_pair_by_id(self, pair_id: str) -> TradingPair | None: + """Get pair object by ID.""" + for pair in self._pairs: + if pair.pair_id == pair_id: + return pair + return None + + def get_current_price(self, symbol: str) -> float | None: + """Get current price for a symbol.""" + if symbol in self._asset_data: + return self._asset_data[symbol]['close'].iloc[-1] + return None diff --git a/live_trading/multi_pair/main.py b/live_trading/multi_pair/main.py new file mode 100644 index 0000000..882888a --- /dev/null +++ b/live_trading/multi_pair/main.py @@ -0,0 +1,609 @@ +#!/usr/bin/env python3 +""" +Multi-Pair Divergence Live Trading Bot. + +Trades the top 10 cryptocurrency pairs based on spread divergence +using a universal ML model for signal generation. + +Usage: + # Run with demo account (default) + uv run python -m live_trading.multi_pair.main + + # Run with specific settings + uv run python -m live_trading.multi_pair.main --max-position 500 --leverage 2 +""" +import argparse +import logging +import signal +import sys +import time +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from live_trading.okx_client import OKXClient +from live_trading.position_manager import PositionManager +from live_trading.multi_pair.config import ( + OKXConfig, MultiPairLiveConfig, PathConfig, get_multi_pair_config +) +from live_trading.multi_pair.data_feed import MultiPairDataFeed, TradingPair +from live_trading.multi_pair.strategy import LiveMultiPairStrategy + + +def setup_logging(log_dir: Path) -> logging.Logger: + """Configure logging for the trading bot.""" + log_file = log_dir / "multi_pair_live.log" + + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(name)s: %(message)s', + handlers=[ + logging.FileHandler(log_file), + logging.StreamHandler(sys.stdout), + ], + force=True + ) + + return logging.getLogger(__name__) + + +@dataclass +class PositionState: + """Track current position state for multi-pair.""" + pair: TradingPair | None = None + pair_id: str | None = None + direction: str | None = None + entry_price: float = 0.0 + size: float = 0.0 + stop_loss: float = 0.0 + take_profit: float = 0.0 + entry_time: datetime | None = None + + +class MultiPairLiveTradingBot: + """ + Main trading bot for multi-pair divergence strategy. + + Coordinates data fetching, pair scoring, and order execution. + """ + + def __init__( + self, + okx_config: OKXConfig, + trading_config: MultiPairLiveConfig, + path_config: PathConfig + ): + self.okx_config = okx_config + self.trading_config = trading_config + self.path_config = path_config + + self.logger = logging.getLogger(__name__) + self.running = True + + # Initialize components + self.logger.info("Initializing multi-pair trading bot...") + + # Create OKX client with adapted config + self._adapted_trading_config = self._adapt_config_for_okx_client() + self.okx_client = OKXClient(okx_config, self._adapted_trading_config) + + # Initialize data feed + self.data_feed = MultiPairDataFeed( + self.okx_client, trading_config, path_config + ) + + # Initialize position manager (reuse from single-pair) + self.position_manager = PositionManager( + self.okx_client, self._adapted_trading_config, path_config + ) + + # Initialize strategy + self.strategy = LiveMultiPairStrategy(trading_config, path_config) + + # Current position state + self.position = PositionState() + + # Register signal handlers + signal.signal(signal.SIGINT, self._handle_shutdown) + signal.signal(signal.SIGTERM, self._handle_shutdown) + + self._print_startup_banner() + + # Sync with exchange positions on startup + self._sync_position_from_exchange() + + def _adapt_config_for_okx_client(self): + """Create config compatible with OKXClient.""" + # OKXClient expects specific attributes + @dataclass + class AdaptedConfig: + eth_symbol: str = "ETH/USDT:USDT" + btc_symbol: str = "BTC/USDT:USDT" + timeframe: str = "1h" + candles_to_fetch: int = 500 + max_position_usdt: float = -1.0 + min_position_usdt: float = 10.0 + leverage: int = 1 + margin_mode: str = "cross" + stop_loss_pct: float = 0.06 + take_profit_pct: float = 0.05 + max_concurrent_positions: int = 1 + z_entry_threshold: float = 1.0 + z_window: int = 24 + model_prob_threshold: float = 0.5 + funding_threshold: float = 0.0005 + sleep_seconds: int = 3600 + slippage_pct: float = 0.001 + + adapted = AdaptedConfig() + adapted.timeframe = self.trading_config.timeframe + adapted.candles_to_fetch = self.trading_config.candles_to_fetch + adapted.max_position_usdt = self.trading_config.max_position_usdt + adapted.min_position_usdt = self.trading_config.min_position_usdt + adapted.leverage = self.trading_config.leverage + adapted.margin_mode = self.trading_config.margin_mode + adapted.max_concurrent_positions = self.trading_config.max_concurrent_positions + adapted.sleep_seconds = self.trading_config.sleep_seconds + adapted.slippage_pct = self.trading_config.slippage_pct + + return adapted + + def _print_startup_banner(self) -> None: + """Print startup information.""" + mode = "DEMO/SANDBOX" if self.okx_config.demo_mode else "LIVE" + + print("=" * 60) + print(" Multi-Pair Divergence Strategy - Live Trading Bot") + print("=" * 60) + print(f" Mode: {mode}") + print(f" Assets: {len(self.trading_config.assets)} assets") + print(f" Pairs: {self.trading_config.get_pair_count()} pairs") + print(f" Timeframe: {self.trading_config.timeframe}") + print(f" Max Position: ${self.trading_config.max_position_usdt if self.trading_config.max_position_usdt > 0 else 'All available'}") + print(f" Leverage: {self.trading_config.leverage}x") + print(f" Z-Entry: > {self.trading_config.z_entry_threshold}") + print(f" Prob Threshold: > {self.trading_config.prob_threshold}") + print(f" Cycle Interval: {self.trading_config.sleep_seconds // 60} minutes") + print("=" * 60) + print(f" Assets: {', '.join([a.split('/')[0] for a in self.trading_config.assets])}") + print("=" * 60) + + if not self.okx_config.demo_mode: + print("\n *** WARNING: LIVE TRADING MODE - REAL FUNDS AT RISK ***\n") + + def _handle_shutdown(self, signum, frame) -> None: + """Handle shutdown signals gracefully.""" + self.logger.info("Shutdown signal received, stopping...") + self.running = False + + def _sync_position_from_exchange(self) -> bool: + """ + Sync internal position state with exchange positions. + + Checks for existing open positions on the exchange and updates + internal state to match. This prevents stacking positions when + the bot is restarted. + + Returns: + True if a position was synced, False otherwise + """ + try: + positions = self.okx_client.get_positions() + + if not positions: + if self.position.pair is not None: + # Position was closed externally (e.g., SL/TP hit) + self.logger.info( + "Position %s was closed externally, resetting state", + self.position.pair.name if self.position.pair else "unknown" + ) + self.position = PositionState() + return False + + # Check each position against our tradeable assets + our_assets = set(self.trading_config.assets) + + for pos in positions: + pos_symbol = pos.get('symbol', '') + contracts = abs(float(pos.get('contracts', 0))) + + if contracts == 0: + continue + + # Check if this position is for one of our assets + if pos_symbol not in our_assets: + continue + + # Found a position for one of our assets + side = pos.get('side', 'long') + entry_price = float(pos.get('entryPrice', 0)) + unrealized_pnl = float(pos.get('unrealizedPnl', 0)) + + # If we already track this position, just update + if (self.position.pair is not None and + self.position.pair.base_asset == pos_symbol): + self.logger.debug( + "Position already tracked: %s %s %.2f contracts", + side, pos_symbol, contracts + ) + return True + + # New position found - sync it + # Find or create a TradingPair for this position + matched_pair = None + for pair in self.data_feed.pairs: + if pair.base_asset == pos_symbol: + matched_pair = pair + break + + if matched_pair is None: + # Create a placeholder pair (we don't know the quote asset) + matched_pair = TradingPair( + base_asset=pos_symbol, + quote_asset="UNKNOWN" + ) + + # Calculate approximate SL/TP based on config defaults + sl_pct = self.trading_config.base_sl_pct + tp_pct = self.trading_config.base_tp_pct + + if side == 'long': + stop_loss = entry_price * (1 - sl_pct) + take_profit = entry_price * (1 + tp_pct) + else: + stop_loss = entry_price * (1 + sl_pct) + take_profit = entry_price * (1 - tp_pct) + + self.position = PositionState( + pair=matched_pair, + pair_id=matched_pair.pair_id, + direction=side, + entry_price=entry_price, + size=contracts, + stop_loss=stop_loss, + take_profit=take_profit, + entry_time=None # Unknown for synced positions + ) + + self.logger.info( + "Synced existing position from exchange: %s %s %.4f @ %.4f (PnL: %.2f)", + side.upper(), + pos_symbol, + contracts, + entry_price, + unrealized_pnl + ) + return True + + # No matching positions found + if self.position.pair is not None: + self.logger.info( + "Position %s no longer exists on exchange, resetting state", + self.position.pair.name + ) + self.position = PositionState() + + return False + + except Exception as e: + self.logger.error("Failed to sync position from exchange: %s", e) + return False + + def run_trading_cycle(self) -> None: + """ + Execute one trading cycle. + + 1. Sync position state with exchange + 2. Fetch latest market data for all assets + 3. Calculate features for all pairs + 4. Score pairs and find best opportunity + 5. Check exit conditions for current position + 6. Execute trades if needed + """ + cycle_start = datetime.now(timezone.utc) + self.logger.info("--- Trading Cycle Start: %s ---", cycle_start.isoformat()) + + try: + # 1. Sync position state with exchange (detect SL/TP closures) + self._sync_position_from_exchange() + + # 2. Fetch all market data + pair_features = self.data_feed.get_latest_data() + if pair_features is None: + self.logger.warning("No market data available, skipping cycle") + return + + # 2. Check exit conditions for current position + if self.position.pair is not None: + exit_signal = self.strategy.check_exit_signal( + pair_features, + self.position.pair_id + ) + + if exit_signal['action'] == 'exit': + self._execute_exit(exit_signal) + else: + # Check SL/TP + current_price = self.data_feed.get_current_price( + self.position.pair.base_asset + ) + if current_price: + sl_tp_exit = self._check_sl_tp(current_price) + if sl_tp_exit: + self._execute_exit({'reason': sl_tp_exit}) + + # 3. Generate entry signal if no position + if self.position.pair is None: + entry_signal = self.strategy.generate_signal( + pair_features, + self.data_feed.pairs + ) + + if entry_signal['action'] == 'entry': + self._execute_entry(entry_signal) + + # 4. Log status + if self.position.pair: + self.logger.info( + "Position: %s %s, entry=%.4f, current PnL check pending", + self.position.direction, + self.position.pair.name, + self.position.entry_price + ) + else: + self.logger.info("No open position") + + except Exception as e: + self.logger.error("Trading cycle error: %s", e, exc_info=True) + + cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds() + self.logger.info("--- Cycle completed in %.1fs ---", cycle_duration) + + def _check_sl_tp(self, current_price: float) -> str | None: + """Check stop-loss and take-profit levels.""" + if self.position.direction == 'long': + if current_price <= self.position.stop_loss: + return f"stop_loss ({current_price:.4f} <= {self.position.stop_loss:.4f})" + if current_price >= self.position.take_profit: + return f"take_profit ({current_price:.4f} >= {self.position.take_profit:.4f})" + else: # short + if current_price >= self.position.stop_loss: + return f"stop_loss ({current_price:.4f} >= {self.position.stop_loss:.4f})" + if current_price <= self.position.take_profit: + return f"take_profit ({current_price:.4f} <= {self.position.take_profit:.4f})" + return None + + def _execute_entry(self, signal: dict) -> None: + """Execute entry trade.""" + pair = signal['pair'] + symbol = pair.base_asset # Trade the base asset + direction = signal['direction'] + + self.logger.info( + "Entry signal: %s %s (z=%.2f, p=%.2f, score=%.3f)", + direction.upper(), + pair.name, + signal['z_score'], + signal['probability'], + signal['divergence_score'] + ) + + # Get account balance + try: + balance = self.okx_client.get_balance() + available_usdt = balance['free'] + except Exception as e: + self.logger.error("Could not get balance: %s", e) + return + + # Calculate position size + size_usdt = self.strategy.calculate_position_size( + signal['divergence_score'], + available_usdt + ) + + if size_usdt <= 0: + self.logger.info("Position size too small, skipping entry") + return + + current_price = signal['base_price'] + size_asset = size_usdt / current_price + + # Calculate SL/TP + stop_loss, take_profit = self.strategy.calculate_sl_tp( + current_price, + direction, + signal['atr'], + signal['atr_pct'] + ) + + self.logger.info( + "Executing %s entry: %.6f %s @ %.4f ($%.2f), SL=%.4f, TP=%.4f", + direction.upper(), + size_asset, + symbol.split('/')[0], + current_price, + size_usdt, + stop_loss, + take_profit + ) + + try: + # Place market order + order_side = "buy" if direction == "long" else "sell" + order = self.okx_client.place_market_order(symbol, order_side, size_asset) + + filled_price = order.get('average') or order.get('price') or current_price + filled_amount = order.get('filled') or order.get('amount') or size_asset + + if filled_price is None or filled_price == 0: + filled_price = current_price + if filled_amount is None or filled_amount == 0: + filled_amount = size_asset + + # Recalculate SL/TP with filled price + stop_loss, take_profit = self.strategy.calculate_sl_tp( + filled_price, direction, signal['atr'], signal['atr_pct'] + ) + + # Update position state + self.position = PositionState( + pair=pair, + pair_id=pair.pair_id, + direction=direction, + entry_price=filled_price, + size=filled_amount, + stop_loss=stop_loss, + take_profit=take_profit, + entry_time=datetime.now(timezone.utc) + ) + + self.logger.info( + "Position opened: %s %s %.6f @ %.4f", + direction.upper(), + pair.name, + filled_amount, + filled_price + ) + + # Try to set SL/TP on exchange + try: + self.okx_client.set_stop_loss_take_profit( + symbol, direction, filled_amount, stop_loss, take_profit + ) + except Exception as e: + self.logger.warning("Could not set SL/TP on exchange: %s", e) + + except Exception as e: + self.logger.error("Order execution failed: %s", e, exc_info=True) + + def _execute_exit(self, signal: dict) -> None: + """Execute exit trade.""" + if self.position.pair is None: + return + + symbol = self.position.pair.base_asset + reason = signal.get('reason', 'unknown') + + self.logger.info( + "Exit signal: %s %s, reason: %s", + self.position.direction, + self.position.pair.name, + reason + ) + + try: + # Close position on exchange + self.okx_client.close_position(symbol) + + self.logger.info( + "Position closed: %s %s", + self.position.direction, + self.position.pair.name + ) + + # Reset position state + self.position = PositionState() + + except Exception as e: + self.logger.error("Exit execution failed: %s", e, exc_info=True) + + def run(self) -> None: + """Main trading loop.""" + self.logger.info("Starting multi-pair trading loop...") + + while self.running: + try: + self.run_trading_cycle() + + if self.running: + sleep_seconds = self.trading_config.sleep_seconds + minutes = sleep_seconds // 60 + self.logger.info("Sleeping for %d minutes...", minutes) + + for _ in range(sleep_seconds): + if not self.running: + break + time.sleep(1) + + except KeyboardInterrupt: + self.logger.info("Keyboard interrupt received") + break + except Exception as e: + self.logger.error("Unexpected error in main loop: %s", e, exc_info=True) + time.sleep(60) + + self.logger.info("Shutting down...") + self.logger.info("Shutdown complete") + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Multi-Pair Divergence Live Trading Bot" + ) + parser.add_argument( + "--max-position", + type=float, + default=None, + help="Maximum position size in USDT" + ) + parser.add_argument( + "--leverage", + type=int, + default=None, + help="Trading leverage (1-125)" + ) + parser.add_argument( + "--interval", + type=int, + default=None, + help="Trading cycle interval in seconds" + ) + parser.add_argument( + "--live", + action="store_true", + help="Use live trading mode (requires OKX_DEMO_MODE=false)" + ) + return parser.parse_args() + + +def main(): + """Main entry point.""" + args = parse_args() + + # Load configuration + okx_config, trading_config, path_config = get_multi_pair_config() + + # Apply command line overrides + if args.max_position is not None: + trading_config.max_position_usdt = args.max_position + if args.leverage is not None: + trading_config.leverage = args.leverage + if args.interval is not None: + trading_config.sleep_seconds = args.interval + if args.live: + okx_config.demo_mode = False + + # Setup logging + logger = setup_logging(path_config.logs_dir) + + try: + # Validate config + okx_config.validate() + + # Create and run bot + bot = MultiPairLiveTradingBot(okx_config, trading_config, path_config) + bot.run() + except ValueError as e: + logger.error("Configuration error: %s", e) + sys.exit(1) + except Exception as e: + logger.error("Fatal error: %s", e, exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/live_trading/multi_pair/strategy.py b/live_trading/multi_pair/strategy.py new file mode 100644 index 0000000..b0e9215 --- /dev/null +++ b/live_trading/multi_pair/strategy.py @@ -0,0 +1,396 @@ +""" +Live Multi-Pair Divergence Strategy. + +Scores all pairs and selects the best divergence opportunity for trading. +Uses the pre-trained universal ML model from backtesting. +""" +import logging +import pickle +from dataclasses import dataclass +from pathlib import Path + +import pandas as pd +import numpy as np +from sklearn.ensemble import RandomForestClassifier + +# Opt-in to future pandas behavior to silence FutureWarning on fillna +pd.set_option('future.no_silent_downcasting', True) + +from .config import MultiPairLiveConfig, PathConfig +from .data_feed import TradingPair + +logger = logging.getLogger(__name__) + + +@dataclass +class DivergenceSignal: + """ + Signal for a divergent pair. + + Attributes: + pair: Trading pair + z_score: Current Z-Score of the spread + probability: ML model probability of profitable reversion + divergence_score: Combined score (|z_score| * probability) + direction: 'long' or 'short' (relative to base asset) + base_price: Current price of base asset + quote_price: Current price of quote asset + atr: Average True Range in price units + atr_pct: ATR as percentage of price + """ + pair: TradingPair + z_score: float + probability: float + divergence_score: float + direction: str + base_price: float + quote_price: float + atr: float + atr_pct: float + base_funding: float = 0.0 + + +class LiveMultiPairStrategy: + """ + Live trading implementation of multi-pair divergence strategy. + + Scores all pairs using the universal ML model and selects + the best opportunity for mean-reversion trading. + """ + + def __init__( + self, + config: MultiPairLiveConfig, + path_config: PathConfig + ): + self.config = config + self.paths = path_config + self.model: RandomForestClassifier | None = None + self.feature_cols: list[str] | None = None + self._load_model() + + def _load_model(self) -> None: + """Load pre-trained model from backtesting.""" + if self.paths.model_path.exists(): + try: + with open(self.paths.model_path, 'rb') as f: + saved = pickle.load(f) + self.model = saved['model'] + self.feature_cols = saved['feature_cols'] + logger.info("Loaded model from %s", self.paths.model_path) + except Exception as e: + logger.error("Could not load model: %s", e) + raise ValueError( + f"Multi-pair model not found at {self.paths.model_path}. " + "Run the backtest first to train the model." + ) + else: + raise ValueError( + f"Multi-pair model not found at {self.paths.model_path}. " + "Run the backtest first to train the model." + ) + + def score_pairs( + self, + pair_features: dict[str, pd.DataFrame], + pairs: list[TradingPair] + ) -> list[DivergenceSignal]: + """ + Score all pairs and return ranked signals. + + Args: + pair_features: Feature DataFrames by pair_id + pairs: List of TradingPair objects + + Returns: + List of DivergenceSignal sorted by score (descending) + """ + if self.model is None: + logger.warning("Model not loaded") + return [] + + signals = [] + pair_map = {p.pair_id: p for p in pairs} + + for pair_id, features in pair_features.items(): + if pair_id not in pair_map: + continue + + pair = pair_map[pair_id] + + # Get latest features + if len(features) == 0: + continue + + latest = features.iloc[-1] + z_score = latest['z_score'] + + # Skip if Z-score below threshold + if abs(z_score) < self.config.z_entry_threshold: + continue + + # Prepare features for prediction + # Handle missing feature columns gracefully + available_cols = [c for c in self.feature_cols if c in latest.index] + missing_cols = [c for c in self.feature_cols if c not in latest.index] + + if missing_cols: + logger.debug("Missing feature columns: %s", missing_cols) + + feature_row = latest[available_cols].fillna(0) + feature_row = feature_row.replace([np.inf, -np.inf], 0) + + # Create full feature vector with zeros for missing + X_dict = {c: 0 for c in self.feature_cols} + for col in available_cols: + X_dict[col] = feature_row[col] + + X = pd.DataFrame([X_dict]) + + # Predict probability + prob = self.model.predict_proba(X)[0, 1] + + # Skip if probability below threshold + if prob < self.config.prob_threshold: + continue + + # Apply funding rate filter + base_funding = latest.get('base_funding', 0) or 0 + funding_thresh = self.config.funding_threshold + + if z_score > 0: # Short signal + if base_funding < -funding_thresh: + logger.debug( + "Skipping %s short: funding too negative (%.4f)", + pair.name, base_funding + ) + continue + else: # Long signal + if base_funding > funding_thresh: + logger.debug( + "Skipping %s long: funding too positive (%.4f)", + pair.name, base_funding + ) + continue + + # Calculate divergence score + divergence_score = abs(z_score) * prob + + # Determine direction + direction = 'short' if z_score > 0 else 'long' + + signal = DivergenceSignal( + pair=pair, + z_score=z_score, + probability=prob, + divergence_score=divergence_score, + direction=direction, + base_price=latest['base_close'], + quote_price=latest['quote_close'], + atr=latest.get('atr_base', 0), + atr_pct=latest.get('atr_pct_base', 0.02), + base_funding=base_funding + ) + signals.append(signal) + + # Sort by divergence score (highest first) + signals.sort(key=lambda s: s.divergence_score, reverse=True) + + if signals: + logger.info( + "Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f, dir=%s)", + len(signals), + signals[0].pair.name, + signals[0].divergence_score, + signals[0].z_score, + signals[0].probability, + signals[0].direction + ) + else: + logger.info("No pairs meet entry criteria") + + return signals + + def select_best_pair( + self, + signals: list[DivergenceSignal] + ) -> DivergenceSignal | None: + """ + Select the best pair from scored signals. + + Args: + signals: List of DivergenceSignal (pre-sorted by score) + + Returns: + Best signal or None if no valid candidates + """ + if not signals: + return None + return signals[0] + + def generate_signal( + self, + pair_features: dict[str, pd.DataFrame], + pairs: list[TradingPair] + ) -> dict: + """ + Generate trading signal from latest features. + + Args: + pair_features: Feature DataFrames by pair_id + pairs: List of TradingPair objects + + Returns: + Signal dictionary with action, pair, direction, etc. + """ + # Score all pairs + signals = self.score_pairs(pair_features, pairs) + + # Select best + best = self.select_best_pair(signals) + + if best is None: + return { + 'action': 'hold', + 'reason': 'no_valid_signals' + } + + return { + 'action': 'entry', + 'pair': best.pair, + 'pair_id': best.pair.pair_id, + 'direction': best.direction, + 'z_score': best.z_score, + 'probability': best.probability, + 'divergence_score': best.divergence_score, + 'base_price': best.base_price, + 'quote_price': best.quote_price, + 'atr': best.atr, + 'atr_pct': best.atr_pct, + 'base_funding': best.base_funding, + 'reason': f'{best.pair.name} z={best.z_score:.2f} p={best.probability:.2f}' + } + + def check_exit_signal( + self, + pair_features: dict[str, pd.DataFrame], + current_pair_id: str + ) -> dict: + """ + Check if current position should be exited. + + Exit conditions: + 1. Z-Score reverted to mean (|Z| < threshold) + + Args: + pair_features: Feature DataFrames by pair_id + current_pair_id: Current position's pair ID + + Returns: + Signal dictionary with action and reason + """ + if current_pair_id not in pair_features: + return { + 'action': 'exit', + 'reason': 'pair_data_missing' + } + + features = pair_features[current_pair_id] + if len(features) == 0: + return { + 'action': 'exit', + 'reason': 'no_data' + } + + latest = features.iloc[-1] + z_score = latest['z_score'] + + # Check mean reversion + if abs(z_score) < self.config.z_exit_threshold: + return { + 'action': 'exit', + 'reason': f'mean_reversion (z={z_score:.2f})' + } + + return { + 'action': 'hold', + 'z_score': z_score, + 'reason': f'holding (z={z_score:.2f})' + } + + def calculate_sl_tp( + self, + entry_price: float, + direction: str, + atr: float, + atr_pct: float + ) -> tuple[float, float]: + """ + Calculate ATR-based dynamic stop-loss and take-profit prices. + + Args: + entry_price: Entry price + direction: 'long' or 'short' + atr: ATR in price units + atr_pct: ATR as percentage of price + + Returns: + Tuple of (stop_loss_price, take_profit_price) + """ + if atr > 0 and atr_pct > 0: + sl_distance = atr * self.config.sl_atr_multiplier + tp_distance = atr * self.config.tp_atr_multiplier + + sl_pct = sl_distance / entry_price + tp_pct = tp_distance / entry_price + else: + sl_pct = self.config.base_sl_pct + tp_pct = self.config.base_tp_pct + + # Apply bounds + sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct)) + tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct)) + + if direction == 'long': + stop_loss = entry_price * (1 - sl_pct) + take_profit = entry_price * (1 + tp_pct) + else: + stop_loss = entry_price * (1 + sl_pct) + take_profit = entry_price * (1 - tp_pct) + + return stop_loss, take_profit + + def calculate_position_size( + self, + divergence_score: float, + available_usdt: float + ) -> float: + """ + Calculate position size based on divergence score. + + Args: + divergence_score: Combined score (|z| * prob) + available_usdt: Available USDT balance + + Returns: + Position size in USDT + """ + if self.config.max_position_usdt <= 0: + base_size = available_usdt + else: + base_size = min(available_usdt, self.config.max_position_usdt) + + # Scale by divergence (1.0 at 0.5 score, up to 2.0 at 1.0+ score) + base_threshold = 0.5 + if divergence_score <= base_threshold: + scale = 1.0 + else: + scale = 1.0 + (divergence_score - base_threshold) / base_threshold + scale = min(scale, 2.0) + + size = base_size * scale + + if size < self.config.min_position_usdt: + return 0.0 + + return min(size, available_usdt * 0.95)