2 Commits

1af0aab5fa feat: Add Multi-Pair Divergence Live Trading Module
- Introduced a new module for live trading based on the Multi-Pair Divergence Strategy.
- Implemented configuration classes for OKX API and multi-pair settings.
- Developed data feed functionality to fetch real-time OHLCV and funding data for multiple assets.
- Created a trading bot orchestrator to manage trading cycles, including entry and exit signals based on ML model predictions.
- Added comprehensive logging and error handling for robust operation.
- Included a README with setup instructions and usage guidelines for the new module.
2026-01-15 22:17:13 +08:00
df37366603 feat: Multi-Pair Divergence Selection Strategy
- Extend regime detection to top 10 cryptocurrencies (45 pairs)
- Dynamic pair selection based on divergence score (|z_score| * probability)
- Universal ML model trained on all pairs
- Correlation-based filtering to avoid redundant positions
- Funding rate integration from OKX for all 10 assets
- ATR-based dynamic stop-loss and take-profit
- Walk-forward training with 70/30 split

Performance: +35.69% return (vs +28.66% baseline), 63.6% win rate
2026-01-15 20:47:23 +08:00
43 changed files with 4466 additions and 3364 deletions

.gitignore vendored

@@ -175,7 +175,3 @@ data/backtest_runs.db
.gitignore
live_trading/regime_model.pkl
live_trading/positions.json
*.pkl
*.db

README.md

@@ -1,262 +1,82 @@
# Lowkey Backtest
A backtesting framework supporting multiple market types (spot, perpetual) with realistic trading simulation including leverage, funding, and shorts.
### Overview
Backtest a simple, long-only strategy driven by a meta Supertrend signal on aggregated OHLCV data. The script:
- Loads 1-minute BTC/USD data from `../data/btcusd_1-min_data.csv`
- Aggregates to multiple timeframes (e.g., `5min`, `15min`, `30min`, `1h`, `4h`, `1d`)
- Computes three Supertrend variants and creates a meta signal when all agree
- Executes entries/exits at the aggregated bar open price
- Applies OKX spot fee assumptions (taker by default)
- Evaluates stop-loss using intra-bar 1-minute data
- Writes detailed trade logs and a summary CSV
## Requirements
- Python 3.12+
- Package manager: `uv`
- Dependencies: `pandas`, `numpy`, `ta`
## Installation
Install dependencies with uv:
```bash
uv sync
# If a dependency is missing, add it explicitly and sync
uv add pandas numpy ta
uv sync
```
### Data
- Expected CSV location: `../data/btcusd_1-min_data.csv` (relative to the repo root)
- Required columns: `Timestamp`, `Open`, `High`, `Low`, `Close`, `Volume`
- `Timestamp` should be UNIX seconds; zero-volume rows are ignored
## Quick Reference
| Command | Description |
|---------|-------------|
| `uv run python main.py download -p BTC-USDT` | Download data |
| `uv run python main.py backtest -s meta_st -p BTC-USDT` | Run backtest |
| `uv run python main.py wfa -s regime -p BTC-USDT` | Walk-forward analysis |
| `uv run python train_model.py --download` | Train/retrain ML model |
| `uv run python research/regime_detection.py` | Run research script |
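The CSV layout described under Data can be loaded with a small helper. This is a minimal sketch, not the exact `load_data` in `main.py` (the `path` keyword is an assumption for illustration):

```python
import pandas as pd

def load_data(start: str, end: str, path: str = "../data/btcusd_1-min_data.csv") -> pd.DataFrame:
    """Load 1-minute OHLCV data indexed by timestamp."""
    df = pd.read_csv(path)
    # The Timestamp column holds UNIX seconds
    df["Timestamp"] = pd.to_datetime(df["Timestamp"], unit="s")
    df = df.set_index("Timestamp").sort_index()
    df = df[df["Volume"] > 0]  # zero-volume rows are ignored
    return df.loc[start:end]
```
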
---
## Backtest CLI
The main entry point is `main.py` which provides three commands: `download`, `backtest`, and `wfa`.
### Download Data
Download historical OHLCV data from exchanges.
```bash
uv run python main.py download -p BTC-USDT -t 1h
```
**Options:**
- `-p, --pair` (required): Trading pair (e.g., `BTC-USDT`, `ETH-USDT`)
- `-t, --timeframe`: Timeframe (default: `1m`)
- `-e, --exchange`: Exchange (default: `okx`)
- `-m, --market`: Market type: `spot` or `perpetual` (default: `spot`)
- `--start`: Start date in `YYYY-MM-DD` format
**Examples:**
```bash
# Download 1-hour spot data
uv run python main.py download -p ETH-USDT -t 1h
# Download perpetual data from a specific date
uv run python main.py download -p BTC-USDT -m perpetual --start 2024-01-01
```
### Quickstart
Run the backtest with defaults:
```bash
uv run python main.py
```
Outputs:
- Per-run trade logs in `backtest_logs/` named like `trade_log_<TIMEFRAME>_sl<STOPLOSS>.csv`
- Run-level summary in `backtest_summary.csv`
### Configuring a Run
Adjust parameters directly in `main.py`:
- Date range (in `load_data`): `load_data('2021-11-01', '2024-10-16')`
- Timeframes to test (any subset of `"5min", "15min", "30min", "1h", "4h", "1d"`):
  - `timeframes = ["5min", "15min", "30min", "1h", "4h", "1d"]`
- Stop-loss percentages:
  - `stoplosses = [0.03, 0.05, 0.1]`
- Supertrend settings (in `add_supertrend_indicators`): `(period, multiplier)` pairs `(12, 3.0)`, `(10, 1.0)`, `(11, 2.0)`
- Fee model (in `calculate_okx_taker_maker_fee`): taker `0.0010`, maker `0.0008`
### What the Backtester Does
- Aggregation: Resamples 1-minute data to your selected timeframe using OHLCV rules
- Supertrend signals: Computes three Supertrends and derives a meta trend of `+1` (bullish) or `-1` (bearish) when all agree; otherwise `0`
- Trade logic (long-only):
- Entry when the meta trend changes to bullish; uses aggregated bar open price
- Exit when the meta trend changes to bearish; uses aggregated bar open price
- Stop-loss: For each aggregated bar, scans corresponding 1-minute closes to detect stop-loss and exits using a realistic fill (threshold or next 1-minute open if gapped)
- Performance metrics: total return, max drawdown, Sharpe (daily, factor 252), win rate, number of trades, final/initial equity, and total fees
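The aggregation and metrics steps above can be sketched as follows. This is illustrative, not the exact code in `main.py`; the function names are assumptions:

```python
import numpy as np
import pandas as pd

# OHLCV resampling rules used when aggregating 1-minute bars
OHLCV_RULES = {"Open": "first", "High": "max", "Low": "min", "Close": "last", "Volume": "sum"}

def aggregate(df_1m: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """Resample 1-minute bars to a coarser timeframe."""
    return df_1m.resample(timeframe).agg(OHLCV_RULES).dropna()

def sharpe_ratio(equity: pd.Series, periods_per_year: int = 252) -> float:
    """Daily Sharpe with the 252 annualisation factor mentioned above."""
    daily = equity.resample("1D").last().pct_change().dropna()
    if len(daily) < 2 or daily.std() == 0:
        return 0.0
    return float(daily.mean() / daily.std() * np.sqrt(periods_per_year))

def max_drawdown(equity: pd.Series) -> float:
    """Largest peak-to-trough decline, as a fraction of the peak."""
    peak = equity.cummax()
    return float(((peak - equity) / peak).max())
```
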
### Important: Lookahead Bias Note
The current implementation uses the meta Supertrend signal of the same bar for entries/exits, which introduces lookahead bias. To avoid this, lag the signal by one bar inside `backtest()` in `main.py`:
```python
# Replace this line, which uses the same bar's signal (lookahead):
# meta_trend_signal = meta_trend
# with a one-bar lag to remove the lookahead:
meta_trend_signal = np.roll(meta_trend, 1)
meta_trend_signal[0] = 0
```
### Outputs
- `backtest_logs/trade_log_<TIMEFRAME>_sl<STOPLOSS>.csv`: trade-by-trade records including type (`buy`, `sell`, `stop_loss`, `forced_close`), timestamps, prices, balances, PnL, and fees
- `backtest_summary.csv`: one row per (timeframe, stop-loss) combination with `timeframe`, `stop_loss`, `total_return`, `max_drawdown`, `sharpe_ratio`, `win_rate`, `num_trades`, `final_equity`, `initial_equity`, `num_stop_losses`, `total_fees`
### Troubleshooting
- CSV not found: Ensure the dataset is located at `../data/btcusd_1-min_data.csv`
- Missing packages: Run `uv add pandas numpy ta` then `uv sync`
- Memory/performance: Large date ranges on 1-minute data can be heavy; narrow the date span or test fewer timeframes
### Run Backtest
Run a backtest with a specific strategy:
```bash
uv run python main.py backtest -s <strategy> -p <pair> [options]
```
**Available Strategies:**
- `meta_st` - Meta Supertrend (triple supertrend consensus)
- `regime` - Regime Reversion (ML-based spread trading)
- `rsi` - RSI overbought/oversold
- `macross` - Moving Average Crossover
**Options:**
- `-s, --strategy` (required): Strategy name
- `-p, --pair` (required): Trading pair
- `-t, --timeframe`: Timeframe (default: `1m`)
- `--start`: Start date
- `--end`: End date
- `-g, --grid`: Run grid search optimization
- `--plot`: Show equity curve plot
- `--sl`: Stop loss percentage
- `--tp`: Take profit percentage
- `--trail`: Enable trailing stop
- `--fees`: Override fee rate
- `--slippage`: Slippage (default: `0.001`)
- `-l, --leverage`: Leverage multiplier
**Examples:**
```bash
# Basic backtest with Meta Supertrend
uv run python main.py backtest -s meta_st -p BTC-USDT -t 1h
# Backtest with date range and plot
uv run python main.py backtest -s meta_st -p BTC-USDT --start 2024-01-01 --end 2024-12-31 --plot
# Grid search optimization
uv run python main.py backtest -s meta_st -p BTC-USDT -t 4h -g
# Backtest with risk parameters
uv run python main.py backtest -s meta_st -p BTC-USDT --sl 0.05 --tp 0.10 --trail
# Regime strategy on ETH/BTC spread
uv run python main.py backtest -s regime -p ETH-USDT -t 1h
```
### Walk-Forward Analysis (WFA)
Run walk-forward optimization to avoid overfitting.
```bash
uv run python main.py wfa -s <strategy> -p <pair> [options]
```
**Options:**
- `-s, --strategy` (required): Strategy name
- `-p, --pair` (required): Trading pair
- `-t, --timeframe`: Timeframe (default: `1d`)
- `-w, --windows`: Number of walk-forward windows (default: `10`)
- `--plot`: Show WFA results plot
**Examples:**
```bash
# Walk-forward analysis with 10 windows
uv run python main.py wfa -s meta_st -p BTC-USDT -t 1d -w 10
# WFA with plot output
uv run python main.py wfa -s regime -p ETH-USDT --plot
```
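The walk-forward idea behind the `wfa` command can be sketched as below. This is a plausible windowing scheme under the assumption of contiguous, non-overlapping windows with an in-window 70/30 train/test split; the actual slicing in `main.py` may differ:

```python
def walk_forward_windows(n_bars: int, n_windows: int = 10, train_frac: float = 0.7):
    """Yield (train, test) index ranges for sequential walk-forward windows.

    Each window is optimised on its train slice and evaluated on the
    immediately following test slice, so no test bar is ever seen in training.
    """
    window = n_bars // n_windows
    for i in range(n_windows):
        start = i * window
        split = start + int(window * train_frac)
        end = min(start + window, n_bars)
        yield (start, split), (split, end)
```
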
---
## Research Scripts
Research scripts are located in the `research/` directory for experimental analysis.
### Regime Detection Research
Tests multiple holding horizons for the regime reversion strategy using walk-forward training.
```bash
uv run python research/regime_detection.py
```
**Options:**
- `--days DAYS`: Number of days of historical data (default: 90)
- `--start DATE`: Start date (YYYY-MM-DD), overrides `--days`
- `--end DATE`: End date (YYYY-MM-DD), defaults to now
- `--output PATH`: Output CSV path
**Examples:**
```bash
# Use last 90 days (default)
uv run python research/regime_detection.py
# Use last 180 days
uv run python research/regime_detection.py --days 180
# Specific date range
uv run python research/regime_detection.py --start 2025-07-01 --end 2025-12-31
```
**What it does:**
- Loads BTC and ETH hourly data
- Calculates spread features (Z-score, RSI, volume ratios)
- Trains RandomForest classifier with walk-forward methodology
- Tests horizons from 6h to 150h
- Outputs best parameters by F1 score, Net PnL, and MAE
**Output:**
- Console: Summary of results for each horizon
- File: `research/horizon_optimization_results.csv`
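The Z-score spread feature mentioned above can be sketched like this. One plausible definition is shown (rolling z-score of the log ETH/BTC spread); the research script may use a different window or spread construction:

```python
import numpy as np
import pandas as pd

def spread_zscore(eth_close: pd.Series, btc_close: pd.Series, window: int = 24) -> pd.Series:
    """Rolling z-score of the log ETH/BTC spread."""
    spread = np.log(eth_close) - np.log(btc_close)
    mean = spread.rolling(window).mean()
    std = spread.rolling(window).std()
    return (spread - mean) / std
```
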
---
## ML Model Training
The `regime` strategy uses a RandomForest classifier that can be trained with new data.
### Train Model
Train or retrain the ML model with latest data:
```bash
uv run python train_model.py [options]
```
**Options:**
- `--days DAYS`: Days of historical data (default: 90)
- `--pair PAIR`: Base pair for context (default: BTC-USDT)
- `--spread-pair PAIR`: Trading pair (default: ETH-USDT)
- `--timeframe TF`: Timeframe (default: 1h)
- `--market TYPE`: Market type: `spot` or `perpetual` (default: perpetual)
- `--output PATH`: Model output path (default: `data/regime_model.pkl`)
- `--train-ratio R`: Train/test split ratio (default: 0.7)
- `--horizon H`: Prediction horizon in bars (default: 102)
- `--download`: Download latest data before training
- `--dry-run`: Run without saving model
**Examples:**
```bash
# Train with last 90 days of data
uv run python train_model.py
# Download fresh data and train
uv run python train_model.py --download
# Train with 180 days of data
uv run python train_model.py --days 180
# Train on spot market data
uv run python train_model.py --market spot
# Dry run to see metrics without saving
uv run python train_model.py --dry-run
```
### Daily Retraining (Cron)
To automate daily model retraining, add a cron job:
```bash
# Edit crontab
crontab -e
# Add entry to retrain daily at 00:30 UTC
30 0 * * * cd /path/to/lowkey_backtest_live && uv run python train_model.py --download >> logs/training.log 2>&1
```
### Model Files
| File | Description |
|------|-------------|
| `data/regime_model.pkl` | Current production model |
| `data/regime_model_YYYYMMDD_HHMMSS.pkl` | Versioned model snapshots |
The model file contains:
- Trained RandomForest classifier
- Feature column names
- Training metrics (F1 score, sample counts)
- Training timestamp
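Persisting that bundle is a plain pickle of a dict. A sketch follows; the exact key names used by `train_model.py` are assumptions:

```python
import pickle
from datetime import datetime, timezone
from pathlib import Path

def save_model(clf, feature_cols, metrics, path="data/regime_model.pkl"):
    """Pickle the classifier together with its metadata."""
    bundle = {
        "model": clf,
        "feature_columns": feature_cols,
        "metrics": metrics,  # e.g. {"f1": ..., "n_train": ...}
        "trained_at": datetime.now(timezone.utc).isoformat(),
    }
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(bundle, f)
    return bundle

def load_model(path="data/regime_model.pkl"):
    """Load a previously saved model bundle."""
    with open(path, "rb") as f:
        return pickle.load(f)
```
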
---
## Output Files
| Location | Description |
|----------|-------------|
| `backtest_logs/` | Trade logs and WFA results |
| `research/` | Research output files |
| `data/` | Downloaded OHLCV data and ML models |
| `data/regime_model.pkl` | Trained ML model for regime strategy |
---
## Running Tests
```bash
uv run pytest tests/
```
Run a specific test file:
```bash
uv run pytest tests/test_data_manager.py
```

check_demo_account.py Normal file

@@ -0,0 +1,98 @@
#!/usr/bin/env python3
"""
Check OKX demo account positions and recent orders.

Usage:
    uv run python check_demo_account.py
"""
import sys
from pathlib import Path
from datetime import datetime, timezone

sys.path.insert(0, str(Path(__file__).parent))

from live_trading.config import OKXConfig
import ccxt


def main():
    """Check demo account status."""
    config = OKXConfig()
    print(f"\n{'='*60}")
    print(" OKX Demo Account Check")
    print(f"{'='*60}")
    print(f" Demo Mode: {config.demo_mode}")
    print(f" API Key: {config.api_key[:8]}..." if config.api_key else " API Key: NOT SET")
    print(f"{'='*60}\n")

    exchange = ccxt.okx({
        'apiKey': config.api_key,
        'secret': config.secret,
        'password': config.password,
        'sandbox': config.demo_mode,
        'options': {'defaultType': 'swap'},
        'enableRateLimit': True,
    })

    # Check balance
    print("--- BALANCE ---")
    balance = exchange.fetch_balance()
    usdt = balance.get('USDT', {})
    print(f"USDT Total: {usdt.get('total', 0):.2f}")
    print(f"USDT Free: {usdt.get('free', 0):.2f}")
    print(f"USDT Used: {usdt.get('used', 0):.2f}")

    # Check all balances
    print("\n--- ALL NON-ZERO BALANCES ---")
    for currency, data in balance.items():
        if isinstance(data, dict) and data.get('total', 0) > 0:
            print(f"{currency}: total={data.get('total', 0):.6f}, free={data.get('free', 0):.6f}")

    # Check open positions
    print("\n--- OPEN POSITIONS ---")
    positions = exchange.fetch_positions()
    open_positions = [p for p in positions if abs(float(p.get('contracts', 0))) > 0]
    if open_positions:
        for pos in open_positions:
            print(f" {pos['symbol']}: {pos['side']} {pos['contracts']} contracts @ {pos.get('entryPrice', 'N/A')}")
            print(f"   Unrealized PnL: {pos.get('unrealizedPnl', 'N/A')}")
    else:
        print(" No open positions")

    # Check recent orders (last 50)
    print("\n--- RECENT ORDERS (last 24h) ---")
    try:
        # Fetch closed orders for AVAX
        orders = exchange.fetch_orders('AVAX/USDT:USDT', limit=20)
        if orders:
            for order in orders[-10:]:  # Last 10
                ts = datetime.fromtimestamp(order['timestamp'] / 1000, tz=timezone.utc)
                print(f" [{ts.strftime('%H:%M:%S')}] {order['side'].upper()} {order['amount']} AVAX @ {order.get('average', order.get('price', 'market'))}")
                print(f"   Status: {order['status']}, Filled: {order.get('filled', 0)}, ID: {order['id']}")
        else:
            print(" No recent AVAX orders")
    except Exception as e:
        print(f" Could not fetch orders: {e}")

    # Check order history more broadly
    print("\n--- ORDER HISTORY (AVAX) ---")
    try:
        # Try fetching my trades
        trades = exchange.fetch_my_trades('AVAX/USDT:USDT', limit=10)
        if trades:
            for trade in trades[-5:]:
                ts = datetime.fromtimestamp(trade['timestamp'] / 1000, tz=timezone.utc)
                print(f" [{ts.strftime('%Y-%m-%d %H:%M:%S')}] {trade['side'].upper()} {trade['amount']} @ {trade['price']}")
                print(f"   Fee: {trade.get('fee', {}).get('cost', 'N/A')} {trade.get('fee', {}).get('currency', '')}")
        else:
            print(" No recent AVAX trades")
    except Exception as e:
        print(f" Could not fetch trades: {e}")

    print(f"\n{'='*60}\n")


if __name__ == "__main__":
    main()

data/multi_pair_model.pkl Normal file (binary, not shown)


@@ -60,7 +60,7 @@ class TradingConfig:
     # Position sizing
     max_position_usdt: float = -1.0  # Max position size in USDT. If <= 0, use all available funds
-    min_position_usdt: float = 1.0   # Min position size in USDT
+    min_position_usdt: float = 10.0  # Min position size in USDT
     leverage: int = 1  # Leverage (1x = no leverage)
     margin_mode: str = "cross"  # "cross" or "isolated"
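The sizing fields above imply logic along these lines. A sketch under assumed semantics (the actual consumer of `TradingConfig` may differ): a non-positive `max_position_usdt` means "use all available funds", and anything below the minimum is skipped:

```python
def position_size(available_usdt: float,
                  max_position_usdt: float = -1.0,
                  min_position_usdt: float = 10.0) -> float:
    """Return the USDT size to deploy, or 0.0 if below the minimum."""
    # If max <= 0, use all available funds; otherwise cap at the max
    size = available_usdt if max_position_usdt <= 0 else min(available_usdt, max_position_usdt)
    return size if size >= min_position_usdt else 0.0
```
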


@@ -1,13 +0,0 @@
"""Database module for live trading persistence."""
from .database import get_db, init_db
from .models import Trade, DailySummary, Session
from .metrics import MetricsCalculator
__all__ = [
"get_db",
"init_db",
"Trade",
"DailySummary",
"Session",
"MetricsCalculator",
]


@@ -1,325 +0,0 @@
"""SQLite database connection and operations."""
import sqlite3
import logging
from pathlib import Path
from typing import Optional
from contextlib import contextmanager
from .models import Trade, DailySummary, Session
logger = logging.getLogger(__name__)
# Database schema
SCHEMA = """
-- Trade history table
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trade_id TEXT UNIQUE NOT NULL,
symbol TEXT NOT NULL,
side TEXT NOT NULL,
entry_price REAL NOT NULL,
exit_price REAL,
size REAL NOT NULL,
size_usdt REAL NOT NULL,
pnl_usd REAL,
pnl_pct REAL,
entry_time TEXT NOT NULL,
exit_time TEXT,
hold_duration_hours REAL,
reason TEXT,
order_id_entry TEXT,
order_id_exit TEXT,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Daily summary table
CREATE TABLE IF NOT EXISTS daily_summary (
id INTEGER PRIMARY KEY AUTOINCREMENT,
date TEXT UNIQUE NOT NULL,
total_trades INTEGER DEFAULT 0,
winning_trades INTEGER DEFAULT 0,
total_pnl_usd REAL DEFAULT 0,
max_drawdown_usd REAL DEFAULT 0,
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Session metadata
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
start_time TEXT NOT NULL,
end_time TEXT,
starting_balance REAL,
ending_balance REAL,
total_pnl REAL,
total_trades INTEGER DEFAULT 0
);
-- Indexes for common queries
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades(entry_time);
CREATE INDEX IF NOT EXISTS idx_trades_exit_time ON trades(exit_time);
CREATE INDEX IF NOT EXISTS idx_daily_summary_date ON daily_summary(date);
"""
_db_instance: Optional["TradingDatabase"] = None
class TradingDatabase:
"""SQLite database for trade persistence."""
def __init__(self, db_path: Path):
self.db_path = db_path
self._connection: Optional[sqlite3.Connection] = None
@property
def connection(self) -> sqlite3.Connection:
"""Get or create database connection."""
if self._connection is None:
self._connection = sqlite3.connect(
str(self.db_path),
check_same_thread=False,
)
self._connection.row_factory = sqlite3.Row
return self._connection
def init_schema(self) -> None:
"""Initialize database schema."""
with self.connection:
self.connection.executescript(SCHEMA)
logger.info(f"Database initialized at {self.db_path}")
def close(self) -> None:
"""Close database connection."""
if self._connection:
self._connection.close()
self._connection = None
@contextmanager
def transaction(self):
"""Context manager for database transactions."""
try:
yield self.connection
self.connection.commit()
except Exception:
self.connection.rollback()
raise
def insert_trade(self, trade: Trade) -> int:
"""
Insert a new trade record.
Args:
trade: Trade object to insert
Returns:
Row ID of inserted trade
"""
sql = """
INSERT INTO trades (
trade_id, symbol, side, entry_price, exit_price,
size, size_usdt, pnl_usd, pnl_pct, entry_time,
exit_time, hold_duration_hours, reason,
order_id_entry, order_id_exit
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
with self.transaction():
cursor = self.connection.execute(
sql,
(
trade.trade_id,
trade.symbol,
trade.side,
trade.entry_price,
trade.exit_price,
trade.size,
trade.size_usdt,
trade.pnl_usd,
trade.pnl_pct,
trade.entry_time,
trade.exit_time,
trade.hold_duration_hours,
trade.reason,
trade.order_id_entry,
trade.order_id_exit,
),
)
return cursor.lastrowid
def update_trade(self, trade_id: str, **kwargs) -> bool:
"""
Update an existing trade record.
Args:
trade_id: Trade ID to update
**kwargs: Fields to update
Returns:
True if trade was updated
"""
if not kwargs:
return False
set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
sql = f"UPDATE trades SET {set_clause} WHERE trade_id = ?"
with self.transaction():
cursor = self.connection.execute(
sql, (*kwargs.values(), trade_id)
)
return cursor.rowcount > 0
def get_trade(self, trade_id: str) -> Optional[Trade]:
"""Get a trade by ID."""
sql = "SELECT * FROM trades WHERE trade_id = ?"
row = self.connection.execute(sql, (trade_id,)).fetchone()
if row:
return Trade(**dict(row))
return None
def get_trades(
self,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
limit: Optional[int] = None,
) -> list[Trade]:
"""
Get trades within a time range.
Args:
start_time: ISO format start time filter
end_time: ISO format end time filter
limit: Maximum number of trades to return
Returns:
List of Trade objects
"""
conditions = []
params = []
if start_time:
conditions.append("entry_time >= ?")
params.append(start_time)
if end_time:
conditions.append("entry_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions) if conditions else "1=1"
limit_clause = f"LIMIT {limit}" if limit else ""
sql = f"""
SELECT * FROM trades
WHERE {where_clause}
ORDER BY entry_time DESC
{limit_clause}
"""
rows = self.connection.execute(sql, params).fetchall()
return [Trade(**dict(row)) for row in rows]
def get_all_trades(self) -> list[Trade]:
"""Get all trades."""
sql = "SELECT * FROM trades ORDER BY entry_time DESC"
rows = self.connection.execute(sql).fetchall()
return [Trade(**dict(row)) for row in rows]
def count_trades(self) -> int:
"""Get total number of trades."""
sql = "SELECT COUNT(*) FROM trades WHERE exit_time IS NOT NULL"
return self.connection.execute(sql).fetchone()[0]
def upsert_daily_summary(self, summary: DailySummary) -> None:
"""Insert or update daily summary."""
sql = """
INSERT INTO daily_summary (
date, total_trades, winning_trades, total_pnl_usd, max_drawdown_usd
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(date) DO UPDATE SET
total_trades = excluded.total_trades,
winning_trades = excluded.winning_trades,
total_pnl_usd = excluded.total_pnl_usd,
max_drawdown_usd = excluded.max_drawdown_usd,
updated_at = CURRENT_TIMESTAMP
"""
with self.transaction():
self.connection.execute(
sql,
(
summary.date,
summary.total_trades,
summary.winning_trades,
summary.total_pnl_usd,
summary.max_drawdown_usd,
),
)
def get_daily_summary(self, date: str) -> Optional[DailySummary]:
"""Get daily summary for a specific date."""
sql = "SELECT * FROM daily_summary WHERE date = ?"
row = self.connection.execute(sql, (date,)).fetchone()
if row:
return DailySummary(**dict(row))
return None
def insert_session(self, session: Session) -> int:
"""Insert a new session record."""
sql = """
INSERT INTO sessions (
start_time, end_time, starting_balance,
ending_balance, total_pnl, total_trades
) VALUES (?, ?, ?, ?, ?, ?)
"""
with self.transaction():
cursor = self.connection.execute(
sql,
(
session.start_time,
session.end_time,
session.starting_balance,
session.ending_balance,
session.total_pnl,
session.total_trades,
),
)
return cursor.lastrowid
def update_session(self, session_id: int, **kwargs) -> bool:
"""Update an existing session."""
if not kwargs:
return False
set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
sql = f"UPDATE sessions SET {set_clause} WHERE id = ?"
with self.transaction():
cursor = self.connection.execute(
sql, (*kwargs.values(), session_id)
)
return cursor.rowcount > 0
def get_latest_session(self) -> Optional[Session]:
"""Get the most recent session."""
sql = "SELECT * FROM sessions ORDER BY id DESC LIMIT 1"
row = self.connection.execute(sql).fetchone()
if row:
return Session(**dict(row))
return None
def init_db(db_path: Path) -> TradingDatabase:
"""
Initialize the database.
Args:
db_path: Path to the SQLite database file
Returns:
TradingDatabase instance
"""
global _db_instance
_db_instance = TradingDatabase(db_path)
_db_instance.init_schema()
return _db_instance
def get_db() -> Optional[TradingDatabase]:
"""Get the global database instance."""
return _db_instance


@@ -1,235 +0,0 @@
"""Metrics calculation from trade database."""
import logging
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from typing import Optional
from .database import TradingDatabase
logger = logging.getLogger(__name__)
@dataclass
class PeriodMetrics:
"""Trading metrics for a time period."""
period_name: str
start_time: Optional[str]
end_time: Optional[str]
total_pnl: float = 0.0
total_trades: int = 0
winning_trades: int = 0
losing_trades: int = 0
win_rate: float = 0.0
avg_trade_duration_hours: float = 0.0
max_drawdown: float = 0.0
max_drawdown_pct: float = 0.0
best_trade: float = 0.0
worst_trade: float = 0.0
avg_win: float = 0.0
avg_loss: float = 0.0
class MetricsCalculator:
"""Calculate trading metrics from database."""
def __init__(self, db: TradingDatabase):
self.db = db
def get_all_time_metrics(self) -> PeriodMetrics:
"""Get metrics for all trades ever."""
return self._calculate_metrics("All Time", None, None)
def get_monthly_metrics(self) -> PeriodMetrics:
"""Get metrics for current month."""
now = datetime.now(timezone.utc)
start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Monthly",
start.isoformat(),
now.isoformat(),
)
def get_weekly_metrics(self) -> PeriodMetrics:
"""Get metrics for current week (Monday to now)."""
now = datetime.now(timezone.utc)
days_since_monday = now.weekday()
start = now - timedelta(days=days_since_monday)
start = start.replace(hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Weekly",
start.isoformat(),
now.isoformat(),
)
def get_daily_metrics(self) -> PeriodMetrics:
"""Get metrics for today (UTC)."""
now = datetime.now(timezone.utc)
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Daily",
start.isoformat(),
now.isoformat(),
)
def _calculate_metrics(
self,
period_name: str,
start_time: Optional[str],
end_time: Optional[str],
) -> PeriodMetrics:
"""
Calculate metrics for a time period.
Args:
period_name: Name of the period
start_time: ISO format start time (None for all time)
end_time: ISO format end time (None for all time)
Returns:
PeriodMetrics object
"""
metrics = PeriodMetrics(
period_name=period_name,
start_time=start_time,
end_time=end_time,
)
# Build query conditions
conditions = ["exit_time IS NOT NULL"]
params = []
if start_time:
conditions.append("exit_time >= ?")
params.append(start_time)
if end_time:
conditions.append("exit_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions)
# Get aggregate metrics
sql = f"""
SELECT
COUNT(*) as total_trades,
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
SUM(CASE WHEN pnl_usd < 0 THEN 1 ELSE 0 END) as losing_trades,
COALESCE(SUM(pnl_usd), 0) as total_pnl,
COALESCE(AVG(hold_duration_hours), 0) as avg_duration,
COALESCE(MAX(pnl_usd), 0) as best_trade,
COALESCE(MIN(pnl_usd), 0) as worst_trade,
COALESCE(AVG(CASE WHEN pnl_usd > 0 THEN pnl_usd END), 0) as avg_win,
COALESCE(AVG(CASE WHEN pnl_usd < 0 THEN pnl_usd END), 0) as avg_loss
FROM trades
WHERE {where_clause}
"""
row = self.db.connection.execute(sql, params).fetchone()
if row and row["total_trades"] > 0:
metrics.total_trades = row["total_trades"]
metrics.winning_trades = row["winning_trades"] or 0
metrics.losing_trades = row["losing_trades"] or 0
metrics.total_pnl = row["total_pnl"]
metrics.avg_trade_duration_hours = row["avg_duration"]
metrics.best_trade = row["best_trade"]
metrics.worst_trade = row["worst_trade"]
metrics.avg_win = row["avg_win"]
metrics.avg_loss = row["avg_loss"]
if metrics.total_trades > 0:
metrics.win_rate = (
metrics.winning_trades / metrics.total_trades * 100
)
# Calculate max drawdown
metrics.max_drawdown = self._calculate_max_drawdown(
start_time, end_time
)
return metrics
def _calculate_max_drawdown(
self,
start_time: Optional[str],
end_time: Optional[str],
) -> float:
"""Calculate maximum drawdown for a period."""
conditions = ["exit_time IS NOT NULL"]
params = []
if start_time:
conditions.append("exit_time >= ?")
params.append(start_time)
if end_time:
conditions.append("exit_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions)
sql = f"""
SELECT pnl_usd
FROM trades
WHERE {where_clause}
ORDER BY exit_time
"""
rows = self.db.connection.execute(sql, params).fetchall()
if not rows:
return 0.0
cumulative = 0.0
peak = 0.0
max_drawdown = 0.0
for row in rows:
pnl = row["pnl_usd"] or 0.0
cumulative += pnl
peak = max(peak, cumulative)
drawdown = peak - cumulative
max_drawdown = max(max_drawdown, drawdown)
return max_drawdown
def has_monthly_data(self) -> bool:
"""Check if we have data spanning more than current month."""
sql = """
SELECT MIN(exit_time) as first_trade
FROM trades
WHERE exit_time IS NOT NULL
"""
row = self.db.connection.execute(sql).fetchone()
if not row or not row["first_trade"]:
return False
first_trade = datetime.fromisoformat(row["first_trade"])
now = datetime.now(timezone.utc)
month_start = now.replace(day=1, hour=0, minute=0, second=0)
return first_trade < month_start
def has_weekly_data(self) -> bool:
"""Check if we have data spanning more than current week."""
sql = """
SELECT MIN(exit_time) as first_trade
FROM trades
WHERE exit_time IS NOT NULL
"""
row = self.db.connection.execute(sql).fetchone()
if not row or not row["first_trade"]:
return False
first_trade = datetime.fromisoformat(row["first_trade"])
now = datetime.now(timezone.utc)
days_since_monday = now.weekday()
week_start = now - timedelta(days=days_since_monday)
week_start = week_start.replace(hour=0, minute=0, second=0)
return first_trade < week_start
def get_session_start_balance(self) -> Optional[float]:
"""Get starting balance from latest session."""
sql = "SELECT starting_balance FROM sessions ORDER BY id DESC LIMIT 1"
row = self.db.connection.execute(sql).fetchone()
return row["starting_balance"] if row else None


@@ -1,191 +0,0 @@
"""Database migrations and CSV import."""
import csv
import logging
from pathlib import Path
from datetime import datetime
from .database import TradingDatabase
from .models import Trade, DailySummary
logger = logging.getLogger(__name__)
def migrate_csv_to_db(db: TradingDatabase, csv_path: Path) -> int:
"""
Migrate trades from CSV file to SQLite database.
Args:
db: TradingDatabase instance
csv_path: Path to trade_log.csv
Returns:
Number of trades migrated
"""
if not csv_path.exists():
logger.info("No CSV file to migrate")
return 0
# Check if database already has trades
existing_count = db.count_trades()
if existing_count > 0:
logger.info(
f"Database already has {existing_count} trades, skipping migration"
)
return 0
migrated = 0
try:
with open(csv_path, "r", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
trade = _csv_row_to_trade(row)
if trade:
try:
db.insert_trade(trade)
migrated += 1
except Exception as e:
logger.warning(
f"Failed to migrate trade {row.get('trade_id')}: {e}"
)
logger.info(f"Migrated {migrated} trades from CSV to database")
except Exception as e:
logger.error(f"CSV migration failed: {e}")
return migrated
def _csv_row_to_trade(row: dict) -> Trade | None:
"""Convert a CSV row to a Trade object."""
try:
return Trade(
trade_id=row["trade_id"],
symbol=row["symbol"],
side=row["side"],
entry_price=float(row["entry_price"]),
exit_price=_safe_float(row.get("exit_price")),
size=float(row["size"]),
size_usdt=float(row["size_usdt"]),
pnl_usd=_safe_float(row.get("pnl_usd")),
pnl_pct=_safe_float(row.get("pnl_pct")),
entry_time=row["entry_time"],
exit_time=row.get("exit_time") or None,
hold_duration_hours=_safe_float(row.get("hold_duration_hours")),
reason=row.get("reason") or None,
order_id_entry=row.get("order_id_entry") or None,
order_id_exit=row.get("order_id_exit") or None,
)
except (KeyError, ValueError) as e:
logger.warning(f"Invalid CSV row: {e}")
return None
def _safe_float(value: str | None) -> float | None:
"""Safely convert string to float."""
if value is None or value == "":
return None
try:
return float(value)
except ValueError:
return None
def rebuild_daily_summaries(db: TradingDatabase) -> int:
"""
Rebuild daily summary table from trades.
Args:
db: TradingDatabase instance
Returns:
Number of daily summaries created
"""
sql = """
SELECT
DATE(exit_time) as date,
COUNT(*) as total_trades,
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
SUM(pnl_usd) as total_pnl_usd
FROM trades
WHERE exit_time IS NOT NULL
GROUP BY DATE(exit_time)
ORDER BY date
"""
rows = db.connection.execute(sql).fetchall()
count = 0
for row in rows:
summary = DailySummary(
date=row["date"],
total_trades=row["total_trades"],
winning_trades=row["winning_trades"],
total_pnl_usd=row["total_pnl_usd"] or 0.0,
max_drawdown_usd=0.0, # Calculated separately
)
db.upsert_daily_summary(summary)
count += 1
# Calculate max drawdowns
_calculate_daily_drawdowns(db)
logger.info(f"Rebuilt {count} daily summaries")
return count
def _calculate_daily_drawdowns(db: TradingDatabase) -> None:
"""Calculate and update max drawdown for each day."""
sql = """
SELECT trade_id, DATE(exit_time) as date, pnl_usd
FROM trades
WHERE exit_time IS NOT NULL
ORDER BY exit_time
"""
rows = db.connection.execute(sql).fetchall()
# Track cumulative PnL and drawdown per day
daily_drawdowns: dict[str, float] = {}
cumulative_pnl = 0.0
peak_pnl = 0.0
for row in rows:
date = row["date"]
pnl = row["pnl_usd"] or 0.0
cumulative_pnl += pnl
peak_pnl = max(peak_pnl, cumulative_pnl)
drawdown = peak_pnl - cumulative_pnl
if date not in daily_drawdowns:
daily_drawdowns[date] = 0.0
daily_drawdowns[date] = max(daily_drawdowns[date], drawdown)
# Update daily summaries with drawdown
for date, drawdown in daily_drawdowns.items():
db.connection.execute(
"UPDATE daily_summary SET max_drawdown_usd = ? WHERE date = ?",
(drawdown, date),
)
db.connection.commit()
def run_migrations(db: TradingDatabase, csv_path: Path) -> None:
"""
Run all migrations.
Args:
db: TradingDatabase instance
csv_path: Path to trade_log.csv for migration
"""
logger.info("Running database migrations...")
# Migrate CSV data if exists
migrate_csv_to_db(db, csv_path)
# Rebuild daily summaries
rebuild_daily_summaries(db)
logger.info("Migrations complete")


@@ -1,69 +0,0 @@
"""Data models for trade persistence."""
from dataclasses import dataclass, asdict
from typing import Optional
from datetime import datetime
@dataclass
class Trade:
"""Represents a completed trade."""
trade_id: str
symbol: str
side: str
entry_price: float
size: float
size_usdt: float
entry_time: str
exit_price: Optional[float] = None
pnl_usd: Optional[float] = None
pnl_pct: Optional[float] = None
exit_time: Optional[str] = None
hold_duration_hours: Optional[float] = None
reason: Optional[str] = None
order_id_entry: Optional[str] = None
order_id_exit: Optional[str] = None
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@classmethod
def from_row(cls, row: tuple, columns: list[str]) -> "Trade":
"""Create Trade from database row."""
data = dict(zip(columns, row))
return cls(**data)
@dataclass
class DailySummary:
"""Daily trading summary."""
date: str
total_trades: int = 0
winning_trades: int = 0
total_pnl_usd: float = 0.0
max_drawdown_usd: float = 0.0
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@dataclass
class Session:
"""Trading session metadata."""
start_time: str
end_time: Optional[str] = None
starting_balance: Optional[float] = None
ending_balance: Optional[float] = None
total_pnl: Optional[float] = None
total_trades: int = 0
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)


@@ -6,7 +6,6 @@ Uses a pre-trained ML model or trains on historical data.
"""
import logging
import pickle
import time
from pathlib import Path
from typing import Optional
@@ -40,48 +39,18 @@ class LiveRegimeStrategy:
self.paths = path_config
self.model: Optional[RandomForestClassifier] = None
self.feature_cols: Optional[list] = None
self.horizon: int = 54 # Default horizon
self._last_model_load_time: float = 0.0
self._last_train_time: float = 0.0
self._load_or_train_model()
def reload_model_if_changed(self) -> None:
"""Check if model file has changed and reload if necessary."""
if not self.paths.model_path.exists():
return
try:
mtime = self.paths.model_path.stat().st_mtime
if mtime > self._last_model_load_time:
logger.info(f"Model file changed, reloading... (last: {self._last_model_load_time}, new: {mtime})")
self._load_or_train_model()
except Exception as e:
logger.warning(f"Error checking model file: {e}")
def _load_or_train_model(self) -> None:
"""Load pre-trained model or train a new one."""
if self.paths.model_path.exists():
try:
self._last_model_load_time = self.paths.model_path.stat().st_mtime
with open(self.paths.model_path, 'rb') as f:
saved = pickle.load(f)
self.model = saved['model']
self.feature_cols = saved['feature_cols']
# Load horizon from metrics if available
if 'metrics' in saved and 'horizon' in saved['metrics']:
self.horizon = saved['metrics']['horizon']
logger.info(f"Loaded model from {self.paths.model_path} (horizon={self.horizon})")
else:
logger.info(f"Loaded model from {self.paths.model_path} (default horizon={self.horizon})")
# Load timestamp if available
if 'timestamp' in saved:
self._last_train_time = saved['timestamp']
else:
self._last_train_time = self._last_model_load_time
return
logger.info(f"Loaded model from {self.paths.model_path}")
return
except Exception as e:
logger.warning(f"Could not load model: {e}")
@@ -97,20 +66,11 @@ class LiveRegimeStrategy:
pickle.dump({
'model': self.model,
'feature_cols': self.feature_cols,
'metrics': {'horizon': self.horizon}, # Save horizon
'timestamp': time.time()
}, f)
logger.info(f"Saved model to {self.paths.model_path}")
except Exception as e:
logger.error(f"Could not save model: {e}")
def check_retrain(self, features: pd.DataFrame) -> None:
"""Check if model needs retraining (older than 24h)."""
if time.time() - self._last_train_time > 24 * 3600:
logger.info("Model is older than 24h. Retraining...")
self.train_model(features)
self._last_train_time = time.time()
def train_model(self, features: pd.DataFrame) -> None:
"""
Train the Random Forest model on historical data.
@@ -121,63 +81,20 @@ class LiveRegimeStrategy:
logger.info(f"Training model on {len(features)} samples...")
z_thresh = self.config.z_entry_threshold
horizon = self.horizon
horizon = 102 # Optimal horizon from research
profit_target = 0.005 # 0.5% profit threshold
stop_loss_pct = self.config.stop_loss_pct
# Calculate targets path-dependently
spread = features['spread'].values
z_score = features['z_score'].values
n = len(spread)
# Define targets
future_min = features['spread'].rolling(window=horizon).min().shift(-horizon)
future_max = features['spread'].rolling(window=horizon).max().shift(-horizon)
targets = np.zeros(n, dtype=int)
target_short = features['spread'] * (1 - profit_target)
target_long = features['spread'] * (1 + profit_target)
candidates = np.where((z_score > z_thresh) | (z_score < -z_thresh))[0]
success_short = (features['z_score'] > z_thresh) & (future_min < target_short)
success_long = (features['z_score'] < -z_thresh) & (future_max > target_long)
for i in candidates:
if i + horizon >= n:
continue
entry_price = spread[i]
future_prices = spread[i+1 : i+1+horizon]
if z_score[i] > z_thresh: # Short
target_price = entry_price * (1 - profit_target)
stop_price = entry_price * (1 + stop_loss_pct)
hit_tp = future_prices <= target_price
hit_sl = future_prices >= stop_price
if not np.any(hit_tp):
targets[i] = 0
elif not np.any(hit_sl):
targets[i] = 1
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
else: # Long
target_price = entry_price * (1 + profit_target)
stop_price = entry_price * (1 - stop_loss_pct)
hit_tp = future_prices >= target_price
hit_sl = future_prices <= stop_price
if not np.any(hit_tp):
targets[i] = 0
elif not np.any(hit_sl):
targets[i] = 1
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
targets = np.select([success_short, success_long], [1, 1], default=0)
# Exclude non-feature columns
exclude = ['spread', 'btc_close', 'eth_close', 'eth_volume']
@@ -187,10 +104,8 @@ class LiveRegimeStrategy:
X = features[self.feature_cols].fillna(0)
X = X.replace([np.inf, -np.inf], 0)
# Use rows where we had enough data to look ahead
valid_mask = np.zeros(n, dtype=bool)
valid_mask[:n-horizon] = True
# Remove rows with invalid targets
valid_mask = ~np.isnan(targets) & future_min.notna().values & future_max.notna().values
X_clean = X[valid_mask]
y_clean = targets[valid_mask]
@@ -214,8 +129,7 @@ class LiveRegimeStrategy:
def generate_signal(
self,
features: pd.DataFrame,
current_funding: dict,
position_side: Optional[str] = None
current_funding: dict
) -> dict:
"""
Generate trading signal from latest features.
@@ -223,14 +137,10 @@ class LiveRegimeStrategy:
Args:
features: DataFrame with calculated features
current_funding: Dictionary with funding rate data
position_side: Current position side ('long', 'short', or None)
Returns:
Signal dictionary with action, side, confidence, etc.
"""
# Check if retraining is needed
self.check_retrain(features)
if self.model is None:
# Train model if not available
if len(features) >= 200:
@@ -300,17 +210,12 @@ class LiveRegimeStrategy:
signal['action'] = 'hold'
signal['reason'] = f'funding_filter_blocked_short (funding={btc_funding:.4f})'
# Check for exit conditions (Overshoot Logic)
if signal['action'] == 'hold' and position_side:
# Overshoot Logic
# If Long, exit if Z > 0.5 (Reverted past 0 to +0.5)
if position_side == 'long' and z_score > 0.5:
signal['action'] = 'check_exit'
signal['reason'] = f'overshoot_exit_long (z={z_score:.2f} > 0.5)'
# If Short, exit if Z < -0.5 (Reverted past 0 to -0.5)
elif position_side == 'short' and z_score < -0.5:
signal['action'] = 'check_exit'
signal['reason'] = f'overshoot_exit_short (z={z_score:.2f} < -0.5)'
# Check for exit conditions (mean reversion complete)
if signal['action'] == 'hold':
# Z-score crossed back through 0
if abs(z_score) < 0.3:
signal['action'] = 'check_exit'
signal['reason'] = f'z_score_reverted_to_mean ({z_score:.2f})'
logger.info(
f"Signal: {signal['action']} {signal['side'] or ''} "
@@ -356,9 +261,9 @@ class LiveRegimeStrategy:
def calculate_sl_tp(
self,
entry_price: Optional[float],
entry_price: float,
side: str
) -> tuple[Optional[float], Optional[float]]:
) -> tuple[float, float]:
"""
Calculate stop-loss and take-profit prices.
@@ -367,21 +272,8 @@ class LiveRegimeStrategy:
side: "long" or "short"
Returns:
Tuple of (stop_loss_price, take_profit_price), or (None, None) if
entry_price is invalid
Raises:
ValueError: If side is not "long" or "short"
Tuple of (stop_loss_price, take_profit_price)
"""
if entry_price is None or entry_price <= 0:
logger.error(
f"Invalid entry_price for SL/TP calculation: {entry_price}"
)
return None, None
if side not in ("long", "short"):
raise ValueError(f"Invalid side: {side}. Must be 'long' or 'short'")
sl_pct = self.config.stop_loss_pct
tp_pct = self.config.take_profit_pct


@@ -11,19 +11,14 @@ Usage:
# Run with specific settings
uv run python -m live_trading.main --max-position 500 --leverage 2
# Run without UI (headless mode)
uv run python -m live_trading.main --no-ui
"""
import argparse
import logging
import queue
import signal
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
@@ -33,47 +28,22 @@ from live_trading.okx_client import OKXClient
from live_trading.data_feed import DataFeed
from live_trading.position_manager import PositionManager
from live_trading.live_regime_strategy import LiveRegimeStrategy
from live_trading.db.database import init_db, TradingDatabase
from live_trading.db.migrations import run_migrations
from live_trading.ui.state import SharedState, PositionState
from live_trading.ui.dashboard import TradingDashboard, setup_ui_logging
def setup_logging(
log_dir: Path,
log_queue: Optional[queue.Queue] = None,
) -> logging.Logger:
"""
Configure logging for the trading bot.
Args:
log_dir: Directory for log files
log_queue: Optional queue for UI log handler
Returns:
Logger instance
"""
def setup_logging(log_dir: Path) -> logging.Logger:
"""Configure logging for the trading bot."""
log_file = log_dir / "live_trading.log"
handlers = [
logging.FileHandler(log_file),
]
# Only add StreamHandler if no UI (log_queue is None)
if log_queue is None:
handlers.append(logging.StreamHandler(sys.stdout))
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
handlers=handlers,
force=True,
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler(sys.stdout),
],
force=True
)
# Add UI log handler if queue provided
if log_queue is not None:
setup_ui_logging(log_queue)
return logging.getLogger(__name__)
@@ -89,15 +59,11 @@ class LiveTradingBot:
self,
okx_config: OKXConfig,
trading_config: TradingConfig,
path_config: PathConfig,
database: Optional[TradingDatabase] = None,
shared_state: Optional[SharedState] = None,
path_config: PathConfig
):
self.okx_config = okx_config
self.trading_config = trading_config
self.path_config = path_config
self.db = database
self.state = shared_state
self.logger = logging.getLogger(__name__)
self.running = True
@@ -108,7 +74,7 @@ class LiveTradingBot:
self.okx_client = OKXClient(okx_config, trading_config)
self.data_feed = DataFeed(self.okx_client, trading_config, path_config)
self.position_manager = PositionManager(
self.okx_client, trading_config, path_config, database
self.okx_client, trading_config, path_config
)
self.strategy = LiveRegimeStrategy(trading_config, path_config)
@@ -116,16 +82,6 @@ class LiveTradingBot:
signal.signal(signal.SIGINT, self._handle_shutdown)
signal.signal(signal.SIGTERM, self._handle_shutdown)
# Initialize shared state if provided
if self.state:
mode = "DEMO" if okx_config.demo_mode else "LIVE"
self.state.set_mode(mode)
self.state.set_symbols(
trading_config.eth_symbol,
trading_config.btc_symbol,
)
self.state.update_account(0.0, 0.0, trading_config.leverage)
self._print_startup_banner()
def _print_startup_banner(self) -> None:
@@ -153,8 +109,6 @@ class LiveTradingBot:
"""Handle shutdown signals gracefully."""
self.logger.info("Shutdown signal received, stopping...")
self.running = False
if self.state:
self.state.stop()
def run_trading_cycle(self) -> None:
"""
@@ -164,20 +118,10 @@ class LiveTradingBot:
2. Update open positions
3. Generate trading signal
4. Execute trades if signal triggers
5. Update shared state for UI
"""
# Reload model if it has changed (e.g. daily training)
try:
self.strategy.reload_model_if_changed()
except Exception as e:
self.logger.warning(f"Failed to reload model: {e}")
cycle_start = datetime.now(timezone.utc)
self.logger.info(f"--- Trading Cycle Start: {cycle_start.isoformat()} ---")
if self.state:
self.state.set_last_cycle_time(cycle_start.isoformat())
try:
# 1. Fetch market data
features = self.data_feed.get_latest_data()
@@ -206,31 +150,19 @@ class LiveTradingBot:
# 3. Sync with exchange positions
self.position_manager.sync_with_exchange()
# Get current position side for signal generation
symbol = self.trading_config.eth_symbol
position = self.position_manager.get_position_for_symbol(symbol)
position_side = position.side if position else None
# 4. Get current funding rates
funding = self.data_feed.get_current_funding_rates()
# 5. Generate trading signal
sig = self.strategy.generate_signal(features, funding, position_side=position_side)
signal = self.strategy.generate_signal(features, funding)
# 6. Update shared state with strategy info
self._update_strategy_state(sig, funding)
# 6. Execute trades based on signal
if signal['action'] == 'entry':
self._execute_entry(signal, eth_price)
elif signal['action'] == 'check_exit':
self._execute_exit(signal)
# 7. Execute trades based on signal
if sig['action'] == 'entry':
self._execute_entry(sig, eth_price)
elif sig['action'] == 'check_exit':
self._execute_exit(sig)
# 8. Update shared state with position and account
self._update_position_state(eth_price)
self._update_account_state()
# 9. Log portfolio summary
# 7. Log portfolio summary
summary = self.position_manager.get_portfolio_summary()
self.logger.info(
f"Portfolio: {summary['open_positions']} positions, "
@@ -246,61 +178,6 @@ class LiveTradingBot:
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
self.logger.info(f"--- Cycle completed in {cycle_duration:.1f}s ---")
def _update_strategy_state(self, sig: dict, funding: dict) -> None:
"""Update shared state with strategy information."""
if not self.state:
return
self.state.update_strategy(
z_score=sig.get('z_score', 0.0),
probability=sig.get('probability', 0.0),
funding_rate=funding.get('btc_funding', 0.0),
action=sig.get('action', 'hold'),
reason=sig.get('reason', ''),
)
def _update_position_state(self, current_price: float) -> None:
"""Update shared state with current position."""
if not self.state:
return
symbol = self.trading_config.eth_symbol
position = self.position_manager.get_position_for_symbol(symbol)
if position is None:
self.state.clear_position()
return
pos_state = PositionState(
trade_id=position.trade_id,
symbol=position.symbol,
side=position.side,
entry_price=position.entry_price,
current_price=position.current_price,
size=position.size,
size_usdt=position.size_usdt,
unrealized_pnl=position.unrealized_pnl,
unrealized_pnl_pct=position.unrealized_pnl_pct,
stop_loss_price=position.stop_loss_price,
take_profit_price=position.take_profit_price,
)
self.state.set_position(pos_state)
def _update_account_state(self) -> None:
"""Update shared state with account information."""
if not self.state:
return
try:
balance = self.okx_client.get_balance()
self.state.update_account(
balance=balance.get('total', 0.0),
available=balance.get('free', 0.0),
leverage=self.trading_config.leverage,
)
except Exception as e:
self.logger.warning(f"Failed to update account state: {e}")
def _execute_entry(self, signal: dict, current_price: float) -> None:
"""Execute entry trade."""
symbol = self.trading_config.eth_symbol
@@ -314,46 +191,43 @@ class LiveTradingBot:
# Get account balance
balance = self.okx_client.get_balance()
available_usdt = balance['free']
self.logger.info(f"Account balance: ${available_usdt:.2f} USDT available")
# Calculate position size
size_usdt = self.strategy.calculate_position_size(signal, available_usdt)
if size_usdt <= 0:
self.logger.info(
f"Position size too small (${size_usdt:.2f}), skipping entry. "
f"Min required: ${self.strategy.config.min_position_usdt:.2f}"
)
self.logger.info("Position size too small, skipping entry")
return
size_eth = size_usdt / current_price
# Calculate SL/TP for logging
# Calculate SL/TP
stop_loss, take_profit = self.strategy.calculate_sl_tp(current_price, side)
sl_str = f"{stop_loss:.2f}" if stop_loss else "N/A"
tp_str = f"{take_profit:.2f}" if take_profit else "N/A"
self.logger.info(
f"Executing {side.upper()} entry: {size_eth:.4f} ETH @ {current_price:.2f} "
f"(${size_usdt:.2f}), SL={sl_str}, TP={tp_str}"
f"(${size_usdt:.2f}), SL={stop_loss:.2f}, TP={take_profit:.2f}"
)
try:
# Place market order (guaranteed to have fill price or raises)
# Place market order
order_side = "buy" if side == "long" else "sell"
order = self.okx_client.place_market_order(symbol, order_side, size_eth)
# Get filled price and amount (guaranteed by OKX client)
filled_price = order['average']
filled_amount = order.get('filled') or size_eth
# Get filled price (handle None values from OKX response)
filled_price = order.get('average') or order.get('price') or current_price
filled_amount = order.get('filled') or order.get('amount') or size_eth
# Calculate SL/TP with filled price
# Ensure we have valid numeric values
if filled_price is None or filled_price == 0:
self.logger.warning(f"No fill price in order response, using current price: {current_price}")
filled_price = current_price
if filled_amount is None or filled_amount == 0:
self.logger.warning(f"No fill amount in order response, using requested: {size_eth}")
filled_amount = size_eth
# Recalculate SL/TP with filled price
stop_loss, take_profit = self.strategy.calculate_sl_tp(filled_price, side)
if stop_loss is None or take_profit is None:
raise RuntimeError(
f"Failed to calculate SL/TP: filled_price={filled_price}, side={side}"
)
# Get order ID from response
order_id = order.get('id', '')
@@ -417,30 +291,22 @@ class LiveTradingBot:
except Exception as e:
self.logger.error(f"Exit execution failed: {e}", exc_info=True)
def _is_running(self) -> bool:
"""Check if bot should continue running."""
if not self.running:
return False
if self.state and not self.state.is_running():
return False
return True
def run(self) -> None:
"""Main trading loop."""
self.logger.info("Starting trading loop...")
while self._is_running():
while self.running:
try:
self.run_trading_cycle()
if self._is_running():
if self.running:
sleep_seconds = self.trading_config.sleep_seconds
minutes = sleep_seconds // 60
self.logger.info(f"Sleeping for {minutes} minutes...")
# Sleep in smaller chunks to allow faster shutdown
for _ in range(sleep_seconds):
if not self._is_running():
if not self.running:
break
time.sleep(1)
@@ -454,8 +320,6 @@ class LiveTradingBot:
# Cleanup
self.logger.info("Shutting down...")
self.position_manager.save_positions()
if self.db:
self.db.close()
self.logger.info("Shutdown complete")
@@ -487,11 +351,6 @@ def parse_args():
action="store_true",
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
)
parser.add_argument(
"--no-ui",
action="store_true",
help="Run in headless mode without terminal UI"
)
return parser.parse_args()
@@ -512,64 +371,19 @@ def main():
if args.live:
okx_config.demo_mode = False
# Determine if UI should be enabled
use_ui = not args.no_ui and sys.stdin.isatty()
# Initialize database
db_path = path_config.base_dir / "live_trading" / "trading.db"
db = init_db(db_path)
# Run migrations (imports CSV if exists)
run_migrations(db, path_config.trade_log_file)
# Initialize UI components if enabled
log_queue: Optional[queue.Queue] = None
shared_state: Optional[SharedState] = None
dashboard: Optional[TradingDashboard] = None
if use_ui:
log_queue = queue.Queue(maxsize=1000)
shared_state = SharedState()
# Setup logging
logger = setup_logging(path_config.logs_dir, log_queue)
logger = setup_logging(path_config.logs_dir)
try:
# Create bot
bot = LiveTradingBot(
okx_config,
trading_config,
path_config,
database=db,
shared_state=shared_state,
)
# Start dashboard if UI enabled
if use_ui and shared_state and log_queue:
dashboard = TradingDashboard(
state=shared_state,
db=db,
log_queue=log_queue,
on_quit=lambda: setattr(bot, 'running', False),
)
dashboard.start()
logger.info("Dashboard started")
# Run bot
# Create and run bot
bot = LiveTradingBot(okx_config, trading_config, path_config)
bot.run()
except ValueError as e:
logger.error(f"Configuration error: {e}")
sys.exit(1)
except Exception as e:
logger.error(f"Fatal error: {e}", exc_info=True)
sys.exit(1)
finally:
# Cleanup
if dashboard:
dashboard.stop()
if db:
db.close()
if __name__ == "__main__":


@@ -0,0 +1,145 @@
# Multi-Pair Divergence Live Trading
This module implements live trading for the Multi-Pair Divergence Selection Strategy on OKX perpetual futures.
## Overview
The strategy scans the 45 pairs formed from 10 cryptocurrencies for spread divergence opportunities:
1. **Pair Universe**: Top 10 assets by market cap (BTC, ETH, SOL, XRP, BNB, DOGE, ADA, AVAX, LINK, DOT)
2. **Spread Z-Score**: Identifies when pairs are divergent from their historical mean
3. **Universal ML Model**: Predicts probability of successful mean reversion
4. **Dynamic Selection**: Trades the pair with highest divergence score
## Prerequisites
Before running live trading, you must train the model via backtesting:
```bash
uv run python scripts/run_multi_pair_backtest.py
```
This creates `data/multi_pair_model.pkl` which the live trading bot requires.
## Setup
### 1. API Keys
Same as single-pair trading. Set in `.env`:
```env
OKX_API_KEY=your_api_key
OKX_SECRET=your_secret
OKX_PASSWORD=your_passphrase
OKX_DEMO_MODE=true # Use demo for testing
```
### 2. Dependencies
All dependencies are in `pyproject.toml`. No additional installation needed.
## Usage
### Run with Demo Account (Recommended First)
```bash
uv run python -m live_trading.multi_pair.main
```
### Command Line Options
```bash
# Custom position size
uv run python -m live_trading.multi_pair.main --max-position 500
# Custom leverage
uv run python -m live_trading.multi_pair.main --leverage 2
# Custom cycle interval (in seconds)
uv run python -m live_trading.multi_pair.main --interval 1800
# Combine options
uv run python -m live_trading.multi_pair.main --max-position 1000 --leverage 3 --interval 3600
```
### Live Trading (Use with Caution)
```bash
uv run python -m live_trading.multi_pair.main --live
```
## How It Works
### Each Trading Cycle
1. **Fetch Data**: Gets OHLCV for all 10 assets from OKX
2. **Calculate Features**: Computes Z-Score, RSI, volatility for all 45 pair combinations
3. **Score Pairs**: Uses the ML model to rank pairs by divergence score (|Z| × probability)
4. **Check Exits**: If a position is open, checks for mean reversion or SL/TP hits
5. **Enter Best**: If no position is open, enters the highest-scoring divergent pair
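The scoring step above can be sketched as follows. This is a minimal illustration, assuming a fitted scikit-learn-style classifier with `predict_proba` and one feature row per pair; `score_pairs` is a hypothetical helper, not this module's API.

```python
import numpy as np

def score_pairs(model, feature_rows, z_scores):
    """Rank pairs by divergence score = |z_score| * P(successful reversion).

    feature_rows: one row of model features per pair (same order as z_scores).
    Returns pair indices sorted best-first.
    """
    # Probability of class 1 ("reversion succeeds") for each pair
    probs = model.predict_proba(np.asarray(feature_rows))[:, 1]
    scores = np.abs(np.asarray(z_scores, dtype=float)) * probs
    return list(np.argsort(scores)[::-1])
```

With one open position allowed (`max_concurrent_positions = 1`), only the top-ranked pair that also passes the entry filters is traded.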
### Entry Conditions
- |Z-Score| > 1.0 (spread diverged from mean)
- ML probability > 0.5 (model predicts successful reversion)
- Funding rate filter passes (avoid crowded trades)
### Exit Conditions
- Mean reversion: |Z-Score| returns to ~0
- Stop-loss: ATR-based (default ~6%)
- Take-profit: ATR-based (default ~5%)
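The entry and exit conditions listed above reduce to a few threshold checks. A minimal sketch, using the default thresholds from the parameter table; the funding filter is simplified here to a symmetric magnitude check, and the 0.3 reversion tolerance is an illustrative assumption rather than this module's exact value:

```python
def entry_ok(z, prob, funding,
             z_entry=1.0, prob_min=0.5, funding_max=0.0005):
    """Entry gate: spread diverged, model agrees, funding not crowded."""
    return abs(z) > z_entry and prob > prob_min and abs(funding) < funding_max

def exit_on_reversion(z, z_exit=0.0, tol=0.3):
    """Exit once the z-score has reverted close to the mean."""
    return abs(z - z_exit) < tol
```

SL/TP exits are handled separately via the ATR-based stop levels.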
## Strategy Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `z_entry_threshold` | 1.0 | Enter when \|Z-Score\| > threshold |
| `z_exit_threshold` | 0.0 | Exit when Z reverts to mean |
| `z_window` | 24 | Rolling window for Z-Score (hours) |
| `prob_threshold` | 0.5 | ML probability threshold for entry |
| `funding_threshold` | 0.0005 | Funding rate filter (0.05%) |
| `sl_atr_multiplier` | 10.0 | Stop-loss as ATR multiple |
| `tp_atr_multiplier` | 8.0 | Take-profit as ATR multiple |
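The ATR multipliers above translate into stop percentages roughly as sketched below, using the defaults from `MultiPairLiveConfig` (multipliers, min/max bounds, and the fixed-percentage fallback); the module's actual formula may differ in detail:

```python
def atr_stop_pcts(atr, price,
                  sl_mult=10.0, tp_mult=8.0,
                  sl_bounds=(0.02, 0.10), tp_bounds=(0.02, 0.15),
                  fallback=(0.06, 0.05)):
    """Convert ATR into SL/TP percentages, clamped to configured bounds.

    Falls back to the fixed base_sl_pct/base_tp_pct when ATR is unavailable.
    """
    if atr is None or price <= 0:
        return fallback
    atr_pct = atr / price
    sl = min(max(sl_mult * atr_pct, sl_bounds[0]), sl_bounds[1])
    tp = min(max(tp_mult * atr_pct, tp_bounds[0]), tp_bounds[1])
    return sl, tp
```

For example, an ATR of 0.5% of price gives a 5% stop-loss and 4% take-profit, while very volatile pairs are capped at the 10%/15% bounds.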
## Files
### Input
- `data/multi_pair_model.pkl` - Pre-trained ML model (required)
### Output
- `logs/multi_pair_live.log` - Trading logs
- `live_trading/multi_pair_positions.json` - Position persistence
- `live_trading/multi_pair_trade_log.csv` - Trade history
## Architecture
```
live_trading/multi_pair/
__init__.py # Module exports
config.py # Configuration classes
data_feed.py # Multi-asset OHLCV fetcher
strategy.py # ML scoring and signal generation
main.py # Bot orchestrator
README.md # This file
```
## Differences from Single-Pair
| Aspect | Single-Pair | Multi-Pair |
|--------|-------------|------------|
| Assets | ETH only (BTC context) | 10 assets, 45 pairs |
| Model | ETH-specific | Universal across pairs |
| Selection | Fixed pair | Dynamic best pair |
| Stops | Fixed 6%/5% | ATR-based dynamic |
## Risk Warning
This is experimental trading software. Use at your own risk:
- Always start with demo trading
- Never risk more than you can afford to lose
- Monitor the bot regularly
- The model was trained on historical data and may not predict future performance


@@ -0,0 +1,11 @@
"""Multi-Pair Divergence Live Trading Module."""
from .config import MultiPairLiveConfig, get_multi_pair_config
from .data_feed import MultiPairDataFeed
from .strategy import LiveMultiPairStrategy
__all__ = [
"MultiPairLiveConfig",
"get_multi_pair_config",
"MultiPairDataFeed",
"LiveMultiPairStrategy",
]


@@ -0,0 +1,145 @@
"""
Configuration for Multi-Pair Live Trading.
Extends the base live trading config with multi-pair specific settings.
"""
import os
from pathlib import Path
from dataclasses import dataclass, field
from dotenv import load_dotenv
load_dotenv()
@dataclass
class OKXConfig:
"""OKX API configuration."""
api_key: str = ""
secret: str = ""
password: str = ""
demo_mode: bool = True
def __post_init__(self):
"""Load credentials based on demo mode setting."""
self.demo_mode = os.getenv("OKX_DEMO_MODE", "true").lower() in ("true", "1", "yes")
if self.demo_mode:
self.api_key = os.getenv("OKX_DEMO_API_KEY", os.getenv("OKX_API_KEY", ""))
self.secret = os.getenv("OKX_DEMO_SECRET", os.getenv("OKX_SECRET", ""))
self.password = os.getenv("OKX_DEMO_PASSWORD", os.getenv("OKX_PASSWORD", ""))
else:
self.api_key = os.getenv("OKX_API_KEY", "")
self.secret = os.getenv("OKX_SECRET", "")
self.password = os.getenv("OKX_PASSWORD", "")
def validate(self) -> None:
"""Validate that required credentials are present."""
mode = "demo" if self.demo_mode else "live"
if not self.api_key:
raise ValueError(f"OKX API key not set for {mode} mode")
if not self.secret:
raise ValueError(f"OKX secret not set for {mode} mode")
if not self.password:
raise ValueError(f"OKX password not set for {mode} mode")
@dataclass
class MultiPairLiveConfig:
"""
Configuration for multi-pair live trading.
Combines trading parameters, strategy settings, and risk management.
"""
# Asset Universe (top 10 by market cap perpetuals)
assets: list[str] = field(default_factory=lambda: [
"BTC/USDT:USDT", "ETH/USDT:USDT", "SOL/USDT:USDT", "XRP/USDT:USDT",
"BNB/USDT:USDT", "DOGE/USDT:USDT", "ADA/USDT:USDT", "AVAX/USDT:USDT",
"LINK/USDT:USDT", "DOT/USDT:USDT"
])
# Timeframe
timeframe: str = "1h"
candles_to_fetch: int = 500 # Enough for feature calculation
# Z-Score Thresholds
z_window: int = 24
z_entry_threshold: float = 1.0
z_exit_threshold: float = 0.0 # Exit at mean reversion
# ML Thresholds
prob_threshold: float = 0.5
# Position sizing
max_position_usdt: float = -1.0 # If <= 0, use all available funds
min_position_usdt: float = 10.0
leverage: int = 1
margin_mode: str = "cross"
max_concurrent_positions: int = 1 # Trade one pair at a time
# Risk Management - ATR-Based Stops
atr_period: int = 14
sl_atr_multiplier: float = 10.0
tp_atr_multiplier: float = 8.0
# Fallback fixed percentages
base_sl_pct: float = 0.06
base_tp_pct: float = 0.05
# ATR bounds
min_sl_pct: float = 0.02
max_sl_pct: float = 0.10
min_tp_pct: float = 0.02
max_tp_pct: float = 0.15
# Funding Rate Filter
funding_threshold: float = 0.0005 # 0.05%
# Trade Management
min_hold_bars: int = 0
cooldown_bars: int = 0
# Execution
sleep_seconds: int = 3600 # Run every hour
slippage_pct: float = 0.001
def get_asset_short_name(self, symbol: str) -> str:
"""Convert symbol to short name (e.g., BTC/USDT:USDT -> btc)."""
return symbol.split("/")[0].lower()
def get_pair_count(self) -> int:
"""Calculate number of unique pairs from asset list."""
n = len(self.assets)
return n * (n - 1) // 2
@dataclass
class PathConfig:
"""File paths configuration."""
base_dir: Path = field(
default_factory=lambda: Path(__file__).parent.parent.parent
)
data_dir: Path | None = None
logs_dir: Path | None = None
model_path: Path | None = None
positions_file: Path | None = None
trade_log_file: Path | None = None
def __post_init__(self):
self.data_dir = self.base_dir / "data"
self.logs_dir = self.base_dir / "logs"
# Use the same model as backtesting
self.model_path = self.base_dir / "data" / "multi_pair_model.pkl"
self.positions_file = self.base_dir / "live_trading" / "multi_pair_positions.json"
self.trade_log_file = self.base_dir / "live_trading" / "multi_pair_trade_log.csv"
# Ensure directories exist
self.data_dir.mkdir(parents=True, exist_ok=True)
self.logs_dir.mkdir(parents=True, exist_ok=True)
def get_multi_pair_config() -> tuple[OKXConfig, MultiPairLiveConfig, PathConfig]:
"""Get all configuration objects for multi-pair trading."""
okx = OKXConfig()
trading = MultiPairLiveConfig()
paths = PathConfig()
return okx, trading, paths


@@ -0,0 +1,336 @@
"""
Multi-Pair Data Feed for Live Trading.
Fetches real-time OHLCV and funding data for all assets in the universe.
"""
import logging
from itertools import combinations
import pandas as pd
import numpy as np
import ta
from live_trading.okx_client import OKXClient
from .config import MultiPairLiveConfig, PathConfig
logger = logging.getLogger(__name__)
class TradingPair:
"""
Represents a tradeable pair for spread analysis.
Attributes:
base_asset: First asset symbol (e.g., ETH/USDT:USDT)
quote_asset: Second asset symbol (e.g., BTC/USDT:USDT)
pair_id: Unique identifier
"""
def __init__(self, base_asset: str, quote_asset: str):
self.base_asset = base_asset
self.quote_asset = quote_asset
self.pair_id = f"{base_asset}__{quote_asset}"
@property
def name(self) -> str:
"""Human-readable pair name."""
base = self.base_asset.split("/")[0]
quote = self.quote_asset.split("/")[0]
return f"{base}/{quote}"
def __hash__(self):
return hash(self.pair_id)
def __eq__(self, other):
if not isinstance(other, TradingPair):
return False
return self.pair_id == other.pair_id
class MultiPairDataFeed:
"""
Real-time data feed for multi-pair strategy.
Fetches OHLCV data for all assets and calculates spread features
for all pair combinations.
"""
def __init__(
self,
okx_client: OKXClient,
config: MultiPairLiveConfig,
path_config: PathConfig
):
self.client = okx_client
self.config = config
self.paths = path_config
# Cache for asset data
self._asset_data: dict[str, pd.DataFrame] = {}
self._funding_rates: dict[str, float] = {}
self._pairs: list[TradingPair] = []
# Generate pairs
self._generate_pairs()
def _generate_pairs(self) -> None:
"""Generate all unique pairs from asset universe."""
self._pairs = []
for base, quote in combinations(self.config.assets, 2):
pair = TradingPair(base_asset=base, quote_asset=quote)
self._pairs.append(pair)
logger.info("Generated %d pairs from %d assets",
len(self._pairs), len(self.config.assets))
@property
def pairs(self) -> list[TradingPair]:
"""Get list of trading pairs."""
return self._pairs
def fetch_all_ohlcv(self) -> dict[str, pd.DataFrame]:
"""
Fetch OHLCV data for all assets.
Returns:
Dictionary mapping symbol to OHLCV DataFrame
"""
self._asset_data = {}
for symbol in self.config.assets:
try:
ohlcv = self.client.fetch_ohlcv(
symbol,
self.config.timeframe,
self.config.candles_to_fetch
)
df = self._ohlcv_to_dataframe(ohlcv)
if len(df) >= 200:
self._asset_data[symbol] = df
logger.debug("Fetched %s: %d candles", symbol, len(df))
else:
logger.warning("Skipping %s: insufficient data (%d)",
symbol, len(df))
except Exception as e:
logger.error("Error fetching %s: %s", symbol, e)
logger.info("Fetched data for %d/%d assets",
len(self._asset_data), len(self.config.assets))
return self._asset_data
def _ohlcv_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
"""Convert OHLCV list to DataFrame."""
df = pd.DataFrame(
ohlcv,
columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
)
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
df.set_index('timestamp', inplace=True)
return df
def fetch_all_funding_rates(self) -> dict[str, float]:
"""
Fetch current funding rates for all assets.
Returns:
Dictionary mapping symbol to funding rate
"""
self._funding_rates = {}
for symbol in self.config.assets:
try:
rate = self.client.get_funding_rate(symbol)
self._funding_rates[symbol] = rate
except Exception as e:
logger.warning("Could not get funding for %s: %s", symbol, e)
self._funding_rates[symbol] = 0.0
return self._funding_rates
def calculate_pair_features(
self,
pair: TradingPair
) -> pd.DataFrame | None:
"""
Calculate features for a single pair.
Args:
pair: Trading pair
Returns:
DataFrame with features, or None if insufficient data
"""
base = pair.base_asset
quote = pair.quote_asset
if base not in self._asset_data or quote not in self._asset_data:
return None
df_base = self._asset_data[base]
df_quote = self._asset_data[quote]
# Align indices
common_idx = df_base.index.intersection(df_quote.index)
if len(common_idx) < 200:
return None
df_a = df_base.loc[common_idx]
df_b = df_quote.loc[common_idx]
# Calculate spread (base / quote)
spread = df_a['close'] / df_b['close']
# Z-Score
z_window = self.config.z_window
rolling_mean = spread.rolling(window=z_window).mean()
rolling_std = spread.rolling(window=z_window).std()
z_score = (spread - rolling_mean) / rolling_std
# Spread Technicals
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
spread_roc = spread.pct_change(periods=5) * 100
spread_change_1h = spread.pct_change(periods=1)
# Volume Analysis
vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10)
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10)
# Volatility
ret_a = df_a['close'].pct_change()
ret_b = df_b['close'].pct_change()
vol_a = ret_a.rolling(window=z_window).std()
vol_b = ret_b.rolling(window=z_window).std()
vol_spread_ratio = vol_a / (vol_b + 1e-10)
# Realized Volatility
realized_vol_a = ret_a.rolling(window=24).std()
realized_vol_b = ret_b.rolling(window=24).std()
# ATR (Average True Range)
high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close']
tr_a = pd.concat([
high_a - low_a,
(high_a - close_a.shift(1)).abs(),
(low_a - close_a.shift(1)).abs()
], axis=1).max(axis=1)
atr_a = tr_a.rolling(window=self.config.atr_period).mean()
atr_pct_a = atr_a / close_a
# Build feature DataFrame
features = pd.DataFrame(index=common_idx)
features['pair_id'] = pair.pair_id
features['base_asset'] = base
features['quote_asset'] = quote
# Price data
features['spread'] = spread
features['base_close'] = df_a['close']
features['quote_close'] = df_b['close']
features['base_volume'] = df_a['volume']
# Core Features
features['z_score'] = z_score
features['spread_rsi'] = spread_rsi
features['spread_roc'] = spread_roc
features['spread_change_1h'] = spread_change_1h
features['vol_ratio'] = vol_ratio
features['vol_ratio_rel'] = vol_ratio_rel
features['vol_diff_ratio'] = vol_spread_ratio
# Volatility
features['realized_vol_base'] = realized_vol_a
features['realized_vol_quote'] = realized_vol_b
features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2
# ATR
features['atr_base'] = atr_a
features['atr_pct_base'] = atr_pct_a
# Pair encoding
assets = self.config.assets
features['base_idx'] = assets.index(base) if base in assets else -1
features['quote_idx'] = assets.index(quote) if quote in assets else -1
# Funding rates
base_funding = self._funding_rates.get(base, 0.0)
quote_funding = self._funding_rates.get(quote, 0.0)
features['base_funding'] = base_funding
features['quote_funding'] = quote_funding
features['funding_diff'] = base_funding - quote_funding
features['funding_avg'] = (base_funding + quote_funding) / 2
# Drop NaN rows in core features
core_cols = [
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
'realized_vol_base', 'atr_base', 'atr_pct_base'
]
features = features.dropna(subset=core_cols)
return features
def calculate_all_pair_features(self) -> dict[str, pd.DataFrame]:
"""
Calculate features for all pairs.
Returns:
Dictionary mapping pair_id to feature DataFrame
"""
all_features = {}
for pair in self._pairs:
features = self.calculate_pair_features(pair)
if features is not None and len(features) > 0:
all_features[pair.pair_id] = features
logger.info("Calculated features for %d/%d pairs",
len(all_features), len(self._pairs))
return all_features
def get_latest_data(self) -> dict[str, pd.DataFrame] | None:
"""
Fetch and process latest market data for all pairs.
Returns:
Dictionary of pair features or None on error
"""
try:
# Fetch OHLCV for all assets
self.fetch_all_ohlcv()
if len(self._asset_data) < 2:
logger.warning("Insufficient assets fetched")
return None
# Fetch funding rates
self.fetch_all_funding_rates()
# Calculate features for all pairs
pair_features = self.calculate_all_pair_features()
if not pair_features:
logger.warning("No pair features calculated")
return None
logger.info("Processed %d pairs with valid features", len(pair_features))
return pair_features
except Exception as e:
logger.error("Error fetching market data: %s", e, exc_info=True)
return None
def get_pair_by_id(self, pair_id: str) -> TradingPair | None:
"""Get pair object by ID."""
for pair in self._pairs:
if pair.pair_id == pair_id:
return pair
return None
def get_current_price(self, symbol: str) -> float | None:
"""Get current price for a symbol."""
if symbol in self._asset_data:
return self._asset_data[symbol]['close'].iloc[-1]
return None
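The z-score construction in `calculate_pair_features` is plain rolling standardization of the price ratio. A self-contained sketch on synthetic closes; the window, seed, and series are arbitrary:

```python
import numpy as np
import pandas as pd

# Synthetic hourly closes for two assets; 24 plays the role of z_window.
rng = np.random.default_rng(0)
idx = pd.date_range("2026-01-01", periods=100, freq="h", tz="UTC")
close_a = pd.Series(100 + rng.normal(0, 1.0, 100).cumsum(), index=idx)
close_b = pd.Series(50 + rng.normal(0, 0.5, 100).cumsum(), index=idx)

spread = close_a / close_b
rolling_mean = spread.rolling(window=24).mean()
rolling_std = spread.rolling(window=24).std()
z_score = (spread - rolling_mean) / rolling_std

# The first window-1 rows are NaN, which is why the feed drops NaN core features.
print(int(z_score.notna().sum()))  # 77
```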


@@ -0,0 +1,609 @@
#!/usr/bin/env python3
"""
Multi-Pair Divergence Live Trading Bot.
Trades the top 10 cryptocurrency pairs based on spread divergence
using a universal ML model for signal generation.
Usage:
# Run with demo account (default)
uv run python -m live_trading.multi_pair.main
# Run with specific settings
uv run python -m live_trading.multi_pair.main --max-position 500 --leverage 2
"""
import argparse
import logging
import signal
import sys
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from live_trading.okx_client import OKXClient
from live_trading.position_manager import PositionManager
from live_trading.multi_pair.config import (
OKXConfig, MultiPairLiveConfig, PathConfig, get_multi_pair_config
)
from live_trading.multi_pair.data_feed import MultiPairDataFeed, TradingPair
from live_trading.multi_pair.strategy import LiveMultiPairStrategy
def setup_logging(log_dir: Path) -> logging.Logger:
"""Configure logging for the trading bot."""
log_file = log_dir / "multi_pair_live.log"
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler(sys.stdout),
],
force=True
)
return logging.getLogger(__name__)
@dataclass
class PositionState:
"""Track current position state for multi-pair."""
pair: TradingPair | None = None
pair_id: str | None = None
direction: str | None = None
entry_price: float = 0.0
size: float = 0.0
stop_loss: float = 0.0
take_profit: float = 0.0
entry_time: datetime | None = None
class MultiPairLiveTradingBot:
"""
Main trading bot for multi-pair divergence strategy.
Coordinates data fetching, pair scoring, and order execution.
"""
def __init__(
self,
okx_config: OKXConfig,
trading_config: MultiPairLiveConfig,
path_config: PathConfig
):
self.okx_config = okx_config
self.trading_config = trading_config
self.path_config = path_config
self.logger = logging.getLogger(__name__)
self.running = True
# Initialize components
self.logger.info("Initializing multi-pair trading bot...")
# Create OKX client with adapted config
self._adapted_trading_config = self._adapt_config_for_okx_client()
self.okx_client = OKXClient(okx_config, self._adapted_trading_config)
# Initialize data feed
self.data_feed = MultiPairDataFeed(
self.okx_client, trading_config, path_config
)
# Initialize position manager (reuse from single-pair)
self.position_manager = PositionManager(
self.okx_client, self._adapted_trading_config, path_config
)
# Initialize strategy
self.strategy = LiveMultiPairStrategy(trading_config, path_config)
# Current position state
self.position = PositionState()
# Register signal handlers
signal.signal(signal.SIGINT, self._handle_shutdown)
signal.signal(signal.SIGTERM, self._handle_shutdown)
self._print_startup_banner()
# Sync with exchange positions on startup
self._sync_position_from_exchange()
def _adapt_config_for_okx_client(self):
"""Create config compatible with OKXClient."""
# OKXClient expects specific attributes
@dataclass
class AdaptedConfig:
eth_symbol: str = "ETH/USDT:USDT"
btc_symbol: str = "BTC/USDT:USDT"
timeframe: str = "1h"
candles_to_fetch: int = 500
max_position_usdt: float = -1.0
min_position_usdt: float = 10.0
leverage: int = 1
margin_mode: str = "cross"
stop_loss_pct: float = 0.06
take_profit_pct: float = 0.05
max_concurrent_positions: int = 1
z_entry_threshold: float = 1.0
z_window: int = 24
model_prob_threshold: float = 0.5
funding_threshold: float = 0.0005
sleep_seconds: int = 3600
slippage_pct: float = 0.001
adapted = AdaptedConfig()
adapted.timeframe = self.trading_config.timeframe
adapted.candles_to_fetch = self.trading_config.candles_to_fetch
adapted.max_position_usdt = self.trading_config.max_position_usdt
adapted.min_position_usdt = self.trading_config.min_position_usdt
adapted.leverage = self.trading_config.leverage
adapted.margin_mode = self.trading_config.margin_mode
adapted.max_concurrent_positions = self.trading_config.max_concurrent_positions
adapted.sleep_seconds = self.trading_config.sleep_seconds
adapted.slippage_pct = self.trading_config.slippage_pct
return adapted
def _print_startup_banner(self) -> None:
"""Print startup information."""
mode = "DEMO/SANDBOX" if self.okx_config.demo_mode else "LIVE"
print("=" * 60)
print(" Multi-Pair Divergence Strategy - Live Trading Bot")
print("=" * 60)
print(f" Mode: {mode}")
print(f" Assets: {len(self.trading_config.assets)} assets")
print(f" Pairs: {self.trading_config.get_pair_count()} pairs")
print(f" Timeframe: {self.trading_config.timeframe}")
max_pos = self.trading_config.max_position_usdt
max_pos_str = f"${max_pos:,.2f}" if max_pos > 0 else "All available"
print(f" Max Position: {max_pos_str}")
print(f" Leverage: {self.trading_config.leverage}x")
print(f" Z-Entry: > {self.trading_config.z_entry_threshold}")
print(f" Prob Threshold: > {self.trading_config.prob_threshold}")
print(f" Cycle Interval: {self.trading_config.sleep_seconds // 60} minutes")
print("=" * 60)
print(f" Assets: {', '.join([a.split('/')[0] for a in self.trading_config.assets])}")
print("=" * 60)
if not self.okx_config.demo_mode:
print("\n *** WARNING: LIVE TRADING MODE - REAL FUNDS AT RISK ***\n")
def _handle_shutdown(self, signum, frame) -> None:
"""Handle shutdown signals gracefully."""
self.logger.info("Shutdown signal received, stopping...")
self.running = False
def _sync_position_from_exchange(self) -> bool:
"""
Sync internal position state with exchange positions.
Checks for existing open positions on the exchange and updates
internal state to match. This prevents stacking positions when
the bot is restarted.
Returns:
True if a position was synced, False otherwise
"""
try:
positions = self.okx_client.get_positions()
if not positions:
if self.position.pair is not None:
# Position was closed externally (e.g., SL/TP hit)
self.logger.info(
"Position %s was closed externally, resetting state",
self.position.pair.name if self.position.pair else "unknown"
)
self.position = PositionState()
return False
# Check each position against our tradeable assets
our_assets = set(self.trading_config.assets)
for pos in positions:
pos_symbol = pos.get('symbol', '')
contracts = abs(float(pos.get('contracts', 0)))
if contracts == 0:
continue
# Check if this position is for one of our assets
if pos_symbol not in our_assets:
continue
# Found a position for one of our assets
side = pos.get('side', 'long')
entry_price = float(pos.get('entryPrice', 0))
unrealized_pnl = float(pos.get('unrealizedPnl', 0))
# If we already track this position, just update
if (self.position.pair is not None and
self.position.pair.base_asset == pos_symbol):
self.logger.debug(
"Position already tracked: %s %s %.2f contracts",
side, pos_symbol, contracts
)
return True
# New position found - sync it
# Find or create a TradingPair for this position
matched_pair = None
for pair in self.data_feed.pairs:
if pair.base_asset == pos_symbol:
matched_pair = pair
break
if matched_pair is None:
# Create a placeholder pair (we don't know the quote asset)
matched_pair = TradingPair(
base_asset=pos_symbol,
quote_asset="UNKNOWN"
)
# Calculate approximate SL/TP based on config defaults
sl_pct = self.trading_config.base_sl_pct
tp_pct = self.trading_config.base_tp_pct
if side == 'long':
stop_loss = entry_price * (1 - sl_pct)
take_profit = entry_price * (1 + tp_pct)
else:
stop_loss = entry_price * (1 + sl_pct)
take_profit = entry_price * (1 - tp_pct)
self.position = PositionState(
pair=matched_pair,
pair_id=matched_pair.pair_id,
direction=side,
entry_price=entry_price,
size=contracts,
stop_loss=stop_loss,
take_profit=take_profit,
entry_time=None # Unknown for synced positions
)
self.logger.info(
"Synced existing position from exchange: %s %s %.4f @ %.4f (PnL: %.2f)",
side.upper(),
pos_symbol,
contracts,
entry_price,
unrealized_pnl
)
return True
# No matching positions found
if self.position.pair is not None:
self.logger.info(
"Position %s no longer exists on exchange, resetting state",
self.position.pair.name
)
self.position = PositionState()
return False
except Exception as e:
self.logger.error("Failed to sync position from exchange: %s", e)
return False
def run_trading_cycle(self) -> None:
"""
Execute one trading cycle.
1. Sync position state with exchange (detect SL/TP closures)
2. Fetch latest market data and compute pair features
3. Check exit conditions for the current position
4. Generate and execute an entry signal if flat
5. Log position status
"""
cycle_start = datetime.now(timezone.utc)
self.logger.info("--- Trading Cycle Start: %s ---", cycle_start.isoformat())
try:
# 1. Sync position state with exchange (detect SL/TP closures)
self._sync_position_from_exchange()
# 2. Fetch all market data
pair_features = self.data_feed.get_latest_data()
if pair_features is None:
self.logger.warning("No market data available, skipping cycle")
return
# 3. Check exit conditions for current position
if self.position.pair is not None:
exit_signal = self.strategy.check_exit_signal(
pair_features,
self.position.pair_id
)
if exit_signal['action'] == 'exit':
self._execute_exit(exit_signal)
else:
# Check SL/TP
current_price = self.data_feed.get_current_price(
self.position.pair.base_asset
)
if current_price:
sl_tp_exit = self._check_sl_tp(current_price)
if sl_tp_exit:
self._execute_exit({'reason': sl_tp_exit})
# 4. Generate entry signal if no position
if self.position.pair is None:
entry_signal = self.strategy.generate_signal(
pair_features,
self.data_feed.pairs
)
if entry_signal['action'] == 'entry':
self._execute_entry(entry_signal)
# 5. Log status
if self.position.pair:
self.logger.info(
"Open position: %s %s, entry=%.4f",
self.position.direction,
self.position.pair.name,
self.position.entry_price
)
else:
self.logger.info("No open position")
except Exception as e:
self.logger.error("Trading cycle error: %s", e, exc_info=True)
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
self.logger.info("--- Cycle completed in %.1fs ---", cycle_duration)
def _check_sl_tp(self, current_price: float) -> str | None:
"""Check stop-loss and take-profit levels."""
if self.position.direction == 'long':
if current_price <= self.position.stop_loss:
return f"stop_loss ({current_price:.4f} <= {self.position.stop_loss:.4f})"
if current_price >= self.position.take_profit:
return f"take_profit ({current_price:.4f} >= {self.position.take_profit:.4f})"
else: # short
if current_price >= self.position.stop_loss:
return f"stop_loss ({current_price:.4f} >= {self.position.stop_loss:.4f})"
if current_price <= self.position.take_profit:
return f"take_profit ({current_price:.4f} <= {self.position.take_profit:.4f})"
return None
def _execute_entry(self, signal: dict) -> None:
"""Execute entry trade."""
pair = signal['pair']
symbol = pair.base_asset # Trade the base asset
direction = signal['direction']
self.logger.info(
"Entry signal: %s %s (z=%.2f, p=%.2f, score=%.3f)",
direction.upper(),
pair.name,
signal['z_score'],
signal['probability'],
signal['divergence_score']
)
# Get account balance
try:
balance = self.okx_client.get_balance()
available_usdt = balance['free']
except Exception as e:
self.logger.error("Could not get balance: %s", e)
return
# Calculate position size
size_usdt = self.strategy.calculate_position_size(
signal['divergence_score'],
available_usdt
)
if size_usdt <= 0:
self.logger.info("Position size too small, skipping entry")
return
current_price = signal['base_price']
size_asset = size_usdt / current_price
# Calculate SL/TP
stop_loss, take_profit = self.strategy.calculate_sl_tp(
current_price,
direction,
signal['atr'],
signal['atr_pct']
)
self.logger.info(
"Executing %s entry: %.6f %s @ %.4f ($%.2f), SL=%.4f, TP=%.4f",
direction.upper(),
size_asset,
symbol.split('/')[0],
current_price,
size_usdt,
stop_loss,
take_profit
)
try:
# Place market order
order_side = "buy" if direction == "long" else "sell"
order = self.okx_client.place_market_order(symbol, order_side, size_asset)
# The `or` chain already falls back through None/zero to the pre-order values
filled_price = order.get('average') or order.get('price') or current_price
filled_amount = order.get('filled') or order.get('amount') or size_asset
# Recalculate SL/TP with filled price
stop_loss, take_profit = self.strategy.calculate_sl_tp(
filled_price, direction, signal['atr'], signal['atr_pct']
)
# Update position state
self.position = PositionState(
pair=pair,
pair_id=pair.pair_id,
direction=direction,
entry_price=filled_price,
size=filled_amount,
stop_loss=stop_loss,
take_profit=take_profit,
entry_time=datetime.now(timezone.utc)
)
self.logger.info(
"Position opened: %s %s %.6f @ %.4f",
direction.upper(),
pair.name,
filled_amount,
filled_price
)
# Try to set SL/TP on exchange
try:
self.okx_client.set_stop_loss_take_profit(
symbol, direction, filled_amount, stop_loss, take_profit
)
except Exception as e:
self.logger.warning("Could not set SL/TP on exchange: %s", e)
except Exception as e:
self.logger.error("Order execution failed: %s", e, exc_info=True)
def _execute_exit(self, signal: dict) -> None:
"""Execute exit trade."""
if self.position.pair is None:
return
symbol = self.position.pair.base_asset
reason = signal.get('reason', 'unknown')
self.logger.info(
"Exit signal: %s %s, reason: %s",
self.position.direction,
self.position.pair.name,
reason
)
try:
# Close position on exchange
self.okx_client.close_position(symbol)
self.logger.info(
"Position closed: %s %s",
self.position.direction,
self.position.pair.name
)
# Reset position state
self.position = PositionState()
except Exception as e:
self.logger.error("Exit execution failed: %s", e, exc_info=True)
def run(self) -> None:
"""Main trading loop."""
self.logger.info("Starting multi-pair trading loop...")
while self.running:
try:
self.run_trading_cycle()
if self.running:
sleep_seconds = self.trading_config.sleep_seconds
minutes = sleep_seconds // 60
self.logger.info("Sleeping for %d minutes...", minutes)
for _ in range(sleep_seconds):
if not self.running:
break
time.sleep(1)
except KeyboardInterrupt:
self.logger.info("Keyboard interrupt received")
break
except Exception as e:
self.logger.error("Unexpected error in main loop: %s", e, exc_info=True)
time.sleep(60)
self.logger.info("Shutting down...")
self.logger.info("Shutdown complete")
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Multi-Pair Divergence Live Trading Bot"
)
parser.add_argument(
"--max-position",
type=float,
default=None,
help="Maximum position size in USDT"
)
parser.add_argument(
"--leverage",
type=int,
default=None,
help="Trading leverage (1-125)"
)
parser.add_argument(
"--interval",
type=int,
default=None,
help="Trading cycle interval in seconds"
)
parser.add_argument(
"--live",
action="store_true",
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
)
return parser.parse_args()
def main():
"""Main entry point."""
args = parse_args()
# Load configuration
okx_config, trading_config, path_config = get_multi_pair_config()
# Apply command line overrides
if args.max_position is not None:
trading_config.max_position_usdt = args.max_position
if args.leverage is not None:
trading_config.leverage = args.leverage
if args.interval is not None:
trading_config.sleep_seconds = args.interval
if args.live:
okx_config.demo_mode = False
# Setup logging
logger = setup_logging(path_config.logs_dir)
try:
# Validate config
okx_config.validate()
# Create and run bot
bot = MultiPairLiveTradingBot(okx_config, trading_config, path_config)
bot.run()
except ValueError as e:
logger.error("Configuration error: %s", e)
sys.exit(1)
except Exception as e:
logger.error("Fatal error: %s", e, exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()


@@ -0,0 +1,396 @@
"""
Live Multi-Pair Divergence Strategy.
Scores all pairs and selects the best divergence opportunity for trading.
Uses the pre-trained universal ML model from backtesting.
"""
import logging
import pickle
from dataclasses import dataclass
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
# Opt-in to future pandas behavior to silence FutureWarning on fillna
pd.set_option('future.no_silent_downcasting', True)
from .config import MultiPairLiveConfig, PathConfig
from .data_feed import TradingPair
logger = logging.getLogger(__name__)
@dataclass
class DivergenceSignal:
"""
Signal for a divergent pair.
Attributes:
pair: Trading pair
z_score: Current Z-Score of the spread
probability: ML model probability of profitable reversion
divergence_score: Combined score (|z_score| * probability)
direction: 'long' or 'short' (relative to base asset)
base_price: Current price of base asset
quote_price: Current price of quote asset
atr: Average True Range in price units
atr_pct: ATR as percentage of price
"""
pair: TradingPair
z_score: float
probability: float
divergence_score: float
direction: str
base_price: float
quote_price: float
atr: float
atr_pct: float
base_funding: float = 0.0
class LiveMultiPairStrategy:
"""
Live trading implementation of multi-pair divergence strategy.
Scores all pairs using the universal ML model and selects
the best opportunity for mean-reversion trading.
"""
def __init__(
self,
config: MultiPairLiveConfig,
path_config: PathConfig
):
self.config = config
self.paths = path_config
self.model: RandomForestClassifier | None = None
self.feature_cols: list[str] | None = None
self._load_model()
def _load_model(self) -> None:
"""Load pre-trained model from backtesting."""
if self.paths.model_path.exists():
try:
with open(self.paths.model_path, 'rb') as f:
saved = pickle.load(f)
self.model = saved['model']
self.feature_cols = saved['feature_cols']
logger.info("Loaded model from %s", self.paths.model_path)
except Exception as e:
logger.error("Could not load model: %s", e)
raise ValueError(
f"Failed to load multi-pair model from {self.paths.model_path}: {e}"
) from e
else:
raise ValueError(
f"Multi-pair model not found at {self.paths.model_path}. "
"Run the backtest first to train the model."
)
def score_pairs(
self,
pair_features: dict[str, pd.DataFrame],
pairs: list[TradingPair]
) -> list[DivergenceSignal]:
"""
Score all pairs and return ranked signals.
Args:
pair_features: Feature DataFrames by pair_id
pairs: List of TradingPair objects
Returns:
List of DivergenceSignal sorted by score (descending)
"""
if self.model is None:
logger.warning("Model not loaded")
return []
signals = []
pair_map = {p.pair_id: p for p in pairs}
for pair_id, features in pair_features.items():
if pair_id not in pair_map:
continue
pair = pair_map[pair_id]
# Get latest features
if len(features) == 0:
continue
latest = features.iloc[-1]
z_score = latest['z_score']
# Skip if Z-score below threshold
if abs(z_score) < self.config.z_entry_threshold:
continue
# Prepare features for prediction
# Handle missing feature columns gracefully
available_cols = [c for c in self.feature_cols if c in latest.index]
missing_cols = [c for c in self.feature_cols if c not in latest.index]
if missing_cols:
logger.debug("Missing feature columns: %s", missing_cols)
feature_row = latest[available_cols].fillna(0)
feature_row = feature_row.replace([np.inf, -np.inf], 0)
# Create full feature vector with zeros for missing
X_dict = {c: 0 for c in self.feature_cols}
for col in available_cols:
X_dict[col] = feature_row[col]
X = pd.DataFrame([X_dict])
# Predict probability
prob = self.model.predict_proba(X)[0, 1]
# Skip if probability below threshold
if prob < self.config.prob_threshold:
continue
# Apply funding rate filter
base_funding = latest.get('base_funding', 0.0)
if pd.isna(base_funding):
base_funding = 0.0  # NaN is truthy, so a bare `or 0` would let it through
funding_thresh = self.config.funding_threshold
if z_score > 0: # Short signal
if base_funding < -funding_thresh:
logger.debug(
"Skipping %s short: funding too negative (%.4f)",
pair.name, base_funding
)
continue
else: # Long signal
if base_funding > funding_thresh:
logger.debug(
"Skipping %s long: funding too positive (%.4f)",
pair.name, base_funding
)
continue
# Calculate divergence score
divergence_score = abs(z_score) * prob
# Determine direction
direction = 'short' if z_score > 0 else 'long'
signal = DivergenceSignal(
pair=pair,
z_score=z_score,
probability=prob,
divergence_score=divergence_score,
direction=direction,
base_price=latest['base_close'],
quote_price=latest['quote_close'],
atr=latest.get('atr_base', 0),
atr_pct=latest.get('atr_pct_base', 0.02),
base_funding=base_funding
)
signals.append(signal)
# Sort by divergence score (highest first)
signals.sort(key=lambda s: s.divergence_score, reverse=True)
if signals:
logger.info(
"Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f, dir=%s)",
len(signals),
signals[0].pair.name,
signals[0].divergence_score,
signals[0].z_score,
signals[0].probability,
signals[0].direction
)
else:
logger.info("No pairs meet entry criteria")
return signals
def select_best_pair(
self,
signals: list[DivergenceSignal]
) -> DivergenceSignal | None:
"""
Select the best pair from scored signals.
Args:
signals: List of DivergenceSignal (pre-sorted by score)
Returns:
Best signal or None if no valid candidates
"""
if not signals:
return None
return signals[0]
def generate_signal(
self,
pair_features: dict[str, pd.DataFrame],
pairs: list[TradingPair]
) -> dict:
"""
Generate trading signal from latest features.
Args:
pair_features: Feature DataFrames by pair_id
pairs: List of TradingPair objects
Returns:
Signal dictionary with action, pair, direction, etc.
"""
# Score all pairs
signals = self.score_pairs(pair_features, pairs)
# Select best
best = self.select_best_pair(signals)
if best is None:
return {
'action': 'hold',
'reason': 'no_valid_signals'
}
return {
'action': 'entry',
'pair': best.pair,
'pair_id': best.pair.pair_id,
'direction': best.direction,
'z_score': best.z_score,
'probability': best.probability,
'divergence_score': best.divergence_score,
'base_price': best.base_price,
'quote_price': best.quote_price,
'atr': best.atr,
'atr_pct': best.atr_pct,
'base_funding': best.base_funding,
'reason': f'{best.pair.name} z={best.z_score:.2f} p={best.probability:.2f}'
}
def check_exit_signal(
self,
pair_features: dict[str, pd.DataFrame],
current_pair_id: str
) -> dict:
"""
Check if current position should be exited.
Exit conditions:
1. Z-Score reverted to mean (|Z| < threshold)
Args:
pair_features: Feature DataFrames by pair_id
current_pair_id: Current position's pair ID
Returns:
Signal dictionary with action and reason
"""
if current_pair_id not in pair_features:
return {
'action': 'exit',
'reason': 'pair_data_missing'
}
features = pair_features[current_pair_id]
if len(features) == 0:
return {
'action': 'exit',
'reason': 'no_data'
}
latest = features.iloc[-1]
z_score = latest['z_score']
# Check mean reversion
if abs(z_score) < self.config.z_exit_threshold:
return {
'action': 'exit',
'reason': f'mean_reversion (z={z_score:.2f})'
}
return {
'action': 'hold',
'z_score': z_score,
'reason': f'holding (z={z_score:.2f})'
}
def calculate_sl_tp(
self,
entry_price: float,
direction: str,
atr: float,
atr_pct: float
) -> tuple[float, float]:
"""
Calculate ATR-based dynamic stop-loss and take-profit prices.
Args:
entry_price: Entry price
direction: 'long' or 'short'
atr: ATR in price units
atr_pct: ATR as percentage of price
Returns:
Tuple of (stop_loss_price, take_profit_price)
"""
if atr > 0 and atr_pct > 0:
sl_distance = atr * self.config.sl_atr_multiplier
tp_distance = atr * self.config.tp_atr_multiplier
sl_pct = sl_distance / entry_price
tp_pct = tp_distance / entry_price
else:
sl_pct = self.config.base_sl_pct
tp_pct = self.config.base_tp_pct
# Apply bounds
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
if direction == 'long':
stop_loss = entry_price * (1 - sl_pct)
take_profit = entry_price * (1 + tp_pct)
else:
stop_loss = entry_price * (1 + sl_pct)
take_profit = entry_price * (1 - tp_pct)
return stop_loss, take_profit
def calculate_position_size(
self,
divergence_score: float,
available_usdt: float
) -> float:
"""
Calculate position size based on divergence score.
Args:
divergence_score: Combined score (|z| * prob)
available_usdt: Available USDT balance
Returns:
Position size in USDT
"""
if self.config.max_position_usdt <= 0:
base_size = available_usdt
else:
base_size = min(available_usdt, self.config.max_position_usdt)
# Scale by divergence (1.0 at 0.5 score, up to 2.0 at 1.0+ score)
base_threshold = 0.5
if divergence_score <= base_threshold:
scale = 1.0
else:
scale = 1.0 + (divergence_score - base_threshold) / base_threshold
scale = min(scale, 2.0)
size = base_size * scale
if size < self.config.min_position_usdt:
return 0.0
return min(size, available_usdt * 0.95)
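The ATR-clamped stop/target logic and the divergence-scaled sizing above can be exercised on their own. A minimal sketch — the multipliers and bounds below are placeholder values, not the real config defaults:

```python
# Placeholder stand-ins for the config fields referenced above.
SL_ATR_MULT, TP_ATR_MULT = 1.5, 3.0
MIN_SL_PCT, MAX_SL_PCT = 0.01, 0.05

def sl_tp(entry: float, direction: str, atr: float) -> tuple[float, float]:
    """ATR-based SL/TP percentages, clamped as in calculate_sl_tp()."""
    sl_pct = min(max(atr * SL_ATR_MULT / entry, MIN_SL_PCT), MAX_SL_PCT)
    tp_pct = atr * TP_ATR_MULT / entry
    if direction == "long":
        return entry * (1 - sl_pct), entry * (1 + tp_pct)
    return entry * (1 + sl_pct), entry * (1 - tp_pct)

def size_scale(divergence_score: float, base: float = 0.5) -> float:
    """1.0 at or below the base score, rising linearly to a 2.0 cap."""
    return min(1.0 + max(divergence_score - base, 0.0) / base, 2.0)
```

For an entry at 100 with ATR 2, a long gets a stop near 97 and a target near 106; a divergence score of 0.75 scales the base position size by 1.5x.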


@@ -153,7 +153,7 @@ class OKXClient:
reduce_only: bool = False
) -> dict:
"""
Place a market order and fetch the fill price.
Place a market order.
Args:
symbol: Trading pair symbol
@@ -162,10 +162,7 @@ class OKXClient:
reduce_only: If True, only reduce existing position
Returns:
Order result dictionary with guaranteed 'average' fill price
Raises:
RuntimeError: If order placement fails or fill price unavailable
Order result dictionary
"""
params = {
'tdMode': self.trading_config.margin_mode,
@@ -176,48 +173,10 @@ class OKXClient:
order = self.exchange.create_market_order(
symbol, side, amount, params=params
)
order_id = order.get('id')
if not order_id:
raise RuntimeError("Order placement failed: no order ID returned")
logger.info(
f"Market {side.upper()} order placed: {amount} {symbol} "
f"@ market price, order_id={order_id}"
f"@ market price, order_id={order['id']}"
)
# Fetch order to get actual fill price if not in initial response
fill_price = order.get('average')
if fill_price is None or fill_price == 0:
logger.info(f"Fetching order {order_id} for fill price...")
try:
fetched_order = self.exchange.fetch_order(order_id, symbol)
fill_price = fetched_order.get('average')
order['average'] = fill_price
order['filled'] = fetched_order.get('filled', order.get('filled'))
order['status'] = fetched_order.get('status', order.get('status'))
except Exception as e:
logger.warning(f"Could not fetch order details: {e}")
# Final fallback: use current ticker price
if fill_price is None or fill_price == 0:
logger.warning(
"No fill price from order response, fetching ticker..."
)
try:
ticker = self.get_ticker(symbol)
fill_price = ticker.get('last')
order['average'] = fill_price
except Exception as e:
logger.error(f"Could not fetch ticker: {e}")
if fill_price is None or fill_price <= 0:
raise RuntimeError(
f"Could not determine fill price for order {order_id}. "
f"Order response: {order}"
)
logger.info(f"Order {order_id} filled at {fill_price}")
return order
def place_limit_order(
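The fill-price fallback chain in this hunk (order response, then a re-fetched order, then the last ticker price) can be sketched independently of ccxt; `fetch_order` and `fetch_ticker` here are injected stand-ins for the exchange calls, not real API signatures:

```python
def resolve_fill_price(order: dict, fetch_order, fetch_ticker) -> float:
    """Walk the fallback chain until a positive fill price is found."""
    price = order.get("average")
    if not price:  # None or 0 -> re-fetch the order
        price = fetch_order(order["id"]).get("average")
    if not price:  # still nothing -> fall back to last traded price
        price = fetch_ticker().get("last")
    if not price or price <= 0:
        raise RuntimeError(f"no fill price for order {order.get('id')}")
    return price
```

Keeping the exchange calls injectable makes the chain trivial to unit-test with lambdas in place of network calls.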


@@ -3,21 +3,16 @@ Position Manager for Live Trading.
Tracks open positions, manages risk, and handles SL/TP logic.
"""
import csv
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, TYPE_CHECKING
from dataclasses import dataclass, asdict
from typing import Optional
from dataclasses import dataclass, field, asdict
from .okx_client import OKXClient
from .config import TradingConfig, PathConfig
if TYPE_CHECKING:
from .db.database import TradingDatabase
from .db.models import Trade
logger = logging.getLogger(__name__)
@@ -83,13 +78,11 @@ class PositionManager:
self,
okx_client: OKXClient,
trading_config: TradingConfig,
path_config: PathConfig,
database: Optional["TradingDatabase"] = None,
path_config: PathConfig
):
self.client = okx_client
self.config = trading_config
self.paths = path_config
self.db = database
self.positions: dict[str, Position] = {}
self.trade_log: list[dict] = []
self._load_positions()
@@ -256,55 +249,16 @@ class PositionManager:
return trade_record
def _append_trade_log(self, trade_record: dict) -> None:
"""Append trade record to CSV and SQLite database."""
# Write to CSV (backup/compatibility)
self._append_trade_csv(trade_record)
# Write to SQLite (primary)
self._append_trade_db(trade_record)
def _append_trade_csv(self, trade_record: dict) -> None:
"""Append trade record to CSV log file."""
import csv
file_exists = self.paths.trade_log_file.exists()
try:
with open(self.paths.trade_log_file, 'a', newline='') as f:
writer = csv.DictWriter(f, fieldnames=trade_record.keys())
if not file_exists:
writer.writeheader()
writer.writerow(trade_record)
except Exception as e:
logger.error(f"Failed to write trade to CSV: {e}")
def _append_trade_db(self, trade_record: dict) -> None:
"""Append trade record to SQLite database."""
if self.db is None:
return
try:
from .db.models import Trade
trade = Trade(
trade_id=trade_record['trade_id'],
symbol=trade_record['symbol'],
side=trade_record['side'],
entry_price=trade_record['entry_price'],
exit_price=trade_record.get('exit_price'),
size=trade_record['size'],
size_usdt=trade_record['size_usdt'],
pnl_usd=trade_record.get('pnl_usd'),
pnl_pct=trade_record.get('pnl_pct'),
entry_time=trade_record['entry_time'],
exit_time=trade_record.get('exit_time'),
hold_duration_hours=trade_record.get('hold_duration_hours'),
reason=trade_record.get('reason'),
order_id_entry=trade_record.get('order_id_entry'),
order_id_exit=trade_record.get('order_id_exit'),
)
self.db.insert_trade(trade)
logger.debug(f"Trade {trade.trade_id} saved to database")
except Exception as e:
logger.error(f"Failed to write trade to database: {e}")
with open(self.paths.trade_log_file, 'a', newline='') as f:
writer = csv.DictWriter(f, fieldnames=trade_record.keys())
if not file_exists:
writer.writeheader()
writer.writerow(trade_record)
def update_positions(self, current_prices: dict[str, float]) -> list[dict]:
"""
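The CSV half of the trade log writes its header only on the first write. Isolated from the filesystem, the pattern is just a `DictWriter` keyed off a "file already exists" flag — a sketch that works against any writable file object:

```python
import csv
import io

def append_trade(f, record: dict, write_header: bool) -> None:
    """Append one trade row, emitting the header only on first write."""
    writer = csv.DictWriter(f, fieldnames=record.keys())
    if write_header:
        writer.writeheader()
    writer.writerow(record)
```

In the real `_append_trade_csv`, `write_header` is `not trade_log_file.exists()`; an `io.StringIO` stands in for the file when testing.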


@@ -1,10 +0,0 @@
"""Terminal UI module for live trading dashboard."""
from .dashboard import TradingDashboard
from .state import SharedState
from .log_handler import UILogHandler
__all__ = [
"TradingDashboard",
"SharedState",
"UILogHandler",
]


@@ -1,240 +0,0 @@
"""Main trading dashboard UI orchestration."""
import logging
import queue
import threading
import time
from typing import Optional, Callable
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.panel import Panel
from rich.text import Text
from .state import SharedState
from .log_handler import LogBuffer, UILogHandler
from .keyboard import KeyboardHandler
from .panels import (
HeaderPanel,
TabBar,
LogPanel,
HelpBar,
build_summary_panel,
)
from ..db.database import TradingDatabase
from ..db.metrics import MetricsCalculator, PeriodMetrics
logger = logging.getLogger(__name__)
class TradingDashboard:
"""
Main trading dashboard orchestrator.
Runs in a separate thread and provides real-time UI updates
while the trading loop runs in the main thread.
"""
def __init__(
self,
state: SharedState,
db: TradingDatabase,
log_queue: queue.Queue,
on_quit: Optional[Callable] = None,
):
self.state = state
self.db = db
self.log_queue = log_queue
self.on_quit = on_quit
self.console = Console()
self.log_buffer = LogBuffer(max_entries=1000)
self.keyboard = KeyboardHandler()
self.metrics_calculator = MetricsCalculator(db)
self._running = False
self._thread: Optional[threading.Thread] = None
self._active_tab = 0
self._cached_metrics: dict[int, PeriodMetrics] = {}
self._last_metrics_refresh = 0.0
def start(self) -> None:
"""Start the dashboard in a separate thread."""
if self._running:
return
self._running = True
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
logger.debug("Dashboard thread started")
def stop(self) -> None:
"""Stop the dashboard."""
self._running = False
if self._thread:
self._thread.join(timeout=2.0)
logger.debug("Dashboard thread stopped")
def _run(self) -> None:
"""Main dashboard loop."""
try:
with self.keyboard:
with Live(
self._build_layout(),
console=self.console,
refresh_per_second=1,
screen=True,
) as live:
while self._running and self.state.is_running():
# Process keyboard input
action = self.keyboard.get_action(timeout=0.1)
if action:
self._handle_action(action)
# Drain log queue
self.log_buffer.drain_queue(self.log_queue)
# Refresh metrics periodically (every 5 seconds)
now = time.time()
if now - self._last_metrics_refresh > 5.0:
self._refresh_metrics()
self._last_metrics_refresh = now
# Update display
live.update(self._build_layout())
# Small sleep to prevent CPU spinning
time.sleep(0.1)
except Exception as e:
logger.error(f"Dashboard error: {e}", exc_info=True)
finally:
self._running = False
def _handle_action(self, action: str) -> None:
"""Handle keyboard action."""
if action == "quit":
logger.info("Quit requested from UI")
self.state.stop()
if self.on_quit:
self.on_quit()
elif action == "refresh":
self._refresh_metrics()
logger.debug("Manual refresh triggered")
elif action == "filter":
new_filter = self.log_buffer.cycle_filter()
logger.debug(f"Log filter changed to: {new_filter}")
elif action == "filter_trades":
self.log_buffer.set_filter(LogBuffer.FILTER_TRADES)
logger.debug("Log filter set to: trades")
elif action == "filter_all":
self.log_buffer.set_filter(LogBuffer.FILTER_ALL)
logger.debug("Log filter set to: all")
elif action == "filter_errors":
self.log_buffer.set_filter(LogBuffer.FILTER_ERRORS)
logger.debug("Log filter set to: errors")
elif action == "tab_general":
self._active_tab = 0
elif action == "tab_monthly":
if self._has_monthly_data():
self._active_tab = 1
elif action == "tab_weekly":
if self._has_weekly_data():
self._active_tab = 2
elif action == "tab_daily":
self._active_tab = 3
def _refresh_metrics(self) -> None:
"""Refresh metrics from database."""
try:
self._cached_metrics[0] = self.metrics_calculator.get_all_time_metrics()
self._cached_metrics[1] = self.metrics_calculator.get_monthly_metrics()
self._cached_metrics[2] = self.metrics_calculator.get_weekly_metrics()
self._cached_metrics[3] = self.metrics_calculator.get_daily_metrics()
except Exception as e:
logger.warning(f"Failed to refresh metrics: {e}")
def _has_monthly_data(self) -> bool:
"""Check if monthly tab should be shown."""
try:
return self.metrics_calculator.has_monthly_data()
except Exception:
return False
def _has_weekly_data(self) -> bool:
"""Check if weekly tab should be shown."""
try:
return self.metrics_calculator.has_weekly_data()
except Exception:
return False
def _build_layout(self) -> Layout:
"""Build the complete dashboard layout."""
layout = Layout()
# Calculate available height
term_height = self.console.height or 40
# Header takes 3 lines
# Help bar takes 1 line
# Summary panel takes about 12-14 lines
# Rest goes to logs
log_height = max(8, term_height - 20)
layout.split_column(
Layout(name="header", size=3),
Layout(name="summary", size=14),
Layout(name="logs", size=log_height),
Layout(name="help", size=1),
)
# Header
layout["header"].update(HeaderPanel(self.state).render())
# Summary panel with tabs
current_metrics = self._cached_metrics.get(self._active_tab)
tab_bar = TabBar(active_tab=self._active_tab)
layout["summary"].update(
build_summary_panel(
state=self.state,
metrics=current_metrics,
tab_bar=tab_bar,
has_monthly=self._has_monthly_data(),
has_weekly=self._has_weekly_data(),
)
)
# Log panel
layout["logs"].update(LogPanel(self.log_buffer).render(height=log_height))
# Help bar
layout["help"].update(HelpBar().render())
return layout
def setup_ui_logging(log_queue: queue.Queue) -> UILogHandler:
"""
Set up logging to capture messages for UI.
Args:
log_queue: Queue to send log messages to
Returns:
UILogHandler instance
"""
handler = UILogHandler(log_queue)
handler.setLevel(logging.INFO)
# Add handler to root logger
root_logger = logging.getLogger()
root_logger.addHandler(handler)
return handler
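The every-5-seconds metrics refresh inside `_run` is a plain time-based throttle. Extracted on its own (with the clock passed in so it stays testable), it looks like this:

```python
class Throttle:
    """Permit an action at most once per `interval` seconds."""

    def __init__(self, interval: float):
        self.interval = interval
        self._last = float("-inf")  # first check always passes

    def ready(self, now: float) -> bool:
        if now - self._last > self.interval:
            self._last = now
            return True
        return False
```

The dashboard inlines this with `_last_metrics_refresh`; in the loop it would be `if throttle.ready(time.time()): self._refresh_metrics()`.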


@@ -1,128 +0,0 @@
"""Keyboard input handling for terminal UI."""
import sys
import select
import termios
import tty
from typing import Optional, Callable
from dataclasses import dataclass
@dataclass
class KeyAction:
"""Represents a keyboard action."""
key: str
action: str
description: str
class KeyboardHandler:
"""
Non-blocking keyboard input handler.
Uses terminal raw mode to capture single keypresses
without waiting for Enter.
"""
# Key mappings
ACTIONS = {
"q": "quit",
"Q": "quit",
"\x03": "quit", # Ctrl+C
"r": "refresh",
"R": "refresh",
"f": "filter",
"F": "filter",
"t": "filter_trades",
"T": "filter_trades",
"l": "filter_all",
"L": "filter_all",
"e": "filter_errors",
"E": "filter_errors",
"1": "tab_general",
"2": "tab_monthly",
"3": "tab_weekly",
"4": "tab_daily",
}
def __init__(self):
self._old_settings = None
self._enabled = False
def enable(self) -> bool:
"""
Enable raw keyboard input mode.
Returns:
True if enabled successfully
"""
try:
if not sys.stdin.isatty():
return False
self._old_settings = termios.tcgetattr(sys.stdin)
tty.setcbreak(sys.stdin.fileno())
self._enabled = True
return True
except Exception:
return False
def disable(self) -> None:
"""Restore normal terminal mode."""
if self._enabled and self._old_settings:
try:
termios.tcsetattr(
sys.stdin,
termios.TCSADRAIN,
self._old_settings,
)
except Exception:
pass
self._enabled = False
def get_key(self, timeout: float = 0.1) -> Optional[str]:
"""
Get a keypress if available (non-blocking).
Args:
timeout: Seconds to wait for input
Returns:
Key character or None if no input
"""
if not self._enabled:
return None
try:
readable, _, _ = select.select([sys.stdin], [], [], timeout)
if readable:
return sys.stdin.read(1)
except Exception:
pass
return None
def get_action(self, timeout: float = 0.1) -> Optional[str]:
"""
Get action name for pressed key.
Args:
timeout: Seconds to wait for input
Returns:
Action name or None
"""
key = self.get_key(timeout)
if key:
return self.ACTIONS.get(key)
return None
def __enter__(self):
"""Context manager entry."""
self.enable()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.disable()
return False


@@ -1,178 +0,0 @@
"""Custom logging handler for UI integration."""
import logging
import queue
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from collections import deque
@dataclass
class LogEntry:
"""A single log entry."""
timestamp: str
level: str
message: str
logger_name: str
@property
def level_color(self) -> str:
"""Get Rich color for log level."""
colors = {
"DEBUG": "dim",
"INFO": "white",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "bold red",
}
return colors.get(self.level, "white")
class UILogHandler(logging.Handler):
"""
Custom logging handler that sends logs to UI.
Uses a thread-safe queue to pass log entries from the trading
thread to the UI thread.
"""
def __init__(
self,
log_queue: queue.Queue,
max_entries: int = 1000,
):
super().__init__()
self.log_queue = log_queue
self.max_entries = max_entries
self.setFormatter(
logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
)
def emit(self, record: logging.LogRecord) -> None:
"""Emit a log record to the queue."""
try:
entry = LogEntry(
timestamp=datetime.fromtimestamp(record.created).strftime(
"%H:%M:%S"
),
level=record.levelname,
message=self.format_message(record),
logger_name=record.name,
)
# Non-blocking put, drop if queue is full
try:
self.log_queue.put_nowait(entry)
except queue.Full:
pass
except Exception:
self.handleError(record)
def format_message(self, record: logging.LogRecord) -> str:
"""Format the log message."""
return record.getMessage()
class LogBuffer:
"""
Thread-safe buffer for log entries with filtering support.
Maintains a fixed-size buffer of log entries and supports
filtering by log type.
"""
FILTER_ALL = "all"
FILTER_ERRORS = "errors"
FILTER_TRADES = "trades"
FILTER_SIGNALS = "signals"
FILTERS = [FILTER_ALL, FILTER_ERRORS, FILTER_TRADES, FILTER_SIGNALS]
def __init__(self, max_entries: int = 1000):
self.max_entries = max_entries
self._entries: deque[LogEntry] = deque(maxlen=max_entries)
self._current_filter = self.FILTER_ALL
def add(self, entry: LogEntry) -> None:
"""Add a log entry to the buffer."""
self._entries.append(entry)
def get_filtered(self, limit: int = 50) -> list[LogEntry]:
"""
Get filtered log entries.
Args:
limit: Maximum number of entries to return
Returns:
List of filtered LogEntry objects (most recent first)
"""
entries = list(self._entries)
if self._current_filter == self.FILTER_ERRORS:
entries = [e for e in entries if e.level in ("ERROR", "CRITICAL")]
elif self._current_filter == self.FILTER_TRADES:
# Key terms indicating actual trading activity
include_keywords = [
"order", "entry", "exit", "executed", "filled",
"opening", "closing", "position opened", "position closed"
]
# Terms to exclude (noise)
exclude_keywords = [
"sync complete", "0 positions", "portfolio: 0 positions"
]
entries = [
e for e in entries
if any(kw in e.message.lower() for kw in include_keywords)
and not any(ex in e.message.lower() for ex in exclude_keywords)
]
elif self._current_filter == self.FILTER_SIGNALS:
signal_keywords = ["signal", "z_score", "prob", "z="]
entries = [
e for e in entries
if any(kw in e.message.lower() for kw in signal_keywords)
]
# Return most recent entries
return list(reversed(entries[-limit:]))
def set_filter(self, filter_name: str) -> None:
"""Set a specific filter."""
if filter_name in self.FILTERS:
self._current_filter = filter_name
def cycle_filter(self) -> str:
"""Cycle to next filter and return its name."""
current_idx = self.FILTERS.index(self._current_filter)
next_idx = (current_idx + 1) % len(self.FILTERS)
self._current_filter = self.FILTERS[next_idx]
return self._current_filter
def get_current_filter(self) -> str:
"""Get current filter name."""
return self._current_filter
def clear(self) -> None:
"""Clear all log entries."""
self._entries.clear()
def drain_queue(self, log_queue: queue.Queue) -> int:
"""
Drain log entries from queue into buffer.
Args:
log_queue: Queue to drain from
Returns:
Number of entries drained
"""
count = 0
while True:
try:
entry = log_queue.get_nowait()
self.add(entry)
count += 1
except queue.Empty:
break
return count
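The level and keyword filters in `get_filtered` reduce to one list comprehension plus a most-recent-first slice. A stripped-down version over `(level, message)` tuples (the real buffer stores `LogEntry` objects in a `deque(maxlen=...)`):

```python
def filter_logs(entries, levels=None, keywords=None, limit=50):
    """Keep entries matching the level/keyword filters, most recent first."""
    kept = [
        (lvl, msg) for lvl, msg in entries
        if (levels is None or lvl in levels)
        and (keywords is None or any(k in msg.lower() for k in keywords))
    ]
    return list(reversed(kept[-limit:]))
```

The trades filter in `LogBuffer` additionally applies an exclude list to drop noisy "sync complete" style messages; that is a second `not any(...)` clause on the same comprehension.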


@@ -1,399 +0,0 @@
"""UI panel components using Rich."""
from typing import Optional
from rich.console import Console, Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from rich.layout import Layout
from .state import SharedState, PositionState, StrategyState, AccountState
from .log_handler import LogBuffer, LogEntry
from ..db.metrics import PeriodMetrics
def format_pnl(value: float, include_sign: bool = True) -> Text:
"""Format PnL value with color."""
if value > 0:
sign = "+" if include_sign else ""
return Text(f"{sign}${value:.2f}", style="green")
elif value < 0:
return Text(f"${value:.2f}", style="red")
else:
return Text(f"${value:.2f}", style="white")
def format_pct(value: float, include_sign: bool = True) -> Text:
"""Format percentage value with color."""
if value > 0:
sign = "+" if include_sign else ""
return Text(f"{sign}{value:.2f}%", style="green")
elif value < 0:
return Text(f"{value:.2f}%", style="red")
else:
return Text(f"{value:.2f}%", style="white")
def format_side(side: str) -> Text:
"""Format position side with color."""
if side.lower() == "long":
return Text("LONG", style="bold green")
else:
return Text("SHORT", style="bold red")
class HeaderPanel:
"""Header panel with title and mode indicator."""
def __init__(self, state: SharedState):
self.state = state
def render(self) -> Panel:
"""Render the header panel."""
mode = self.state.get_mode()
eth_symbol, _ = self.state.get_symbols()
mode_style = "yellow" if mode == "DEMO" else "bold red"
mode_text = Text(f"[{mode}]", style=mode_style)
title = Text()
title.append("REGIME REVERSION STRATEGY - LIVE TRADING", style="bold white")
title.append(" ")
title.append(mode_text)
title.append(" ")
title.append(eth_symbol, style="cyan")
return Panel(title, style="blue", height=3)
class TabBar:
"""Tab bar for period selection."""
TABS = ["1:General", "2:Monthly", "3:Weekly", "4:Daily"]
def __init__(self, active_tab: int = 0):
self.active_tab = active_tab
def render(
self,
has_monthly: bool = True,
has_weekly: bool = True,
) -> Text:
"""Render the tab bar."""
text = Text()
text.append(" ")
for i, tab in enumerate(self.TABS):
# Check if tab should be shown
if i == 1 and not has_monthly:
continue
if i == 2 and not has_weekly:
continue
if i == self.active_tab:
text.append(f"[{tab}]", style="bold white on blue")
else:
text.append(f"[{tab}]", style="dim")
text.append(" ")
return text
class MetricsPanel:
"""Panel showing trading metrics."""
def __init__(self, metrics: Optional[PeriodMetrics] = None):
self.metrics = metrics
def render(self) -> Table:
"""Render metrics as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.metrics is None or self.metrics.total_trades == 0:
table.add_row("Status", Text("No trade data", style="dim"))
return table
m = self.metrics
table.add_row("Total PnL:", format_pnl(m.total_pnl))
table.add_row("Win Rate:", Text(f"{m.win_rate:.1f}%", style="white"))
table.add_row("Total Trades:", Text(str(m.total_trades), style="white"))
table.add_row(
"Win/Loss:",
Text(f"{m.winning_trades}/{m.losing_trades}", style="white"),
)
table.add_row(
"Avg Duration:",
Text(f"{m.avg_trade_duration_hours:.1f}h", style="white"),
)
table.add_row("Max Drawdown:", format_pnl(-m.max_drawdown))
table.add_row("Best Trade:", format_pnl(m.best_trade))
table.add_row("Worst Trade:", format_pnl(m.worst_trade))
return table
class PositionPanel:
"""Panel showing current position."""
def __init__(self, position: Optional[PositionState] = None):
self.position = position
def render(self) -> Table:
"""Render position as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.position is None:
table.add_row("Status", Text("No open position", style="dim"))
return table
p = self.position
table.add_row("Side:", format_side(p.side))
table.add_row("Entry:", Text(f"${p.entry_price:.2f}", style="white"))
table.add_row("Current:", Text(f"${p.current_price:.2f}", style="white"))
# Unrealized PnL
pnl_text = Text()
pnl_text.append_text(format_pnl(p.unrealized_pnl))
pnl_text.append(" (")
pnl_text.append_text(format_pct(p.unrealized_pnl_pct))
pnl_text.append(")")
table.add_row("Unrealized:", pnl_text)
table.add_row("Size:", Text(f"${p.size_usdt:.2f}", style="white"))
# SL/TP
if p.side == "long":
sl_dist = (p.stop_loss_price / p.entry_price - 1) * 100
tp_dist = (p.take_profit_price / p.entry_price - 1) * 100
else:
sl_dist = (1 - p.stop_loss_price / p.entry_price) * 100
tp_dist = (1 - p.take_profit_price / p.entry_price) * 100
sl_text = Text(f"${p.stop_loss_price:.2f} ({sl_dist:+.1f}%)", style="red")
tp_text = Text(f"${p.take_profit_price:.2f} ({tp_dist:+.1f}%)", style="green")
table.add_row("Stop Loss:", sl_text)
table.add_row("Take Profit:", tp_text)
return table
class AccountPanel:
"""Panel showing account information."""
def __init__(self, account: Optional[AccountState] = None):
self.account = account
def render(self) -> Table:
"""Render account info as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.account is None:
table.add_row("Status", Text("Loading...", style="dim"))
return table
a = self.account
table.add_row("Balance:", Text(f"${a.balance:.2f}", style="white"))
table.add_row("Available:", Text(f"${a.available:.2f}", style="white"))
table.add_row("Leverage:", Text(f"{a.leverage}x", style="cyan"))
return table
class StrategyPanel:
"""Panel showing strategy state."""
def __init__(self, strategy: Optional[StrategyState] = None):
self.strategy = strategy
def render(self) -> Table:
"""Render strategy state as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.strategy is None:
table.add_row("Status", Text("Waiting...", style="dim"))
return table
s = self.strategy
# Z-score with color based on threshold
z_style = "white"
if abs(s.z_score) > 1.0:
z_style = "yellow"
if abs(s.z_score) > 1.5:
z_style = "bold yellow"
table.add_row("Z-Score:", Text(f"{s.z_score:.2f}", style=z_style))
# Probability with color
prob_style = "white"
if s.probability > 0.5:
prob_style = "green"
if s.probability > 0.7:
prob_style = "bold green"
table.add_row("Probability:", Text(f"{s.probability:.2f}", style=prob_style))
# Funding rate
funding_style = "green" if s.funding_rate >= 0 else "red"
table.add_row(
"Funding:",
Text(f"{s.funding_rate:.4f}", style=funding_style),
)
# Last action
action_style = "white"
if s.last_action == "entry":
action_style = "bold cyan"
elif s.last_action == "check_exit":
action_style = "yellow"
table.add_row("Last Action:", Text(s.last_action, style=action_style))
return table
class LogPanel:
"""Panel showing log entries."""
def __init__(self, log_buffer: LogBuffer):
self.log_buffer = log_buffer
def render(self, height: int = 10) -> Panel:
"""Render log panel."""
filter_name = self.log_buffer.get_current_filter().title()
entries = self.log_buffer.get_filtered(limit=height - 2)
lines = []
for entry in entries:
line = Text()
line.append(f"{entry.timestamp} ", style="dim")
line.append(f"[{entry.level}] ", style=entry.level_color)
line.append(entry.message)
lines.append(line)
if not lines:
lines.append(Text("No logs to display", style="dim"))
content = Group(*lines)
# Build "tabbed" title
tabs = []
# All Logs tab
if filter_name == "All":
tabs.append("[bold white on blue] [L]ogs [/]")
else:
tabs.append("[dim] [L]ogs [/]")
# Trades tab
if filter_name == "Trades":
tabs.append("[bold white on blue] [T]rades [/]")
else:
tabs.append("[dim] [T]rades [/]")
# Errors tab
if filter_name == "Errors":
tabs.append("[bold white on blue] [E]rrors [/]")
else:
tabs.append("[dim] [E]rrors [/]")
title = " ".join(tabs)
subtitle = "Press 'l', 't', 'e' to switch tabs"
return Panel(
content,
title=title,
subtitle=subtitle,
title_align="left",
subtitle_align="right",
border_style="blue",
)
class HelpBar:
"""Bottom help bar with keyboard shortcuts."""
def render(self) -> Text:
"""Render help bar."""
text = Text()
text.append(" [q]", style="bold")
text.append("Quit ", style="dim")
text.append("[r]", style="bold")
text.append("Refresh ", style="dim")
text.append("[1-4]", style="bold")
text.append("Tabs ", style="dim")
text.append("[l/t/e]", style="bold")
text.append("LogView", style="dim")
return text
def build_summary_panel(
state: SharedState,
metrics: Optional[PeriodMetrics],
tab_bar: TabBar,
has_monthly: bool,
has_weekly: bool,
) -> Panel:
"""Build the complete summary panel with all sections."""
# Create layout for summary content
layout = Layout()
# Tab bar at top
tabs = tab_bar.render(has_monthly, has_weekly)
# Create tables for each section
metrics_table = MetricsPanel(metrics).render()
position_table = PositionPanel(state.get_position()).render()
account_table = AccountPanel(state.get_account()).render()
strategy_table = StrategyPanel(state.get_strategy()).render()
# Build two-column layout
left_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
left_col.add_column("PERFORMANCE", style="bold cyan")
left_col.add_column("ACCOUNT", style="bold cyan")
left_col.add_row(metrics_table, account_table)
right_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
right_col.add_column("CURRENT POSITION", style="bold cyan")
right_col.add_column("STRATEGY STATE", style="bold cyan")
right_col.add_row(position_table, strategy_table)
# Combine into main table
main_table = Table(show_header=False, show_edge=False, box=None, expand=True)
main_table.add_column(ratio=1)
main_table.add_column(ratio=1)
main_table.add_row(left_col, right_col)
content = Group(tabs, Text(""), main_table)
return Panel(content, border_style="blue")


@@ -1,195 +0,0 @@
"""Thread-safe shared state for UI and trading loop."""
import threading
from dataclasses import dataclass, field
from typing import Optional
from datetime import datetime, timezone
@dataclass
class PositionState:
"""Current position information."""
trade_id: str = ""
symbol: str = ""
side: str = ""
entry_price: float = 0.0
current_price: float = 0.0
size: float = 0.0
size_usdt: float = 0.0
unrealized_pnl: float = 0.0
unrealized_pnl_pct: float = 0.0
stop_loss_price: float = 0.0
take_profit_price: float = 0.0
@dataclass
class StrategyState:
"""Current strategy signal state."""
z_score: float = 0.0
probability: float = 0.0
funding_rate: float = 0.0
last_action: str = "hold"
last_reason: str = ""
last_signal_time: str = ""
@dataclass
class AccountState:
"""Account balance information."""
balance: float = 0.0
available: float = 0.0
leverage: int = 1
class SharedState:
"""
Thread-safe shared state between trading loop and UI.
All access to state fields should go through the getter/setter methods
which use a lock for thread safety.
"""
def __init__(self):
self._lock = threading.Lock()
self._position: Optional[PositionState] = None
self._strategy = StrategyState()
self._account = AccountState()
self._is_running = True
self._last_cycle_time: Optional[str] = None
self._mode = "DEMO"
self._eth_symbol = "ETH/USDT:USDT"
self._btc_symbol = "BTC/USDT:USDT"
# Position methods
def get_position(self) -> Optional[PositionState]:
"""Get current position state."""
with self._lock:
return self._position
def set_position(self, position: Optional[PositionState]) -> None:
"""Set current position state."""
with self._lock:
self._position = position
def update_position_price(self, current_price: float) -> None:
"""Update current price and recalculate PnL."""
with self._lock:
if self._position is None:
return
self._position.current_price = current_price
if self._position.side == "long":
pnl = (current_price - self._position.entry_price)
self._position.unrealized_pnl = pnl * self._position.size
pnl_pct = (current_price / self._position.entry_price - 1) * 100
else:
pnl = (self._position.entry_price - current_price)
self._position.unrealized_pnl = pnl * self._position.size
pnl_pct = (1 - current_price / self._position.entry_price) * 100
self._position.unrealized_pnl_pct = pnl_pct
def clear_position(self) -> None:
"""Clear current position."""
with self._lock:
self._position = None
# Strategy methods
def get_strategy(self) -> StrategyState:
"""Get current strategy state."""
with self._lock:
return StrategyState(
z_score=self._strategy.z_score,
probability=self._strategy.probability,
funding_rate=self._strategy.funding_rate,
last_action=self._strategy.last_action,
last_reason=self._strategy.last_reason,
last_signal_time=self._strategy.last_signal_time,
)
def update_strategy(
self,
z_score: float,
probability: float,
funding_rate: float,
action: str,
reason: str,
) -> None:
"""Update strategy state."""
with self._lock:
self._strategy.z_score = z_score
self._strategy.probability = probability
self._strategy.funding_rate = funding_rate
self._strategy.last_action = action
self._strategy.last_reason = reason
self._strategy.last_signal_time = datetime.now(
timezone.utc
).isoformat()
# Account methods
def get_account(self) -> AccountState:
"""Get current account state."""
with self._lock:
return AccountState(
balance=self._account.balance,
available=self._account.available,
leverage=self._account.leverage,
)
def update_account(
self,
balance: float,
available: float,
leverage: int,
) -> None:
"""Update account state."""
with self._lock:
self._account.balance = balance
self._account.available = available
self._account.leverage = leverage
# Control methods
def is_running(self) -> bool:
"""Check if trading loop is running."""
with self._lock:
return self._is_running
def stop(self) -> None:
"""Signal to stop trading loop."""
with self._lock:
self._is_running = False
def get_last_cycle_time(self) -> Optional[str]:
"""Get last trading cycle time."""
with self._lock:
return self._last_cycle_time
def set_last_cycle_time(self, time_str: str) -> None:
"""Set last trading cycle time."""
with self._lock:
self._last_cycle_time = time_str
# Config methods
def get_mode(self) -> str:
"""Get trading mode (DEMO/LIVE)."""
with self._lock:
return self._mode
def set_mode(self, mode: str) -> None:
"""Set trading mode."""
with self._lock:
self._mode = mode
def get_symbols(self) -> tuple[str, str]:
"""Get trading symbols (eth, btc)."""
with self._lock:
return self._eth_symbol, self._btc_symbol
def set_symbols(self, eth_symbol: str, btc_symbol: str) -> None:
"""Set trading symbols."""
with self._lock:
self._eth_symbol = eth_symbol
self._btc_symbol = btc_symbol
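The getters above all acquire a lock and hand back freshly constructed dataclass instances, so callers never hold a live reference to shared mutable state. A minimal sketch of that snapshot pattern (class and field names are hypothetical, not the module's actual ones):

```python
import threading
from dataclasses import dataclass, replace

@dataclass
class AccountState:
    balance: float = 0.0
    available: float = 0.0
    leverage: int = 1

class StateStore:
    """Lock-guarded store that returns snapshots, not live references."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._account = AccountState()

    def get_account(self) -> AccountState:
        with self._lock:
            # Copy under the lock so the caller cannot mutate shared state.
            return replace(self._account)

    def update_account(self, balance: float, available: float, leverage: int) -> None:
        with self._lock:
            self._account.balance = balance
            self._account.available = available
            self._account.leverage = leverage

store = StateStore()
store.update_account(10_000.0, 9_500.0, 3)
snap = store.get_account()
snap.balance = 0.0                   # mutating the snapshot...
print(store.get_account().balance)   # ...leaves the store untouched: 10000.0
```

The same reasoning explains why `get_strategy` and `get_account` above rebuild their state objects rather than returning `self._strategy` or `self._account` directly.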

View File

@@ -15,8 +15,6 @@ dependencies = [
"plotly>=5.24.0",
"requests>=2.32.5",
"python-dotenv>=1.2.1",
# Terminal UI
"rich>=13.0.0",
# API dependencies
"fastapi>=0.115.0",
"uvicorn[standard]>=0.34.0",

View File

@@ -3,16 +3,7 @@ Regime Detection Research Script with Walk-Forward Training.
Tests multiple holding horizons to find optimal parameters
without look-ahead bias.
Usage:
uv run python research/regime_detection.py [options]
Options:
--days DAYS Number of days of data (default: 90)
--start DATE Start date (YYYY-MM-DD), overrides --days
--end DATE End date (YYYY-MM-DD), defaults to now
"""
import argparse
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -32,39 +23,20 @@ logger = get_logger(__name__)
# Configuration
TRAIN_RATIO = 0.7 # 70% train, 30% test
PROFIT_THRESHOLD = 0.005 # 0.5% profit target
STOP_LOSS_PCT = 0.06 # 6% stop loss
Z_WINDOW = 24
FEE_RATE = 0.001 # 0.1% round-trip fee
DEFAULT_DAYS = 90 # Default lookback period in days
def load_data(days: int = DEFAULT_DAYS, start_date: str = None, end_date: str = None):
"""
Load and align BTC/ETH data.
Args:
days: Number of days of historical data (default: 90)
start_date: Optional start date (YYYY-MM-DD), overrides days
end_date: Optional end date (YYYY-MM-DD), defaults to now
Returns:
Tuple of (df_btc, df_eth) DataFrames
"""
def load_data():
"""Load and align BTC/ETH data."""
dm = DataManager()
df_btc = dm.load_data("okx", "BTC-USDT", "1h", MarketType.SPOT)
df_eth = dm.load_data("okx", "ETH-USDT", "1h", MarketType.SPOT)
# Determine date range
if end_date:
end = pd.Timestamp(end_date, tz="UTC")
else:
end = pd.Timestamp.now(tz="UTC")
if start_date:
start = pd.Timestamp(start_date, tz="UTC")
else:
start = end - pd.Timedelta(days=days)
# Filter to Oct-Dec 2025
start = pd.Timestamp("2025-10-01", tz="UTC")
end = pd.Timestamp("2025-12-31", tz="UTC")
df_btc = df_btc[(df_btc.index >= start) & (df_btc.index <= end)]
df_eth = df_eth[(df_eth.index >= start) & (df_eth.index <= end)]
@@ -74,7 +46,7 @@ def load_data(days: int = DEFAULT_DAYS, start_date: str = None, end_date: str =
df_btc = df_btc.loc[common]
df_eth = df_eth.loc[common]
logger.info(f"Loaded {len(common)} aligned hourly bars from {start} to {end}")
logger.info(f"Loaded {len(common)} aligned hourly bars")
return df_btc, df_eth
@@ -140,74 +112,26 @@ def calculate_features(df_btc, df_eth, cq_df=None):
def calculate_targets(features, horizon):
"""
Calculate target labels for a given horizon.
"""Calculate target labels for a given horizon."""
spread = features['spread']
z_score = features['z_score']
Uses path-dependent labeling: Success is hitting Profit Target BEFORE Stop Loss.
"""
spread = features['spread'].values
z_score = features['z_score'].values
n = len(spread)
# For Short (Z > 1): Did spread drop below target?
future_min = spread.rolling(window=horizon).min().shift(-horizon)
target_short = spread * (1 - PROFIT_THRESHOLD)
success_short = (z_score > 1.0) & (future_min < target_short)
targets = np.zeros(n, dtype=int)
# For Long (Z < -1): Did spread rise above target?
future_max = spread.rolling(window=horizon).max().shift(-horizon)
target_long = spread * (1 + PROFIT_THRESHOLD)
success_long = (z_score < -1.0) & (future_max > target_long)
targets = np.select([success_short, success_long], [1, 1], default=0)
# Create valid mask (rows with complete future data)
valid_mask = np.zeros(n, dtype=bool)
valid_mask[:n-horizon] = True
valid_mask = future_min.notna() & future_max.notna()
# Only iterate relevant rows for efficiency
candidates = np.where((z_score > 1.0) | (z_score < -1.0))[0]
for i in candidates:
if i + horizon >= n:
continue
entry_price = spread[i]
future_prices = spread[i+1 : i+1+horizon]
if z_score[i] > 1.0: # Short
target_price = entry_price * (1 - PROFIT_THRESHOLD)
stop_price = entry_price * (1 + STOP_LOSS_PCT)
# Identify first hit indices
hit_tp = future_prices <= target_price
hit_sl = future_prices >= stop_price
if not np.any(hit_tp):
targets[i] = 0 # Target never hit
elif not np.any(hit_sl):
targets[i] = 1 # Target hit, SL never hit
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
# Success if TP hit before SL
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
else: # Long
target_price = entry_price * (1 + PROFIT_THRESHOLD)
stop_price = entry_price * (1 - STOP_LOSS_PCT)
hit_tp = future_prices >= target_price
hit_sl = future_prices <= stop_price
if not np.any(hit_tp):
targets[i] = 0
elif not np.any(hit_sl):
targets[i] = 1
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
return targets, pd.Series(valid_mask, index=features.index), None, None
return targets, valid_mask, future_min, future_max
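The path-dependent labeling above reduces to one rule: a trade is labeled 1 only if the take-profit level is touched strictly before the stop. A minimal sketch for the short case, using the thresholds defined above (the function name and price path are illustrative):

```python
import numpy as np

def label_short(entry: float, future: np.ndarray,
                profit_threshold: float = 0.005,
                stop_loss_pct: float = 0.06) -> int:
    """Label a short entry 1 if TP is hit strictly before SL, else 0."""
    tp = entry * (1 - profit_threshold)
    sl = entry * (1 + stop_loss_pct)
    hit_tp = future <= tp
    hit_sl = future >= sl
    if not hit_tp.any():
        return 0          # target never reached
    if not hit_sl.any():
        return 1          # target reached, never stopped out
    # np.argmax on a boolean array returns the index of the first True
    return int(np.argmax(hit_tp) < np.argmax(hit_sl))

path = np.array([100.2, 99.6, 99.4, 106.5])  # dips through TP before the SL bar
print(label_short(100.0, path))  # 1
```

The long case is the mirror image: TP at `entry * (1 + profit_threshold)`, SL at `entry * (1 - stop_loss_pct)`, with the inequalities flipped.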
def calculate_mae(features, predictions, test_idx, horizon):
@@ -244,10 +168,7 @@ def calculate_mae(features, predictions, test_idx, horizon):
def calculate_net_profit(features, predictions, test_idx, horizon):
"""
Calculate estimated net profit including fees.
Enforces 'one trade at a time' and simulates SL/TP exits.
"""
"""Calculate estimated net profit including fees."""
test_features = features.loc[test_idx]
spread = test_features['spread']
z_score = test_features['z_score']
@@ -255,17 +176,7 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
total_pnl = 0.0
n_trades = 0
# Track when we are free to trade again
next_trade_idx = 0
# Pre-calculate indices for speed
all_indices = features.index
for i, (idx, pred) in enumerate(zip(test_idx, predictions)):
# Skip if we are still in a trade
if i < next_trade_idx:
continue
if pred != 1:
continue
@@ -273,77 +184,26 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
z = z_score.loc[idx]
# Get future spread values
current_loc = features.index.get_loc(idx)
future_end_loc = min(current_loc + horizon, len(features))
future_spreads = features['spread'].iloc[current_loc+1 : future_end_loc]
future_idx = features.index.get_loc(idx)
future_end = min(future_idx + horizon, len(features))
future_spreads = features['spread'].iloc[future_idx:future_end]
if len(future_spreads) < 1:
if len(future_spreads) < 2:
continue
pnl = 0.0
trade_duration = len(future_spreads)
if z > 1.0: # Short trade
tp_price = entry_spread * (1 - PROFIT_THRESHOLD)
sl_price = entry_spread * (1 + STOP_LOSS_PCT)
hit_tp = future_spreads <= tp_price
hit_sl = future_spreads >= sl_price
# Check what happened first
first_tp = np.argmax(hit_tp.values) if hit_tp.any() else 99999
first_sl = np.argmax(hit_sl.values) if hit_sl.any() else 99999
if first_sl < first_tp and first_sl < 99999:
# Stopped out
exit_price = future_spreads.iloc[first_sl]
# Approximation: exit at the close of the bar that crossed the SL
# level, rather than at the exact stop price.
pnl = (entry_spread - exit_price) / entry_spread
trade_duration = first_sl + 1
elif first_tp < first_sl and first_tp < 99999:
# Take profit
exit_price = future_spreads.iloc[first_tp]
pnl = (entry_spread - exit_price) / entry_spread
trade_duration = first_tp + 1
else:
# Held to horizon
exit_price = future_spreads.iloc[-1]
pnl = (entry_spread - exit_price) / entry_spread
else: # Long trade
tp_price = entry_spread * (1 + PROFIT_THRESHOLD)
sl_price = entry_spread * (1 - STOP_LOSS_PCT)
hit_tp = future_spreads >= tp_price
hit_sl = future_spreads <= sl_price
first_tp = np.argmax(hit_tp.values) if hit_tp.any() else 99999
first_sl = np.argmax(hit_sl.values) if hit_sl.any() else 99999
if first_sl < first_tp and first_sl < 99999:
# Stopped out
exit_price = future_spreads.iloc[first_sl]
pnl = (exit_price - entry_spread) / entry_spread
trade_duration = first_sl + 1
elif first_tp < first_sl and first_tp < 99999:
# Take profit
exit_price = future_spreads.iloc[first_tp]
pnl = (exit_price - entry_spread) / entry_spread
trade_duration = first_tp + 1
else:
# Held to horizon
exit_price = future_spreads.iloc[-1]
pnl = (exit_price - entry_spread) / entry_spread
# Calculate PnL based on direction
if z > 1.0: # Short trade - profit if spread drops
exit_spread = future_spreads.iloc[-1] # Exit at horizon
pnl = (entry_spread - exit_spread) / entry_spread
else: # Long trade - profit if spread rises
exit_spread = future_spreads.iloc[-1]
pnl = (exit_spread - entry_spread) / entry_spread
# Subtract fees
net_pnl = pnl - FEE_RATE
total_pnl += net_pnl
n_trades += 1
# Set next available trade index
next_trade_idx = i + trade_duration
return total_pnl, n_trades
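The `next_trade_idx` bookkeeping above is what enforces "one trade at a time": once a signal is taken, subsequent signals are skipped until that trade's duration has elapsed. Stripped to its core (function name, signals, and durations are made up for illustration):

```python
def take_non_overlapping(signals: list[int], durations: dict[int, int]) -> list[int]:
    """Keep only entries whose holding window does not overlap a prior trade."""
    taken = []
    next_free = 0
    for i in signals:
        if i < next_free:   # still in a trade - skip this signal
            continue
        taken.append(i)
        next_free = i + durations[i]
    return taken

# Signals at bars 0, 2, 5; the first trade lasts 3 bars, so bar 2 is skipped.
print(take_non_overlapping([0, 2, 5], {0: 3, 2: 4, 5: 2}))  # [0, 5]
```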
@@ -420,7 +280,7 @@ def test_horizons(features, horizons):
print("\n" + "=" * 80)
print("WALK-FORWARD HORIZON OPTIMIZATION")
print(f"Train Ratio: {TRAIN_RATIO*100:.0f}% | Profit Target: {PROFIT_THRESHOLD*100:.1f}% | Stop Loss: {STOP_LOSS_PCT*100:.1f}% | Fee Rate: {FEE_RATE*100:.2f}%")
print(f"Train Ratio: {TRAIN_RATIO*100:.0f}% | Profit Target: {PROFIT_THRESHOLD*100:.1f}% | Fee Rate: {FEE_RATE*100:.2f}%")
print("=" * 80)
for h in horizons:
@@ -435,54 +295,10 @@ def test_horizons(features, horizons):
return results
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Regime detection research - test multiple horizons"
)
parser.add_argument(
"--days",
type=int,
default=DEFAULT_DAYS,
help=f"Number of days of data (default: {DEFAULT_DAYS})"
)
parser.add_argument(
"--start",
type=str,
default=None,
help="Start date (YYYY-MM-DD), overrides --days"
)
parser.add_argument(
"--end",
type=str,
default=None,
help="End date (YYYY-MM-DD), defaults to now"
)
parser.add_argument(
"--output",
type=str,
default="research/horizon_optimization_results.csv",
help="Output CSV path"
)
parser.add_argument(
"--output-horizon",
type=str,
default=None,
help="Path to save the best horizon (integer) to a file"
)
return parser.parse_args()
def main():
"""Main research function."""
args = parse_args()
# Load data with dynamic date range
df_btc, df_eth = load_data(
days=args.days,
start_date=args.start,
end_date=args.end
)
# Load data
df_btc, df_eth = load_data()
cq_df = load_cryptoquant_data()
# Calculate features
@@ -496,7 +312,7 @@ def main():
if not results:
print("No valid results!")
return None
return
# Find best by different metrics
results_df = pd.DataFrame(results)
@@ -515,15 +331,9 @@ def main():
print(f"Lowest MAE: {lowest_mae['horizon']:.0f}h (MAE={lowest_mae['avg_mae']:.2f}%)")
# Save results
results_df.to_csv(args.output, index=False)
print(f"\nResults saved to {args.output}")
# Save best horizon if requested
if args.output_horizon:
best_h = int(best_pnl['horizon'])
with open(args.output_horizon, 'w') as f:
f.write(str(best_h))
print(f"Best horizon {best_h}h saved to {args.output_horizon}")
output_path = "research/horizon_optimization_results.csv"
results_df.to_csv(output_path, index=False)
print(f"\nResults saved to {output_path}")
return results_df

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python3
"""
Download historical data for Multi-Pair Divergence Strategy.
Downloads 1h OHLCV data for top 10 cryptocurrencies from OKX.
"""
import sys
sys.path.insert(0, '.')
from engine.data_manager import DataManager
from engine.market import MarketType
from engine.logging_config import setup_logging, get_logger
from strategies.multi_pair import MultiPairConfig
logger = get_logger(__name__)
def main():
"""Download data for all configured assets."""
setup_logging()
config = MultiPairConfig()
dm = DataManager()
logger.info("Downloading data for %d assets...", len(config.assets))
for symbol in config.assets:
logger.info("Downloading %s perpetual 1h data...", symbol)
try:
df = dm.download_data(
exchange_id=config.exchange_id,
symbol=symbol,
timeframe=config.timeframe,
market_type=MarketType.PERPETUAL
)
if df is not None:
logger.info("Downloaded %d candles for %s", len(df), symbol)
else:
logger.warning("No data downloaded for %s", symbol)
except Exception as e:
logger.error("Failed to download %s: %s", symbol, e)
logger.info("Download complete!")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,156 @@
#!/usr/bin/env python3
"""
Run Multi-Pair Divergence Strategy backtest and compare with baseline.
Compares the multi-pair strategy against the single-pair BTC/ETH regime strategy.
"""
import sys
sys.path.insert(0, '.')
from engine.backtester import Backtester
from engine.data_manager import DataManager
from engine.logging_config import setup_logging, get_logger
from engine.reporting import Reporter
from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig
from strategies.regime_strategy import RegimeReversionStrategy
from engine.market import MarketType
logger = get_logger(__name__)
def run_baseline():
"""Run baseline BTC/ETH regime strategy."""
logger.info("=" * 60)
logger.info("BASELINE: BTC/ETH Regime Reversion Strategy")
logger.info("=" * 60)
dm = DataManager()
bt = Backtester(dm)
strategy = RegimeReversionStrategy()
result = bt.run_strategy(
strategy,
'okx',
'ETH-USDT',
timeframe='1h',
init_cash=10000
)
logger.info("Baseline Results:")
logger.info(" Total Return: %.2f%%", result.portfolio.total_return() * 100)
logger.info(" Total Trades: %d", result.portfolio.trades.count())
logger.info(" Win Rate: %.1f%%", result.portfolio.trades.win_rate() * 100)
return result
def run_multi_pair(assets: list[str] | None = None):
"""Run multi-pair divergence strategy."""
logger.info("=" * 60)
logger.info("MULTI-PAIR: Divergence Selection Strategy")
logger.info("=" * 60)
dm = DataManager()
bt = Backtester(dm)
# Use provided assets or default
if assets:
config = MultiPairConfig(assets=assets)
else:
config = MultiPairConfig()
logger.info("Configured %d assets, %d pairs", len(config.assets), config.get_pair_count())
strategy = MultiPairDivergenceStrategy(config=config)
result = bt.run_strategy(
strategy,
'okx',
'ETH-USDT', # Reference asset (not used for trading, just index alignment)
timeframe='1h',
init_cash=10000
)
logger.info("Multi-Pair Results:")
logger.info(" Total Return: %.2f%%", result.portfolio.total_return() * 100)
logger.info(" Total Trades: %d", result.portfolio.trades.count())
logger.info(" Win Rate: %.1f%%", result.portfolio.trades.win_rate() * 100)
return result
def compare_results(baseline, multi_pair):
"""Compare and display results."""
logger.info("=" * 60)
logger.info("COMPARISON")
logger.info("=" * 60)
baseline_return = baseline.portfolio.total_return() * 100
multi_return = multi_pair.portfolio.total_return() * 100
improvement = multi_return - baseline_return
logger.info("Baseline Return: %.2f%%", baseline_return)
logger.info("Multi-Pair Return: %.2f%%", multi_return)
logger.info("Improvement: %.2f%% (%.1fx)",
improvement,
multi_return / baseline_return if baseline_return != 0 else 0)
baseline_trades = baseline.portfolio.trades.count()
multi_trades = multi_pair.portfolio.trades.count()
logger.info("Baseline Trades: %d", baseline_trades)
logger.info("Multi-Pair Trades: %d", multi_trades)
return {
'baseline_return': baseline_return,
'multi_pair_return': multi_return,
'improvement': improvement,
'baseline_trades': baseline_trades,
'multi_pair_trades': multi_trades
}
def main():
"""Main entry point."""
setup_logging()
# Check available assets
dm = DataManager()
available = []
for symbol in MultiPairConfig().assets:
try:
dm.load_data('okx', symbol, '1h', market_type=MarketType.PERPETUAL)
available.append(symbol)
except FileNotFoundError:
pass
if len(available) < 2:
logger.error(
"Need at least 2 assets to run multi-pair strategy. "
"Run: uv run python scripts/download_multi_pair_data.py"
)
return
logger.info("Found data for %d assets: %s", len(available), available)
# Run baseline
baseline_result = run_baseline()
# Run multi-pair
multi_result = run_multi_pair(available)
# Compare
comparison = compare_results(baseline_result, multi_result)
# Save reports
reporter = Reporter()
reporter.save_reports(multi_result, "multi_pair_divergence")
logger.info("Reports saved to backtest_logs/")
if __name__ == "__main__":
main()

View File

@@ -37,6 +37,7 @@ def _build_registry() -> dict[str, StrategyConfig]:
from strategies.examples import MaCrossStrategy, RsiStrategy
from strategies.supertrend import MetaSupertrendStrategy
from strategies.regime_strategy import RegimeReversionStrategy
from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig
return {
"rsi": StrategyConfig(
@@ -98,6 +99,18 @@ def _build_registry() -> dict[str, StrategyConfig]:
'stop_loss': [0.04, 0.06, 0.08],
'funding_threshold': [0.005, 0.01, 0.02]
}
),
"multi_pair": StrategyConfig(
strategy_class=MultiPairDivergenceStrategy,
default_params={
# Multi-pair divergence strategy uses config object
# Parameters passed here will override MultiPairConfig defaults
},
grid_params={
'z_entry_threshold': [0.8, 1.0, 1.2],
'prob_threshold': [0.4, 0.5, 0.6],
'correlation_threshold': [0.75, 0.85, 0.95]
}
)
}

View File

@@ -0,0 +1,24 @@
"""
Multi-Pair Divergence Selection Strategy.
Extends regime detection to multiple cryptocurrency pairs and dynamically
selects the most divergent pair for trading.
"""
from .config import MultiPairConfig
from .pair_scanner import PairScanner, TradingPair
from .correlation import CorrelationFilter
from .feature_engine import MultiPairFeatureEngine
from .divergence_scorer import DivergenceScorer
from .strategy import MultiPairDivergenceStrategy
from .funding import FundingRateFetcher
__all__ = [
"MultiPairConfig",
"PairScanner",
"TradingPair",
"CorrelationFilter",
"MultiPairFeatureEngine",
"DivergenceScorer",
"MultiPairDivergenceStrategy",
"FundingRateFetcher",
]

View File

@@ -0,0 +1,88 @@
"""
Configuration for Multi-Pair Divergence Strategy.
"""
from dataclasses import dataclass, field
@dataclass
class MultiPairConfig:
"""
Configuration parameters for multi-pair divergence strategy.
Attributes:
assets: List of asset symbols to analyze (top 10 by market cap)
z_window: Rolling window for Z-Score calculation (hours)
z_entry_threshold: Minimum |Z-Score| to consider for entry
prob_threshold: Minimum ML probability to consider for entry
correlation_threshold: Max correlation to allow between pairs
correlation_window: Rolling window for correlation (hours)
atr_period: ATR lookback period for dynamic stops
sl_atr_multiplier: Stop-loss as multiple of ATR
tp_atr_multiplier: Take-profit as multiple of ATR
train_ratio: Walk-forward train/test split ratio
horizon: Look-ahead horizon for target calculation (hours)
profit_target: Minimum profit threshold for target labels
funding_threshold: Funding rate threshold for filtering
"""
# Asset Universe
assets: list[str] = field(default_factory=lambda: [
"BTC-USDT", "ETH-USDT", "SOL-USDT", "XRP-USDT", "BNB-USDT",
"DOGE-USDT", "ADA-USDT", "AVAX-USDT", "LINK-USDT", "DOT-USDT"
])
# Z-Score Thresholds
z_window: int = 24
z_entry_threshold: float = 1.0
# ML Thresholds
prob_threshold: float = 0.5
train_ratio: float = 0.7
horizon: int = 102
profit_target: float = 0.005
# Correlation Filtering
correlation_threshold: float = 0.85
correlation_window: int = 168 # 7 days in hours
# Risk Management - ATR-Based Stops
# SL/TP are calculated as multiples of ATR
# Mean ATR for crypto is ~0.6% per hour, so:
# - 10x ATR = ~6% SL (matches previous fixed 6%)
# - 8x ATR = ~5% TP (matches previous fixed 5%)
atr_period: int = 14 # ATR lookback period (hours for 1h timeframe)
sl_atr_multiplier: float = 10.0 # Stop-loss = entry +/- (ATR * multiplier)
tp_atr_multiplier: float = 8.0 # Take-profit = entry +/- (ATR * multiplier)
# Fallback fixed percentages (used if ATR is unavailable)
base_sl_pct: float = 0.06
base_tp_pct: float = 0.05
# ATR bounds to prevent extreme stops
min_sl_pct: float = 0.02 # Minimum 2% stop-loss
max_sl_pct: float = 0.10 # Maximum 10% stop-loss
min_tp_pct: float = 0.02 # Minimum 2% take-profit
max_tp_pct: float = 0.15 # Maximum 15% take-profit
volatility_window: int = 24
# Funding Rate Filter
# OKX funding rates are typically 0.0001 (0.01%) per 8h
# Extreme funding is > 0.0005 (0.05%), which indicates a crowded trade
funding_threshold: float = 0.0005 # 0.05% - filter extreme funding
# Trade Management
# Note: Setting min_hold_bars=0 and z_exit_threshold=0 gives best results
# The mean-reversion exit at Z=0 is the primary profit driver
min_hold_bars: int = 0 # Disabled - let mean reversion drive exits
switch_threshold: float = 999.0 # Disabled - don't switch mid-trade
cooldown_bars: int = 0 # Disabled - enter when signal appears
z_exit_threshold: float = 0.0 # Exit at Z=0 (mean reversion complete)
# Exchange
exchange_id: str = "okx"
timeframe: str = "1h"
def get_pair_count(self) -> int:
"""Calculate number of unique pairs from asset list."""
n = len(self.assets)
return n * (n - 1) // 2
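The ATR comments above translate to: stop distance = ATR% × multiplier, clamped into `[min_sl_pct, max_sl_pct]`. A worked sketch with this config's defaults (the ~0.6% hourly ATR figure comes from the comment above; the function name is mine):

```python
def atr_stop_pct(atr_pct: float,
                 multiplier: float = 10.0,
                 min_pct: float = 0.02,
                 max_pct: float = 0.10) -> float:
    """Stop-loss distance as a fraction of price, bounded to sane limits."""
    return min(max(atr_pct * multiplier, min_pct), max_pct)

print(atr_stop_pct(0.006))  # ~0.06: 10x a 0.6% ATR recovers the old fixed 6% SL
print(atr_stop_pct(0.015))  # 0.1: a volatile regime is capped at max_sl_pct
print(atr_stop_pct(0.001))  # 0.02: a quiet regime is floored at min_sl_pct
```

Take-profit works the same way with `tp_atr_multiplier` and the `min_tp_pct`/`max_tp_pct` bounds; the fixed `base_sl_pct`/`base_tp_pct` values are only a fallback when no ATR is available.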

View File

@@ -0,0 +1,173 @@
"""
Correlation Filter for Multi-Pair Divergence Strategy.
Calculates rolling correlation matrix and filters pairs
to avoid highly correlated positions.
"""
import pandas as pd
import numpy as np
from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import TradingPair
logger = get_logger(__name__)
class CorrelationFilter:
"""
Calculates asset correlations and filters candidate pairs accordingly.
Uses rolling correlation of returns to identify assets
moving together, avoiding redundant positions.
"""
def __init__(self, config: MultiPairConfig):
self.config = config
self._correlation_matrix: pd.DataFrame | None = None
self._last_update_idx: int = -1
def calculate_correlation_matrix(
self,
price_data: dict[str, pd.Series],
current_idx: int | None = None
) -> pd.DataFrame:
"""
Calculate rolling correlation matrix between all assets.
Args:
price_data: Dictionary mapping asset symbols to price series
current_idx: Current bar index (for caching)
Returns:
Correlation matrix DataFrame
"""
# Use cached if recent
if (
current_idx is not None
and self._correlation_matrix is not None
and current_idx - self._last_update_idx < 24 # Update every 24 bars
):
return self._correlation_matrix
# Calculate returns
returns = {}
for symbol, prices in price_data.items():
returns[symbol] = prices.pct_change()
returns_df = pd.DataFrame(returns)
# Rolling correlation
window = self.config.correlation_window
# Get latest correlation (last row of rolling correlation)
if len(returns_df) >= window:
rolling_corr = returns_df.rolling(window=window).corr()
# Extract last timestamp correlation matrix
last_idx = returns_df.index[-1]
corr_matrix = rolling_corr.loc[last_idx]
else:
# Fallback to full-period correlation if not enough data
corr_matrix = returns_df.corr()
self._correlation_matrix = corr_matrix
if current_idx is not None:
self._last_update_idx = current_idx
return corr_matrix
def filter_pairs(
self,
pairs: list[TradingPair],
current_position_asset: str | None,
price_data: dict[str, pd.Series],
current_idx: int | None = None
) -> list[TradingPair]:
"""
Filter pairs based on correlation with current position.
If we have an open position in an asset, exclude pairs where
either asset is highly correlated with the held asset.
Args:
pairs: List of candidate pairs
current_position_asset: Currently held asset (or None)
price_data: Dictionary of price series by symbol
current_idx: Current bar index for caching
Returns:
Filtered list of pairs
"""
if current_position_asset is None:
return pairs
corr_matrix = self.calculate_correlation_matrix(price_data, current_idx)
threshold = self.config.correlation_threshold
filtered = []
for pair in pairs:
# Check correlation of base and quote with held asset
base_corr = self._get_correlation(
corr_matrix, pair.base_asset, current_position_asset
)
quote_corr = self._get_correlation(
corr_matrix, pair.quote_asset, current_position_asset
)
# Filter if either asset highly correlated with position
if abs(base_corr) > threshold or abs(quote_corr) > threshold:
logger.debug(
"Filtered %s: base_corr=%.2f, quote_corr=%.2f (held: %s)",
pair.name, base_corr, quote_corr, current_position_asset
)
continue
filtered.append(pair)
if len(filtered) < len(pairs):
logger.info(
"Correlation filter: %d/%d pairs remaining (held: %s)",
len(filtered), len(pairs), current_position_asset
)
return filtered
def _get_correlation(
self,
corr_matrix: pd.DataFrame,
asset1: str,
asset2: str
) -> float:
"""
Get correlation between two assets from matrix.
Args:
corr_matrix: Correlation matrix
asset1: First asset symbol
asset2: Second asset symbol
Returns:
Correlation coefficient (-1 to 1), or 0 if not found
"""
if asset1 == asset2:
return 1.0
try:
return corr_matrix.loc[asset1, asset2]
except KeyError:
return 0.0
def get_correlation_report(
self,
price_data: dict[str, pd.Series]
) -> pd.DataFrame:
"""
Generate a readable correlation report.
Args:
price_data: Dictionary of price series
Returns:
Correlation matrix as DataFrame
"""
return self.calculate_correlation_matrix(price_data)
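`DataFrame.rolling(...).corr()` on a multi-column frame returns a long frame indexed by (timestamp, column); the filter above slices out the block at the last timestamp as the current correlation matrix. A minimal sketch of that extraction (data is random, for illustration only):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
idx = pd.date_range("2025-01-01", periods=200, freq="h")
rets = pd.DataFrame({
    "BTC": rng.normal(0, 0.01, 200),
    "ETH": rng.normal(0, 0.01, 200),
}, index=idx)

window = 168  # 7 days of hourly bars, as in the config
rolling_corr = rets.rolling(window=window).corr()

# The result is MultiIndexed by (timestamp, asset); take the last block.
corr_matrix = rolling_corr.loc[idx[-1]]
print(corr_matrix.shape)  # (2, 2), with values close to 1.0 on the diagonal
```

This is why the code falls back to `returns_df.corr()` when fewer than `window` rows exist: the last rolling block would be all-NaN.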

View File

@@ -0,0 +1,311 @@
"""
Divergence Scorer for Multi-Pair Strategy.
Ranks pairs by divergence score and selects the best candidate.
"""
from dataclasses import dataclass
from typing import Optional
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import pickle
from pathlib import Path
from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import TradingPair
logger = get_logger(__name__)
@dataclass
class DivergenceSignal:
"""
Signal for a divergent pair.
Attributes:
pair: Trading pair
z_score: Current Z-Score of the spread
probability: ML model probability of profitable reversion
divergence_score: Combined score (|z_score| * probability)
direction: 'long' or 'short' (relative to base asset)
base_price: Current price of base asset
quote_price: Current price of quote asset
atr: Average True Range in price units
atr_pct: ATR as percentage of price
"""
pair: TradingPair
z_score: float
probability: float
divergence_score: float
direction: str
base_price: float
quote_price: float
atr: float
atr_pct: float
timestamp: pd.Timestamp
class DivergenceScorer:
"""
Scores and ranks pairs by divergence potential.
Uses ML model predictions combined with Z-Score magnitude
to identify the most promising mean-reversion opportunity.
"""
def __init__(self, config: MultiPairConfig, model_path: str = "data/multi_pair_model.pkl"):
self.config = config
self.model_path = Path(model_path)
self.model: RandomForestClassifier | None = None
self.feature_cols: list[str] | None = None
self._load_model()
def _load_model(self) -> None:
"""Load pre-trained model if available."""
if self.model_path.exists():
try:
with open(self.model_path, 'rb') as f:
saved = pickle.load(f)
self.model = saved['model']
self.feature_cols = saved['feature_cols']
logger.info("Loaded model from %s", self.model_path)
except Exception as e:
logger.warning("Could not load model: %s", e)
def save_model(self) -> None:
"""Save trained model."""
if self.model is None:
return
self.model_path.parent.mkdir(parents=True, exist_ok=True)
with open(self.model_path, 'wb') as f:
pickle.dump({
'model': self.model,
'feature_cols': self.feature_cols,
}, f)
logger.info("Saved model to %s", self.model_path)
def train_model(
self,
combined_features: pd.DataFrame,
pair_features: dict[str, pd.DataFrame]
) -> None:
"""
Train universal model on all pairs.
Args:
combined_features: Combined feature DataFrame from all pairs
pair_features: Individual pair feature DataFrames (for target calculation)
"""
logger.info("Training universal model on %d samples...", len(combined_features))
z_thresh = self.config.z_entry_threshold
horizon = self.config.horizon
profit_target = self.config.profit_target
# Calculate targets for each pair
all_targets = []
all_features = []
for pair_id, features in pair_features.items():
if len(features) < horizon + 50:
continue
spread = features['spread']
z_score = features['z_score']
# Future price movements
future_min = spread.rolling(window=horizon).min().shift(-horizon)
future_max = spread.rolling(window=horizon).max().shift(-horizon)
# Target labels
target_short = spread * (1 - profit_target)
target_long = spread * (1 + profit_target)
success_short = (z_score > z_thresh) & (future_min < target_short)
success_long = (z_score < -z_thresh) & (future_max > target_long)
targets = np.select([success_short, success_long], [1, 1], default=0)
# Valid mask (exclude rows without complete future data)
valid_mask = future_min.notna() & future_max.notna()
# Collect valid samples
valid_features = features[valid_mask]
valid_targets = targets[valid_mask.values]
if len(valid_features) > 0:
all_features.append(valid_features)
all_targets.extend(valid_targets)
if not all_features:
logger.warning("No valid training samples")
return
# Combine all training data
X_df = pd.concat(all_features, ignore_index=True)
y = np.array(all_targets)
# Get feature columns
exclude_cols = [
'pair_id', 'base_asset', 'quote_asset',
'spread', 'base_close', 'quote_close', 'base_volume'
]
self.feature_cols = [c for c in X_df.columns if c not in exclude_cols]
# Prepare features
X = X_df[self.feature_cols].fillna(0)
X = X.replace([np.inf, -np.inf], 0)
# Train model
self.model = RandomForestClassifier(
n_estimators=300,
max_depth=5,
min_samples_leaf=30,
class_weight={0: 1, 1: 3},
random_state=42
)
self.model.fit(X, y)
logger.info(
"Model trained on %d samples, %d features, %.1f%% positive class",
len(X), len(self.feature_cols), y.mean() * 100
)
self.save_model()
def score_pairs(
self,
pair_features: dict[str, pd.DataFrame],
pairs: list[TradingPair],
timestamp: pd.Timestamp | None = None
) -> list[DivergenceSignal]:
"""
Score all pairs and return ranked signals.
Args:
pair_features: Feature DataFrames by pair_id
pairs: List of TradingPair objects
timestamp: Current timestamp for feature extraction
Returns:
List of DivergenceSignal sorted by score (descending)
"""
if self.model is None:
logger.warning("Model not trained, returning empty signals")
return []
signals = []
pair_map = {p.pair_id: p for p in pairs}
for pair_id, features in pair_features.items():
if pair_id not in pair_map:
continue
pair = pair_map[pair_id]
# Get latest features
if timestamp is not None:
valid = features[features.index <= timestamp]
if len(valid) == 0:
continue
latest = valid.iloc[-1]
ts = valid.index[-1]
else:
latest = features.iloc[-1]
ts = features.index[-1]
z_score = latest['z_score']
# Skip if Z-score below threshold
if abs(z_score) < self.config.z_entry_threshold:
continue
# Prepare features for prediction
feature_row = latest[self.feature_cols].fillna(0).infer_objects(copy=False)
feature_row = feature_row.replace([np.inf, -np.inf], 0)
X = pd.DataFrame([feature_row.values], columns=self.feature_cols)
# Predict probability
prob = self.model.predict_proba(X)[0, 1]
# Skip if probability below threshold
if prob < self.config.prob_threshold:
continue
# Apply funding rate filter
# Block trades where funding opposes our direction
base_funding = latest.get('base_funding', 0)
if pd.isna(base_funding):  # NaN is truthy, so 'or 0' would not catch it
    base_funding = 0.0
funding_thresh = self.config.funding_threshold
if z_score > 0: # Short signal
# High negative funding = shorts are paying -> skip
if base_funding < -funding_thresh:
logger.debug(
"Skipping %s short: funding too negative (%.4f)",
pair.name, base_funding
)
continue
else: # Long signal
# High positive funding = longs are paying -> skip
if base_funding > funding_thresh:
logger.debug(
"Skipping %s long: funding too positive (%.4f)",
pair.name, base_funding
)
continue
# Calculate divergence score
divergence_score = abs(z_score) * prob
# Determine direction
# Z > 0: Spread high (base expensive vs quote) -> Short base
# Z < 0: Spread low (base cheap vs quote) -> Long base
direction = 'short' if z_score > 0 else 'long'
signal = DivergenceSignal(
pair=pair,
z_score=z_score,
probability=prob,
divergence_score=divergence_score,
direction=direction,
base_price=latest['base_close'],
quote_price=latest['quote_close'],
atr=latest.get('atr_base', 0),
atr_pct=latest.get('atr_pct_base', 0.02),
timestamp=ts
)
signals.append(signal)
# Sort by divergence score (highest first)
signals.sort(key=lambda s: s.divergence_score, reverse=True)
if signals:
logger.debug(
"Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f)",
len(signals),
signals[0].pair.name,
signals[0].divergence_score,
signals[0].z_score,
signals[0].probability
)
return signals
def select_best_pair(
self,
signals: list[DivergenceSignal]
) -> DivergenceSignal | None:
"""
Select the best pair from scored signals.
Args:
signals: List of DivergenceSignal (pre-sorted by score)
Returns:
Best signal or None if no valid candidates
"""
if not signals:
return None
return signals[0]
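The scoring and selection logic above reduces to a rank-by-product rule. A minimal standalone sketch (pair names and numbers are illustrative, not taken from the module):

```python
# divergence_score = |z_score| * probability; highest score wins.
# Direction: z > 0 -> short the base, z < 0 -> long the base.
candidates = [
    ("BTC-USDT__ETH-USDT", 2.4, 0.62),   # (pair_id, z_score, probability)
    ("SOL-USDT__ADA-USDT", -3.1, 0.55),
    ("ETH-USDT__SOL-USDT", 1.9, 0.71),
]
scored = sorted(
    ((pid, abs(z) * p, "short" if z > 0 else "long") for pid, z, p in candidates),
    key=lambda t: t[1],
    reverse=True,
)
best = scored[0]  # select_best_pair simply takes the top-scored signal
```

Note that a large divergence with mediocre model confidence can outrank a small divergence with high confidence, which is exactly the trade-off the product encodes.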


@@ -0,0 +1,433 @@
"""
Feature Engineering for Multi-Pair Divergence Strategy.
Calculates features for all pairs in the universe, including
spread technicals, volatility, and on-chain data.
"""
import pandas as pd
import numpy as np
import ta
from engine.logging_config import get_logger
from engine.data_manager import DataManager
from engine.market import MarketType
from .config import MultiPairConfig
from .pair_scanner import TradingPair
from .funding import FundingRateFetcher
logger = get_logger(__name__)
class MultiPairFeatureEngine:
"""
Calculates features for multiple trading pairs.
Generates consistent feature sets across all pairs for
the universal ML model.
"""
def __init__(self, config: MultiPairConfig):
self.config = config
self.dm = DataManager()
self.funding_fetcher = FundingRateFetcher()
self._funding_data: pd.DataFrame | None = None
def load_all_assets(
self,
start_date: str | None = None,
end_date: str | None = None
) -> dict[str, pd.DataFrame]:
"""
Load OHLCV data for all assets in the universe.
Args:
start_date: Start date filter (YYYY-MM-DD)
end_date: End date filter (YYYY-MM-DD)
Returns:
Dictionary mapping symbol to OHLCV DataFrame
"""
data = {}
market_type = MarketType.PERPETUAL
for symbol in self.config.assets:
try:
df = self.dm.load_data(
self.config.exchange_id,
symbol,
self.config.timeframe,
market_type
)
# Apply date filters
if start_date:
df = df[df.index >= pd.Timestamp(start_date, tz="UTC")]
if end_date:
df = df[df.index <= pd.Timestamp(end_date, tz="UTC")]
if len(df) >= 200: # Minimum data requirement
data[symbol] = df
logger.debug("Loaded %s: %d bars", symbol, len(df))
else:
logger.warning(
"Skipping %s: insufficient data (%d bars)",
symbol, len(df)
)
except FileNotFoundError:
logger.warning("Data not found for %s", symbol)
except Exception as e:
logger.error("Error loading %s: %s", symbol, e)
logger.info("Loaded %d/%d assets", len(data), len(self.config.assets))
return data
def load_funding_data(
self,
start_date: str | None = None,
end_date: str | None = None,
use_cache: bool = True
) -> pd.DataFrame:
"""
Load funding rate data for all assets.
Args:
start_date: Start date filter
end_date: End date filter
use_cache: Whether to use cached data
Returns:
DataFrame with funding rates for all assets
"""
self._funding_data = self.funding_fetcher.get_funding_data(
self.config.assets,
start_date=start_date,
end_date=end_date,
use_cache=use_cache
)
if self._funding_data is not None and not self._funding_data.empty:
logger.info(
"Loaded funding data: %d rows, %d assets",
len(self._funding_data),
len(self._funding_data.columns)
)
else:
logger.warning("No funding data available")
return self._funding_data
def calculate_pair_features(
self,
pair: TradingPair,
asset_data: dict[str, pd.DataFrame],
on_chain_data: pd.DataFrame | None = None
) -> pd.DataFrame | None:
"""
Calculate features for a single pair.
Args:
pair: Trading pair
asset_data: Dictionary of OHLCV DataFrames by symbol
on_chain_data: Optional on-chain data (funding, inflows)
Returns:
DataFrame with features, or None if insufficient data
"""
base = pair.base_asset
quote = pair.quote_asset
if base not in asset_data or quote not in asset_data:
return None
df_base = asset_data[base]
df_quote = asset_data[quote]
# Align indices
common_idx = df_base.index.intersection(df_quote.index)
if len(common_idx) < 200:
logger.debug("Pair %s: insufficient aligned data", pair.name)
return None
df_a = df_base.loc[common_idx]
df_b = df_quote.loc[common_idx]
# Calculate spread (base / quote)
spread = df_a['close'] / df_b['close']
# Z-Score
z_window = self.config.z_window
rolling_mean = spread.rolling(window=z_window).mean()
rolling_std = spread.rolling(window=z_window).std()
z_score = (spread - rolling_mean) / rolling_std
# Spread Technicals
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
spread_roc = spread.pct_change(periods=5) * 100
spread_change_1h = spread.pct_change(periods=1)
# Volume Analysis
vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10)
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10)
# Volatility
ret_a = df_a['close'].pct_change()
ret_b = df_b['close'].pct_change()
vol_a = ret_a.rolling(window=z_window).std()
vol_b = ret_b.rolling(window=z_window).std()
vol_spread_ratio = vol_a / (vol_b + 1e-10)
# Realized Volatility (for dynamic SL/TP)
realized_vol_a = ret_a.rolling(window=self.config.volatility_window).std()
realized_vol_b = ret_b.rolling(window=self.config.volatility_window).std()
        # ATR (Average True Range) for dynamic stops
        # TR = max(high-low, |high-prev_close|, |low-prev_close|); ATR = rolling mean of TR
high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close']
high_b, low_b, close_b = df_b['high'], df_b['low'], df_b['close']
# True Range for base asset
tr_a = pd.concat([
high_a - low_a,
(high_a - close_a.shift(1)).abs(),
(low_a - close_a.shift(1)).abs()
], axis=1).max(axis=1)
atr_a = tr_a.rolling(window=self.config.atr_period).mean()
# True Range for quote asset
tr_b = pd.concat([
high_b - low_b,
(high_b - close_b.shift(1)).abs(),
(low_b - close_b.shift(1)).abs()
], axis=1).max(axis=1)
atr_b = tr_b.rolling(window=self.config.atr_period).mean()
# ATR as percentage of price (normalized)
atr_pct_a = atr_a / close_a
atr_pct_b = atr_b / close_b
# Build feature DataFrame
features = pd.DataFrame(index=common_idx)
features['pair_id'] = pair.pair_id
features['base_asset'] = base
features['quote_asset'] = quote
# Price data (for reference, not features)
features['spread'] = spread
features['base_close'] = df_a['close']
features['quote_close'] = df_b['close']
features['base_volume'] = df_a['volume']
# Core Features
features['z_score'] = z_score
features['spread_rsi'] = spread_rsi
features['spread_roc'] = spread_roc
features['spread_change_1h'] = spread_change_1h
features['vol_ratio'] = vol_ratio
features['vol_ratio_rel'] = vol_ratio_rel
features['vol_diff_ratio'] = vol_spread_ratio
# Volatility for SL/TP
features['realized_vol_base'] = realized_vol_a
features['realized_vol_quote'] = realized_vol_b
features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2
# ATR for dynamic stops (in price units and as percentage)
features['atr_base'] = atr_a
features['atr_quote'] = atr_b
features['atr_pct_base'] = atr_pct_a
features['atr_pct_quote'] = atr_pct_b
features['atr_pct_avg'] = (atr_pct_a + atr_pct_b) / 2
# Pair encoding (for universal model)
# Using base and quote indices for hierarchical encoding
assets = self.config.assets
features['base_idx'] = assets.index(base) if base in assets else -1
features['quote_idx'] = assets.index(quote) if quote in assets else -1
# Add funding and on-chain features
# Funding data is always added from self._funding_data (OKX, all 10 assets)
# On-chain data is optional (CryptoQuant, BTC/ETH only)
features = self._add_on_chain_features(
features, on_chain_data, base, quote
)
# Drop rows with NaN in core features only (not funding/on-chain)
core_cols = [
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
'atr_base', 'atr_pct_base' # ATR is core for SL/TP
]
features = features.dropna(subset=core_cols)
# Fill missing funding/on-chain features with 0 (neutral)
optional_cols = [
'base_funding', 'quote_funding', 'funding_diff', 'funding_avg',
'base_inflow', 'quote_inflow', 'inflow_ratio'
]
for col in optional_cols:
if col in features.columns:
features[col] = features[col].fillna(0)
return features
def calculate_all_pair_features(
self,
pairs: list[TradingPair],
asset_data: dict[str, pd.DataFrame],
on_chain_data: pd.DataFrame | None = None
) -> dict[str, pd.DataFrame]:
"""
Calculate features for all pairs.
Args:
pairs: List of trading pairs
asset_data: Dictionary of OHLCV DataFrames
on_chain_data: Optional on-chain data
Returns:
Dictionary mapping pair_id to feature DataFrame
"""
all_features = {}
for pair in pairs:
features = self.calculate_pair_features(
pair, asset_data, on_chain_data
)
if features is not None and len(features) > 0:
all_features[pair.pair_id] = features
logger.info(
"Calculated features for %d/%d pairs",
len(all_features), len(pairs)
)
return all_features
def get_combined_features(
self,
pair_features: dict[str, pd.DataFrame],
timestamp: pd.Timestamp | None = None
) -> pd.DataFrame:
"""
Combine all pair features into a single DataFrame.
Useful for batch model prediction across all pairs.
Args:
pair_features: Dictionary of feature DataFrames by pair_id
timestamp: Optional specific timestamp to filter to
Returns:
Combined DataFrame with all pairs as rows
"""
if not pair_features:
return pd.DataFrame()
if timestamp is not None:
# Get latest row from each pair at or before timestamp
rows = []
for pair_id, features in pair_features.items():
valid = features[features.index <= timestamp]
if len(valid) > 0:
row = valid.iloc[-1:].copy()
rows.append(row)
if rows:
return pd.concat(rows, ignore_index=False)
return pd.DataFrame()
# Combine all features (for training)
return pd.concat(pair_features.values(), ignore_index=False)
def _add_on_chain_features(
self,
features: pd.DataFrame,
on_chain_data: pd.DataFrame | None,
base_asset: str,
quote_asset: str
) -> pd.DataFrame:
"""
Add on-chain and funding rate features for the pair.
Uses funding data from OKX (all 10 assets) and on-chain data
from CryptoQuant (BTC/ETH only for inflows).
"""
base_short = base_asset.replace('-USDT', '').lower()
quote_short = quote_asset.replace('-USDT', '').lower()
# Add funding rates from cached funding data
if self._funding_data is not None and not self._funding_data.empty:
funding_aligned = self._funding_data.reindex(
features.index, method='ffill'
)
base_funding_col = f'{base_short}_funding'
quote_funding_col = f'{quote_short}_funding'
if base_funding_col in funding_aligned.columns:
features['base_funding'] = funding_aligned[base_funding_col]
if quote_funding_col in funding_aligned.columns:
features['quote_funding'] = funding_aligned[quote_funding_col]
# Funding difference (positive = base has higher funding)
if 'base_funding' in features.columns and 'quote_funding' in features.columns:
features['funding_diff'] = (
features['base_funding'] - features['quote_funding']
)
# Funding sentiment: average of both assets
features['funding_avg'] = (
features['base_funding'] + features['quote_funding']
) / 2
# Add on-chain features from CryptoQuant (BTC/ETH only)
if on_chain_data is not None and not on_chain_data.empty:
cq_aligned = on_chain_data.reindex(features.index, method='ffill')
# Inflows (only available for BTC/ETH)
base_inflow_col = f'{base_short}_inflow'
quote_inflow_col = f'{quote_short}_inflow'
if base_inflow_col in cq_aligned.columns:
features['base_inflow'] = cq_aligned[base_inflow_col]
if quote_inflow_col in cq_aligned.columns:
features['quote_inflow'] = cq_aligned[quote_inflow_col]
if 'base_inflow' in features.columns and 'quote_inflow' in features.columns:
features['inflow_ratio'] = (
features['base_inflow'] /
(features['quote_inflow'] + 1)
)
return features
def get_feature_columns(self) -> list[str]:
"""
Get list of feature columns for ML model.
Excludes metadata and target-related columns.
Returns:
List of feature column names
"""
# Core features (always present)
core_features = [
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
'base_idx', 'quote_idx'
]
# Funding features (now available for all 10 assets via OKX)
funding_features = [
'base_funding', 'quote_funding', 'funding_diff', 'funding_avg'
]
# On-chain features (BTC/ETH only via CryptoQuant)
onchain_features = [
'base_inflow', 'quote_inflow', 'inflow_ratio'
]
return core_features + funding_features + onchain_features
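The core spread feature in `calculate_pair_features` is a rolling z-score of the base/quote close ratio. A self-contained sketch with synthetic prices (the 48-bar window and the series lengths are illustrative stand-ins for `config.z_window` and real data):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
idx = pd.date_range("2024-01-01", periods=300, freq="1h", tz="UTC")
base = pd.Series(100 * np.exp(np.cumsum(rng.normal(0, 0.01, 300))), index=idx)
quote = pd.Series(50 * np.exp(np.cumsum(rng.normal(0, 0.01, 300))), index=idx)

z_window = 48  # illustrative; the real value comes from MultiPairConfig
spread = base / quote
z_score = (spread - spread.rolling(z_window).mean()) / spread.rolling(z_window).std()
# The first z_window - 1 bars are NaN, which is why the feature engine
# drops NaN rows in the core columns before returning features.
```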


@@ -0,0 +1,272 @@
"""
Funding Rate Fetcher for Multi-Pair Strategy.
Fetches historical funding rates from OKX for all assets.
CryptoQuant only supports BTC/ETH, so we use OKX for the full universe.
"""
import time
from pathlib import Path
from datetime import datetime, timezone
import ccxt
import pandas as pd
from engine.logging_config import get_logger
logger = get_logger(__name__)
class FundingRateFetcher:
"""
Fetches and caches funding rate data from OKX.
OKX funding rates are settled every 8 hours (00:00, 08:00, 16:00 UTC).
This fetcher retrieves historical funding rate data and aligns it
to hourly candles for use in the multi-pair strategy.
"""
def __init__(self, cache_dir: str = "data/funding"):
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.exchange: ccxt.okx | None = None
def _init_exchange(self) -> None:
"""Initialize OKX exchange connection."""
if self.exchange is None:
self.exchange = ccxt.okx({
'enableRateLimit': True,
'options': {'defaultType': 'swap'}
})
self.exchange.load_markets()
def fetch_funding_history(
self,
symbol: str,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 100
) -> pd.DataFrame:
"""
Fetch historical funding rates for a symbol.
Args:
symbol: Asset symbol (e.g., 'BTC-USDT')
start_date: Start date (YYYY-MM-DD)
end_date: End date (YYYY-MM-DD)
limit: Max records per request
Returns:
DataFrame with funding rate history
"""
self._init_exchange()
# Convert symbol format
base = symbol.replace('-USDT', '')
okx_symbol = f"{base}/USDT:USDT"
try:
# OKX funding rate history endpoint
# Uses fetch_funding_rate_history if available
all_funding = []
# Parse dates
if start_date:
since = self.exchange.parse8601(f"{start_date}T00:00:00Z")
else:
# Default to 1 year ago
since = self.exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000
if end_date:
until = self.exchange.parse8601(f"{end_date}T23:59:59Z")
else:
until = self.exchange.milliseconds()
# Fetch in batches
current_since = since
while current_since < until:
try:
funding = self.exchange.fetch_funding_rate_history(
okx_symbol,
since=current_since,
limit=limit
)
if not funding:
break
all_funding.extend(funding)
# Move to next batch
last_ts = funding[-1]['timestamp']
if last_ts <= current_since:
break
current_since = last_ts + 1
time.sleep(0.1) # Rate limit
except Exception as e:
logger.warning(
"Error fetching funding batch for %s: %s",
symbol, str(e)[:50]
)
break
if not all_funding:
return pd.DataFrame()
# Convert to DataFrame
df = pd.DataFrame(all_funding)
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
df.set_index('timestamp', inplace=True)
df = df[['fundingRate']].rename(columns={'fundingRate': 'funding_rate'})
df.sort_index(inplace=True)
# Remove duplicates
df = df[~df.index.duplicated(keep='first')]
logger.info("Fetched %d funding records for %s", len(df), symbol)
return df
except Exception as e:
logger.error("Failed to fetch funding for %s: %s", symbol, e)
return pd.DataFrame()
def fetch_all_assets(
self,
assets: list[str],
start_date: str | None = None,
end_date: str | None = None
) -> pd.DataFrame:
"""
Fetch funding rates for all assets and combine.
Args:
assets: List of asset symbols (e.g., ['BTC-USDT', 'ETH-USDT'])
start_date: Start date
end_date: End date
Returns:
Combined DataFrame with columns like 'btc_funding', 'eth_funding', etc.
"""
combined = pd.DataFrame()
for symbol in assets:
df = self.fetch_funding_history(symbol, start_date, end_date)
if df.empty:
continue
# Rename column
asset_name = symbol.replace('-USDT', '').lower()
col_name = f"{asset_name}_funding"
df = df.rename(columns={'funding_rate': col_name})
if combined.empty:
combined = df
else:
combined = combined.join(df, how='outer')
time.sleep(0.2) # Be nice to API
# Forward fill to hourly (funding is every 8h)
if not combined.empty:
combined = combined.sort_index()
combined = combined.ffill()
return combined
def save_to_cache(self, df: pd.DataFrame, filename: str = "funding_rates.csv") -> None:
"""Save funding data to cache file."""
path = self.cache_dir / filename
df.to_csv(path)
logger.info("Saved funding rates to %s", path)
def load_from_cache(self, filename: str = "funding_rates.csv") -> pd.DataFrame | None:
"""Load funding data from cache if available."""
path = self.cache_dir / filename
if path.exists():
df = pd.read_csv(path, index_col='timestamp', parse_dates=True)
logger.info("Loaded funding rates from cache: %d rows", len(df))
return df
return None
def get_funding_data(
self,
assets: list[str],
start_date: str | None = None,
end_date: str | None = None,
use_cache: bool = True,
force_refresh: bool = False
) -> pd.DataFrame:
"""
Get funding data, using cache if available.
Args:
assets: List of asset symbols
start_date: Start date
end_date: End date
use_cache: Whether to use cached data
force_refresh: Force refresh even if cache exists
Returns:
DataFrame with funding rates for all assets
"""
cache_file = "funding_rates.csv"
# Try cache first
if use_cache and not force_refresh:
cached = self.load_from_cache(cache_file)
if cached is not None:
# Check if cache covers requested range
if start_date and end_date:
start_ts = pd.Timestamp(start_date, tz='UTC')
end_ts = pd.Timestamp(end_date, tz='UTC')
if cached.index.min() <= start_ts and cached.index.max() >= end_ts:
# Filter to requested range
return cached[(cached.index >= start_ts) & (cached.index <= end_ts)]
# Fetch fresh data
logger.info("Fetching fresh funding rate data...")
df = self.fetch_all_assets(assets, start_date, end_date)
if not df.empty and use_cache:
self.save_to_cache(df, cache_file)
return df
def download_funding_data():
"""Download funding data for all multi-pair assets."""
from strategies.multi_pair.config import MultiPairConfig
config = MultiPairConfig()
fetcher = FundingRateFetcher()
# Fetch last year of data
end_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
start_date = (datetime.now(timezone.utc) - pd.Timedelta(days=365)).strftime("%Y-%m-%d")
logger.info("Downloading funding rates for %d assets...", len(config.assets))
logger.info("Date range: %s to %s", start_date, end_date)
df = fetcher.get_funding_data(
config.assets,
start_date=start_date,
end_date=end_date,
force_refresh=True
)
if not df.empty:
logger.info("Downloaded %d funding rate records", len(df))
logger.info("Columns: %s", list(df.columns))
else:
logger.warning("No funding data downloaded")
return df
if __name__ == "__main__":
from engine.logging_config import setup_logging
setup_logging()
download_funding_data()
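The 8-hour settlement grid is coarser than the strategy's hourly candles, so both the fetcher and the feature engine forward-fill funding onto the candle index. A minimal sketch of that alignment (the rates are made-up values):

```python
import pandas as pd

funding = pd.Series(
    [0.0001, -0.0002, 0.0003],
    index=pd.date_range("2024-01-01 00:00", periods=3, freq="8h", tz="UTC"),
    name="btc_funding",
)
hourly = pd.date_range("2024-01-01 00:00", periods=24, freq="1h", tz="UTC")
# Same idea as reindex(features.index, method='ffill') in the feature engine:
aligned = funding.reindex(hourly, method="ffill")
# Each settled rate is carried forward across the 8 hourly bars it covers.
```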


@@ -0,0 +1,168 @@
"""
Pair Scanner for Multi-Pair Divergence Strategy.
Generates all possible pairs from the asset universe and checks tradeability.
"""
from dataclasses import dataclass
from itertools import combinations
from typing import Optional
import ccxt
from engine.logging_config import get_logger
from .config import MultiPairConfig
logger = get_logger(__name__)
@dataclass
class TradingPair:
"""
Represents a tradeable pair for spread analysis.
Attributes:
base_asset: First asset in the pair (numerator)
quote_asset: Second asset in the pair (denominator)
pair_id: Unique identifier for the pair
is_direct: Whether pair can be traded directly on exchange
exchange_symbol: Symbol for direct trading (if available)
"""
base_asset: str
quote_asset: str
pair_id: str
is_direct: bool = False
exchange_symbol: Optional[str] = None
@property
def name(self) -> str:
"""Human-readable pair name."""
return f"{self.base_asset}/{self.quote_asset}"
def __hash__(self):
return hash(self.pair_id)
def __eq__(self, other):
if not isinstance(other, TradingPair):
return False
return self.pair_id == other.pair_id
class PairScanner:
"""
    Scans and generates tradeable pairs from the asset universe.
Checks OKX for directly tradeable cross-pairs and generates
synthetic pairs via USDT for others.
"""
def __init__(self, config: MultiPairConfig):
self.config = config
self.exchange: Optional[ccxt.Exchange] = None
self._available_markets: set[str] = set()
def _init_exchange(self) -> None:
"""Initialize exchange connection for market lookup."""
if self.exchange is None:
exchange_class = getattr(ccxt, self.config.exchange_id)
self.exchange = exchange_class({'enableRateLimit': True})
self.exchange.load_markets()
self._available_markets = set(self.exchange.symbols)
logger.info(
"Loaded %d markets from %s",
len(self._available_markets),
self.config.exchange_id
)
def generate_pairs(self, check_exchange: bool = True) -> list[TradingPair]:
"""
Generate all unique pairs from asset universe.
Args:
check_exchange: Whether to check OKX for direct trading
Returns:
List of TradingPair objects
"""
if check_exchange:
self._init_exchange()
pairs = []
assets = self.config.assets
for base, quote in combinations(assets, 2):
pair_id = f"{base}__{quote}"
# Check if directly tradeable as cross-pair on OKX
is_direct = False
exchange_symbol = None
if check_exchange:
                # Check for a directly listed cross perpetual (e.g., ETH/BTC:USDT)
                # OKX perpetuals are typically quoted in USDT;
                # cross-pairs like ETH/BTC are less common
cross_symbol = f"{base.replace('-USDT', '')}/{quote.replace('-USDT', '')}:USDT"
if cross_symbol in self._available_markets:
is_direct = True
exchange_symbol = cross_symbol
pair = TradingPair(
base_asset=base,
quote_asset=quote,
pair_id=pair_id,
is_direct=is_direct,
exchange_symbol=exchange_symbol
)
pairs.append(pair)
# Log summary
direct_count = sum(1 for p in pairs if p.is_direct)
logger.info(
"Generated %d pairs: %d direct, %d synthetic",
len(pairs), direct_count, len(pairs) - direct_count
)
return pairs
def get_required_symbols(self, pairs: list[TradingPair]) -> list[str]:
"""
Get list of symbols needed to calculate all pair spreads.
For synthetic pairs, we need both USDT pairs.
For direct pairs, we still load USDT pairs for simplicity.
Args:
pairs: List of trading pairs
Returns:
List of unique symbols to load (e.g., ['BTC-USDT', 'ETH-USDT'])
"""
symbols = set()
for pair in pairs:
symbols.add(pair.base_asset)
symbols.add(pair.quote_asset)
return list(symbols)
def filter_by_assets(
self,
pairs: list[TradingPair],
exclude_assets: list[str]
) -> list[TradingPair]:
"""
Filter pairs that contain any of the excluded assets.
Args:
pairs: List of trading pairs
exclude_assets: Assets to exclude
Returns:
Filtered list of pairs
"""
if not exclude_assets:
return pairs
exclude_set = set(exclude_assets)
return [
p for p in pairs
if p.base_asset not in exclude_set
and p.quote_asset not in exclude_set
]


@@ -0,0 +1,525 @@
"""
Multi-Pair Divergence Selection Strategy.
Main strategy class that orchestrates pair scanning, feature calculation,
model training, and signal generation for backtesting.
"""
from dataclasses import dataclass
from typing import Optional
import pandas as pd
import numpy as np
from strategies.base import BaseStrategy
from engine.market import MarketType
from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import PairScanner, TradingPair
from .correlation import CorrelationFilter
from .feature_engine import MultiPairFeatureEngine
from .divergence_scorer import DivergenceScorer, DivergenceSignal
logger = get_logger(__name__)
@dataclass
class PositionState:
"""Tracks current position state."""
pair: TradingPair | None = None
direction: str | None = None # 'long' or 'short'
entry_price: float = 0.0
entry_idx: int = -1
stop_loss: float = 0.0
take_profit: float = 0.0
atr: float = 0.0 # ATR at entry for reference
last_exit_idx: int = -100 # For cooldown tracking
class MultiPairDivergenceStrategy(BaseStrategy):
"""
Multi-Pair Divergence Selection Strategy.
Scans multiple cryptocurrency pairs for spread divergence,
selects the most divergent pair using ML-enhanced scoring,
and trades mean-reversion opportunities.
Key Features:
- Universal ML model across all pairs
- Correlation-based pair filtering
- Dynamic SL/TP based on volatility
- Walk-forward training
"""
def __init__(
self,
config: MultiPairConfig | None = None,
model_path: str = "data/multi_pair_model.pkl"
):
super().__init__()
self.config = config or MultiPairConfig()
# Initialize components
self.pair_scanner = PairScanner(self.config)
self.correlation_filter = CorrelationFilter(self.config)
self.feature_engine = MultiPairFeatureEngine(self.config)
self.divergence_scorer = DivergenceScorer(self.config, model_path)
# Strategy configuration
self.default_market_type = MarketType.PERPETUAL
self.default_leverage = 1
# Runtime state
self.pairs: list[TradingPair] = []
self.asset_data: dict[str, pd.DataFrame] = {}
self.pair_features: dict[str, pd.DataFrame] = {}
self.position = PositionState()
self.train_end_idx: int = 0
def run(self, close: pd.Series, **kwargs) -> tuple:
"""
Execute the multi-pair divergence strategy.
This method is called by the backtester with the primary asset's
close prices. For multi-pair, we load all assets internally.
Args:
close: Primary close prices (used for index alignment)
**kwargs: Additional data (high, low, volume)
Returns:
Tuple of (long_entries, long_exits, short_entries, short_exits, size)
"""
logger.info("Starting Multi-Pair Divergence Strategy")
# 1. Load all asset data
start_date = close.index.min().strftime("%Y-%m-%d")
end_date = close.index.max().strftime("%Y-%m-%d")
self.asset_data = self.feature_engine.load_all_assets(
start_date=start_date,
end_date=end_date
)
# 1b. Load funding rate data for all assets
self.feature_engine.load_funding_data(
start_date=start_date,
end_date=end_date,
use_cache=True
)
if len(self.asset_data) < 2:
logger.error("Insufficient assets loaded, need at least 2")
return self._empty_signals(close)
# 2. Generate pairs
self.pairs = self.pair_scanner.generate_pairs(check_exchange=False)
# Filter to pairs with available data
available_assets = set(self.asset_data.keys())
self.pairs = [
p for p in self.pairs
if p.base_asset in available_assets
and p.quote_asset in available_assets
]
logger.info("Trading %d pairs from %d assets", len(self.pairs), len(self.asset_data))
# 3. Calculate features for all pairs
self.pair_features = self.feature_engine.calculate_all_pair_features(
self.pairs, self.asset_data
)
if not self.pair_features:
logger.error("No pair features calculated")
return self._empty_signals(close)
# 4. Align to common index
common_index = self._get_common_index()
if len(common_index) < 200:
logger.error("Insufficient common data across pairs")
return self._empty_signals(close)
# 5. Walk-forward split
n_samples = len(common_index)
train_size = int(n_samples * self.config.train_ratio)
self.train_end_idx = train_size
train_end_date = common_index[train_size - 1]
test_start_date = common_index[train_size]
logger.info(
"Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)",
train_size, train_end_date.strftime('%Y-%m-%d'),
n_samples - train_size, test_start_date.strftime('%Y-%m-%d')
)
# 6. Train model on training period
if self.divergence_scorer.model is None:
train_features = {
pid: feat[feat.index <= train_end_date]
for pid, feat in self.pair_features.items()
}
combined = self.feature_engine.get_combined_features(train_features)
self.divergence_scorer.train_model(combined, train_features)
# 7. Generate signals for test period
return self._generate_signals(common_index, train_size, close)
def _generate_signals(
self,
index: pd.DatetimeIndex,
train_size: int,
reference_close: pd.Series
) -> tuple:
"""
Generate entry/exit signals for the test period.
Iterates through each bar in the test period, scoring pairs
and generating signals based on divergence scores.
"""
# Initialize signal arrays aligned to reference close
long_entries = pd.Series(False, index=reference_close.index)
long_exits = pd.Series(False, index=reference_close.index)
short_entries = pd.Series(False, index=reference_close.index)
short_exits = pd.Series(False, index=reference_close.index)
size = pd.Series(1.0, index=reference_close.index)
# Track position state
self.position = PositionState()
# Price data for correlation calculation
price_data = {
symbol: df['close'] for symbol, df in self.asset_data.items()
}
# Iterate through test period
test_indices = index[train_size:]
trade_count = 0
for i, timestamp in enumerate(test_indices):
current_idx = train_size + i
# Check exit conditions first
if self.position.pair is not None:
# Enforce minimum hold period
bars_held = current_idx - self.position.entry_idx
if bars_held < self.config.min_hold_bars:
# Only allow SL/TP exits during min hold period
should_exit, exit_reason = self._check_sl_tp_only(timestamp)
else:
should_exit, exit_reason = self._check_exit(timestamp)
if should_exit:
# Map exit signal to reference index
if timestamp in reference_close.index:
if self.position.direction == 'long':
long_exits.loc[timestamp] = True
else:
short_exits.loc[timestamp] = True
logger.debug(
"Exit %s %s at %s: %s (held %d bars)",
self.position.direction,
self.position.pair.name,
timestamp.strftime('%Y-%m-%d %H:%M'),
exit_reason,
bars_held
)
self.position = PositionState(last_exit_idx=current_idx)
# Score pairs (with correlation filter if position exists)
held_asset = None
if self.position.pair is not None:
held_asset = self.position.pair.base_asset
# Filter pairs by correlation
candidate_pairs = self.correlation_filter.filter_pairs(
self.pairs,
held_asset,
price_data,
current_idx
)
# Get candidate features
candidate_features = {
pid: feat for pid, feat in self.pair_features.items()
if any(p.pair_id == pid for p in candidate_pairs)
}
# Score pairs
signals = self.divergence_scorer.score_pairs(
candidate_features,
candidate_pairs,
timestamp
)
# Get best signal
best = self.divergence_scorer.select_best_pair(signals)
if best is None:
continue
# Check if we should switch positions or enter new
should_enter = False
# Check cooldown
bars_since_exit = current_idx - self.position.last_exit_idx
in_cooldown = bars_since_exit < self.config.cooldown_bars
if self.position.pair is None and not in_cooldown:
# No position and not in cooldown, can enter
should_enter = True
elif self.position.pair is not None:
# Check if we should switch (requires min hold + significant improvement)
bars_held = current_idx - self.position.entry_idx
current_score = self._get_current_score(timestamp)
if (bars_held >= self.config.min_hold_bars and
best.divergence_score > current_score * self.config.switch_threshold):
# New opportunity is significantly better
if timestamp in reference_close.index:
if self.position.direction == 'long':
long_exits.loc[timestamp] = True
else:
short_exits.loc[timestamp] = True
self.position = PositionState(last_exit_idx=current_idx)
should_enter = True
if should_enter:
# Calculate ATR-based dynamic SL/TP
sl_price, tp_price = self._calculate_sl_tp(
best.base_price,
best.direction,
best.atr,
best.atr_pct
)
# Set position
self.position = PositionState(
pair=best.pair,
direction=best.direction,
entry_price=best.base_price,
entry_idx=current_idx,
stop_loss=sl_price,
take_profit=tp_price,
atr=best.atr
)
# Calculate position size based on divergence
pos_size = self._calculate_size(best.divergence_score)
# Generate entry signal
if timestamp in reference_close.index:
if best.direction == 'long':
long_entries.loc[timestamp] = True
else:
short_entries.loc[timestamp] = True
size.loc[timestamp] = pos_size
trade_count += 1
logger.debug(
"Entry %s %s at %s: z=%.2f, prob=%.2f, score=%.3f",
best.direction,
best.pair.name,
timestamp.strftime('%Y-%m-%d %H:%M'),
best.z_score,
best.probability,
best.divergence_score
)
logger.info("Generated %d trades in test period", trade_count)
return long_entries, long_exits, short_entries, short_exits, size
def _check_exit(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
"""
Check if current position should be exited.
Exit conditions:
1. Z-Score reverted to mean (|Z| < threshold)
2. Stop-loss hit
3. Take-profit hit
Returns:
Tuple of (should_exit, reason)
"""
if self.position.pair is None:
return False, ""
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return True, "pair_data_missing"
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return True, "no_data"
latest = valid.iloc[-1]
z_score = latest['z_score']
current_price = latest['base_close']
# Check mean reversion (primary exit)
if abs(z_score) < self.config.z_exit_threshold:
return True, f"mean_reversion (z={z_score:.2f})"
# Check SL/TP
return self._check_sl_tp(current_price)
def _check_sl_tp_only(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
"""
Check only stop-loss and take-profit conditions.
Used during minimum hold period.
"""
if self.position.pair is None:
return False, ""
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return True, "pair_data_missing"
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return True, "no_data"
latest = valid.iloc[-1]
current_price = latest['base_close']
return self._check_sl_tp(current_price)
def _check_sl_tp(self, current_price: float) -> tuple[bool, str]:
"""Check stop-loss and take-profit levels."""
if self.position.direction == 'long':
if current_price <= self.position.stop_loss:
return True, f"stop_loss ({current_price:.2f} <= {self.position.stop_loss:.2f})"
if current_price >= self.position.take_profit:
return True, f"take_profit ({current_price:.2f} >= {self.position.take_profit:.2f})"
else: # short
if current_price >= self.position.stop_loss:
return True, f"stop_loss ({current_price:.2f} >= {self.position.stop_loss:.2f})"
if current_price <= self.position.take_profit:
return True, f"take_profit ({current_price:.2f} <= {self.position.take_profit:.2f})"
return False, ""
def _get_current_score(self, timestamp: pd.Timestamp) -> float:
"""Get current position's divergence score for comparison."""
if self.position.pair is None:
return 0.0
pair_id = self.position.pair.pair_id
if pair_id not in self.pair_features:
return 0.0
features = self.pair_features[pair_id]
valid = features[features.index <= timestamp]
if len(valid) == 0:
return 0.0
latest = valid.iloc[-1]
z_score = abs(latest['z_score'])
# Re-score with model
if self.divergence_scorer.model is not None:
feature_row = latest[self.divergence_scorer.feature_cols].fillna(0)
feature_row = feature_row.replace([np.inf, -np.inf], 0)
X = pd.DataFrame(
[feature_row.values],
columns=self.divergence_scorer.feature_cols
)
prob = self.divergence_scorer.model.predict_proba(X)[0, 1]
return z_score * prob
return z_score * 0.5
def _calculate_sl_tp(
self,
entry_price: float,
direction: str,
atr: float,
atr_pct: float
) -> tuple[float, float]:
"""
Calculate ATR-based dynamic stop-loss and take-profit prices.
Uses ATR (Average True Range) to set stops that adapt to
each asset's volatility. More volatile assets get wider stops.
Args:
entry_price: Entry price
direction: 'long' or 'short'
atr: ATR in price units
atr_pct: ATR as percentage of price
Returns:
Tuple of (stop_loss_price, take_profit_price)
"""
# Calculate SL/TP as ATR multiples
if atr > 0 and atr_pct > 0:
# ATR-based calculation
sl_distance = atr * self.config.sl_atr_multiplier
tp_distance = atr * self.config.tp_atr_multiplier
# Convert to percentage for bounds checking
sl_pct = sl_distance / entry_price
tp_pct = tp_distance / entry_price
else:
# Fallback to fixed percentages if ATR unavailable
sl_pct = self.config.base_sl_pct
tp_pct = self.config.base_tp_pct
# Apply bounds to prevent extreme stops
sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))
# Calculate actual prices
if direction == 'long':
stop_loss = entry_price * (1 - sl_pct)
take_profit = entry_price * (1 + tp_pct)
else: # short
stop_loss = entry_price * (1 + sl_pct)
take_profit = entry_price * (1 - tp_pct)
return stop_loss, take_profit
def _calculate_size(self, divergence_score: float) -> float:
"""
Calculate position size based on divergence score.
Higher divergence = larger position (up to 2x).
"""
# Base score threshold (Z=1.0, prob=0.5 -> score=0.5)
base_threshold = 0.5
# Scale factor
if divergence_score <= base_threshold:
return 1.0
# Linear scaling: 1.0 at threshold, up to 2.0 at 2x threshold
scale = 1.0 + (divergence_score - base_threshold) / base_threshold
return min(scale, 2.0)
def _get_common_index(self) -> pd.DatetimeIndex:
"""Get the intersection of all pair feature indices."""
if not self.pair_features:
return pd.DatetimeIndex([])
common = None
for features in self.pair_features.values():
if common is None:
common = features.index
else:
common = common.intersection(features.index)
return common.sort_values()
def _empty_signals(self, close: pd.Series) -> tuple:
"""Return empty signal arrays."""
empty = self.create_empty_signals(close)
size = pd.Series(1.0, index=close.index)
return empty, empty, empty, empty, size


@@ -30,7 +30,7 @@ class RegimeReversionStrategy(BaseStrategy):
# Optimal parameters from walk-forward research (2025-10 to 2025-12)
# Research: research/horizon_optimization_results.csv
OPTIMAL_HORIZON = 54 # Updated from 102h based on corrected labeling
OPTIMAL_HORIZON = 102 # 4.25 days - best Net PnL (+232%)
OPTIMAL_Z_WINDOW = 24 # 24h rolling window for spread Z-score
OPTIMAL_TRAIN_RATIO = 0.7 # 70% train / 30% test split
OPTIMAL_PROFIT_TARGET = 0.005 # 0.5% profit threshold for target definition
@@ -321,64 +321,21 @@ class RegimeReversionStrategy(BaseStrategy):
train_features: DataFrame containing features for training period only
"""
threshold = self.profit_target
stop_loss_pct = self.stop_loss
horizon = self.horizon
z_thresh = self.z_entry_threshold
# Calculate targets path-dependently (checking SL before TP)
spread = train_features['spread'].values
z_score = train_features['z_score'].values
n = len(spread)
# Define targets using ONLY training data
# For Short Spread (Z > threshold): Did spread drop below target within horizon?
future_min = train_features['spread'].rolling(window=horizon).min().shift(-horizon)
target_short = train_features['spread'] * (1 - threshold)
success_short = (train_features['z_score'] > z_thresh) & (future_min < target_short)
targets = np.zeros(n, dtype=int)
# For Long Spread (Z < -threshold): Did spread rise above target within horizon?
future_max = train_features['spread'].rolling(window=horizon).max().shift(-horizon)
target_long = train_features['spread'] * (1 + threshold)
success_long = (train_features['z_score'] < -z_thresh) & (future_max > target_long)
# Only iterate relevant rows for efficiency
candidates = np.where((z_score > z_thresh) | (z_score < -z_thresh))[0]
for i in candidates:
if i + horizon >= n:
continue
entry_price = spread[i]
future_prices = spread[i+1 : i+1+horizon]
if z_score[i] > z_thresh: # Short
target_price = entry_price * (1 - threshold)
stop_price = entry_price * (1 + stop_loss_pct)
hit_tp = future_prices <= target_price
hit_sl = future_prices >= stop_price
if not np.any(hit_tp):
targets[i] = 0
elif not np.any(hit_sl):
targets[i] = 1
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
else: # Long
target_price = entry_price * (1 + threshold)
stop_price = entry_price * (1 - stop_loss_pct)
hit_tp = future_prices >= target_price
hit_sl = future_prices <= stop_price
if not np.any(hit_tp):
targets[i] = 0
elif not np.any(hit_sl):
targets[i] = 1
else:
first_tp_idx = np.argmax(hit_tp)
first_sl_idx = np.argmax(hit_sl)
if first_tp_idx < first_sl_idx:
targets[i] = 1
else:
targets[i] = 0
targets = np.select([success_short, success_long], [1, 1], default=0)
# Build model
model = RandomForestClassifier(
@@ -394,9 +351,10 @@ class RegimeReversionStrategy(BaseStrategy):
X_train = train_features[cols].fillna(0)
X_train = X_train.replace([np.inf, -np.inf], 0)
# Use rows where we had enough data to look ahead
valid_mask = np.zeros(n, dtype=bool)
valid_mask[:n-horizon] = True
# Remove rows with NaN targets (from rolling window at end of training period)
valid_mask = ~np.isnan(targets) & ~np.isinf(targets)
# Also check for rows where future data doesn't exist (shift created NaNs)
valid_mask = valid_mask & (future_min.notna().values) & (future_max.notna().values)
X_train_clean = X_train[valid_mask]
targets_clean = targets[valid_mask]


@@ -0,0 +1,321 @@
# PRD: Multi-Pair Divergence Selection Strategy
## 1. Introduction / Overview
This document describes the **Multi-Pair Divergence Selection Strategy**, an extension of the existing BTC/ETH regime reversion system. The strategy expands spread analysis to the **top 10 cryptocurrencies by market cap**, calculates divergence scores for all tradeable pairs, and dynamically selects the **most divergent pair** for trading.
The core hypothesis: by scanning multiple pairs simultaneously, we can identify stronger mean-reversion opportunities than focusing on a single pair, improving net PnL while maintaining the proven ML-based regime detection approach.
---
## 2. Goals
1. **Extend regime detection** to top 10 market cap cryptocurrencies
2. **Dynamically select** the most divergent tradeable pair each cycle
3. **Integrate volatility** into dynamic SL/TP calculations
4. **Filter correlated pairs** to avoid redundant positions
5. **Improve net PnL** compared to single-pair BTC/ETH strategy
6. **Backtest-first** implementation with walk-forward validation
---
## 3. User Stories
### US-1: Multi-Pair Analysis
> As a trader, I want the system to analyze spread divergence across multiple cryptocurrency pairs so that I can identify the best trading opportunity at any given moment.
### US-2: Dynamic Pair Selection
> As a trader, I want the system to automatically select and trade the pair with the highest divergence score (combination of Z-score magnitude and ML probability) so that I maximize mean-reversion profit potential.
### US-3: Volatility-Adjusted Risk
> As a trader, I want stop-loss and take-profit levels to adapt to each pair's volatility so that I avoid being stopped out prematurely on volatile assets while protecting profits on stable ones.
### US-4: Correlation Filtering
> As a trader, I want the system to avoid selecting pairs that are highly correlated with my current position so that I don't inadvertently double-down on the same market exposure.
### US-5: Backtest Validation
> As a researcher, I want to backtest this multi-pair strategy with walk-forward training so that I can validate improvement over the single-pair baseline without look-ahead bias.
---
## 4. Functional Requirements
### 4.1 Data Management
| ID | Requirement |
|----|-------------|
| FR-1.1 | System must support loading OHLCV data for top 10 market cap cryptocurrencies |
| FR-1.2 | Target assets: BTC, ETH, SOL, XRP, BNB, DOGE, ADA, AVAX, LINK, DOT (configurable) |
| FR-1.3 | System must identify all directly tradeable cross-pairs on OKX perpetuals |
| FR-1.4 | System must align timestamps across all pairs for synchronized analysis |
| FR-1.5 | System must handle missing data gracefully (skip pair if insufficient history) |
### 4.2 Pair Generation
| ID | Requirement |
|----|-------------|
| FR-2.1 | Generate all unique pairs from asset universe: N*(N-1)/2 pairs (e.g., 45 pairs for 10 assets) |
| FR-2.2 | Filter pairs to only those directly tradeable on OKX (no USDT intermediate) |
| FR-2.3 | Fallback: If cross-pair not available, calculate synthetic spread via USDT pairs |
| FR-2.4 | Store pair metadata: base asset, quote asset, exchange symbol, tradeable flag |
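FR-2.1 and FR-2.4 can be sketched with `itertools.combinations`; the dict layout is illustrative only (the real pair scanner still needs the OKX tradeable check of FR-2.2 to fill in the `tradeable` flag):

```python
from itertools import combinations

ASSETS = ["BTC", "ETH", "SOL", "XRP", "BNB",
          "DOGE", "ADA", "AVAX", "LINK", "DOT"]

# All unique unordered pairs: N*(N-1)/2 = 10*9/2 = 45 (FR-2.1)
pairs = [
    {"base": a, "quote": b, "symbol": f"{a}-{b}", "tradeable": None}
    for a, b in combinations(ASSETS, 2)
]
assert len(pairs) == 45
```
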
### 4.3 Feature Engineering (Per Pair)
| ID | Requirement |
|----|-------------|
| FR-3.1 | Calculate spread ratio: `asset_a_close / asset_b_close` |
| FR-3.2 | Calculate Z-Score with configurable rolling window (default: 24h) |
| FR-3.3 | Calculate spread technicals: RSI(14), ROC(5), 1h change |
| FR-3.4 | Calculate volume ratio and relative volume |
| FR-3.5 | Calculate volatility ratio: `std(returns_a) / std(returns_b)` over Z-window |
| FR-3.6 | Calculate realized volatility for each asset (for dynamic SL/TP) |
| FR-3.7 | Merge on-chain data (funding rates, inflows) if available per asset |
| FR-3.8 | Add pair identifier as categorical feature for universal model |
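A minimal sketch of FR-3.1 and FR-3.2 using pandas (the function name and column labels are assumptions, not the final feature engine API):

```python
import pandas as pd

def pair_features(close_a: pd.Series, close_b: pd.Series,
                  z_window: int = 24) -> pd.DataFrame:
    """Spread ratio (FR-3.1) and rolling Z-score (FR-3.2) for one pair."""
    spread = close_a / close_b
    mean = spread.rolling(z_window).mean()
    std = spread.rolling(z_window).std()
    return pd.DataFrame({
        "spread": spread,
        "z_score": (spread - mean) / std,  # NaN until z_window bars exist
    })
```
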
### 4.4 Correlation Filtering
| ID | Requirement |
|----|-------------|
| FR-4.1 | Calculate rolling correlation matrix between all assets (default: 168h / 7 days) |
| FR-4.2 | Define correlation threshold (default: 0.85) |
| FR-4.3 | If current position exists, exclude pairs where either asset has correlation > threshold with held asset |
| FR-4.4 | Log filtered pairs with reason for exclusion |
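FR-4.1 through FR-4.4 reduce to one correlation lookup per candidate leg; a sketch under the assumption that `returns` is an hourly returns DataFrame keyed by asset name (function and argument names are illustrative):

```python
import pandas as pd

def filter_correlated(pairs, returns: pd.DataFrame, held_asset: str,
                      threshold: float = 0.85, window: int = 168):
    """Drop pairs where either leg correlates > threshold with the held
    asset over the trailing window (FR-4.1..4.3). `pairs` = [(base, quote)]."""
    corr = returns.tail(window).corr()[held_asset]
    kept, dropped = [], []
    for base, quote in pairs:
        if corr.get(base, 0.0) > threshold or corr.get(quote, 0.0) > threshold:
            dropped.append((base, quote))  # FR-4.4: log with exclusion reason
        else:
            kept.append((base, quote))
    return kept, dropped
```

Note that the held asset correlates 1.0 with itself, so its own pairs are excluded automatically.
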
### 4.5 Divergence Scoring & Pair Selection
| ID | Requirement |
|----|-------------|
| FR-5.1 | Calculate divergence score: `abs(z_score) * model_probability` |
| FR-5.2 | Only consider pairs where `abs(z_score) > z_entry_threshold` (default: 1.0) |
| FR-5.3 | Only consider pairs where `model_probability > prob_threshold` (default: 0.5) |
| FR-5.4 | Apply correlation filter to eligible pairs |
| FR-5.5 | Select pair with highest divergence score |
| FR-5.6 | If no pair qualifies, signal "hold" |
| FR-5.7 | Log all pair scores for analysis/debugging |
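The selection rule of FR-5.1 through FR-5.6 fits in a few lines; a sketch assuming each candidate is a dict carrying the pair's latest `z_score` and model `probability`:

```python
def select_best_pair(candidates, z_entry=1.0, prob_min=0.5):
    """Rank eligible pairs by divergence score = |z| * probability
    (FR-5.1), applying the entry gates of FR-5.2/FR-5.3."""
    eligible = [c for c in candidates
                if abs(c["z_score"]) > z_entry and c["probability"] > prob_min]
    if not eligible:
        return None  # FR-5.6: no qualifying pair -> hold
    return max(eligible, key=lambda c: abs(c["z_score"]) * c["probability"])
```
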
### 4.6 ML Model (Universal)
| ID | Requirement |
|----|-------------|
| FR-6.1 | Train single Random Forest model on all pairs combined |
| FR-6.2 | Include `pair_id` as one-hot encoded or label-encoded feature |
| FR-6.3 | Target: binary (1 = profitable reversion within horizon, 0 = no reversion) |
| FR-6.4 | Walk-forward training: 70% train / 30% test split |
| FR-6.5 | Daily retraining schedule (for live, configurable for backtest) |
| FR-6.6 | Model hyperparameters: `n_estimators=300, max_depth=5, min_samples_leaf=30, class_weight={0:1, 1:3}` |
| FR-6.7 | Save/load model with feature column metadata |
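The walk-forward split of FR-6.4 is chronological, never shuffled; a sketch assuming a time-ordered feature DataFrame:

```python
import pandas as pd

def walk_forward_split(features: pd.DataFrame, train_ratio: float = 0.7):
    """Chronological 70/30 split (FR-6.4): fit on the earliest rows,
    evaluate strictly on later ones to avoid look-ahead bias."""
    split = int(len(features) * train_ratio)
    return features.iloc[:split], features.iloc[split:]
```
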
### 4.7 Signal Generation
| ID | Requirement |
|----|-------------|
| FR-7.1 | Direction: If `z_score > threshold` -> Short spread (short asset_a), If `z_score < -threshold` -> Long spread (long asset_a) |
| FR-7.2 | Apply funding rate filter per asset (block if extreme funding opposes direction) |
| FR-7.3 | Output signal: `{pair, action, side, probability, z_score, divergence_score, reason}` |
### 4.8 Position Sizing
| ID | Requirement |
|----|-------------|
| FR-8.1 | Base size: 100% of available subaccount balance |
| FR-8.2 | Scale by divergence: `size_multiplier = 1.0 + (divergence_score - base_threshold) * scaling_factor` |
| FR-8.3 | Cap multiplier between 1.0x and 2.0x |
| FR-8.4 | Respect exchange minimum order size per asset |
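FR-8.2/FR-8.3 as a sketch; with `scaling_factor=2.0` this reproduces the `1.0 + (score - 0.5) / 0.5` rule used in the backtest's `_calculate_size`, but both defaults are assumptions here:

```python
def size_multiplier(divergence_score: float, base_threshold: float = 0.5,
                    scaling_factor: float = 2.0) -> float:
    """Scale position with divergence (FR-8.2), capped to [1.0, 2.0] (FR-8.3)."""
    mult = 1.0 + max(0.0, divergence_score - base_threshold) * scaling_factor
    return min(max(mult, 1.0), 2.0)
```
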
### 4.9 Dynamic SL/TP (Volatility-Adjusted)
| ID | Requirement |
|----|-------------|
| FR-9.1 | Calculate asset realized volatility: `std(returns) * sqrt(24)` for daily vol |
| FR-9.2 | Base SL: `entry_price * (1 - base_sl_pct * vol_multiplier)` for longs |
| FR-9.3 | Base TP: `entry_price * (1 + base_tp_pct * vol_multiplier)` for longs |
| FR-9.4 | `vol_multiplier = asset_volatility / baseline_volatility` (baseline = BTC volatility) |
| FR-9.5 | Cap vol_multiplier between 0.5x and 2.0x to prevent extreme values |
| FR-9.6 | Invert logic for short positions |
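FR-9.2 through FR-9.6 combine into one function; a sketch where `asset_vol` and `baseline_vol` are the realized volatilities of FR-9.1 (`std(returns) * sqrt(24)`), computed upstream:

```python
def dynamic_sl_tp(entry_price: float, direction: str,
                  asset_vol: float, baseline_vol: float,
                  base_sl_pct: float = 0.06, base_tp_pct: float = 0.05):
    """Volatility-adjusted SL/TP: wider stops for volatile assets."""
    vol_mult = min(max(asset_vol / baseline_vol, 0.5), 2.0)  # FR-9.5 clamp
    sl_pct = base_sl_pct * vol_mult
    tp_pct = base_tp_pct * vol_mult
    if direction == "long":
        return entry_price * (1 - sl_pct), entry_price * (1 + tp_pct)
    return entry_price * (1 + sl_pct), entry_price * (1 - tp_pct)  # FR-9.6
```
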
### 4.10 Exit Conditions
| ID | Requirement |
|----|-------------|
| FR-10.1 | Exit when Z-score crosses back through 0 (mean reversion complete) |
| FR-10.2 | Exit when dynamic SL or TP hit |
| FR-10.3 | No minimum holding period (can switch pairs immediately) |
| FR-10.4 | If new pair has higher divergence score, close current and open new |
### 4.11 Backtest Integration
| ID | Requirement |
|----|-------------|
| FR-11.1 | Integrate with existing `engine/backtester.py` framework |
| FR-11.2 | Support 1h timeframe (matching live trading) |
| FR-11.3 | Walk-forward validation: train on 70%, test on 30% |
| FR-11.4 | Output: trades log, equity curve, performance metrics |
| FR-11.5 | Compare against single-pair BTC/ETH baseline |
---
## 5. Non-Goals (Out of Scope)
1. **Live trading implementation** - Backtest validation first
2. **Multi-position portfolio** - Single pair at a time for v1
3. **Cross-exchange arbitrage** - OKX only
4. **Alternative ML models** - Stick with Random Forest for consistency
5. **Sub-1h timeframes** - 1h candles only for initial version
6. **Leveraged positions** - 1x leverage for backtest
7. **Portfolio-level VaR/risk budgeting** - Full subaccount allocation
---
## 6. Design Considerations
### 6.1 Architecture
```
strategies/
multi_pair/
__init__.py
pair_scanner.py # Generates all pairs, filters tradeable
feature_engine.py # Calculates features for all pairs
correlation.py # Rolling correlation matrix & filtering
divergence_scorer.py # Ranks pairs by divergence score
strategy.py # Main strategy orchestration
```
### 6.2 Data Flow
```
1. Load OHLCV for all 10 assets
2. Generate pair combinations (45 pairs)
3. Filter to tradeable pairs (OKX check)
4. Calculate features for each pair
5. Train/load universal ML model
6. Predict probability for all pairs
7. Calculate divergence scores
8. Apply correlation filter
9. Select top pair
10. Generate signal with dynamic SL/TP
11. Execute in backtest engine
```
### 6.3 Configuration
```python
@dataclass
class MultiPairConfig:
# Assets
assets: list[str] = field(default_factory=lambda: [
"BTC", "ETH", "SOL", "XRP", "BNB",
"DOGE", "ADA", "AVAX", "LINK", "DOT"
])
# Thresholds
z_window: int = 24
z_entry_threshold: float = 1.0
prob_threshold: float = 0.5
correlation_threshold: float = 0.85
correlation_window: int = 168 # 7 days in hours
# Risk
base_sl_pct: float = 0.06
base_tp_pct: float = 0.05
vol_multiplier_min: float = 0.5
vol_multiplier_max: float = 2.0
# Model
train_ratio: float = 0.7
horizon: int = 102
profit_target: float = 0.005
```
---
## 7. Technical Considerations
### 7.1 Dependencies
- Extend `DataManager` to load multiple symbols
- Query OKX API for available perpetual cross-pairs
- Reuse existing feature engineering from `RegimeReversionStrategy`
### 7.2 Performance
- Pre-calculate all pair features in batch (vectorized)
- Cache correlation matrix (update every N candles, not every minute)
- Model inference is fast (single predict call with all pairs as rows)
### 7.3 Edge Cases
- Handle pairs with insufficient history (< 200 bars) - exclude
- Handle assets delisted mid-backtest - skip pair
- Handle zero-volume periods - use last valid price
---
## 8. Success Metrics
| Metric | Baseline (BTC/ETH) | Target |
|--------|-------------------|--------|
| Net PnL | +28.66% (single-pair BTC/ETH) | > 10% improvement |
| Number of Trades | N | Comparable or higher |
| Win Rate | Baseline % | Maintain or improve |
| Average Trade Duration | Baseline hours | Flexible |
| Max Drawdown | Baseline % | Not significantly worse |
---
## 9. Open Questions
1. **OKX Cross-Pairs**: Need to verify which cross-pairs are available on OKX perpetuals. May need to fallback to synthetic spreads for most pairs.
2. **On-Chain Data**: CryptoQuant data currently covers BTC/ETH. Should we:
- Run without on-chain features for other assets?
- Source alternative on-chain data?
- Use funding rates only (available from OKX)?
3. **Pair ID Encoding**: For the universal model, should pair_id be:
- One-hot encoded (adds 45 features)?
- Label encoded (single ordinal feature)?
- Hierarchical (base_asset + quote_asset as separate features)?
4. **Synthetic Spreads**: If trading SOL/DOT spread but only USDT pairs available:
- Calculate spread synthetically: `SOL-USDT / DOT-USDT`
- Execute as two legs: Long SOL-USDT, Short DOT-USDT
- This doubles fees and adds execution complexity. Include in v1?
---
## 10. Implementation Phases
### Phase 1: Data & Infrastructure (Est. 2-3 days)
- Extend DataManager for multi-symbol loading
- Build pair scanner with OKX tradeable filter
- Implement correlation matrix calculation
### Phase 2: Feature Engineering (Est. 2 days)
- Adapt existing feature calculation for arbitrary pairs
- Add pair identifier feature
- Batch feature calculation for all pairs
### Phase 3: Model & Scoring (Est. 2 days)
- Train universal model on all pairs
- Implement divergence scoring
- Add correlation filtering to pair selection
### Phase 4: Strategy Integration (Est. 2-3 days)
- Implement dynamic SL/TP with volatility
- Integrate with backtester
- Build strategy orchestration class
### Phase 5: Validation & Comparison (Est. 2 days)
- Run walk-forward backtest
- Compare against BTC/ETH baseline
- Generate performance report
**Total Estimated Effort: 10-12 days**
---
*Document Version: 1.0*
*Created: 2026-01-15*
*Author: AI Assistant*
*Status: Draft - Awaiting Review*


@@ -1,351 +0,0 @@
# PRD: Terminal UI for Live Trading Bot
## Introduction/Overview
The live trading bot currently uses basic console logging for output, making it difficult to monitor trading activity, track performance, and understand the system state at a glance. This feature introduces a Rich-based terminal UI that provides a professional, real-time dashboard for monitoring the live trading bot.
The UI will display a horizontal split layout with a **summary panel** at the top (with tabbed time-period views) and a **scrollable log panel** at the bottom. The interface will update every second and support keyboard navigation.
## Goals
1. Provide real-time visibility into trading performance (PnL, win rate, trade count)
2. Enable monitoring of current position state (entry, SL/TP, unrealized PnL)
3. Display strategy signals (Z-score, model probability) for transparency
4. Support historical performance tracking across time periods (daily, weekly, monthly, all-time)
5. Improve operational experience with keyboard shortcuts and log filtering
6. Create a responsive design that works across different terminal sizes
## User Stories
1. **As a trader**, I want to see my total PnL and daily PnL at a glance so I can quickly assess performance.
2. **As a trader**, I want to see my current position details (entry price, unrealized PnL, SL/TP levels) so I can monitor risk.
3. **As a trader**, I want to view performance metrics by time period (daily, weekly, monthly) so I can track trends.
4. **As a trader**, I want to filter logs by type (errors, trades, signals) so I can focus on relevant information.
5. **As a trader**, I want keyboard shortcuts to navigate the UI without using a mouse.
6. **As a trader**, I want the UI to show strategy state (Z-score, probability) so I understand why signals are generated.
## Functional Requirements
### FR1: Layout Structure
1.1. The UI must use a horizontal split layout with the summary panel at the top and logs panel at the bottom.
1.2. The summary panel must contain tabbed views accessible via number keys:
- Tab 1 (`1`): **General** - Overall metrics since bot started
- Tab 2 (`2`): **Monthly** - Current month metrics (shown only if data spans > 1 month)
- Tab 3 (`3`): **Weekly** - Current week metrics (shown only if data spans > 1 week)
- Tab 4 (`4`): **Daily** - Today's metrics
1.3. The logs panel must be a scrollable area showing recent log entries.
1.4. The UI must be responsive and adapt to terminal size (minimum 80x24).
### FR2: Metrics Display
The summary panel must display the following metrics:
**Performance Metrics:**
2.1. Total PnL (USD) - cumulative profit/loss since tracking began
2.2. Period PnL (USD) - profit/loss for selected time period (daily/weekly/monthly)
2.3. Win Rate (%) - percentage of winning trades
2.4. Total Number of Trades
2.5. Average Trade Duration (hours)
2.6. Max Drawdown (USD and %)
**Current Position (if open):**
2.7. Symbol and side (long/short)
2.8. Entry price
2.9. Current price
2.10. Unrealized PnL (USD and %)
2.11. Stop-loss price and distance (%)
2.12. Take-profit price and distance (%)
2.13. Position size (USD)
**Account Status:**
2.14. Account balance / available margin (USDT)
2.15. Current leverage setting
**Strategy State:**
2.16. Current Z-score
2.17. Model probability
2.18. Current funding rate (BTC)
2.19. Last signal action and reason
### FR3: Historical Data Loading
3.1. On startup, the system must initialize SQLite database at `live_trading/trading.db`.
3.2. If `trade_log.csv` exists and database is empty, migrate CSV data to SQLite.
3.3. The UI must load current positions from `live_trading/positions.json` (kept for compatibility with existing position manager).
3.4. Metrics must be calculated via SQL aggregation queries for each time period.
3.5. If no historical data exists, the UI must show "No data" gracefully.
3.6. New trades must be written to both SQLite (primary) and CSV (backup/compatibility).
### FR4: Real-Time Updates
4.1. The UI must refresh every 1 second.
4.2. Position unrealized PnL must update based on latest price data.
4.3. New log entries must appear in real-time.
4.4. Metrics must recalculate when trades are opened/closed.
### FR5: Log Panel
5.1. The log panel must display log entries with timestamp, level, and message.
5.2. Log entries must be color-coded by level:
- ERROR: Red
- WARNING: Yellow
- INFO: White/Default
- DEBUG: Gray (if shown)
5.3. The log panel must support filtering by log type:
- All logs (default)
- Errors only
- Trades only (entries containing "position", "trade", "order")
- Signals only (entries containing "signal", "z_score", "prob")
5.4. Filter switching must be available via keyboard shortcut (`f` to cycle filters).
### FR6: Keyboard Controls
6.1. `q` or `Ctrl+C` - Graceful shutdown
6.2. `r` - Force refresh data
6.3. `1` - Switch to General tab
6.4. `2` - Switch to Monthly tab
6.5. `3` - Switch to Weekly tab
6.6. `4` - Switch to Daily tab
6.7. `f` - Cycle log filter
6.8. Arrow keys - Scroll logs (if supported)
### FR7: Color Scheme
7.1. Use dark theme as base.
7.2. PnL values must be colored:
- Positive: Green
- Negative: Red
- Zero/Neutral: White
7.3. Position side must be colored:
- Long: Green
- Short: Red
7.4. Use consistent color coding for emphasis and warnings.
## Non-Goals (Out of Scope)
1. **Mouse support** - This is a keyboard-driven terminal UI
2. **Trade execution from UI** - The UI is read-only; trades are executed by the bot
3. **Configuration editing** - Config changes require restarting the bot
4. **Multi-exchange support** - Only OKX is supported
5. **Charts/graphs** - Text-based metrics only (no ASCII charts in v1)
6. **Sound alerts** - No audio notifications
7. **Remote access** - Local terminal only
## Design Considerations
### Technology Choice: Rich
Use the [Rich](https://github.com/Textualize/rich) Python library for terminal UI:
- Rich provides `Live` display for real-time updates
- Rich `Layout` for split-screen design
- Rich `Table` for metrics display
- Rich `Panel` for bordered sections
- Rich `Text` for colored output
Alternative considered: **Textual** (also by Will McGugan) provides more advanced TUI features but adds complexity. Rich is simpler and sufficient for this use case.
### UI Mockup
```
+==============================================================================+
| REGIME REVERSION STRATEGY - LIVE TRADING [DEMO] ETH/USDT |
+==============================================================================+
| [1:General] [2:Monthly] [3:Weekly] [4:Daily] |
+------------------------------------------------------------------------------+
| PERFORMANCE | CURRENT POSITION |
| Total PnL: $1,234.56 | Side: LONG |
| Today PnL: $45.23 | Entry: $3,245.50 |
| Win Rate: 67.5% | Current: $3,289.00 |
| Total Trades: 24 | Unrealized: +$43.50 (+1.34%) |
| Avg Duration: 4.2h | Size: $500.00 |
| Max Drawdown: -$156.00 | SL: $3,050.00 (-6.0%) TP: $3,408.00 (+5%)|
| | |
| ACCOUNT | STRATEGY STATE |
| Balance: $5,432.10 | Z-Score: 1.45 |
| Available: $4,932.10 | Probability: 0.72 |
| Leverage: 2x | Funding: 0.0012 |
+------------------------------------------------------------------------------+
| LOGS [Filter: All] Press 'f' cycle |
+------------------------------------------------------------------------------+
| 14:32:15 [INFO] Trading Cycle Start: 2026-01-16T14:32:15+00:00 |
| 14:32:16 [INFO] Signal: entry long (prob=0.72, z=-1.45, reason=z_score...) |
| 14:32:17 [INFO] Executing LONG entry: 0.1540 ETH @ 3245.50 ($500.00) |
| 14:32:18 [INFO] Position opened: ETH/USDT:USDT_20260116_143217 |
| 14:32:18 [INFO] Portfolio: 1 positions, exposure=$500.00, unrealized=$0.00 |
| 14:32:18 [INFO] --- Cycle completed in 3.2s --- |
| 14:32:18 [INFO] Sleeping for 60 minutes... |
| |
+------------------------------------------------------------------------------+
| [q]Quit [r]Refresh [1-4]Tabs [f]Filter |
+==============================================================================+
```
### File Structure
```
live_trading/
ui/
__init__.py
dashboard.py # Main UI orchestration and threading
panels.py # Panel components (metrics, logs, position)
state.py # Thread-safe shared state
log_handler.py # Custom logging handler for UI queue
keyboard.py # Keyboard input handling
db/
__init__.py
database.py # SQLite connection and queries
models.py # Data models (Trade, DailySummary, Session)
migrations.py # CSV migration and schema setup
metrics.py # Metrics aggregation queries
trading.db # SQLite database file (created at runtime)
```
## Technical Considerations
### Integration with Existing Code
1. **Logging Integration**: Create a custom `logging.Handler` that captures log messages and forwards them to the UI log panel while still writing to file.
2. **Data Access**: The UI needs access to:
- `PositionManager` for current positions
- `TradingConfig` for settings display
- SQLite database for historical metrics
- Real-time data from `DataFeed` for current prices
3. **Main Loop Modification**: The `LiveTradingBot.run()` method needs modification to run the UI in a **separate thread** alongside the trading loop. The UI thread handles rendering and keyboard input while the main thread executes trading logic.
4. **Graceful Shutdown**: Ensure `SIGINT`/`SIGTERM` handlers work with the UI layer and properly terminate the UI thread.
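The logging integration from point 1 can be sketched as a `logging.Handler` subclass that mirrors records into the UI queue (class and parameter names are illustrative, not the final module API):

```python
import logging
import queue

class UILogHandler(logging.Handler):
    """Forwards formatted log records to a thread-safe queue consumed by
    the UI thread; existing file/console handlers keep working unchanged."""
    def __init__(self, log_queue: "queue.Queue[str]", maxlen: int = 1000):
        super().__init__()
        self.log_queue = log_queue
        self.maxlen = maxlen

    def emit(self, record: logging.LogRecord) -> None:
        try:
            if self.log_queue.qsize() >= self.maxlen:
                self.log_queue.get_nowait()  # drop oldest to bound memory
            self.log_queue.put_nowait(self.format(record))
        except Exception:
            self.handleError(record)
```

Attached with a plain `logger.addHandler(UILogHandler(q))`, so no call sites change.
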
### Database Schema (SQLite)
Create `live_trading/trading.db` with the following schema:
```sql
-- Trade history table
CREATE TABLE trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trade_id TEXT UNIQUE NOT NULL,
symbol TEXT NOT NULL,
side TEXT NOT NULL, -- 'long' or 'short'
entry_price REAL NOT NULL,
exit_price REAL,
size REAL NOT NULL,
size_usdt REAL NOT NULL,
pnl_usd REAL,
pnl_pct REAL,
entry_time TEXT NOT NULL, -- ISO format
exit_time TEXT,
hold_duration_hours REAL,
reason TEXT, -- 'stop_loss', 'take_profit', 'signal', etc.
order_id_entry TEXT,
order_id_exit TEXT,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Daily summary table (for faster queries)
CREATE TABLE daily_summary (
id INTEGER PRIMARY KEY AUTOINCREMENT,
date TEXT UNIQUE NOT NULL, -- YYYY-MM-DD
total_trades INTEGER DEFAULT 0,
winning_trades INTEGER DEFAULT 0,
total_pnl_usd REAL DEFAULT 0,
max_drawdown_usd REAL DEFAULT 0,
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Session metadata
CREATE TABLE sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
start_time TEXT NOT NULL,
end_time TEXT,
starting_balance REAL,
ending_balance REAL,
total_pnl REAL,
total_trades INTEGER DEFAULT 0
);
-- Indexes for common queries
CREATE INDEX idx_trades_entry_time ON trades(entry_time);
CREATE INDEX idx_trades_exit_time ON trades(exit_time);
CREATE INDEX idx_daily_summary_date ON daily_summary(date);
```
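Against this schema, the period metrics of FR-3.4 are single aggregation queries over `trades`; a sketch using the stdlib `sqlite3` (the returned dict keys are assumptions):

```python
import sqlite3

def period_metrics(conn: sqlite3.Connection, since_iso: str) -> dict:
    """PnL, trade count, win rate and avg duration for trades
    closed at or after `since_iso` (ISO timestamp)."""
    row = conn.execute(
        """
        SELECT COALESCE(SUM(pnl_usd), 0),
               COUNT(*),
               COALESCE(AVG(CASE WHEN pnl_usd > 0 THEN 1.0 ELSE 0.0 END), 0),
               COALESCE(AVG(hold_duration_hours), 0)
        FROM trades
        WHERE exit_time IS NOT NULL AND exit_time >= ?
        """,
        (since_iso,),
    ).fetchone()
    return {"pnl_usd": row[0], "trades": row[1],
            "win_rate": row[2], "avg_hours": row[3]}
```

The `exit_time` index above keeps these under the 100ms budget; open positions (NULL `exit_time`) are excluded.
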
### Migration from CSV
On first run with the new system:
1. Check if `trading.db` exists
2. If not, create database with schema
3. If `trade_log.csv` exists, migrate data to `trades` table
4. Rebuild `daily_summary` from migrated trades
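Steps 3 and 4 can be sketched as follows; the CSV column names are assumptions and must be mapped to the real `trade_log.csv` header:

```python
import csv
import sqlite3

def migrate_csv(conn: sqlite3.Connection, csv_path: str) -> int:
    """One-time migration of trade_log.csv rows into the trades table.
    No-op if the database is already populated (step 3 precondition)."""
    if conn.execute("SELECT COUNT(*) FROM trades").fetchone()[0] > 0:
        return 0
    with open(csv_path, newline="") as f:
        rows = [(r["trade_id"], r["symbol"], r["side"],
                 float(r["entry_price"]), float(r["pnl_usd"]),
                 r["entry_time"], r["exit_time"]) for r in csv.DictReader(f)]
    conn.executemany(
        "INSERT OR IGNORE INTO trades "
        "(trade_id, symbol, side, entry_price, pnl_usd, entry_time, exit_time) "
        "VALUES (?,?,?,?,?,?,?)", rows)
    conn.commit()
    return len(rows)
```

`INSERT OR IGNORE` makes the migration idempotent via the `trade_id` UNIQUE constraint.
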
### Dependencies
Add to `pyproject.toml`:
```toml
dependencies = [
# ... existing deps
"rich>=13.0.0",
]
```
Note: SQLite is part of Python's standard library (`sqlite3`), no additional dependency needed.
### Performance Considerations
- UI runs in a separate thread to avoid blocking trading logic
- Log buffer capped at 1000 entries in memory to prevent unbounded growth
- SQLite queries should use indexes for fast period-based aggregations
- Historical data loading happens once at startup, incremental updates thereafter
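The 1000-entry cap falls out naturally from a bounded `collections.deque`, which also gives the UI a cheap way to filter by level (a sketch; `append_log` and `filtered` are illustrative names):

```python
from collections import deque

LOG_BUFFER_SIZE = 1000  # matches the in-memory cap described above

log_buffer = deque(maxlen=LOG_BUFFER_SIZE)

def append_log(level, message):
    # Oldest entries are evicted automatically once the cap is reached.
    log_buffer.append((level, message))

def filtered(level):
    # Log filtering for the UI: newest first, matching level only.
    return [msg for (lvl, msg) in reversed(log_buffer) if lvl == level]
```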
### Threading Model
```
Main Thread UI Thread
| |
v v
[Trading Loop] [Rich Live Display]
| |
+---> SharedState <------------+
(thread-safe)
| |
+---> LogQueue <---------------+
(thread-safe)
```
- Use `threading.Lock` for shared state access
- Use `queue.Queue` for log message passing
- UI thread polls for updates every 1 second
## Success Metrics
1. UI starts successfully and displays all required metrics
2. UI updates in real-time (1-second refresh) without impacting trading performance
3. All keyboard shortcuts function correctly
4. Historical data loads and displays accurately from SQLite
5. Log filtering works as expected
6. UI gracefully handles edge cases (no data, no position, terminal resize)
7. CSV migration completes successfully on first run
8. Database queries complete within 100ms
---
*Generated: 2026-01-16*
*Decisions: Threading model, 1000 log buffer, SQLite database, no fallback mode*

uv.lock (generated)
```diff
@@ -981,7 +981,6 @@ dependencies = [
     { name = "plotly" },
     { name = "python-dotenv" },
     { name = "requests" },
-    { name = "rich" },
     { name = "scikit-learn" },
     { name = "sqlalchemy" },
     { name = "ta" },
@@ -1005,7 +1004,6 @@ requires-dist = [
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
     { name = "python-dotenv", specifier = ">=1.2.1" },
     { name = "requests", specifier = ">=2.32.5" },
-    { name = "rich", specifier = ">=13.0.0" },
     { name = "scikit-learn", specifier = ">=1.6.0" },
     { name = "sqlalchemy", specifier = ">=2.0.0" },
     { name = "ta", specifier = ">=0.11.0" },
@@ -1014,18 +1012,6 @@ requires-dist = [
 ]
 provides-extras = ["dev"]
 
-[[package]]
-name = "markdown-it-py"
-version = "4.0.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "mdurl" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
-]
-
 [[package]]
 name = "matplotlib"
 version = "3.10.8"
@@ -1092,15 +1078,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" },
 ]
 
-[[package]]
-name = "mdurl"
-version = "0.1.2"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
-]
-
 [[package]]
 name = "multidict"
 version = "6.7.0"
@@ -1935,19 +1912,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
 ]
 
-[[package]]
-name = "rich"
-version = "14.2.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "markdown-it-py" },
-    { name = "pygments" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/fb/d2/8920e102050a0de7bfabeb4c4614a49248cf8d5d7a8d01885fbb24dc767a/rich-14.2.0.tar.gz", hash = "sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4", size = 219990, upload-time = "2025-10-09T14:16:53.064Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/25/7a/b0178788f8dc6cafce37a212c99565fa1fe7872c70c6c9c1e1a372d9d88f/rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd", size = 243393, upload-time = "2025-10-09T14:16:51.245Z" },
-]
-
 [[package]]
 name = "schedule"
 version = "1.2.2"
```