feat: Add Multi-Pair Divergence Live Trading Module
- Introduced a new module for live trading based on the Multi-Pair Divergence Strategy. - Implemented configuration classes for OKX API and multi-pair settings. - Developed data feed functionality to fetch real-time OHLCV and funding data for multiple assets. - Created a trading bot orchestrator to manage trading cycles, including entry and exit signals based on ML model predictions. - Added comprehensive logging and error handling for robust operation. - Included a README with setup instructions and usage guidelines for the new module.
This commit is contained in:
98
check_demo_account.py
Normal file
98
check_demo_account.py
Normal file
@@ -0,0 +1,98 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Check OKX demo account positions and recent orders.
|
||||
|
||||
Usage:
|
||||
uv run python check_demo_account.py
|
||||
"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from live_trading.config import OKXConfig
|
||||
import ccxt
|
||||
|
||||
|
||||
def main():
|
||||
"""Check demo account status."""
|
||||
config = OKXConfig()
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f" OKX Demo Account Check")
|
||||
print(f"{'='*60}")
|
||||
print(f" Demo Mode: {config.demo_mode}")
|
||||
print(f" API Key: {config.api_key[:8]}..." if config.api_key else " API Key: NOT SET")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
exchange = ccxt.okx({
|
||||
'apiKey': config.api_key,
|
||||
'secret': config.secret,
|
||||
'password': config.password,
|
||||
'sandbox': config.demo_mode,
|
||||
'options': {'defaultType': 'swap'},
|
||||
'enableRateLimit': True,
|
||||
})
|
||||
|
||||
# Check balance
|
||||
print("--- BALANCE ---")
|
||||
balance = exchange.fetch_balance()
|
||||
usdt = balance.get('USDT', {})
|
||||
print(f"USDT Total: {usdt.get('total', 0):.2f}")
|
||||
print(f"USDT Free: {usdt.get('free', 0):.2f}")
|
||||
print(f"USDT Used: {usdt.get('used', 0):.2f}")
|
||||
|
||||
# Check all balances
|
||||
print("\n--- ALL NON-ZERO BALANCES ---")
|
||||
for currency, data in balance.items():
|
||||
if isinstance(data, dict) and data.get('total', 0) > 0:
|
||||
print(f"{currency}: total={data.get('total', 0):.6f}, free={data.get('free', 0):.6f}")
|
||||
|
||||
# Check open positions
|
||||
print("\n--- OPEN POSITIONS ---")
|
||||
positions = exchange.fetch_positions()
|
||||
open_positions = [p for p in positions if abs(float(p.get('contracts', 0))) > 0]
|
||||
|
||||
if open_positions:
|
||||
for pos in open_positions:
|
||||
print(f" {pos['symbol']}: {pos['side']} {pos['contracts']} contracts @ {pos.get('entryPrice', 'N/A')}")
|
||||
print(f" Unrealized PnL: {pos.get('unrealizedPnl', 'N/A')}")
|
||||
else:
|
||||
print(" No open positions")
|
||||
|
||||
# Check recent orders (last 50)
|
||||
print("\n--- RECENT ORDERS (last 24h) ---")
|
||||
try:
|
||||
# Fetch closed orders for AVAX
|
||||
orders = exchange.fetch_orders('AVAX/USDT:USDT', limit=20)
|
||||
if orders:
|
||||
for order in orders[-10:]: # Last 10
|
||||
ts = datetime.fromtimestamp(order['timestamp']/1000, tz=timezone.utc)
|
||||
print(f" [{ts.strftime('%H:%M:%S')}] {order['side'].upper()} {order['amount']} AVAX @ {order.get('average', order.get('price', 'market'))}")
|
||||
print(f" Status: {order['status']}, Filled: {order.get('filled', 0)}, ID: {order['id']}")
|
||||
else:
|
||||
print(" No recent AVAX orders")
|
||||
except Exception as e:
|
||||
print(f" Could not fetch orders: {e}")
|
||||
|
||||
# Check order history more broadly
|
||||
print("\n--- ORDER HISTORY (AVAX) ---")
|
||||
try:
|
||||
# Try fetching my trades
|
||||
trades = exchange.fetch_my_trades('AVAX/USDT:USDT', limit=10)
|
||||
if trades:
|
||||
for trade in trades[-5:]:
|
||||
ts = datetime.fromtimestamp(trade['timestamp']/1000, tz=timezone.utc)
|
||||
print(f" [{ts.strftime('%Y-%m-%d %H:%M:%S')}] {trade['side'].upper()} {trade['amount']} @ {trade['price']}")
|
||||
print(f" Fee: {trade.get('fee', {}).get('cost', 'N/A')} {trade.get('fee', {}).get('currency', '')}")
|
||||
else:
|
||||
print(" No recent AVAX trades")
|
||||
except Exception as e:
|
||||
print(f" Could not fetch trades: {e}")
|
||||
|
||||
print(f"\n{'='*60}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -9,13 +9,7 @@ from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load .env from sibling project (BTC_spot_MVRV)
|
||||
ENV_PATH = Path(__file__).parent.parent.parent / "BTC_spot_MVRV" / ".env"
|
||||
if ENV_PATH.exists():
|
||||
load_dotenv(ENV_PATH)
|
||||
else:
|
||||
# Fallback to local .env
|
||||
load_dotenv()
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
145
live_trading/multi_pair/README.md
Normal file
145
live_trading/multi_pair/README.md
Normal file
@@ -0,0 +1,145 @@
|
||||
# Multi-Pair Divergence Live Trading
|
||||
|
||||
This module implements live trading for the Multi-Pair Divergence Selection Strategy on OKX perpetual futures.
|
||||
|
||||
## Overview
|
||||
|
||||
The strategy scans 10 cryptocurrency pairs for spread divergence opportunities:
|
||||
|
||||
1. **Pair Universe**: Top 10 assets by market cap (BTC, ETH, SOL, XRP, BNB, DOGE, ADA, AVAX, LINK, DOT)
|
||||
2. **Spread Z-Score**: Identifies when pairs are divergent from their historical mean
|
||||
3. **Universal ML Model**: Predicts probability of successful mean reversion
|
||||
4. **Dynamic Selection**: Trades the pair with highest divergence score
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before running live trading, you must train the model via backtesting:
|
||||
|
||||
```bash
|
||||
uv run python scripts/run_multi_pair_backtest.py
|
||||
```
|
||||
|
||||
This creates `data/multi_pair_model.pkl` which the live trading bot requires.
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. API Keys
|
||||
|
||||
Same as single-pair trading. Set in `.env`:
|
||||
|
||||
```env
|
||||
OKX_API_KEY=your_api_key
|
||||
OKX_SECRET=your_secret
|
||||
OKX_PASSWORD=your_passphrase
|
||||
OKX_DEMO_MODE=true # Use demo for testing
|
||||
```
|
||||
|
||||
### 2. Dependencies
|
||||
|
||||
All dependencies are in `pyproject.toml`. No additional installation needed.
|
||||
|
||||
## Usage
|
||||
|
||||
### Run with Demo Account (Recommended First)
|
||||
|
||||
```bash
|
||||
uv run python -m live_trading.multi_pair.main
|
||||
```
|
||||
|
||||
### Command Line Options
|
||||
|
||||
```bash
|
||||
# Custom position size
|
||||
uv run python -m live_trading.multi_pair.main --max-position 500
|
||||
|
||||
# Custom leverage
|
||||
uv run python -m live_trading.multi_pair.main --leverage 2
|
||||
|
||||
# Custom cycle interval (in seconds)
|
||||
uv run python -m live_trading.multi_pair.main --interval 1800
|
||||
|
||||
# Combine options
|
||||
uv run python -m live_trading.multi_pair.main --max-position 1000 --leverage 3 --interval 3600
|
||||
```
|
||||
|
||||
### Live Trading (Use with Caution)
|
||||
|
||||
```bash
|
||||
uv run python -m live_trading.multi_pair.main --live
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### Each Trading Cycle
|
||||
|
||||
1. **Fetch Data**: Gets OHLCV for all 10 assets from OKX
|
||||
2. **Calculate Features**: Computes Z-Score, RSI, volatility for all 45 pair combinations
|
||||
3. **Score Pairs**: Uses ML model to rank pairs by divergence score (|Z| x probability)
|
||||
4. **Check Exits**: If holding, check mean reversion or SL/TP
|
||||
5. **Enter Best**: If no position, enter the highest-scoring divergent pair
|
||||
|
||||
### Entry Conditions
|
||||
|
||||
- |Z-Score| > 1.0 (spread diverged from mean)
|
||||
- ML probability > 0.5 (model predicts successful reversion)
|
||||
- Funding rate filter passes (avoid crowded trades)
|
||||
|
||||
### Exit Conditions
|
||||
|
||||
- Mean reversion: |Z-Score| returns to ~0
|
||||
- Stop-loss: ATR-based (default ~6%)
|
||||
- Take-profit: ATR-based (default ~5%)
|
||||
|
||||
## Strategy Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `z_entry_threshold` | 1.0 | Enter when \|Z-Score\| > threshold |
|
||||
| `z_exit_threshold` | 0.0 | Exit when Z reverts to mean |
|
||||
| `z_window` | 24 | Rolling window for Z-Score (hours) |
|
||||
| `prob_threshold` | 0.5 | ML probability threshold for entry |
|
||||
| `funding_threshold` | 0.0005 | Funding rate filter (0.05%) |
|
||||
| `sl_atr_multiplier` | 10.0 | Stop-loss as ATR multiple |
|
||||
| `tp_atr_multiplier` | 8.0 | Take-profit as ATR multiple |
|
||||
|
||||
## Files
|
||||
|
||||
### Input
|
||||
|
||||
- `data/multi_pair_model.pkl` - Pre-trained ML model (required)
|
||||
|
||||
### Output
|
||||
|
||||
- `logs/multi_pair_live.log` - Trading logs
|
||||
- `live_trading/multi_pair_positions.json` - Position persistence
|
||||
- `live_trading/multi_pair_trade_log.csv` - Trade history
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
live_trading/multi_pair/
|
||||
__init__.py # Module exports
|
||||
config.py # Configuration classes
|
||||
data_feed.py # Multi-asset OHLCV fetcher
|
||||
strategy.py # ML scoring and signal generation
|
||||
main.py # Bot orchestrator
|
||||
README.md # This file
|
||||
```
|
||||
|
||||
## Differences from Single-Pair
|
||||
|
||||
| Aspect | Single-Pair | Multi-Pair |
|
||||
|--------|-------------|------------|
|
||||
| Assets | ETH only (BTC context) | 10 assets, 45 pairs |
|
||||
| Model | ETH-specific | Universal across pairs |
|
||||
| Selection | Fixed pair | Dynamic best pair |
|
||||
| Stops | Fixed 6%/5% | ATR-based dynamic |
|
||||
|
||||
## Risk Warning
|
||||
|
||||
This is experimental trading software. Use at your own risk:
|
||||
|
||||
- Always start with demo trading
|
||||
- Never risk more than you can afford to lose
|
||||
- Monitor the bot regularly
|
||||
- The model was trained on historical data and may not predict future performance
|
||||
11
live_trading/multi_pair/__init__.py
Normal file
11
live_trading/multi_pair/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Multi-Pair Divergence Live Trading Module."""
|
||||
from .config import MultiPairLiveConfig, get_multi_pair_config
|
||||
from .data_feed import MultiPairDataFeed
|
||||
from .strategy import LiveMultiPairStrategy
|
||||
|
||||
__all__ = [
|
||||
"MultiPairLiveConfig",
|
||||
"get_multi_pair_config",
|
||||
"MultiPairDataFeed",
|
||||
"LiveMultiPairStrategy",
|
||||
]
|
||||
145
live_trading/multi_pair/config.py
Normal file
145
live_trading/multi_pair/config.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Configuration for Multi-Pair Live Trading.
|
||||
|
||||
Extends the base live trading config with multi-pair specific settings.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OKXConfig:
|
||||
"""OKX API configuration."""
|
||||
api_key: str = field(default_factory=lambda: "")
|
||||
secret: str = field(default_factory=lambda: "")
|
||||
password: str = field(default_factory=lambda: "")
|
||||
demo_mode: bool = field(default_factory=lambda: True)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Load credentials based on demo mode setting."""
|
||||
self.demo_mode = os.getenv("OKX_DEMO_MODE", "true").lower() in ("true", "1", "yes")
|
||||
|
||||
if self.demo_mode:
|
||||
self.api_key = os.getenv("OKX_DEMO_API_KEY", os.getenv("OKX_API_KEY", ""))
|
||||
self.secret = os.getenv("OKX_DEMO_SECRET", os.getenv("OKX_SECRET", ""))
|
||||
self.password = os.getenv("OKX_DEMO_PASSWORD", os.getenv("OKX_PASSWORD", ""))
|
||||
else:
|
||||
self.api_key = os.getenv("OKX_API_KEY", "")
|
||||
self.secret = os.getenv("OKX_SECRET", "")
|
||||
self.password = os.getenv("OKX_PASSWORD", "")
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate that required credentials are present."""
|
||||
mode = "demo" if self.demo_mode else "live"
|
||||
if not self.api_key:
|
||||
raise ValueError(f"OKX API key not set for {mode} mode")
|
||||
if not self.secret:
|
||||
raise ValueError(f"OKX secret not set for {mode} mode")
|
||||
if not self.password:
|
||||
raise ValueError(f"OKX password not set for {mode} mode")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiPairLiveConfig:
|
||||
"""
|
||||
Configuration for multi-pair live trading.
|
||||
|
||||
Combines trading parameters, strategy settings, and risk management.
|
||||
"""
|
||||
# Asset Universe (top 10 by market cap perpetuals)
|
||||
assets: list[str] = field(default_factory=lambda: [
|
||||
"BTC/USDT:USDT", "ETH/USDT:USDT", "SOL/USDT:USDT", "XRP/USDT:USDT",
|
||||
"BNB/USDT:USDT", "DOGE/USDT:USDT", "ADA/USDT:USDT", "AVAX/USDT:USDT",
|
||||
"LINK/USDT:USDT", "DOT/USDT:USDT"
|
||||
])
|
||||
|
||||
# Timeframe
|
||||
timeframe: str = "1h"
|
||||
candles_to_fetch: int = 500 # Enough for feature calculation
|
||||
|
||||
# Z-Score Thresholds
|
||||
z_window: int = 24
|
||||
z_entry_threshold: float = 1.0
|
||||
z_exit_threshold: float = 0.0 # Exit at mean reversion
|
||||
|
||||
# ML Thresholds
|
||||
prob_threshold: float = 0.5
|
||||
|
||||
# Position sizing
|
||||
max_position_usdt: float = -1.0 # If <= 0, use all available funds
|
||||
min_position_usdt: float = 10.0
|
||||
leverage: int = 1
|
||||
margin_mode: str = "cross"
|
||||
max_concurrent_positions: int = 1 # Trade one pair at a time
|
||||
|
||||
# Risk Management - ATR-Based Stops
|
||||
atr_period: int = 14
|
||||
sl_atr_multiplier: float = 10.0
|
||||
tp_atr_multiplier: float = 8.0
|
||||
|
||||
# Fallback fixed percentages
|
||||
base_sl_pct: float = 0.06
|
||||
base_tp_pct: float = 0.05
|
||||
|
||||
# ATR bounds
|
||||
min_sl_pct: float = 0.02
|
||||
max_sl_pct: float = 0.10
|
||||
min_tp_pct: float = 0.02
|
||||
max_tp_pct: float = 0.15
|
||||
|
||||
# Funding Rate Filter
|
||||
funding_threshold: float = 0.0005 # 0.05%
|
||||
|
||||
# Trade Management
|
||||
min_hold_bars: int = 0
|
||||
cooldown_bars: int = 0
|
||||
|
||||
# Execution
|
||||
sleep_seconds: int = 3600 # Run every hour
|
||||
slippage_pct: float = 0.001
|
||||
|
||||
def get_asset_short_name(self, symbol: str) -> str:
|
||||
"""Convert symbol to short name (e.g., BTC/USDT:USDT -> btc)."""
|
||||
return symbol.split("/")[0].lower()
|
||||
|
||||
def get_pair_count(self) -> int:
|
||||
"""Calculate number of unique pairs from asset list."""
|
||||
n = len(self.assets)
|
||||
return n * (n - 1) // 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class PathConfig:
|
||||
"""File paths configuration."""
|
||||
base_dir: Path = field(
|
||||
default_factory=lambda: Path(__file__).parent.parent.parent
|
||||
)
|
||||
data_dir: Path = field(default=None)
|
||||
logs_dir: Path = field(default=None)
|
||||
model_path: Path = field(default=None)
|
||||
positions_file: Path = field(default=None)
|
||||
trade_log_file: Path = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
self.data_dir = self.base_dir / "data"
|
||||
self.logs_dir = self.base_dir / "logs"
|
||||
# Use the same model as backtesting
|
||||
self.model_path = self.base_dir / "data" / "multi_pair_model.pkl"
|
||||
self.positions_file = self.base_dir / "live_trading" / "multi_pair_positions.json"
|
||||
self.trade_log_file = self.base_dir / "live_trading" / "multi_pair_trade_log.csv"
|
||||
|
||||
# Ensure directories exist
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_multi_pair_config() -> tuple[OKXConfig, MultiPairLiveConfig, PathConfig]:
|
||||
"""Get all configuration objects for multi-pair trading."""
|
||||
okx = OKXConfig()
|
||||
trading = MultiPairLiveConfig()
|
||||
paths = PathConfig()
|
||||
return okx, trading, paths
|
||||
336
live_trading/multi_pair/data_feed.py
Normal file
336
live_trading/multi_pair/data_feed.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""
|
||||
Multi-Pair Data Feed for Live Trading.
|
||||
|
||||
Fetches real-time OHLCV and funding data for all assets in the universe.
|
||||
"""
|
||||
import logging
|
||||
from itertools import combinations
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import ta
|
||||
|
||||
from live_trading.okx_client import OKXClient
|
||||
from .config import MultiPairLiveConfig, PathConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TradingPair:
|
||||
"""
|
||||
Represents a tradeable pair for spread analysis.
|
||||
|
||||
Attributes:
|
||||
base_asset: First asset symbol (e.g., ETH/USDT:USDT)
|
||||
quote_asset: Second asset symbol (e.g., BTC/USDT:USDT)
|
||||
pair_id: Unique identifier
|
||||
"""
|
||||
def __init__(self, base_asset: str, quote_asset: str):
|
||||
self.base_asset = base_asset
|
||||
self.quote_asset = quote_asset
|
||||
self.pair_id = f"{base_asset}__{quote_asset}"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Human-readable pair name."""
|
||||
base = self.base_asset.split("/")[0]
|
||||
quote = self.quote_asset.split("/")[0]
|
||||
return f"{base}/{quote}"
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.pair_id)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, TradingPair):
|
||||
return False
|
||||
return self.pair_id == other.pair_id
|
||||
|
||||
|
||||
class MultiPairDataFeed:
|
||||
"""
|
||||
Real-time data feed for multi-pair strategy.
|
||||
|
||||
Fetches OHLCV data for all assets and calculates spread features
|
||||
for all pair combinations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
okx_client: OKXClient,
|
||||
config: MultiPairLiveConfig,
|
||||
path_config: PathConfig
|
||||
):
|
||||
self.client = okx_client
|
||||
self.config = config
|
||||
self.paths = path_config
|
||||
|
||||
# Cache for asset data
|
||||
self._asset_data: dict[str, pd.DataFrame] = {}
|
||||
self._funding_rates: dict[str, float] = {}
|
||||
self._pairs: list[TradingPair] = []
|
||||
|
||||
# Generate pairs
|
||||
self._generate_pairs()
|
||||
|
||||
def _generate_pairs(self) -> None:
|
||||
"""Generate all unique pairs from asset universe."""
|
||||
self._pairs = []
|
||||
for base, quote in combinations(self.config.assets, 2):
|
||||
pair = TradingPair(base_asset=base, quote_asset=quote)
|
||||
self._pairs.append(pair)
|
||||
|
||||
logger.info("Generated %d pairs from %d assets",
|
||||
len(self._pairs), len(self.config.assets))
|
||||
|
||||
@property
|
||||
def pairs(self) -> list[TradingPair]:
|
||||
"""Get list of trading pairs."""
|
||||
return self._pairs
|
||||
|
||||
def fetch_all_ohlcv(self) -> dict[str, pd.DataFrame]:
|
||||
"""
|
||||
Fetch OHLCV data for all assets.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping symbol to OHLCV DataFrame
|
||||
"""
|
||||
self._asset_data = {}
|
||||
|
||||
for symbol in self.config.assets:
|
||||
try:
|
||||
ohlcv = self.client.fetch_ohlcv(
|
||||
symbol,
|
||||
self.config.timeframe,
|
||||
self.config.candles_to_fetch
|
||||
)
|
||||
df = self._ohlcv_to_dataframe(ohlcv)
|
||||
|
||||
if len(df) >= 200:
|
||||
self._asset_data[symbol] = df
|
||||
logger.debug("Fetched %s: %d candles", symbol, len(df))
|
||||
else:
|
||||
logger.warning("Skipping %s: insufficient data (%d)",
|
||||
symbol, len(df))
|
||||
except Exception as e:
|
||||
logger.error("Error fetching %s: %s", symbol, e)
|
||||
|
||||
logger.info("Fetched data for %d/%d assets",
|
||||
len(self._asset_data), len(self.config.assets))
|
||||
return self._asset_data
|
||||
|
||||
def _ohlcv_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
|
||||
"""Convert OHLCV list to DataFrame."""
|
||||
df = pd.DataFrame(
|
||||
ohlcv,
|
||||
columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
)
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
|
||||
df.set_index('timestamp', inplace=True)
|
||||
return df
|
||||
|
||||
def fetch_all_funding_rates(self) -> dict[str, float]:
|
||||
"""
|
||||
Fetch current funding rates for all assets.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping symbol to funding rate
|
||||
"""
|
||||
self._funding_rates = {}
|
||||
|
||||
for symbol in self.config.assets:
|
||||
try:
|
||||
rate = self.client.get_funding_rate(symbol)
|
||||
self._funding_rates[symbol] = rate
|
||||
except Exception as e:
|
||||
logger.warning("Could not get funding for %s: %s", symbol, e)
|
||||
self._funding_rates[symbol] = 0.0
|
||||
|
||||
return self._funding_rates
|
||||
|
||||
def calculate_pair_features(
|
||||
self,
|
||||
pair: TradingPair
|
||||
) -> pd.DataFrame | None:
|
||||
"""
|
||||
Calculate features for a single pair.
|
||||
|
||||
Args:
|
||||
pair: Trading pair
|
||||
|
||||
Returns:
|
||||
DataFrame with features, or None if insufficient data
|
||||
"""
|
||||
base = pair.base_asset
|
||||
quote = pair.quote_asset
|
||||
|
||||
if base not in self._asset_data or quote not in self._asset_data:
|
||||
return None
|
||||
|
||||
df_base = self._asset_data[base]
|
||||
df_quote = self._asset_data[quote]
|
||||
|
||||
# Align indices
|
||||
common_idx = df_base.index.intersection(df_quote.index)
|
||||
if len(common_idx) < 200:
|
||||
return None
|
||||
|
||||
df_a = df_base.loc[common_idx]
|
||||
df_b = df_quote.loc[common_idx]
|
||||
|
||||
# Calculate spread (base / quote)
|
||||
spread = df_a['close'] / df_b['close']
|
||||
|
||||
# Z-Score
|
||||
z_window = self.config.z_window
|
||||
rolling_mean = spread.rolling(window=z_window).mean()
|
||||
rolling_std = spread.rolling(window=z_window).std()
|
||||
z_score = (spread - rolling_mean) / rolling_std
|
||||
|
||||
# Spread Technicals
|
||||
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
|
||||
spread_roc = spread.pct_change(periods=5) * 100
|
||||
spread_change_1h = spread.pct_change(periods=1)
|
||||
|
||||
# Volume Analysis
|
||||
vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10)
|
||||
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
|
||||
vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10)
|
||||
|
||||
# Volatility
|
||||
ret_a = df_a['close'].pct_change()
|
||||
ret_b = df_b['close'].pct_change()
|
||||
vol_a = ret_a.rolling(window=z_window).std()
|
||||
vol_b = ret_b.rolling(window=z_window).std()
|
||||
vol_spread_ratio = vol_a / (vol_b + 1e-10)
|
||||
|
||||
# Realized Volatility
|
||||
realized_vol_a = ret_a.rolling(window=24).std()
|
||||
realized_vol_b = ret_b.rolling(window=24).std()
|
||||
|
||||
# ATR (Average True Range)
|
||||
high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close']
|
||||
|
||||
tr_a = pd.concat([
|
||||
high_a - low_a,
|
||||
(high_a - close_a.shift(1)).abs(),
|
||||
(low_a - close_a.shift(1)).abs()
|
||||
], axis=1).max(axis=1)
|
||||
atr_a = tr_a.rolling(window=self.config.atr_period).mean()
|
||||
atr_pct_a = atr_a / close_a
|
||||
|
||||
# Build feature DataFrame
|
||||
features = pd.DataFrame(index=common_idx)
|
||||
features['pair_id'] = pair.pair_id
|
||||
features['base_asset'] = base
|
||||
features['quote_asset'] = quote
|
||||
|
||||
# Price data
|
||||
features['spread'] = spread
|
||||
features['base_close'] = df_a['close']
|
||||
features['quote_close'] = df_b['close']
|
||||
features['base_volume'] = df_a['volume']
|
||||
|
||||
# Core Features
|
||||
features['z_score'] = z_score
|
||||
features['spread_rsi'] = spread_rsi
|
||||
features['spread_roc'] = spread_roc
|
||||
features['spread_change_1h'] = spread_change_1h
|
||||
features['vol_ratio'] = vol_ratio
|
||||
features['vol_ratio_rel'] = vol_ratio_rel
|
||||
features['vol_diff_ratio'] = vol_spread_ratio
|
||||
|
||||
# Volatility
|
||||
features['realized_vol_base'] = realized_vol_a
|
||||
features['realized_vol_quote'] = realized_vol_b
|
||||
features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2
|
||||
|
||||
# ATR
|
||||
features['atr_base'] = atr_a
|
||||
features['atr_pct_base'] = atr_pct_a
|
||||
|
||||
# Pair encoding
|
||||
assets = self.config.assets
|
||||
features['base_idx'] = assets.index(base) if base in assets else -1
|
||||
features['quote_idx'] = assets.index(quote) if quote in assets else -1
|
||||
|
||||
# Funding rates
|
||||
base_funding = self._funding_rates.get(base, 0.0)
|
||||
quote_funding = self._funding_rates.get(quote, 0.0)
|
||||
features['base_funding'] = base_funding
|
||||
features['quote_funding'] = quote_funding
|
||||
features['funding_diff'] = base_funding - quote_funding
|
||||
features['funding_avg'] = (base_funding + quote_funding) / 2
|
||||
|
||||
# Drop NaN rows in core features
|
||||
core_cols = [
|
||||
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
|
||||
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
|
||||
'realized_vol_base', 'atr_base', 'atr_pct_base'
|
||||
]
|
||||
features = features.dropna(subset=core_cols)
|
||||
|
||||
return features
|
||||
|
||||
def calculate_all_pair_features(self) -> dict[str, pd.DataFrame]:
|
||||
"""
|
||||
Calculate features for all pairs.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping pair_id to feature DataFrame
|
||||
"""
|
||||
all_features = {}
|
||||
|
||||
for pair in self._pairs:
|
||||
features = self.calculate_pair_features(pair)
|
||||
if features is not None and len(features) > 0:
|
||||
all_features[pair.pair_id] = features
|
||||
|
||||
logger.info("Calculated features for %d/%d pairs",
|
||||
len(all_features), len(self._pairs))
|
||||
|
||||
return all_features
|
||||
|
||||
def get_latest_data(self) -> dict[str, pd.DataFrame] | None:
|
||||
"""
|
||||
Fetch and process latest market data for all pairs.
|
||||
|
||||
Returns:
|
||||
Dictionary of pair features or None on error
|
||||
"""
|
||||
try:
|
||||
# Fetch OHLCV for all assets
|
||||
self.fetch_all_ohlcv()
|
||||
|
||||
if len(self._asset_data) < 2:
|
||||
logger.warning("Insufficient assets fetched")
|
||||
return None
|
||||
|
||||
# Fetch funding rates
|
||||
self.fetch_all_funding_rates()
|
||||
|
||||
# Calculate features for all pairs
|
||||
pair_features = self.calculate_all_pair_features()
|
||||
|
||||
if not pair_features:
|
||||
logger.warning("No pair features calculated")
|
||||
return None
|
||||
|
||||
logger.info("Processed %d pairs with valid features", len(pair_features))
|
||||
return pair_features
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error fetching market data: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
def get_pair_by_id(self, pair_id: str) -> TradingPair | None:
|
||||
"""Get pair object by ID."""
|
||||
for pair in self._pairs:
|
||||
if pair.pair_id == pair_id:
|
||||
return pair
|
||||
return None
|
||||
|
||||
def get_current_price(self, symbol: str) -> float | None:
|
||||
"""Get current price for a symbol."""
|
||||
if symbol in self._asset_data:
|
||||
return self._asset_data[symbol]['close'].iloc[-1]
|
||||
return None
|
||||
609
live_trading/multi_pair/main.py
Normal file
609
live_trading/multi_pair/main.py
Normal file
@@ -0,0 +1,609 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Multi-Pair Divergence Live Trading Bot.
|
||||
|
||||
Trades the top 10 cryptocurrency pairs based on spread divergence
|
||||
using a universal ML model for signal generation.
|
||||
|
||||
Usage:
|
||||
# Run with demo account (default)
|
||||
uv run python -m live_trading.multi_pair.main
|
||||
|
||||
# Run with specific settings
|
||||
uv run python -m live_trading.multi_pair.main --max-position 500 --leverage 2
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from live_trading.okx_client import OKXClient
|
||||
from live_trading.position_manager import PositionManager
|
||||
from live_trading.multi_pair.config import (
|
||||
OKXConfig, MultiPairLiveConfig, PathConfig, get_multi_pair_config
|
||||
)
|
||||
from live_trading.multi_pair.data_feed import MultiPairDataFeed, TradingPair
|
||||
from live_trading.multi_pair.strategy import LiveMultiPairStrategy
|
||||
|
||||
|
||||
def setup_logging(log_dir: Path) -> logging.Logger:
|
||||
"""Configure logging for the trading bot."""
|
||||
log_file = log_dir / "multi_pair_live.log"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(log_file),
|
||||
logging.StreamHandler(sys.stdout),
|
||||
],
|
||||
force=True
|
||||
)
|
||||
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PositionState:
|
||||
"""Track current position state for multi-pair."""
|
||||
pair: TradingPair | None = None
|
||||
pair_id: str | None = None
|
||||
direction: str | None = None
|
||||
entry_price: float = 0.0
|
||||
size: float = 0.0
|
||||
stop_loss: float = 0.0
|
||||
take_profit: float = 0.0
|
||||
entry_time: datetime | None = None
|
||||
|
||||
|
||||
class MultiPairLiveTradingBot:
|
||||
"""
|
||||
Main trading bot for multi-pair divergence strategy.
|
||||
|
||||
Coordinates data fetching, pair scoring, and order execution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
okx_config: OKXConfig,
|
||||
trading_config: MultiPairLiveConfig,
|
||||
path_config: PathConfig
|
||||
):
|
||||
self.okx_config = okx_config
|
||||
self.trading_config = trading_config
|
||||
self.path_config = path_config
|
||||
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.running = True
|
||||
|
||||
# Initialize components
|
||||
self.logger.info("Initializing multi-pair trading bot...")
|
||||
|
||||
# Create OKX client with adapted config
|
||||
self._adapted_trading_config = self._adapt_config_for_okx_client()
|
||||
self.okx_client = OKXClient(okx_config, self._adapted_trading_config)
|
||||
|
||||
# Initialize data feed
|
||||
self.data_feed = MultiPairDataFeed(
|
||||
self.okx_client, trading_config, path_config
|
||||
)
|
||||
|
||||
# Initialize position manager (reuse from single-pair)
|
||||
self.position_manager = PositionManager(
|
||||
self.okx_client, self._adapted_trading_config, path_config
|
||||
)
|
||||
|
||||
# Initialize strategy
|
||||
self.strategy = LiveMultiPairStrategy(trading_config, path_config)
|
||||
|
||||
# Current position state
|
||||
self.position = PositionState()
|
||||
|
||||
# Register signal handlers
|
||||
signal.signal(signal.SIGINT, self._handle_shutdown)
|
||||
signal.signal(signal.SIGTERM, self._handle_shutdown)
|
||||
|
||||
self._print_startup_banner()
|
||||
|
||||
# Sync with exchange positions on startup
|
||||
self._sync_position_from_exchange()
|
||||
|
||||
def _adapt_config_for_okx_client(self):
|
||||
"""Create config compatible with OKXClient."""
|
||||
# OKXClient expects specific attributes
|
||||
@dataclass
|
||||
class AdaptedConfig:
|
||||
eth_symbol: str = "ETH/USDT:USDT"
|
||||
btc_symbol: str = "BTC/USDT:USDT"
|
||||
timeframe: str = "1h"
|
||||
candles_to_fetch: int = 500
|
||||
max_position_usdt: float = -1.0
|
||||
min_position_usdt: float = 10.0
|
||||
leverage: int = 1
|
||||
margin_mode: str = "cross"
|
||||
stop_loss_pct: float = 0.06
|
||||
take_profit_pct: float = 0.05
|
||||
max_concurrent_positions: int = 1
|
||||
z_entry_threshold: float = 1.0
|
||||
z_window: int = 24
|
||||
model_prob_threshold: float = 0.5
|
||||
funding_threshold: float = 0.0005
|
||||
sleep_seconds: int = 3600
|
||||
slippage_pct: float = 0.001
|
||||
|
||||
adapted = AdaptedConfig()
|
||||
adapted.timeframe = self.trading_config.timeframe
|
||||
adapted.candles_to_fetch = self.trading_config.candles_to_fetch
|
||||
adapted.max_position_usdt = self.trading_config.max_position_usdt
|
||||
adapted.min_position_usdt = self.trading_config.min_position_usdt
|
||||
adapted.leverage = self.trading_config.leverage
|
||||
adapted.margin_mode = self.trading_config.margin_mode
|
||||
adapted.max_concurrent_positions = self.trading_config.max_concurrent_positions
|
||||
adapted.sleep_seconds = self.trading_config.sleep_seconds
|
||||
adapted.slippage_pct = self.trading_config.slippage_pct
|
||||
|
||||
return adapted
|
||||
|
||||
def _print_startup_banner(self) -> None:
|
||||
"""Print startup information."""
|
||||
mode = "DEMO/SANDBOX" if self.okx_config.demo_mode else "LIVE"
|
||||
|
||||
print("=" * 60)
|
||||
print(" Multi-Pair Divergence Strategy - Live Trading Bot")
|
||||
print("=" * 60)
|
||||
print(f" Mode: {mode}")
|
||||
print(f" Assets: {len(self.trading_config.assets)} assets")
|
||||
print(f" Pairs: {self.trading_config.get_pair_count()} pairs")
|
||||
print(f" Timeframe: {self.trading_config.timeframe}")
|
||||
print(f" Max Position: ${self.trading_config.max_position_usdt if self.trading_config.max_position_usdt > 0 else 'All available'}")
|
||||
print(f" Leverage: {self.trading_config.leverage}x")
|
||||
print(f" Z-Entry: > {self.trading_config.z_entry_threshold}")
|
||||
print(f" Prob Threshold: > {self.trading_config.prob_threshold}")
|
||||
print(f" Cycle Interval: {self.trading_config.sleep_seconds // 60} minutes")
|
||||
print("=" * 60)
|
||||
print(f" Assets: {', '.join([a.split('/')[0] for a in self.trading_config.assets])}")
|
||||
print("=" * 60)
|
||||
|
||||
if not self.okx_config.demo_mode:
|
||||
print("\n *** WARNING: LIVE TRADING MODE - REAL FUNDS AT RISK ***\n")
|
||||
|
||||
def _handle_shutdown(self, signum, frame) -> None:
|
||||
"""Handle shutdown signals gracefully."""
|
||||
self.logger.info("Shutdown signal received, stopping...")
|
||||
self.running = False
|
||||
|
||||
def _sync_position_from_exchange(self) -> bool:
|
||||
"""
|
||||
Sync internal position state with exchange positions.
|
||||
|
||||
Checks for existing open positions on the exchange and updates
|
||||
internal state to match. This prevents stacking positions when
|
||||
the bot is restarted.
|
||||
|
||||
Returns:
|
||||
True if a position was synced, False otherwise
|
||||
"""
|
||||
try:
|
||||
positions = self.okx_client.get_positions()
|
||||
|
||||
if not positions:
|
||||
if self.position.pair is not None:
|
||||
# Position was closed externally (e.g., SL/TP hit)
|
||||
self.logger.info(
|
||||
"Position %s was closed externally, resetting state",
|
||||
self.position.pair.name if self.position.pair else "unknown"
|
||||
)
|
||||
self.position = PositionState()
|
||||
return False
|
||||
|
||||
# Check each position against our tradeable assets
|
||||
our_assets = set(self.trading_config.assets)
|
||||
|
||||
for pos in positions:
|
||||
pos_symbol = pos.get('symbol', '')
|
||||
contracts = abs(float(pos.get('contracts', 0)))
|
||||
|
||||
if contracts == 0:
|
||||
continue
|
||||
|
||||
# Check if this position is for one of our assets
|
||||
if pos_symbol not in our_assets:
|
||||
continue
|
||||
|
||||
# Found a position for one of our assets
|
||||
side = pos.get('side', 'long')
|
||||
entry_price = float(pos.get('entryPrice', 0))
|
||||
unrealized_pnl = float(pos.get('unrealizedPnl', 0))
|
||||
|
||||
# If we already track this position, just update
|
||||
if (self.position.pair is not None and
|
||||
self.position.pair.base_asset == pos_symbol):
|
||||
self.logger.debug(
|
||||
"Position already tracked: %s %s %.2f contracts",
|
||||
side, pos_symbol, contracts
|
||||
)
|
||||
return True
|
||||
|
||||
# New position found - sync it
|
||||
# Find or create a TradingPair for this position
|
||||
matched_pair = None
|
||||
for pair in self.data_feed.pairs:
|
||||
if pair.base_asset == pos_symbol:
|
||||
matched_pair = pair
|
||||
break
|
||||
|
||||
if matched_pair is None:
|
||||
# Create a placeholder pair (we don't know the quote asset)
|
||||
matched_pair = TradingPair(
|
||||
base_asset=pos_symbol,
|
||||
quote_asset="UNKNOWN"
|
||||
)
|
||||
|
||||
# Calculate approximate SL/TP based on config defaults
|
||||
sl_pct = self.trading_config.base_sl_pct
|
||||
tp_pct = self.trading_config.base_tp_pct
|
||||
|
||||
if side == 'long':
|
||||
stop_loss = entry_price * (1 - sl_pct)
|
||||
take_profit = entry_price * (1 + tp_pct)
|
||||
else:
|
||||
stop_loss = entry_price * (1 + sl_pct)
|
||||
take_profit = entry_price * (1 - tp_pct)
|
||||
|
||||
self.position = PositionState(
|
||||
pair=matched_pair,
|
||||
pair_id=matched_pair.pair_id,
|
||||
direction=side,
|
||||
entry_price=entry_price,
|
||||
size=contracts,
|
||||
stop_loss=stop_loss,
|
||||
take_profit=take_profit,
|
||||
entry_time=None # Unknown for synced positions
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"Synced existing position from exchange: %s %s %.4f @ %.4f (PnL: %.2f)",
|
||||
side.upper(),
|
||||
pos_symbol,
|
||||
contracts,
|
||||
entry_price,
|
||||
unrealized_pnl
|
||||
)
|
||||
return True
|
||||
|
||||
# No matching positions found
|
||||
if self.position.pair is not None:
|
||||
self.logger.info(
|
||||
"Position %s no longer exists on exchange, resetting state",
|
||||
self.position.pair.name
|
||||
)
|
||||
self.position = PositionState()
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Failed to sync position from exchange: %s", e)
|
||||
return False
|
||||
|
||||
def run_trading_cycle(self) -> None:
|
||||
"""
|
||||
Execute one trading cycle.
|
||||
|
||||
1. Sync position state with exchange
|
||||
2. Fetch latest market data for all assets
|
||||
3. Calculate features for all pairs
|
||||
4. Score pairs and find best opportunity
|
||||
5. Check exit conditions for current position
|
||||
6. Execute trades if needed
|
||||
"""
|
||||
cycle_start = datetime.now(timezone.utc)
|
||||
self.logger.info("--- Trading Cycle Start: %s ---", cycle_start.isoformat())
|
||||
|
||||
try:
|
||||
# 1. Sync position state with exchange (detect SL/TP closures)
|
||||
self._sync_position_from_exchange()
|
||||
|
||||
# 2. Fetch all market data
|
||||
pair_features = self.data_feed.get_latest_data()
|
||||
if pair_features is None:
|
||||
self.logger.warning("No market data available, skipping cycle")
|
||||
return
|
||||
|
||||
# 2. Check exit conditions for current position
|
||||
if self.position.pair is not None:
|
||||
exit_signal = self.strategy.check_exit_signal(
|
||||
pair_features,
|
||||
self.position.pair_id
|
||||
)
|
||||
|
||||
if exit_signal['action'] == 'exit':
|
||||
self._execute_exit(exit_signal)
|
||||
else:
|
||||
# Check SL/TP
|
||||
current_price = self.data_feed.get_current_price(
|
||||
self.position.pair.base_asset
|
||||
)
|
||||
if current_price:
|
||||
sl_tp_exit = self._check_sl_tp(current_price)
|
||||
if sl_tp_exit:
|
||||
self._execute_exit({'reason': sl_tp_exit})
|
||||
|
||||
# 3. Generate entry signal if no position
|
||||
if self.position.pair is None:
|
||||
entry_signal = self.strategy.generate_signal(
|
||||
pair_features,
|
||||
self.data_feed.pairs
|
||||
)
|
||||
|
||||
if entry_signal['action'] == 'entry':
|
||||
self._execute_entry(entry_signal)
|
||||
|
||||
# 4. Log status
|
||||
if self.position.pair:
|
||||
self.logger.info(
|
||||
"Position: %s %s, entry=%.4f, current PnL check pending",
|
||||
self.position.direction,
|
||||
self.position.pair.name,
|
||||
self.position.entry_price
|
||||
)
|
||||
else:
|
||||
self.logger.info("No open position")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Trading cycle error: %s", e, exc_info=True)
|
||||
|
||||
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
|
||||
self.logger.info("--- Cycle completed in %.1fs ---", cycle_duration)
|
||||
|
||||
def _check_sl_tp(self, current_price: float) -> str | None:
|
||||
"""Check stop-loss and take-profit levels."""
|
||||
if self.position.direction == 'long':
|
||||
if current_price <= self.position.stop_loss:
|
||||
return f"stop_loss ({current_price:.4f} <= {self.position.stop_loss:.4f})"
|
||||
if current_price >= self.position.take_profit:
|
||||
return f"take_profit ({current_price:.4f} >= {self.position.take_profit:.4f})"
|
||||
else: # short
|
||||
if current_price >= self.position.stop_loss:
|
||||
return f"stop_loss ({current_price:.4f} >= {self.position.stop_loss:.4f})"
|
||||
if current_price <= self.position.take_profit:
|
||||
return f"take_profit ({current_price:.4f} <= {self.position.take_profit:.4f})"
|
||||
return None
|
||||
|
||||
def _execute_entry(self, signal: dict) -> None:
|
||||
"""Execute entry trade."""
|
||||
pair = signal['pair']
|
||||
symbol = pair.base_asset # Trade the base asset
|
||||
direction = signal['direction']
|
||||
|
||||
self.logger.info(
|
||||
"Entry signal: %s %s (z=%.2f, p=%.2f, score=%.3f)",
|
||||
direction.upper(),
|
||||
pair.name,
|
||||
signal['z_score'],
|
||||
signal['probability'],
|
||||
signal['divergence_score']
|
||||
)
|
||||
|
||||
# Get account balance
|
||||
try:
|
||||
balance = self.okx_client.get_balance()
|
||||
available_usdt = balance['free']
|
||||
except Exception as e:
|
||||
self.logger.error("Could not get balance: %s", e)
|
||||
return
|
||||
|
||||
# Calculate position size
|
||||
size_usdt = self.strategy.calculate_position_size(
|
||||
signal['divergence_score'],
|
||||
available_usdt
|
||||
)
|
||||
|
||||
if size_usdt <= 0:
|
||||
self.logger.info("Position size too small, skipping entry")
|
||||
return
|
||||
|
||||
current_price = signal['base_price']
|
||||
size_asset = size_usdt / current_price
|
||||
|
||||
# Calculate SL/TP
|
||||
stop_loss, take_profit = self.strategy.calculate_sl_tp(
|
||||
current_price,
|
||||
direction,
|
||||
signal['atr'],
|
||||
signal['atr_pct']
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"Executing %s entry: %.6f %s @ %.4f ($%.2f), SL=%.4f, TP=%.4f",
|
||||
direction.upper(),
|
||||
size_asset,
|
||||
symbol.split('/')[0],
|
||||
current_price,
|
||||
size_usdt,
|
||||
stop_loss,
|
||||
take_profit
|
||||
)
|
||||
|
||||
try:
|
||||
# Place market order
|
||||
order_side = "buy" if direction == "long" else "sell"
|
||||
order = self.okx_client.place_market_order(symbol, order_side, size_asset)
|
||||
|
||||
filled_price = order.get('average') or order.get('price') or current_price
|
||||
filled_amount = order.get('filled') or order.get('amount') or size_asset
|
||||
|
||||
if filled_price is None or filled_price == 0:
|
||||
filled_price = current_price
|
||||
if filled_amount is None or filled_amount == 0:
|
||||
filled_amount = size_asset
|
||||
|
||||
# Recalculate SL/TP with filled price
|
||||
stop_loss, take_profit = self.strategy.calculate_sl_tp(
|
||||
filled_price, direction, signal['atr'], signal['atr_pct']
|
||||
)
|
||||
|
||||
# Update position state
|
||||
self.position = PositionState(
|
||||
pair=pair,
|
||||
pair_id=pair.pair_id,
|
||||
direction=direction,
|
||||
entry_price=filled_price,
|
||||
size=filled_amount,
|
||||
stop_loss=stop_loss,
|
||||
take_profit=take_profit,
|
||||
entry_time=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"Position opened: %s %s %.6f @ %.4f",
|
||||
direction.upper(),
|
||||
pair.name,
|
||||
filled_amount,
|
||||
filled_price
|
||||
)
|
||||
|
||||
# Try to set SL/TP on exchange
|
||||
try:
|
||||
self.okx_client.set_stop_loss_take_profit(
|
||||
symbol, direction, filled_amount, stop_loss, take_profit
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning("Could not set SL/TP on exchange: %s", e)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Order execution failed: %s", e, exc_info=True)
|
||||
|
||||
def _execute_exit(self, signal: dict) -> None:
|
||||
"""Execute exit trade."""
|
||||
if self.position.pair is None:
|
||||
return
|
||||
|
||||
symbol = self.position.pair.base_asset
|
||||
reason = signal.get('reason', 'unknown')
|
||||
|
||||
self.logger.info(
|
||||
"Exit signal: %s %s, reason: %s",
|
||||
self.position.direction,
|
||||
self.position.pair.name,
|
||||
reason
|
||||
)
|
||||
|
||||
try:
|
||||
# Close position on exchange
|
||||
self.okx_client.close_position(symbol)
|
||||
|
||||
self.logger.info(
|
||||
"Position closed: %s %s",
|
||||
self.position.direction,
|
||||
self.position.pair.name
|
||||
)
|
||||
|
||||
# Reset position state
|
||||
self.position = PositionState()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Exit execution failed: %s", e, exc_info=True)
|
||||
|
||||
def run(self) -> None:
|
||||
"""Main trading loop."""
|
||||
self.logger.info("Starting multi-pair trading loop...")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
self.run_trading_cycle()
|
||||
|
||||
if self.running:
|
||||
sleep_seconds = self.trading_config.sleep_seconds
|
||||
minutes = sleep_seconds // 60
|
||||
self.logger.info("Sleeping for %d minutes...", minutes)
|
||||
|
||||
for _ in range(sleep_seconds):
|
||||
if not self.running:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
self.logger.info("Keyboard interrupt received")
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error("Unexpected error in main loop: %s", e, exc_info=True)
|
||||
time.sleep(60)
|
||||
|
||||
self.logger.info("Shutting down...")
|
||||
self.logger.info("Shutdown complete")
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Multi-Pair Divergence Live Trading Bot"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-position",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Maximum position size in USDT"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--leverage",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Trading leverage (1-125)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interval",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Trading cycle interval in seconds"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--live",
|
||||
action="store_true",
|
||||
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
args = parse_args()
|
||||
|
||||
# Load configuration
|
||||
okx_config, trading_config, path_config = get_multi_pair_config()
|
||||
|
||||
# Apply command line overrides
|
||||
if args.max_position is not None:
|
||||
trading_config.max_position_usdt = args.max_position
|
||||
if args.leverage is not None:
|
||||
trading_config.leverage = args.leverage
|
||||
if args.interval is not None:
|
||||
trading_config.sleep_seconds = args.interval
|
||||
if args.live:
|
||||
okx_config.demo_mode = False
|
||||
|
||||
# Setup logging
|
||||
logger = setup_logging(path_config.logs_dir)
|
||||
|
||||
try:
|
||||
# Validate config
|
||||
okx_config.validate()
|
||||
|
||||
# Create and run bot
|
||||
bot = MultiPairLiveTradingBot(okx_config, trading_config, path_config)
|
||||
bot.run()
|
||||
except ValueError as e:
|
||||
logger.error("Configuration error: %s", e)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.error("Fatal error: %s", e, exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
396
live_trading/multi_pair/strategy.py
Normal file
396
live_trading/multi_pair/strategy.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""
|
||||
Live Multi-Pair Divergence Strategy.
|
||||
|
||||
Scores all pairs and selects the best divergence opportunity for trading.
|
||||
Uses the pre-trained universal ML model from backtesting.
|
||||
"""
|
||||
import logging
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
# Opt-in to future pandas behavior to silence FutureWarning on fillna
|
||||
pd.set_option('future.no_silent_downcasting', True)
|
||||
|
||||
from .config import MultiPairLiveConfig, PathConfig
|
||||
from .data_feed import TradingPair
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DivergenceSignal:
|
||||
"""
|
||||
Signal for a divergent pair.
|
||||
|
||||
Attributes:
|
||||
pair: Trading pair
|
||||
z_score: Current Z-Score of the spread
|
||||
probability: ML model probability of profitable reversion
|
||||
divergence_score: Combined score (|z_score| * probability)
|
||||
direction: 'long' or 'short' (relative to base asset)
|
||||
base_price: Current price of base asset
|
||||
quote_price: Current price of quote asset
|
||||
atr: Average True Range in price units
|
||||
atr_pct: ATR as percentage of price
|
||||
"""
|
||||
pair: TradingPair
|
||||
z_score: float
|
||||
probability: float
|
||||
divergence_score: float
|
||||
direction: str
|
||||
base_price: float
|
||||
quote_price: float
|
||||
atr: float
|
||||
atr_pct: float
|
||||
base_funding: float = 0.0
|
||||
|
||||
|
||||
class LiveMultiPairStrategy:
|
||||
"""
|
||||
Live trading implementation of multi-pair divergence strategy.
|
||||
|
||||
Scores all pairs using the universal ML model and selects
|
||||
the best opportunity for mean-reversion trading.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MultiPairLiveConfig,
|
||||
path_config: PathConfig
|
||||
):
|
||||
self.config = config
|
||||
self.paths = path_config
|
||||
self.model: RandomForestClassifier | None = None
|
||||
self.feature_cols: list[str] | None = None
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self) -> None:
|
||||
"""Load pre-trained model from backtesting."""
|
||||
if self.paths.model_path.exists():
|
||||
try:
|
||||
with open(self.paths.model_path, 'rb') as f:
|
||||
saved = pickle.load(f)
|
||||
self.model = saved['model']
|
||||
self.feature_cols = saved['feature_cols']
|
||||
logger.info("Loaded model from %s", self.paths.model_path)
|
||||
except Exception as e:
|
||||
logger.error("Could not load model: %s", e)
|
||||
raise ValueError(
|
||||
f"Multi-pair model not found at {self.paths.model_path}. "
|
||||
"Run the backtest first to train the model."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Multi-pair model not found at {self.paths.model_path}. "
|
||||
"Run the backtest first to train the model."
|
||||
)
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pair_features: dict[str, pd.DataFrame],
|
||||
pairs: list[TradingPair]
|
||||
) -> list[DivergenceSignal]:
|
||||
"""
|
||||
Score all pairs and return ranked signals.
|
||||
|
||||
Args:
|
||||
pair_features: Feature DataFrames by pair_id
|
||||
pairs: List of TradingPair objects
|
||||
|
||||
Returns:
|
||||
List of DivergenceSignal sorted by score (descending)
|
||||
"""
|
||||
if self.model is None:
|
||||
logger.warning("Model not loaded")
|
||||
return []
|
||||
|
||||
signals = []
|
||||
pair_map = {p.pair_id: p for p in pairs}
|
||||
|
||||
for pair_id, features in pair_features.items():
|
||||
if pair_id not in pair_map:
|
||||
continue
|
||||
|
||||
pair = pair_map[pair_id]
|
||||
|
||||
# Get latest features
|
||||
if len(features) == 0:
|
||||
continue
|
||||
|
||||
latest = features.iloc[-1]
|
||||
z_score = latest['z_score']
|
||||
|
||||
# Skip if Z-score below threshold
|
||||
if abs(z_score) < self.config.z_entry_threshold:
|
||||
continue
|
||||
|
||||
# Prepare features for prediction
|
||||
# Handle missing feature columns gracefully
|
||||
available_cols = [c for c in self.feature_cols if c in latest.index]
|
||||
missing_cols = [c for c in self.feature_cols if c not in latest.index]
|
||||
|
||||
if missing_cols:
|
||||
logger.debug("Missing feature columns: %s", missing_cols)
|
||||
|
||||
feature_row = latest[available_cols].fillna(0)
|
||||
feature_row = feature_row.replace([np.inf, -np.inf], 0)
|
||||
|
||||
# Create full feature vector with zeros for missing
|
||||
X_dict = {c: 0 for c in self.feature_cols}
|
||||
for col in available_cols:
|
||||
X_dict[col] = feature_row[col]
|
||||
|
||||
X = pd.DataFrame([X_dict])
|
||||
|
||||
# Predict probability
|
||||
prob = self.model.predict_proba(X)[0, 1]
|
||||
|
||||
# Skip if probability below threshold
|
||||
if prob < self.config.prob_threshold:
|
||||
continue
|
||||
|
||||
# Apply funding rate filter
|
||||
base_funding = latest.get('base_funding', 0) or 0
|
||||
funding_thresh = self.config.funding_threshold
|
||||
|
||||
if z_score > 0: # Short signal
|
||||
if base_funding < -funding_thresh:
|
||||
logger.debug(
|
||||
"Skipping %s short: funding too negative (%.4f)",
|
||||
pair.name, base_funding
|
||||
)
|
||||
continue
|
||||
else: # Long signal
|
||||
if base_funding > funding_thresh:
|
||||
logger.debug(
|
||||
"Skipping %s long: funding too positive (%.4f)",
|
||||
pair.name, base_funding
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate divergence score
|
||||
divergence_score = abs(z_score) * prob
|
||||
|
||||
# Determine direction
|
||||
direction = 'short' if z_score > 0 else 'long'
|
||||
|
||||
signal = DivergenceSignal(
|
||||
pair=pair,
|
||||
z_score=z_score,
|
||||
probability=prob,
|
||||
divergence_score=divergence_score,
|
||||
direction=direction,
|
||||
base_price=latest['base_close'],
|
||||
quote_price=latest['quote_close'],
|
||||
atr=latest.get('atr_base', 0),
|
||||
atr_pct=latest.get('atr_pct_base', 0.02),
|
||||
base_funding=base_funding
|
||||
)
|
||||
signals.append(signal)
|
||||
|
||||
# Sort by divergence score (highest first)
|
||||
signals.sort(key=lambda s: s.divergence_score, reverse=True)
|
||||
|
||||
if signals:
|
||||
logger.info(
|
||||
"Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f, dir=%s)",
|
||||
len(signals),
|
||||
signals[0].pair.name,
|
||||
signals[0].divergence_score,
|
||||
signals[0].z_score,
|
||||
signals[0].probability,
|
||||
signals[0].direction
|
||||
)
|
||||
else:
|
||||
logger.info("No pairs meet entry criteria")
|
||||
|
||||
return signals
|
||||
|
||||
def select_best_pair(
|
||||
self,
|
||||
signals: list[DivergenceSignal]
|
||||
) -> DivergenceSignal | None:
|
||||
"""
|
||||
Select the best pair from scored signals.
|
||||
|
||||
Args:
|
||||
signals: List of DivergenceSignal (pre-sorted by score)
|
||||
|
||||
Returns:
|
||||
Best signal or None if no valid candidates
|
||||
"""
|
||||
if not signals:
|
||||
return None
|
||||
return signals[0]
|
||||
|
||||
def generate_signal(
|
||||
self,
|
||||
pair_features: dict[str, pd.DataFrame],
|
||||
pairs: list[TradingPair]
|
||||
) -> dict:
|
||||
"""
|
||||
Generate trading signal from latest features.
|
||||
|
||||
Args:
|
||||
pair_features: Feature DataFrames by pair_id
|
||||
pairs: List of TradingPair objects
|
||||
|
||||
Returns:
|
||||
Signal dictionary with action, pair, direction, etc.
|
||||
"""
|
||||
# Score all pairs
|
||||
signals = self.score_pairs(pair_features, pairs)
|
||||
|
||||
# Select best
|
||||
best = self.select_best_pair(signals)
|
||||
|
||||
if best is None:
|
||||
return {
|
||||
'action': 'hold',
|
||||
'reason': 'no_valid_signals'
|
||||
}
|
||||
|
||||
return {
|
||||
'action': 'entry',
|
||||
'pair': best.pair,
|
||||
'pair_id': best.pair.pair_id,
|
||||
'direction': best.direction,
|
||||
'z_score': best.z_score,
|
||||
'probability': best.probability,
|
||||
'divergence_score': best.divergence_score,
|
||||
'base_price': best.base_price,
|
||||
'quote_price': best.quote_price,
|
||||
'atr': best.atr,
|
||||
'atr_pct': best.atr_pct,
|
||||
'base_funding': best.base_funding,
|
||||
'reason': f'{best.pair.name} z={best.z_score:.2f} p={best.probability:.2f}'
|
||||
}
|
||||
|
||||
def check_exit_signal(
|
||||
self,
|
||||
pair_features: dict[str, pd.DataFrame],
|
||||
current_pair_id: str
|
||||
) -> dict:
|
||||
"""
|
||||
Check if current position should be exited.
|
||||
|
||||
Exit conditions:
|
||||
1. Z-Score reverted to mean (|Z| < threshold)
|
||||
|
||||
Args:
|
||||
pair_features: Feature DataFrames by pair_id
|
||||
current_pair_id: Current position's pair ID
|
||||
|
||||
Returns:
|
||||
Signal dictionary with action and reason
|
||||
"""
|
||||
if current_pair_id not in pair_features:
|
||||
return {
|
||||
'action': 'exit',
|
||||
'reason': 'pair_data_missing'
|
||||
}
|
||||
|
||||
features = pair_features[current_pair_id]
|
||||
if len(features) == 0:
|
||||
return {
|
||||
'action': 'exit',
|
||||
'reason': 'no_data'
|
||||
}
|
||||
|
||||
latest = features.iloc[-1]
|
||||
z_score = latest['z_score']
|
||||
|
||||
# Check mean reversion
|
||||
if abs(z_score) < self.config.z_exit_threshold:
|
||||
return {
|
||||
'action': 'exit',
|
||||
'reason': f'mean_reversion (z={z_score:.2f})'
|
||||
}
|
||||
|
||||
return {
|
||||
'action': 'hold',
|
||||
'z_score': z_score,
|
||||
'reason': f'holding (z={z_score:.2f})'
|
||||
}
|
||||
|
||||
def calculate_sl_tp(
|
||||
self,
|
||||
entry_price: float,
|
||||
direction: str,
|
||||
atr: float,
|
||||
atr_pct: float
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate ATR-based dynamic stop-loss and take-profit prices.
|
||||
|
||||
Args:
|
||||
entry_price: Entry price
|
||||
direction: 'long' or 'short'
|
||||
atr: ATR in price units
|
||||
atr_pct: ATR as percentage of price
|
||||
|
||||
Returns:
|
||||
Tuple of (stop_loss_price, take_profit_price)
|
||||
"""
|
||||
if atr > 0 and atr_pct > 0:
|
||||
sl_distance = atr * self.config.sl_atr_multiplier
|
||||
tp_distance = atr * self.config.tp_atr_multiplier
|
||||
|
||||
sl_pct = sl_distance / entry_price
|
||||
tp_pct = tp_distance / entry_price
|
||||
else:
|
||||
sl_pct = self.config.base_sl_pct
|
||||
tp_pct = self.config.base_tp_pct
|
||||
|
||||
# Apply bounds
|
||||
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
|
||||
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
|
||||
|
||||
if direction == 'long':
|
||||
stop_loss = entry_price * (1 - sl_pct)
|
||||
take_profit = entry_price * (1 + tp_pct)
|
||||
else:
|
||||
stop_loss = entry_price * (1 + sl_pct)
|
||||
take_profit = entry_price * (1 - tp_pct)
|
||||
|
||||
return stop_loss, take_profit
|
||||
|
||||
def calculate_position_size(
|
||||
self,
|
||||
divergence_score: float,
|
||||
available_usdt: float
|
||||
) -> float:
|
||||
"""
|
||||
Calculate position size based on divergence score.
|
||||
|
||||
Args:
|
||||
divergence_score: Combined score (|z| * prob)
|
||||
available_usdt: Available USDT balance
|
||||
|
||||
Returns:
|
||||
Position size in USDT
|
||||
"""
|
||||
if self.config.max_position_usdt <= 0:
|
||||
base_size = available_usdt
|
||||
else:
|
||||
base_size = min(available_usdt, self.config.max_position_usdt)
|
||||
|
||||
# Scale by divergence (1.0 at 0.5 score, up to 2.0 at 1.0+ score)
|
||||
base_threshold = 0.5
|
||||
if divergence_score <= base_threshold:
|
||||
scale = 1.0
|
||||
else:
|
||||
scale = 1.0 + (divergence_score - base_threshold) / base_threshold
|
||||
scale = min(scale, 2.0)
|
||||
|
||||
size = base_size * scale
|
||||
|
||||
if size < self.config.min_position_usdt:
|
||||
return 0.0
|
||||
|
||||
return min(size, available_usdt * 0.95)
|
||||
Reference in New Issue
Block a user