# Compare commits

regime-imb ... regime-imb — 2 commits

| Author | SHA1 | Date |
|---|---|---|
| | `1af0aab5fa` | |
| | `df37366603` | |
## .gitignore (vendored, 4 changes)

```diff
@@ -175,7 +175,3 @@ data/backtest_runs.db
 .gitignore
 live_trading/regime_model.pkl
 live_trading/positions.json
-
-
-*.pkl
-*.db
```
## README.md (304 changed lines)

````diff
@@ -1,262 +1,82 @@
-# Lowkey Backtest
+### lowkey_backtest — Supertrend Backtester
 
-A backtesting framework supporting multiple market types (spot, perpetual) with realistic trading simulation including leverage, funding, and shorts.
+### Overview
+
+Backtest a simple, long-only strategy driven by a meta Supertrend signal on aggregated OHLCV data. The script:
+
+- Loads 1-minute BTC/USD data from `../data/btcusd_1-min_data.csv`
+- Aggregates to multiple timeframes (e.g., `5min`, `15min`, `30min`, `1h`, `4h`, `1d`)
+- Computes three Supertrend variants and creates a meta signal when all agree
+- Executes entries/exits at the aggregated bar open price
+- Applies OKX spot fee assumptions (taker by default)
+- Evaluates stop-loss using intra-bar 1-minute data
+- Writes detailed trade logs and a summary CSV
 
-## Requirements
+### Requirements
 
 - Python 3.12+
-- Package manager: `uv`
+- Dependencies: `pandas`, `numpy`, `ta`
+- Package management: `uv`
 
-## Installation
+Install dependencies with uv:
 
 ```bash
 uv sync
+
+# If a dependency is missing, add it explicitly and sync
+uv add pandas numpy ta
+uv sync
 ```
 
-## Quick Reference
+### Data
+
+- Expected CSV location: `../data/btcusd_1-min_data.csv` (relative to the repo root)
+- Required columns: `Timestamp`, `Open`, `High`, `Low`, `Close`, `Volume`
+- `Timestamp` should be UNIX seconds; zero-volume rows are ignored
 
-| Command | Description |
-|---------|-------------|
-| `uv run python main.py download -p BTC-USDT` | Download data |
-| `uv run python main.py backtest -s meta_st -p BTC-USDT` | Run backtest |
-| `uv run python main.py wfa -s regime -p BTC-USDT` | Walk-forward analysis |
-| `uv run python train_model.py --download` | Train/retrain ML model |
-| `uv run python research/regime_detection.py` | Run research script |
-
----
-
-## Backtest CLI
-
-The main entry point is `main.py` which provides three commands: `download`, `backtest`, and `wfa`.
-
-### Download Data
-
-Download historical OHLCV data from exchanges.
+### Quickstart
+
+Run the backtest with defaults:
 
 ```bash
-uv run python main.py download -p BTC-USDT -t 1h
+uv run python main.py
 ```
 
-**Options:**
-
-- `-p, --pair` (required): Trading pair (e.g., `BTC-USDT`, `ETH-USDT`)
-- `-t, --timeframe`: Timeframe (default: `1m`)
-- `-e, --exchange`: Exchange (default: `okx`)
-- `-m, --market`: Market type: `spot` or `perpetual` (default: `spot`)
-- `--start`: Start date in `YYYY-MM-DD` format
+Outputs:
+
+- Per-run trade logs in `backtest_logs/` named like `trade_log_<TIMEFRAME>_sl<STOPLOSS>.csv`
+- Run-level summary in `backtest_summary.csv`
 
-**Examples:**
-
-```bash
-# Download 1-hour spot data
-uv run python main.py download -p ETH-USDT -t 1h
-
-# Download perpetual data from a specific date
-uv run python main.py download -p BTC-USDT -m perpetual --start 2024-01-01
+### Configuring a Run
+
+Adjust parameters directly in `main.py`:
+
+- Date range (in `load_data`): `load_data('2021-11-01', '2024-10-16')`
+- Timeframes to test (any subset of `"5min", "15min", "30min", "1h", "4h", "1d"`):
+  - `timeframes = ["5min", "15min", "30min", "1h", "4h", "1d"]`
+- Stop-loss percentages:
+  - `stoplosses = [0.03, 0.05, 0.1]`
+- Supertrend settings (in `add_supertrend_indicators`): `(period, multiplier)` pairs `(12, 3.0)`, `(10, 1.0)`, `(11, 2.0)`
+- Fee model (in `calculate_okx_taker_maker_fee`): taker `0.0010`, maker `0.0008`
+
+### What the Backtester Does
+
+- Aggregation: Resamples 1-minute data to your selected timeframe using OHLCV rules
+- Supertrend signals: Computes three Supertrends and derives a meta trend of `+1` (bullish) or `-1` (bearish) when all agree; otherwise `0`
+- Trade logic (long-only):
+  - Entry when the meta trend changes to bullish; uses aggregated bar open price
+  - Exit when the meta trend changes to bearish; uses aggregated bar open price
+- Stop-loss: For each aggregated bar, scans corresponding 1-minute closes to detect stop-loss and exits using a realistic fill (threshold or next 1-minute open if gapped)
+- Performance metrics: total return, max drawdown, Sharpe (daily, factor 252), win rate, number of trades, final/initial equity, and total fees
````
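The aggregation step described above is the standard pandas OHLCV resample (first/max/min/last/sum); a minimal self-contained sketch with synthetic data, not the repo's loader:

```python
import pandas as pd
import numpy as np

# Synthetic 1-minute bars standing in for the BTC/USD CSV
idx = pd.date_range("2024-01-01", periods=120, freq="min")
df = pd.DataFrame({
    "Open": np.arange(120.0),
    "High": np.arange(120.0) + 1,
    "Low": np.arange(120.0) - 1,
    "Close": np.arange(120.0) + 0.5,
    "Volume": 1.0,
}, index=idx)

# Standard OHLCV resampling rules for a higher timeframe
agg = df.resample("1h").agg({
    "Open": "first", "High": "max", "Low": "min",
    "Close": "last", "Volume": "sum",
})
print(len(agg))  # 2 hourly bars from 120 minutes
```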
````diff
-```
+### Important: Lookahead Bias Note
+
+The current implementation uses the meta Supertrend signal of the same bar for entries/exits, which introduces lookahead bias. To avoid this, lag the signal by one bar inside `backtest()` in `main.py`:
+
+```python
+# Replace the current line
+meta_trend_signal = meta_trend
+
+# With a one-bar lag to remove lookahead
+# meta_trend_signal = np.roll(meta_trend, 1)
+# meta_trend_signal[0] = 0
+```
````
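Applied to a toy series, the one-bar lag suggested in that note behaves as follows (the array values are illustrative):

```python
import numpy as np

# Hypothetical meta-trend series: +1 bullish, -1 bearish, 0 no consensus
meta_trend = np.array([0, 1, 1, -1, -1, 1])

# Trade on the previous bar's signal to remove lookahead bias
meta_trend_signal = np.roll(meta_trend, 1)
meta_trend_signal[0] = 0  # no signal exists before the first bar

print(meta_trend_signal.tolist())  # [0, 0, 1, 1, -1, -1]
```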
````diff
+### Outputs
+
+- `backtest_logs/trade_log_<TIMEFRAME>_sl<STOPLOSS>.csv`: trade-by-trade records including type (`buy`, `sell`, `stop_loss`, `forced_close`), timestamps, prices, balances, PnL, and fees
+- `backtest_summary.csv`: one row per (timeframe, stop-loss) combination with `timeframe`, `stop_loss`, `total_return`, `max_drawdown`, `sharpe_ratio`, `win_rate`, `num_trades`, `final_equity`, `initial_equity`, `num_stop_losses`, `total_fees`
+
+### Troubleshooting
+
+- CSV not found: Ensure the dataset is located at `../data/btcusd_1-min_data.csv`
+- Missing packages: Run `uv add pandas numpy ta` then `uv sync`
+- Memory/performance: Large date ranges on 1-minute data can be heavy; narrow the date span or test fewer timeframes
-### Run Backtest
-
-Run a backtest with a specific strategy.
-
-```bash
-uv run python main.py backtest -s <strategy> -p <pair> [options]
-```
-
-**Available Strategies:**
-
-- `meta_st` - Meta Supertrend (triple supertrend consensus)
-- `regime` - Regime Reversion (ML-based spread trading)
-- `rsi` - RSI overbought/oversold
-- `macross` - Moving Average Crossover
-
-**Options:**
-
-- `-s, --strategy` (required): Strategy name
-- `-p, --pair` (required): Trading pair
-- `-t, --timeframe`: Timeframe (default: `1m`)
-- `--start`: Start date
-- `--end`: End date
-- `-g, --grid`: Run grid search optimization
-- `--plot`: Show equity curve plot
-- `--sl`: Stop loss percentage
-- `--tp`: Take profit percentage
-- `--trail`: Enable trailing stop
-- `--fees`: Override fee rate
-- `--slippage`: Slippage (default: `0.001`)
-- `-l, --leverage`: Leverage multiplier
-
-**Examples:**
-
-```bash
-# Basic backtest with Meta Supertrend
-uv run python main.py backtest -s meta_st -p BTC-USDT -t 1h
-
-# Backtest with date range and plot
-uv run python main.py backtest -s meta_st -p BTC-USDT --start 2024-01-01 --end 2024-12-31 --plot
-
-# Grid search optimization
-uv run python main.py backtest -s meta_st -p BTC-USDT -t 4h -g
-
-# Backtest with risk parameters
-uv run python main.py backtest -s meta_st -p BTC-USDT --sl 0.05 --tp 0.10 --trail
-
-# Regime strategy on ETH/BTC spread
-uv run python main.py backtest -s regime -p ETH-USDT -t 1h
-```
-
-### Walk-Forward Analysis (WFA)
-
-Run walk-forward optimization to avoid overfitting.
-
-```bash
-uv run python main.py wfa -s <strategy> -p <pair> [options]
-```
-
-**Options:**
-
-- `-s, --strategy` (required): Strategy name
-- `-p, --pair` (required): Trading pair
-- `-t, --timeframe`: Timeframe (default: `1d`)
-- `-w, --windows`: Number of walk-forward windows (default: `10`)
-- `--plot`: Show WFA results plot
-
-**Examples:**
-
-```bash
-# Walk-forward analysis with 10 windows
-uv run python main.py wfa -s meta_st -p BTC-USDT -t 1d -w 10
-
-# WFA with plot output
-uv run python main.py wfa -s regime -p ETH-USDT --plot
-```
````
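The walk-forward split behind the removed WFA section can be sketched as successive train/test windows; the function name and the 70/30 split below are illustrative assumptions, not the repo's API:

```python
def walk_forward_windows(n_bars, n_windows, train_frac=0.7):
    """Split [0, n_bars) into successive in-sample/out-of-sample index ranges."""
    window = n_bars // n_windows
    train = int(window * train_frac)
    out = []
    for w in range(n_windows):
        start = w * window
        out.append((
            range(start, start + train),           # in-sample: optimize here
            range(start + train, start + window),  # out-of-sample: evaluate here
        ))
    return out

for train_ix, test_ix in walk_forward_windows(1000, 4):
    print(len(train_ix), len(test_ix))  # prints "175 75" for each of the 4 windows
```

Optimizing only on each in-sample range and scoring on the following out-of-sample range is what keeps the procedure honest about overfitting.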
````diff
----
-
-## Research Scripts
-
-Research scripts are located in the `research/` directory for experimental analysis.
-
-### Regime Detection Research
-
-Tests multiple holding horizons for the regime reversion strategy using walk-forward training.
-
-```bash
-uv run python research/regime_detection.py
-```
-
-**Options:**
-
-- `--days DAYS`: Number of days of historical data (default: 90)
-- `--start DATE`: Start date (YYYY-MM-DD), overrides `--days`
-- `--end DATE`: End date (YYYY-MM-DD), defaults to now
-- `--output PATH`: Output CSV path
-
-**Examples:**
-
-```bash
-# Use last 90 days (default)
-uv run python research/regime_detection.py
-
-# Use last 180 days
-uv run python research/regime_detection.py --days 180
-
-# Specific date range
-uv run python research/regime_detection.py --start 2025-07-01 --end 2025-12-31
-```
-
-**What it does:**
-
-- Loads BTC and ETH hourly data
-- Calculates spread features (Z-score, RSI, volume ratios)
-- Trains RandomForest classifier with walk-forward methodology
-- Tests horizons from 6h to 150h
-- Outputs best parameters by F1 score, Net PnL, and MAE
````
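One of the spread features listed above, the rolling Z-score, can be sketched with pandas; the 24-bar window and the synthetic price series are assumptions for illustration:

```python
import numpy as np
import pandas as pd

# Synthetic hourly price legs standing in for ETH and BTC
rng = np.random.default_rng(0)
eth = pd.Series(2000 + rng.normal(0, 5, 500).cumsum())
btc = pd.Series(60000 + rng.normal(0, 50, 500).cumsum())

spread = np.log(eth) - np.log(btc)   # log-spread between the two legs
mean = spread.rolling(24).mean()     # 24-bar rolling stats (assumed window)
std = spread.rolling(24).std()
zscore = (spread - mean) / std       # mean-reversion feature fed to the classifier

print(int(zscore.notna().sum()))  # 477 valid values (first 23 bars lack a full window)
```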
````diff
-**Output:**
-
-- Console: Summary of results for each horizon
-- File: `research/horizon_optimization_results.csv`
-
----
-
-## ML Model Training
-
-The `regime` strategy uses a RandomForest classifier that can be trained with new data.
-
-### Train Model
-
-Train or retrain the ML model with latest data:
-
-```bash
-uv run python train_model.py [options]
-```
-
-**Options:**
-
-- `--days DAYS`: Days of historical data (default: 90)
-- `--pair PAIR`: Base pair for context (default: BTC-USDT)
-- `--spread-pair PAIR`: Trading pair (default: ETH-USDT)
-- `--timeframe TF`: Timeframe (default: 1h)
-- `--market TYPE`: Market type: `spot` or `perpetual` (default: perpetual)
-- `--output PATH`: Model output path (default: `data/regime_model.pkl`)
-- `--train-ratio R`: Train/test split ratio (default: 0.7)
-- `--horizon H`: Prediction horizon in bars (default: 102)
-- `--download`: Download latest data before training
-- `--dry-run`: Run without saving model
-
-**Examples:**
-
-```bash
-# Train with last 90 days of data
-uv run python train_model.py
-
-# Download fresh data and train
-uv run python train_model.py --download
-
-# Train with 180 days of data
-uv run python train_model.py --days 180
-
-# Train on spot market data
-uv run python train_model.py --market spot
-
-# Dry run to see metrics without saving
-uv run python train_model.py --dry-run
-```
-
-### Daily Retraining (Cron)
-
-To automate daily model retraining, add a cron job:
-
-```bash
-# Edit crontab
-crontab -e
-
-# Add entry to retrain daily at 00:30 UTC
-30 0 * * * cd /path/to/lowkey_backtest_live && uv run python train_model.py --download >> logs/training.log 2>&1
-```
-
-### Model Files
-
-| File | Description |
-|------|-------------|
-| `data/regime_model.pkl` | Current production model |
-| `data/regime_model_YYYYMMDD_HHMMSS.pkl` | Versioned model snapshots |
-
-The model file contains:
-
-- Trained RandomForest classifier
-- Feature column names
-- Training metrics (F1 score, sample counts)
-- Training timestamp
````
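A model bundle carrying those four pieces might be round-tripped roughly like this; the key names and values are illustrative, not the repo's actual schema:

```python
import pickle
from datetime import datetime, timezone

# Stand-in for the trained classifier; the repo uses a RandomForest
model = {"kind": "RandomForestClassifier (placeholder)"}

bundle = {
    "model": model,
    "feature_columns": ["zscore", "rsi", "volume_ratio"],     # illustrative names
    "metrics": {"f1": 0.61, "n_train": 1500, "n_test": 640},  # illustrative values
    "trained_at": datetime.now(timezone.utc).isoformat(),
}

# Serialize and deserialize the whole bundle in one object
blob = pickle.dumps(bundle)
loaded = pickle.loads(blob)
print(sorted(loaded))  # ['feature_columns', 'metrics', 'model', 'trained_at']
```

Keeping the feature column names inside the pickle lets inference code rebuild the exact feature matrix the classifier was trained on.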
````diff
----
-
-## Output Files
-
-| Location | Description |
-|----------|-------------|
-| `backtest_logs/` | Trade logs and WFA results |
-| `research/` | Research output files |
-| `data/` | Downloaded OHLCV data and ML models |
-| `data/regime_model.pkl` | Trained ML model for regime strategy |
-
----
-
-## Running Tests
-
-```bash
-uv run pytest tests/
-```
-
-Run a specific test file:
-
-```bash
-uv run pytest tests/test_data_manager.py
-```
````
## check_demo_account.py (new file, +98)

```python
#!/usr/bin/env python3
"""
Check OKX demo account positions and recent orders.

Usage:
    uv run python check_demo_account.py
"""
import sys
from pathlib import Path
from datetime import datetime, timezone

sys.path.insert(0, str(Path(__file__).parent))

from live_trading.config import OKXConfig
import ccxt


def main():
    """Check demo account status."""
    config = OKXConfig()

    print(f"\n{'='*60}")
    print(f" OKX Demo Account Check")
    print(f"{'='*60}")
    print(f" Demo Mode: {config.demo_mode}")
    print(f" API Key: {config.api_key[:8]}..." if config.api_key else " API Key: NOT SET")
    print(f"{'='*60}\n")

    exchange = ccxt.okx({
        'apiKey': config.api_key,
        'secret': config.secret,
        'password': config.password,
        'sandbox': config.demo_mode,
        'options': {'defaultType': 'swap'},
        'enableRateLimit': True,
    })

    # Check balance
    print("--- BALANCE ---")
    balance = exchange.fetch_balance()
    usdt = balance.get('USDT', {})
    print(f"USDT Total: {usdt.get('total', 0):.2f}")
    print(f"USDT Free: {usdt.get('free', 0):.2f}")
    print(f"USDT Used: {usdt.get('used', 0):.2f}")

    # Check all balances
    print("\n--- ALL NON-ZERO BALANCES ---")
    for currency, data in balance.items():
        if isinstance(data, dict) and data.get('total', 0) > 0:
            print(f"{currency}: total={data.get('total', 0):.6f}, free={data.get('free', 0):.6f}")

    # Check open positions
    print("\n--- OPEN POSITIONS ---")
    positions = exchange.fetch_positions()
    open_positions = [p for p in positions if abs(float(p.get('contracts', 0))) > 0]

    if open_positions:
        for pos in open_positions:
            print(f" {pos['symbol']}: {pos['side']} {pos['contracts']} contracts @ {pos.get('entryPrice', 'N/A')}")
            print(f" Unrealized PnL: {pos.get('unrealizedPnl', 'N/A')}")
    else:
        print(" No open positions")

    # Check recent orders (last 50)
    print("\n--- RECENT ORDERS (last 24h) ---")
    try:
        # Fetch closed orders for AVAX
        orders = exchange.fetch_orders('AVAX/USDT:USDT', limit=20)
        if orders:
            for order in orders[-10:]:  # Last 10
                ts = datetime.fromtimestamp(order['timestamp']/1000, tz=timezone.utc)
                print(f" [{ts.strftime('%H:%M:%S')}] {order['side'].upper()} {order['amount']} AVAX @ {order.get('average', order.get('price', 'market'))}")
                print(f" Status: {order['status']}, Filled: {order.get('filled', 0)}, ID: {order['id']}")
        else:
            print(" No recent AVAX orders")
    except Exception as e:
        print(f" Could not fetch orders: {e}")

    # Check order history more broadly
    print("\n--- ORDER HISTORY (AVAX) ---")
    try:
        # Try fetching my trades
        trades = exchange.fetch_my_trades('AVAX/USDT:USDT', limit=10)
        if trades:
            for trade in trades[-5:]:
                ts = datetime.fromtimestamp(trade['timestamp']/1000, tz=timezone.utc)
                print(f" [{ts.strftime('%Y-%m-%d %H:%M:%S')}] {trade['side'].upper()} {trade['amount']} @ {trade['price']}")
                print(f" Fee: {trade.get('fee', {}).get('cost', 'N/A')} {trade.get('fee', {}).get('currency', '')}")
        else:
            print(" No recent AVAX trades")
    except Exception as e:
        print(f" Could not fetch trades: {e}")

    print(f"\n{'='*60}\n")


if __name__ == "__main__":
    main()
```
## data/multi_pair_model.pkl (new binary file)

Binary file not shown.
## TradingConfig (position sizing defaults)

```diff
@@ -60,7 +60,7 @@ class TradingConfig:
 
     # Position sizing
     max_position_usdt: float = -1.0  # Max position size in USDT. If <= 0, use all available funds
-    min_position_usdt: float = 1.0  # Min position size in USDT
+    min_position_usdt: float = 10.0  # Min position size in USDT
    leverage: int = 1  # Leverage (1x = no leverage)
     margin_mode: str = "cross"  # "cross" or "isolated"
```
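The sizing rule these comments describe can be sketched as a small helper; the function and its logic are an illustration of the documented defaults, not the repo's implementation:

```python
def position_size_usdt(free_usdt: float, max_position_usdt: float = -1.0,
                       min_position_usdt: float = 10.0) -> float:
    """Return the USDT notional to deploy, or 0.0 if below the minimum."""
    # If max_position_usdt <= 0, use all available funds; otherwise cap at the max
    size = free_usdt if max_position_usdt <= 0 else min(free_usdt, max_position_usdt)
    return size if size >= min_position_usdt else 0.0

print(position_size_usdt(250.0))                            # 250.0 (no cap: all funds)
print(position_size_usdt(250.0, max_position_usdt=100.0))   # 100.0 (capped)
print(position_size_usdt(5.0))                              # 0.0 (below 10 USDT minimum)
```

Raising the minimum from 1 to 10 USDT means balances under 10 USDT now skip the trade entirely rather than opening a dust position.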
## Database package `__init__.py` — deleted (−13)

```python
"""Database module for live trading persistence."""
from .database import get_db, init_db
from .models import Trade, DailySummary, Session
from .metrics import MetricsCalculator

__all__ = [
    "get_db",
    "init_db",
    "Trade",
    "DailySummary",
    "Session",
    "MetricsCalculator",
]
```
## `database.py` — deleted (−325)

```python
"""SQLite database connection and operations."""
import sqlite3
import logging
from pathlib import Path
from typing import Optional
from contextlib import contextmanager

from .models import Trade, DailySummary, Session

logger = logging.getLogger(__name__)

# Database schema
SCHEMA = """
-- Trade history table
CREATE TABLE IF NOT EXISTS trades (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    trade_id TEXT UNIQUE NOT NULL,
    symbol TEXT NOT NULL,
    side TEXT NOT NULL,
    entry_price REAL NOT NULL,
    exit_price REAL,
    size REAL NOT NULL,
    size_usdt REAL NOT NULL,
    pnl_usd REAL,
    pnl_pct REAL,
    entry_time TEXT NOT NULL,
    exit_time TEXT,
    hold_duration_hours REAL,
    reason TEXT,
    order_id_entry TEXT,
    order_id_exit TEXT,
    created_at TEXT DEFAULT CURRENT_TIMESTAMP
);

-- Daily summary table
CREATE TABLE IF NOT EXISTS daily_summary (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    date TEXT UNIQUE NOT NULL,
    total_trades INTEGER DEFAULT 0,
    winning_trades INTEGER DEFAULT 0,
    total_pnl_usd REAL DEFAULT 0,
    max_drawdown_usd REAL DEFAULT 0,
    updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);

-- Session metadata
CREATE TABLE IF NOT EXISTS sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    start_time TEXT NOT NULL,
    end_time TEXT,
    starting_balance REAL,
    ending_balance REAL,
    total_pnl REAL,
    total_trades INTEGER DEFAULT 0
);

-- Indexes for common queries
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades(entry_time);
CREATE INDEX IF NOT EXISTS idx_trades_exit_time ON trades(exit_time);
CREATE INDEX IF NOT EXISTS idx_daily_summary_date ON daily_summary(date);
"""

_db_instance: Optional["TradingDatabase"] = None


class TradingDatabase:
    """SQLite database for trade persistence."""

    def __init__(self, db_path: Path):
        self.db_path = db_path
        self._connection: Optional[sqlite3.Connection] = None

    @property
    def connection(self) -> sqlite3.Connection:
        """Get or create database connection."""
        if self._connection is None:
            self._connection = sqlite3.connect(
                str(self.db_path),
                check_same_thread=False,
            )
            self._connection.row_factory = sqlite3.Row
        return self._connection

    def init_schema(self) -> None:
        """Initialize database schema."""
        with self.connection:
            self.connection.executescript(SCHEMA)
        logger.info(f"Database initialized at {self.db_path}")

    def close(self) -> None:
        """Close database connection."""
        if self._connection:
            self._connection.close()
            self._connection = None

    @contextmanager
    def transaction(self):
        """Context manager for database transactions."""
        try:
            yield self.connection
            self.connection.commit()
        except Exception:
            self.connection.rollback()
            raise

    def insert_trade(self, trade: Trade) -> int:
        """
        Insert a new trade record.

        Args:
            trade: Trade object to insert

        Returns:
            Row ID of inserted trade
        """
        sql = """
            INSERT INTO trades (
                trade_id, symbol, side, entry_price, exit_price,
                size, size_usdt, pnl_usd, pnl_pct, entry_time,
                exit_time, hold_duration_hours, reason,
                order_id_entry, order_id_exit
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        with self.transaction():
            cursor = self.connection.execute(
                sql,
                (
                    trade.trade_id,
                    trade.symbol,
                    trade.side,
                    trade.entry_price,
                    trade.exit_price,
                    trade.size,
                    trade.size_usdt,
                    trade.pnl_usd,
                    trade.pnl_pct,
                    trade.entry_time,
                    trade.exit_time,
                    trade.hold_duration_hours,
                    trade.reason,
                    trade.order_id_entry,
                    trade.order_id_exit,
                ),
            )
            return cursor.lastrowid

    def update_trade(self, trade_id: str, **kwargs) -> bool:
        """
        Update an existing trade record.

        Args:
            trade_id: Trade ID to update
            **kwargs: Fields to update

        Returns:
            True if trade was updated
        """
        if not kwargs:
            return False

        set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
        sql = f"UPDATE trades SET {set_clause} WHERE trade_id = ?"

        with self.transaction():
            cursor = self.connection.execute(
                sql, (*kwargs.values(), trade_id)
            )
            return cursor.rowcount > 0

    def get_trade(self, trade_id: str) -> Optional[Trade]:
        """Get a trade by ID."""
        sql = "SELECT * FROM trades WHERE trade_id = ?"
        row = self.connection.execute(sql, (trade_id,)).fetchone()
        if row:
            return Trade(**dict(row))
        return None

    def get_trades(
        self,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> list[Trade]:
        """
        Get trades within a time range.

        Args:
            start_time: ISO format start time filter
            end_time: ISO format end time filter
            limit: Maximum number of trades to return

        Returns:
            List of Trade objects
        """
        conditions = []
        params = []

        if start_time:
            conditions.append("entry_time >= ?")
            params.append(start_time)
        if end_time:
            conditions.append("entry_time <= ?")
            params.append(end_time)

        where_clause = " AND ".join(conditions) if conditions else "1=1"
        limit_clause = f"LIMIT {limit}" if limit else ""

        sql = f"""
            SELECT * FROM trades
            WHERE {where_clause}
            ORDER BY entry_time DESC
            {limit_clause}
        """

        rows = self.connection.execute(sql, params).fetchall()
        return [Trade(**dict(row)) for row in rows]

    def get_all_trades(self) -> list[Trade]:
        """Get all trades."""
        sql = "SELECT * FROM trades ORDER BY entry_time DESC"
        rows = self.connection.execute(sql).fetchall()
        return [Trade(**dict(row)) for row in rows]

    def count_trades(self) -> int:
        """Get total number of trades."""
        sql = "SELECT COUNT(*) FROM trades WHERE exit_time IS NOT NULL"
        return self.connection.execute(sql).fetchone()[0]

    def upsert_daily_summary(self, summary: DailySummary) -> None:
        """Insert or update daily summary."""
        sql = """
            INSERT INTO daily_summary (
                date, total_trades, winning_trades, total_pnl_usd, max_drawdown_usd
            ) VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(date) DO UPDATE SET
                total_trades = excluded.total_trades,
                winning_trades = excluded.winning_trades,
                total_pnl_usd = excluded.total_pnl_usd,
                max_drawdown_usd = excluded.max_drawdown_usd,
                updated_at = CURRENT_TIMESTAMP
        """
        with self.transaction():
            self.connection.execute(
                sql,
                (
                    summary.date,
                    summary.total_trades,
                    summary.winning_trades,
                    summary.total_pnl_usd,
                    summary.max_drawdown_usd,
                ),
            )

    def get_daily_summary(self, date: str) -> Optional[DailySummary]:
        """Get daily summary for a specific date."""
        sql = "SELECT * FROM daily_summary WHERE date = ?"
        row = self.connection.execute(sql, (date,)).fetchone()
        if row:
            return DailySummary(**dict(row))
        return None

    def insert_session(self, session: Session) -> int:
        """Insert a new session record."""
        sql = """
            INSERT INTO sessions (
                start_time, end_time, starting_balance,
                ending_balance, total_pnl, total_trades
            ) VALUES (?, ?, ?, ?, ?, ?)
        """
        with self.transaction():
            cursor = self.connection.execute(
                sql,
                (
                    session.start_time,
                    session.end_time,
                    session.starting_balance,
                    session.ending_balance,
                    session.total_pnl,
                    session.total_trades,
                ),
            )
            return cursor.lastrowid

    def update_session(self, session_id: int, **kwargs) -> bool:
        """Update an existing session."""
        if not kwargs:
            return False

        set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
        sql = f"UPDATE sessions SET {set_clause} WHERE id = ?"

        with self.transaction():
            cursor = self.connection.execute(
                sql, (*kwargs.values(), session_id)
            )
            return cursor.rowcount > 0

    def get_latest_session(self) -> Optional[Session]:
        """Get the most recent session."""
        sql = "SELECT * FROM sessions ORDER BY id DESC LIMIT 1"
        row = self.connection.execute(sql).fetchone()
        if row:
            return Session(**dict(row))
        return None


def init_db(db_path: Path) -> TradingDatabase:
    """
    Initialize the database.

    Args:
        db_path: Path to the SQLite database file

    Returns:
        TradingDatabase instance
    """
    global _db_instance
    _db_instance = TradingDatabase(db_path)
    _db_instance.init_schema()
    return _db_instance


def get_db() -> Optional[TradingDatabase]:
    """Get the global database instance."""
    return _db_instance
```
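The `ON CONFLICT(date) DO UPDATE` upsert that `upsert_daily_summary` relies on can be exercised standalone against an in-memory SQLite database (schema trimmed to two columns for illustration):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE daily_summary (date TEXT UNIQUE NOT NULL, total_trades INTEGER DEFAULT 0)"
)

sql = """
    INSERT INTO daily_summary (date, total_trades) VALUES (?, ?)
    ON CONFLICT(date) DO UPDATE SET total_trades = excluded.total_trades
"""
conn.execute(sql, ("2025-01-01", 3))
conn.execute(sql, ("2025-01-01", 7))  # second call updates the row instead of failing

rows = conn.execute("SELECT date, total_trades FROM daily_summary").fetchall()
print(rows)  # [('2025-01-01', 7)]
```

The UNIQUE constraint on `date` is what makes the conflict clause fire, so each calendar day keeps exactly one summary row.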
@@ -1,235 +0,0 @@
-"""Metrics calculation from trade database."""
-import logging
-from dataclasses import dataclass
-from datetime import datetime, timezone, timedelta
-from typing import Optional
-
-from .database import TradingDatabase
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class PeriodMetrics:
-    """Trading metrics for a time period."""
-
-    period_name: str
-    start_time: Optional[str]
-    end_time: Optional[str]
-    total_pnl: float = 0.0
-    total_trades: int = 0
-    winning_trades: int = 0
-    losing_trades: int = 0
-    win_rate: float = 0.0
-    avg_trade_duration_hours: float = 0.0
-    max_drawdown: float = 0.0
-    max_drawdown_pct: float = 0.0
-    best_trade: float = 0.0
-    worst_trade: float = 0.0
-    avg_win: float = 0.0
-    avg_loss: float = 0.0
-
-
-class MetricsCalculator:
-    """Calculate trading metrics from database."""
-
-    def __init__(self, db: TradingDatabase):
-        self.db = db
-
-    def get_all_time_metrics(self) -> PeriodMetrics:
-        """Get metrics for all trades ever."""
-        return self._calculate_metrics("All Time", None, None)
-
-    def get_monthly_metrics(self) -> PeriodMetrics:
-        """Get metrics for current month."""
-        now = datetime.now(timezone.utc)
-        start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
-        return self._calculate_metrics(
-            "Monthly",
-            start.isoformat(),
-            now.isoformat(),
-        )
-
-    def get_weekly_metrics(self) -> PeriodMetrics:
-        """Get metrics for current week (Monday to now)."""
-        now = datetime.now(timezone.utc)
-        days_since_monday = now.weekday()
-        start = now - timedelta(days=days_since_monday)
-        start = start.replace(hour=0, minute=0, second=0, microsecond=0)
-        return self._calculate_metrics(
-            "Weekly",
-            start.isoformat(),
-            now.isoformat(),
-        )
-
-    def get_daily_metrics(self) -> PeriodMetrics:
-        """Get metrics for today (UTC)."""
-        now = datetime.now(timezone.utc)
-        start = now.replace(hour=0, minute=0, second=0, microsecond=0)
-        return self._calculate_metrics(
-            "Daily",
-            start.isoformat(),
-            now.isoformat(),
-        )
-
-    def _calculate_metrics(
-        self,
-        period_name: str,
-        start_time: Optional[str],
-        end_time: Optional[str],
-    ) -> PeriodMetrics:
-        """
-        Calculate metrics for a time period.
-
-        Args:
-            period_name: Name of the period
-            start_time: ISO format start time (None for all time)
-            end_time: ISO format end time (None for all time)
-
-        Returns:
-            PeriodMetrics object
-        """
-        metrics = PeriodMetrics(
-            period_name=period_name,
-            start_time=start_time,
-            end_time=end_time,
-        )
-
-        # Build query conditions
-        conditions = ["exit_time IS NOT NULL"]
-        params = []
-
-        if start_time:
-            conditions.append("exit_time >= ?")
-            params.append(start_time)
-        if end_time:
-            conditions.append("exit_time <= ?")
-            params.append(end_time)
-
-        where_clause = " AND ".join(conditions)
-
-        # Get aggregate metrics
-        sql = f"""
-            SELECT
-                COUNT(*) as total_trades,
-                SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
-                SUM(CASE WHEN pnl_usd < 0 THEN 1 ELSE 0 END) as losing_trades,
-                COALESCE(SUM(pnl_usd), 0) as total_pnl,
-                COALESCE(AVG(hold_duration_hours), 0) as avg_duration,
-                COALESCE(MAX(pnl_usd), 0) as best_trade,
-                COALESCE(MIN(pnl_usd), 0) as worst_trade,
-                COALESCE(AVG(CASE WHEN pnl_usd > 0 THEN pnl_usd END), 0) as avg_win,
-                COALESCE(AVG(CASE WHEN pnl_usd < 0 THEN pnl_usd END), 0) as avg_loss
-            FROM trades
-            WHERE {where_clause}
-        """
-
-        row = self.db.connection.execute(sql, params).fetchone()
-
-        if row and row["total_trades"] > 0:
-            metrics.total_trades = row["total_trades"]
-            metrics.winning_trades = row["winning_trades"] or 0
-            metrics.losing_trades = row["losing_trades"] or 0
-            metrics.total_pnl = row["total_pnl"]
-            metrics.avg_trade_duration_hours = row["avg_duration"]
-            metrics.best_trade = row["best_trade"]
-            metrics.worst_trade = row["worst_trade"]
-            metrics.avg_win = row["avg_win"]
-            metrics.avg_loss = row["avg_loss"]
-
-            if metrics.total_trades > 0:
-                metrics.win_rate = (
-                    metrics.winning_trades / metrics.total_trades * 100
-                )
-
-        # Calculate max drawdown
-        metrics.max_drawdown = self._calculate_max_drawdown(
-            start_time, end_time
-        )
-
-        return metrics
-
-    def _calculate_max_drawdown(
-        self,
-        start_time: Optional[str],
-        end_time: Optional[str],
-    ) -> float:
-        """Calculate maximum drawdown for a period."""
-        conditions = ["exit_time IS NOT NULL"]
-        params = []
-
-        if start_time:
-            conditions.append("exit_time >= ?")
-            params.append(start_time)
-        if end_time:
-            conditions.append("exit_time <= ?")
-            params.append(end_time)
-
-        where_clause = " AND ".join(conditions)
-
-        sql = f"""
-            SELECT pnl_usd
-            FROM trades
-            WHERE {where_clause}
-            ORDER BY exit_time
-        """
-
-        rows = self.db.connection.execute(sql, params).fetchall()
-
-        if not rows:
-            return 0.0
-
-        cumulative = 0.0
-        peak = 0.0
-        max_drawdown = 0.0
-
-        for row in rows:
-            pnl = row["pnl_usd"] or 0.0
-            cumulative += pnl
-            peak = max(peak, cumulative)
-            drawdown = peak - cumulative
-            max_drawdown = max(max_drawdown, drawdown)
-
-        return max_drawdown
-
-    def has_monthly_data(self) -> bool:
-        """Check if we have data spanning more than current month."""
-        sql = """
-            SELECT MIN(exit_time) as first_trade
-            FROM trades
-            WHERE exit_time IS NOT NULL
-        """
-        row = self.db.connection.execute(sql).fetchone()
-        if not row or not row["first_trade"]:
-            return False
-
-        first_trade = datetime.fromisoformat(row["first_trade"])
-        now = datetime.now(timezone.utc)
-        month_start = now.replace(day=1, hour=0, minute=0, second=0)
-
-        return first_trade < month_start
-
-    def has_weekly_data(self) -> bool:
-        """Check if we have data spanning more than current week."""
-        sql = """
-            SELECT MIN(exit_time) as first_trade
-            FROM trades
-            WHERE exit_time IS NOT NULL
-        """
-        row = self.db.connection.execute(sql).fetchone()
-        if not row or not row["first_trade"]:
-            return False
-
-        first_trade = datetime.fromisoformat(row["first_trade"])
-        now = datetime.now(timezone.utc)
-        days_since_monday = now.weekday()
-        week_start = now - timedelta(days=days_since_monday)
-        week_start = week_start.replace(hour=0, minute=0, second=0)
-
-        return first_trade < week_start
-
-    def get_session_start_balance(self) -> Optional[float]:
-        """Get starting balance from latest session."""
-        sql = "SELECT starting_balance FROM sessions ORDER BY id DESC LIMIT 1"
-        row = self.db.connection.execute(sql).fetchone()
-        return row["starting_balance"] if row else None
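The `_calculate_max_drawdown` method in the deleted metrics module is a single-pass running-peak scan over the cumulative PnL curve. The same logic, extracted as a standalone function over a plain list of per-trade PnLs:

```python
def max_drawdown(pnls: list[float]) -> float:
    """Largest peak-to-trough decline of the cumulative PnL curve."""
    cumulative = 0.0
    peak = 0.0   # highest cumulative PnL seen so far
    worst = 0.0  # largest drop below that peak
    for pnl in pnls:
        cumulative += pnl
        peak = max(peak, cumulative)
        worst = max(worst, peak - cumulative)
    return worst

# Curve: 10 -> 6 -> 3 -> 8; peak is 10, trough after it is 3.
assert max_drawdown([10.0, -4.0, -3.0, 5.0]) == 7.0
```

Note that drawdown is measured in dollars from the running peak, so a strategy that only ever loses still reports its full cumulative loss (peak stays at the initial 0.0).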
@@ -1,191 +0,0 @@
-"""Database migrations and CSV import."""
-import csv
-import logging
-from pathlib import Path
-from datetime import datetime
-
-from .database import TradingDatabase
-from .models import Trade, DailySummary
-
-logger = logging.getLogger(__name__)
-
-
-def migrate_csv_to_db(db: TradingDatabase, csv_path: Path) -> int:
-    """
-    Migrate trades from CSV file to SQLite database.
-
-    Args:
-        db: TradingDatabase instance
-        csv_path: Path to trade_log.csv
-
-    Returns:
-        Number of trades migrated
-    """
-    if not csv_path.exists():
-        logger.info("No CSV file to migrate")
-        return 0
-
-    # Check if database already has trades
-    existing_count = db.count_trades()
-    if existing_count > 0:
-        logger.info(
-            f"Database already has {existing_count} trades, skipping migration"
-        )
-        return 0
-
-    migrated = 0
-    try:
-        with open(csv_path, "r", newline="") as f:
-            reader = csv.DictReader(f)
-
-            for row in reader:
-                trade = _csv_row_to_trade(row)
-                if trade:
-                    try:
-                        db.insert_trade(trade)
-                        migrated += 1
-                    except Exception as e:
-                        logger.warning(
-                            f"Failed to migrate trade {row.get('trade_id')}: {e}"
-                        )
-
-        logger.info(f"Migrated {migrated} trades from CSV to database")
-
-    except Exception as e:
-        logger.error(f"CSV migration failed: {e}")
-
-    return migrated
-
-
-def _csv_row_to_trade(row: dict) -> Trade | None:
-    """Convert a CSV row to a Trade object."""
-    try:
-        return Trade(
-            trade_id=row["trade_id"],
-            symbol=row["symbol"],
-            side=row["side"],
-            entry_price=float(row["entry_price"]),
-            exit_price=_safe_float(row.get("exit_price")),
-            size=float(row["size"]),
-            size_usdt=float(row["size_usdt"]),
-            pnl_usd=_safe_float(row.get("pnl_usd")),
-            pnl_pct=_safe_float(row.get("pnl_pct")),
-            entry_time=row["entry_time"],
-            exit_time=row.get("exit_time") or None,
-            hold_duration_hours=_safe_float(row.get("hold_duration_hours")),
-            reason=row.get("reason") or None,
-            order_id_entry=row.get("order_id_entry") or None,
-            order_id_exit=row.get("order_id_exit") or None,
-        )
-    except (KeyError, ValueError) as e:
-        logger.warning(f"Invalid CSV row: {e}")
-        return None
-
-
-def _safe_float(value: str | None) -> float | None:
-    """Safely convert string to float."""
-    if value is None or value == "":
-        return None
-    try:
-        return float(value)
-    except ValueError:
-        return None
-
-
-def rebuild_daily_summaries(db: TradingDatabase) -> int:
-    """
-    Rebuild daily summary table from trades.
-
-    Args:
-        db: TradingDatabase instance
-
-    Returns:
-        Number of daily summaries created
-    """
-    sql = """
-        SELECT
-            DATE(exit_time) as date,
-            COUNT(*) as total_trades,
-            SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
-            SUM(pnl_usd) as total_pnl_usd
-        FROM trades
-        WHERE exit_time IS NOT NULL
-        GROUP BY DATE(exit_time)
-        ORDER BY date
-    """
-
-    rows = db.connection.execute(sql).fetchall()
-    count = 0
-
-    for row in rows:
-        summary = DailySummary(
-            date=row["date"],
-            total_trades=row["total_trades"],
-            winning_trades=row["winning_trades"],
-            total_pnl_usd=row["total_pnl_usd"] or 0.0,
-            max_drawdown_usd=0.0,  # Calculated separately
-        )
-        db.upsert_daily_summary(summary)
-        count += 1
-
-    # Calculate max drawdowns
-    _calculate_daily_drawdowns(db)
-
-    logger.info(f"Rebuilt {count} daily summaries")
-    return count
-
-
-def _calculate_daily_drawdowns(db: TradingDatabase) -> None:
-    """Calculate and update max drawdown for each day."""
-    sql = """
-        SELECT trade_id, DATE(exit_time) as date, pnl_usd
-        FROM trades
-        WHERE exit_time IS NOT NULL
-        ORDER BY exit_time
-    """
-
-    rows = db.connection.execute(sql).fetchall()
-
-    # Track cumulative PnL and drawdown per day
-    daily_drawdowns: dict[str, float] = {}
-    cumulative_pnl = 0.0
-    peak_pnl = 0.0
-
-    for row in rows:
-        date = row["date"]
-        pnl = row["pnl_usd"] or 0.0
-
-        cumulative_pnl += pnl
-        peak_pnl = max(peak_pnl, cumulative_pnl)
-        drawdown = peak_pnl - cumulative_pnl
-
-        if date not in daily_drawdowns:
-            daily_drawdowns[date] = 0.0
-        daily_drawdowns[date] = max(daily_drawdowns[date], drawdown)
-
-    # Update daily summaries with drawdown
-    for date, drawdown in daily_drawdowns.items():
-        db.connection.execute(
-            "UPDATE daily_summary SET max_drawdown_usd = ? WHERE date = ?",
-            (drawdown, date),
-        )
-    db.connection.commit()
-
-
-def run_migrations(db: TradingDatabase, csv_path: Path) -> None:
-    """
-    Run all migrations.
-
-    Args:
-        db: TradingDatabase instance
-        csv_path: Path to trade_log.csv for migration
-    """
-    logger.info("Running database migrations...")
-
-    # Migrate CSV data if exists
-    migrate_csv_to_db(db, csv_path)
-
-    # Rebuild daily summaries
-    rebuild_daily_summaries(db)
-
-    logger.info("Migrations complete")
@@ -1,69 +0,0 @@
-"""Data models for trade persistence."""
-from dataclasses import dataclass, asdict
-from typing import Optional
-from datetime import datetime
-
-
-@dataclass
-class Trade:
-    """Represents a completed trade."""
-
-    trade_id: str
-    symbol: str
-    side: str
-    entry_price: float
-    size: float
-    size_usdt: float
-    entry_time: str
-    exit_price: Optional[float] = None
-    pnl_usd: Optional[float] = None
-    pnl_pct: Optional[float] = None
-    exit_time: Optional[str] = None
-    hold_duration_hours: Optional[float] = None
-    reason: Optional[str] = None
-    order_id_entry: Optional[str] = None
-    order_id_exit: Optional[str] = None
-    id: Optional[int] = None
-
-    def to_dict(self) -> dict:
-        """Convert to dictionary."""
-        return asdict(self)
-
-    @classmethod
-    def from_row(cls, row: tuple, columns: list[str]) -> "Trade":
-        """Create Trade from database row."""
-        data = dict(zip(columns, row))
-        return cls(**data)
-
-
-@dataclass
-class DailySummary:
-    """Daily trading summary."""
-
-    date: str
-    total_trades: int = 0
-    winning_trades: int = 0
-    total_pnl_usd: float = 0.0
-    max_drawdown_usd: float = 0.0
-    id: Optional[int] = None
-
-    def to_dict(self) -> dict:
-        """Convert to dictionary."""
-        return asdict(self)
-
-
-@dataclass
-class Session:
-    """Trading session metadata."""
-
-    start_time: str
-    end_time: Optional[str] = None
-    starting_balance: Optional[float] = None
-    ending_balance: Optional[float] = None
-    total_pnl: Optional[float] = None
-    total_trades: int = 0
-    id: Optional[int] = None
-
-    def to_dict(self) -> dict:
-        """Convert to dictionary."""
-        return asdict(self)
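These models rely on `dataclasses.asdict` for serialization and keyword construction (`cls(**data)`) for deserialization, which round-trips cleanly as long as field names match column names. A trimmed stand-in (only three of the real fields) showing the round trip:

```python
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class Trade:  # trimmed to a few fields for illustration
    trade_id: str
    entry_price: float
    exit_price: Optional[float] = None

t = Trade(trade_id="T1", entry_price=100.0)
d = asdict(t)
assert d == {"trade_id": "T1", "entry_price": 100.0, "exit_price": None}
assert Trade(**d) == t  # dataclass equality is field-by-field
```

This is also why the deleted `from_row` zips SQLite column names with the row tuple: the resulting dict feeds `cls(**data)` directly, so schema and dataclass must stay in lockstep.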
@@ -6,7 +6,6 @@ Uses a pre-trained ML model or trains on historical data.
 """
 import logging
 import pickle
-import time
 from pathlib import Path
 from typing import Optional
 
@@ -40,47 +39,17 @@ class LiveRegimeStrategy:
         self.paths = path_config
         self.model: Optional[RandomForestClassifier] = None
         self.feature_cols: Optional[list] = None
-        self.horizon: int = 54  # Default horizon
-        self._last_model_load_time: float = 0.0
-        self._last_train_time: float = 0.0
         self._load_or_train_model()
 
-    def reload_model_if_changed(self) -> None:
-        """Check if model file has changed and reload if necessary."""
-        if not self.paths.model_path.exists():
-            return
-
-        try:
-            mtime = self.paths.model_path.stat().st_mtime
-            if mtime > self._last_model_load_time:
-                logger.info(f"Model file changed, reloading... (last: {self._last_model_load_time}, new: {mtime})")
-                self._load_or_train_model()
-        except Exception as e:
-            logger.warning(f"Error checking model file: {e}")
-
     def _load_or_train_model(self) -> None:
         """Load pre-trained model or train a new one."""
         if self.paths.model_path.exists():
             try:
-                self._last_model_load_time = self.paths.model_path.stat().st_mtime
                 with open(self.paths.model_path, 'rb') as f:
                     saved = pickle.load(f)
                 self.model = saved['model']
                 self.feature_cols = saved['feature_cols']
-                # Load horizon from metrics if available
-                if 'metrics' in saved and 'horizon' in saved['metrics']:
-                    self.horizon = saved['metrics']['horizon']
-                    logger.info(f"Loaded model from {self.paths.model_path} (horizon={self.horizon})")
-                else:
-                    logger.info(f"Loaded model from {self.paths.model_path} (default horizon={self.horizon})")
-
-                # Load timestamp if available
-                if 'timestamp' in saved:
-                    self._last_train_time = saved['timestamp']
-                else:
-                    self._last_train_time = self._last_model_load_time
-
+                logger.info(f"Loaded model from {self.paths.model_path}")
                 return
             except Exception as e:
                 logger.warning(f"Could not load model: {e}")
@@ -97,20 +66,11 @@ class LiveRegimeStrategy:
                 pickle.dump({
                     'model': self.model,
                     'feature_cols': self.feature_cols,
-                    'metrics': {'horizon': self.horizon},  # Save horizon
-                    'timestamp': time.time()
                 }, f)
             logger.info(f"Saved model to {self.paths.model_path}")
         except Exception as e:
             logger.error(f"Could not save model: {e}")
 
-    def check_retrain(self, features: pd.DataFrame) -> None:
-        """Check if model needs retraining (older than 24h)."""
-        if time.time() - self._last_train_time > 24 * 3600:
-            logger.info("Model is older than 24h. Retraining...")
-            self.train_model(features)
-            self._last_train_time = time.time()
-
     def train_model(self, features: pd.DataFrame) -> None:
         """
         Train the Random Forest model on historical data.
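The removed `check_retrain` is a simple staleness gate: compare the stored training timestamp against the clock and retrain past 24 hours. A standalone sketch with the trainer injected as a callable and the clock made explicit so it can be tested deterministically (`maybe_retrain` is an illustrative name, not the repo's API):

```python
import time
from typing import Callable, Optional

RETRAIN_AFTER = 24 * 3600  # seconds

def maybe_retrain(
    last_train_time: float,
    train: Callable[[], None],
    now: Optional[float] = None,
) -> float:
    """Run train() if the model is stale; return the new last-train time."""
    now = time.time() if now is None else now
    if now - last_train_time > RETRAIN_AFTER:
        train()
        return now
    return last_train_time
```

Injecting `now` is the main design point: the original method reads `time.time()` directly, which makes the 24-hour boundary awkward to unit-test.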
@@ -121,63 +81,20 @@ class LiveRegimeStrategy:
         logger.info(f"Training model on {len(features)} samples...")
 
         z_thresh = self.config.z_entry_threshold
-        horizon = self.horizon
+        horizon = 102  # Optimal horizon from research
         profit_target = 0.005  # 0.5% profit threshold
-        stop_loss_pct = self.config.stop_loss_pct
-
-        # Calculate targets path-dependently
-        spread = features['spread'].values
-        z_score = features['z_score'].values
-        n = len(spread)
-
-        targets = np.zeros(n, dtype=int)
-
-        candidates = np.where((z_score > z_thresh) | (z_score < -z_thresh))[0]
-
-        for i in candidates:
-            if i + horizon >= n:
-                continue
-
-            entry_price = spread[i]
-            future_prices = spread[i+1 : i+1+horizon]
-
-            if z_score[i] > z_thresh:  # Short
-                target_price = entry_price * (1 - profit_target)
-                stop_price = entry_price * (1 + stop_loss_pct)
-
-                hit_tp = future_prices <= target_price
-                hit_sl = future_prices >= stop_price
-
-                if not np.any(hit_tp):
-                    targets[i] = 0
-                elif not np.any(hit_sl):
-                    targets[i] = 1
-                else:
-                    first_tp_idx = np.argmax(hit_tp)
-                    first_sl_idx = np.argmax(hit_sl)
-                    if first_tp_idx < first_sl_idx:
-                        targets[i] = 1
-                    else:
-                        targets[i] = 0
-
-            else:  # Long
-                target_price = entry_price * (1 + profit_target)
-                stop_price = entry_price * (1 - stop_loss_pct)
-
-                hit_tp = future_prices >= target_price
-                hit_sl = future_prices <= stop_price
-
-                if not np.any(hit_tp):
-                    targets[i] = 0
-                elif not np.any(hit_sl):
-                    targets[i] = 1
-                else:
-                    first_tp_idx = np.argmax(hit_tp)
-                    first_sl_idx = np.argmax(hit_sl)
-                    if first_tp_idx < first_sl_idx:
-                        targets[i] = 1
-                    else:
-                        targets[i] = 0
+
+        # Define targets
+        future_min = features['spread'].rolling(window=horizon).min().shift(-horizon)
+        future_max = features['spread'].rolling(window=horizon).max().shift(-horizon)
+
+        target_short = features['spread'] * (1 - profit_target)
+        target_long = features['spread'] * (1 + profit_target)
+
+        success_short = (features['z_score'] > z_thresh) & (future_min < target_short)
+        success_long = (features['z_score'] < -z_thresh) & (future_max > target_long)
+
+        targets = np.select([success_short, success_long], [1, 1], default=0)
 
         # Exclude non-feature columns
         exclude = ['spread', 'btc_close', 'eth_close', 'eth_volume']
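The added labeling code hinges on one pandas idiom: `s.rolling(window=h).min().shift(-h)` gives, at each row, the minimum of the *next* `h` values (current row excluded), because the trailing rolling min ending at row `i+h` covers `s[i+1..i+h]`. A tiny demonstration with a toy spread series:

```python
import pandas as pd

def forward_min(s: pd.Series, horizon: int) -> pd.Series:
    # Min of the next `horizon` values after each row (row itself excluded);
    # the last `horizon` rows have no full look-ahead window and become NaN.
    return s.rolling(window=horizon).min().shift(-horizon)

spread = pd.Series([5.0, 3.0, 4.0, 2.0, 6.0])
out = forward_min(spread, 2)
# row 0 looks at [3, 4] -> 3; row 1 at [4, 2] -> 2; row 2 at [2, 6] -> 2
```

The trailing NaNs are exactly why the hunk below this one swaps the validity mask to `future_min.notna() & future_max.notna()`: rows without a full look-ahead window must be dropped before training. Note the trade-off versus the removed loop: this vectorized labeling only asks whether the target level was touched within the horizon, not whether a stop would have been hit first.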
@@ -187,10 +104,8 @@ class LiveRegimeStrategy:
         X = features[self.feature_cols].fillna(0)
         X = X.replace([np.inf, -np.inf], 0)
 
-        # Use rows where we had enough data to look ahead
-        valid_mask = np.zeros(n, dtype=bool)
-        valid_mask[:n-horizon] = True
+        # Remove rows with invalid targets
+        valid_mask = ~np.isnan(targets) & future_min.notna().values & future_max.notna().values
 
         X_clean = X[valid_mask]
         y_clean = targets[valid_mask]
 
@@ -214,8 +129,7 @@ class LiveRegimeStrategy:
     def generate_signal(
         self,
         features: pd.DataFrame,
-        current_funding: dict,
-        position_side: Optional[str] = None
+        current_funding: dict
     ) -> dict:
         """
         Generate trading signal from latest features.
@@ -223,14 +137,10 @@ class LiveRegimeStrategy:
         Args:
             features: DataFrame with calculated features
             current_funding: Dictionary with funding rate data
-            position_side: Current position side ('long', 'short', or None)
 
         Returns:
             Signal dictionary with action, side, confidence, etc.
         """
-        # Check if retraining is needed
-        self.check_retrain(features)
-
         if self.model is None:
             # Train model if not available
             if len(features) >= 200:
@@ -300,17 +210,12 @@ class LiveRegimeStrategy:
                 signal['action'] = 'hold'
                 signal['reason'] = f'funding_filter_blocked_short (funding={btc_funding:.4f})'
 
-        # Check for exit conditions (Overshoot Logic)
-        if signal['action'] == 'hold' and position_side:
-            # Overshoot Logic
-            # If Long, exit if Z > 0.5 (Reverted past 0 to +0.5)
-            if position_side == 'long' and z_score > 0.5:
-                signal['action'] = 'check_exit'
-                signal['reason'] = f'overshoot_exit_long (z={z_score:.2f} > 0.5)'
-            # If Short, exit if Z < -0.5 (Reverted past 0 to -0.5)
-            elif position_side == 'short' and z_score < -0.5:
-                signal['action'] = 'check_exit'
-                signal['reason'] = f'overshoot_exit_short (z={z_score:.2f} < -0.5)'
+        # Check for exit conditions (mean reversion complete)
+        if signal['action'] == 'hold':
+            # Z-score crossed back through 0
+            if abs(z_score) < 0.3:
+                signal['action'] = 'check_exit'
+                signal['reason'] = f'z_score_reverted_to_mean ({z_score:.2f})'
 
         logger.info(
             f"Signal: {signal['action']} {signal['side'] or ''} "
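The exit-condition hunk swaps two distinct rules, which are easiest to compare side by side as predicates (standalone sketch; function names are illustrative): the removed overshoot rule is position-aware and waits for the z-score to cross *past* the mean, while the added rule is symmetric and fires as soon as the z-score is *near* the mean.

```python
def overshoot_exit(position_side: str, z: float) -> bool:
    """Removed rule: exit only once z has reverted past 0 by 0.5."""
    if position_side == "long":
        return z > 0.5
    return z < -0.5  # short

def reversion_exit(z: float) -> bool:
    """Added rule: exit once |z| has collapsed to within 0.3 of the mean."""
    return abs(z) < 0.3

# A long at z = -2 exits at z = +0.6 under the old rule,
# but already at z = -0.2 under the new one.
```

The practical difference: the overshoot rule rides the reversion further (and needs to know the position side), while the symmetric rule exits earlier and needs no `position_side` argument, which is exactly why that parameter disappears from `generate_signal` in this diff.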
@@ -356,9 +261,9 @@ class LiveRegimeStrategy:
 
     def calculate_sl_tp(
         self,
-        entry_price: Optional[float],
+        entry_price: float,
         side: str
-    ) -> tuple[Optional[float], Optional[float]]:
+    ) -> tuple[float, float]:
         """
         Calculate stop-loss and take-profit prices.
 
@@ -367,21 +272,8 @@ class LiveRegimeStrategy:
             side: "long" or "short"
 
         Returns:
-            Tuple of (stop_loss_price, take_profit_price), or (None, None) if
-            entry_price is invalid
-
-        Raises:
-            ValueError: If side is not "long" or "short"
+            Tuple of (stop_loss_price, take_profit_price)
         """
-        if entry_price is None or entry_price <= 0:
-            logger.error(
-                f"Invalid entry_price for SL/TP calculation: {entry_price}"
-            )
-            return None, None
-
-        if side not in ("long", "short"):
-            raise ValueError(f"Invalid side: {side}. Must be 'long' or 'short'")
-
         sl_pct = self.config.stop_loss_pct
         tp_pct = self.config.take_profit_pct
 
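The body of `calculate_sl_tp` after the percentage lookups is cut off in this diff, so the following is a hypothetical sketch of the conventional percentage-offset computation such a method typically performs (assumed, not taken from the source): a long's stop sits below entry and its target above, mirrored for a short.

```python
def sl_tp(entry_price: float, side: str, sl_pct: float, tp_pct: float) -> tuple[float, float]:
    """Hypothetical SL/TP from percentage offsets around the entry price."""
    if side == "long":
        # stop below entry, target above
        return entry_price * (1 - sl_pct), entry_price * (1 + tp_pct)
    # short: stop above entry, target below
    return entry_price * (1 + sl_pct), entry_price * (1 - tp_pct)
```

For example, a long entered at 100 with a 2% stop and 5% target would get roughly (98, 105), and the mirror-image short (102, 95).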
@@ -11,19 +11,14 @@ Usage:
 
     # Run with specific settings
     uv run python -m live_trading.main --max-position 500 --leverage 2
 
-    # Run without UI (headless mode)
-    uv run python -m live_trading.main --no-ui
 """
 import argparse
 import logging
-import queue
import signal
 import sys
 import time
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Optional
 
 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).parent.parent))
@@ -33,47 +28,22 @@ from live_trading.okx_client import OKXClient
 from live_trading.data_feed import DataFeed
 from live_trading.position_manager import PositionManager
 from live_trading.live_regime_strategy import LiveRegimeStrategy
-from live_trading.db.database import init_db, TradingDatabase
-from live_trading.db.migrations import run_migrations
-from live_trading.ui.state import SharedState, PositionState
-from live_trading.ui.dashboard import TradingDashboard, setup_ui_logging
 
 
-def setup_logging(
-    log_dir: Path,
-    log_queue: Optional[queue.Queue] = None,
-) -> logging.Logger:
-    """
-    Configure logging for the trading bot.
-
-    Args:
-        log_dir: Directory for log files
-        log_queue: Optional queue for UI log handler
-
-    Returns:
-        Logger instance
-    """
+def setup_logging(log_dir: Path) -> logging.Logger:
+    """Configure logging for the trading bot."""
     log_file = log_dir / "live_trading.log"
 
-    handlers = [
-        logging.FileHandler(log_file),
-    ]
-
-    # Only add StreamHandler if no UI (log_queue is None)
-    if log_queue is None:
-        handlers.append(logging.StreamHandler(sys.stdout))
-
     logging.basicConfig(
         level=logging.INFO,
         format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
-        handlers=handlers,
-        force=True,
+        handlers=[
+            logging.FileHandler(log_file),
+            logging.StreamHandler(sys.stdout),
+        ],
+        force=True
    )
 
-    # Add UI log handler if queue provided
-    if log_queue is not None:
-        setup_ui_logging(log_queue)
-
     return logging.getLogger(__name__)
 
 
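The removed `setup_logging` variant builds its handler list conditionally: the file handler is always attached, but stdout output is suppressed when a UI owns the terminal (signalled here by a log queue being supplied). That conditional wiring, extracted as a small standalone helper (`build_handlers` and `ui_active` are illustrative names):

```python
import logging
import sys

def build_handlers(log_file: str, ui_active: bool) -> list[logging.Handler]:
    """Always log to file; echo to stdout only when no UI owns the terminal."""
    handlers: list[logging.Handler] = [logging.FileHandler(log_file)]
    if not ui_active:
        handlers.append(logging.StreamHandler(sys.stdout))
    return handlers
```

Passing the result to `logging.basicConfig(handlers=..., force=True)`, as both versions of `setup_logging` do, replaces any previously configured root handlers, so repeated initialization does not stack duplicate handlers.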
@@ -89,15 +59,11 @@ class LiveTradingBot:
         self,
         okx_config: OKXConfig,
         trading_config: TradingConfig,
-        path_config: PathConfig,
-        database: Optional[TradingDatabase] = None,
-        shared_state: Optional[SharedState] = None,
+        path_config: PathConfig
     ):
         self.okx_config = okx_config
         self.trading_config = trading_config
         self.path_config = path_config
-        self.db = database
-        self.state = shared_state
 
         self.logger = logging.getLogger(__name__)
         self.running = True
@@ -108,7 +74,7 @@ class LiveTradingBot:
|
|||||||
self.okx_client = OKXClient(okx_config, trading_config)
|
self.okx_client = OKXClient(okx_config, trading_config)
|
||||||
self.data_feed = DataFeed(self.okx_client, trading_config, path_config)
|
self.data_feed = DataFeed(self.okx_client, trading_config, path_config)
|
||||||
self.position_manager = PositionManager(
|
self.position_manager = PositionManager(
|
||||||
self.okx_client, trading_config, path_config, database
|
self.okx_client, trading_config, path_config
|
||||||
)
|
)
|
||||||
self.strategy = LiveRegimeStrategy(trading_config, path_config)
|
self.strategy = LiveRegimeStrategy(trading_config, path_config)
|
||||||
|
|
||||||
@@ -116,16 +82,6 @@ class LiveTradingBot:
|
|||||||
signal.signal(signal.SIGINT, self._handle_shutdown)
|
signal.signal(signal.SIGINT, self._handle_shutdown)
|
||||||
signal.signal(signal.SIGTERM, self._handle_shutdown)
|
signal.signal(signal.SIGTERM, self._handle_shutdown)
|
||||||
|
|
||||||
# Initialize shared state if provided
|
|
||||||
if self.state:
|
|
||||||
mode = "DEMO" if okx_config.demo_mode else "LIVE"
|
|
||||||
self.state.set_mode(mode)
|
|
||||||
self.state.set_symbols(
|
|
||||||
trading_config.eth_symbol,
|
|
||||||
trading_config.btc_symbol,
|
|
||||||
)
|
|
||||||
self.state.update_account(0.0, 0.0, trading_config.leverage)
|
|
||||||
|
|
||||||
self._print_startup_banner()
|
self._print_startup_banner()
|
||||||
|
|
||||||
def _print_startup_banner(self) -> None:
|
def _print_startup_banner(self) -> None:
|
||||||
@@ -153,8 +109,6 @@ class LiveTradingBot:
|
|||||||
"""Handle shutdown signals gracefully."""
|
"""Handle shutdown signals gracefully."""
|
||||||
self.logger.info("Shutdown signal received, stopping...")
|
self.logger.info("Shutdown signal received, stopping...")
|
||||||
self.running = False
|
self.running = False
|
||||||
if self.state:
|
|
||||||
self.state.stop()
|
|
||||||
|
|
||||||
def run_trading_cycle(self) -> None:
|
def run_trading_cycle(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -164,20 +118,10 @@ class LiveTradingBot:
|
|||||||
2. Update open positions
|
2. Update open positions
|
||||||
3. Generate trading signal
|
3. Generate trading signal
|
||||||
4. Execute trades if signal triggers
|
4. Execute trades if signal triggers
|
||||||
5. Update shared state for UI
|
|
||||||
"""
|
"""
|
||||||
# Reload model if it has changed (e.g. daily training)
|
|
||||||
try:
|
|
||||||
self.strategy.reload_model_if_changed()
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(f"Failed to reload model: {e}")
|
|
||||||
|
|
||||||
cycle_start = datetime.now(timezone.utc)
|
cycle_start = datetime.now(timezone.utc)
|
||||||
self.logger.info(f"--- Trading Cycle Start: {cycle_start.isoformat()} ---")
|
self.logger.info(f"--- Trading Cycle Start: {cycle_start.isoformat()} ---")
|
||||||
|
|
||||||
if self.state:
|
|
||||||
self.state.set_last_cycle_time(cycle_start.isoformat())
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. Fetch market data
|
# 1. Fetch market data
|
||||||
features = self.data_feed.get_latest_data()
|
features = self.data_feed.get_latest_data()
|
||||||
@@ -206,31 +150,19 @@ class LiveTradingBot:
|
|||||||
# 3. Sync with exchange positions
|
# 3. Sync with exchange positions
|
||||||
self.position_manager.sync_with_exchange()
|
self.position_manager.sync_with_exchange()
|
||||||
|
|
||||||
# Get current position side for signal generation
|
|
||||||
symbol = self.trading_config.eth_symbol
|
|
||||||
position = self.position_manager.get_position_for_symbol(symbol)
|
|
||||||
position_side = position.side if position else None
|
|
||||||
|
|
||||||
# 4. Get current funding rates
|
# 4. Get current funding rates
|
||||||
funding = self.data_feed.get_current_funding_rates()
|
funding = self.data_feed.get_current_funding_rates()
|
||||||
|
|
||||||
# 5. Generate trading signal
|
# 5. Generate trading signal
|
||||||
sig = self.strategy.generate_signal(features, funding, position_side=position_side)
|
signal = self.strategy.generate_signal(features, funding)
|
||||||
|
|
||||||
# 6. Update shared state with strategy info
|
# 6. Execute trades based on signal
|
||||||
self._update_strategy_state(sig, funding)
|
if signal['action'] == 'entry':
|
||||||
|
self._execute_entry(signal, eth_price)
|
||||||
|
elif signal['action'] == 'check_exit':
|
||||||
|
self._execute_exit(signal)
|
||||||
|
|
||||||
# 7. Execute trades based on signal
|
# 7. Log portfolio summary
|
||||||
if sig['action'] == 'entry':
|
|
||||||
self._execute_entry(sig, eth_price)
|
|
||||||
elif sig['action'] == 'check_exit':
|
|
||||||
self._execute_exit(sig)
|
|
||||||
|
|
||||||
# 8. Update shared state with position and account
|
|
||||||
self._update_position_state(eth_price)
|
|
||||||
self._update_account_state()
|
|
||||||
|
|
||||||
# 9. Log portfolio summary
|
|
||||||
summary = self.position_manager.get_portfolio_summary()
|
summary = self.position_manager.get_portfolio_summary()
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Portfolio: {summary['open_positions']} positions, "
|
f"Portfolio: {summary['open_positions']} positions, "
|
||||||
@@ -246,61 +178,6 @@ class LiveTradingBot:
|
|||||||
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
|
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
|
||||||
self.logger.info(f"--- Cycle completed in {cycle_duration:.1f}s ---")
|
self.logger.info(f"--- Cycle completed in {cycle_duration:.1f}s ---")
|
||||||
|
|
||||||
def _update_strategy_state(self, sig: dict, funding: dict) -> None:
|
|
||||||
"""Update shared state with strategy information."""
|
|
||||||
if not self.state:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.state.update_strategy(
|
|
||||||
z_score=sig.get('z_score', 0.0),
|
|
||||||
probability=sig.get('probability', 0.0),
|
|
||||||
funding_rate=funding.get('btc_funding', 0.0),
|
|
||||||
action=sig.get('action', 'hold'),
|
|
||||||
reason=sig.get('reason', ''),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_position_state(self, current_price: float) -> None:
|
|
||||||
"""Update shared state with current position."""
|
|
||||||
if not self.state:
|
|
||||||
return
|
|
||||||
|
|
||||||
symbol = self.trading_config.eth_symbol
|
|
||||||
position = self.position_manager.get_position_for_symbol(symbol)
|
|
||||||
|
|
||||||
if position is None:
|
|
||||||
self.state.clear_position()
|
|
||||||
return
|
|
||||||
|
|
||||||
pos_state = PositionState(
|
|
||||||
trade_id=position.trade_id,
|
|
||||||
symbol=position.symbol,
|
|
||||||
side=position.side,
|
|
||||||
entry_price=position.entry_price,
|
|
||||||
current_price=position.current_price,
|
|
||||||
size=position.size,
|
|
||||||
size_usdt=position.size_usdt,
|
|
||||||
unrealized_pnl=position.unrealized_pnl,
|
|
||||||
unrealized_pnl_pct=position.unrealized_pnl_pct,
|
|
||||||
stop_loss_price=position.stop_loss_price,
|
|
||||||
take_profit_price=position.take_profit_price,
|
|
||||||
)
|
|
||||||
self.state.set_position(pos_state)
|
|
||||||
|
|
||||||
def _update_account_state(self) -> None:
|
|
||||||
"""Update shared state with account information."""
|
|
||||||
if not self.state:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
balance = self.okx_client.get_balance()
|
|
||||||
self.state.update_account(
|
|
||||||
balance=balance.get('total', 0.0),
|
|
||||||
available=balance.get('free', 0.0),
|
|
||||||
leverage=self.trading_config.leverage,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(f"Failed to update account state: {e}")
|
|
||||||
|
|
||||||
def _execute_entry(self, signal: dict, current_price: float) -> None:
|
def _execute_entry(self, signal: dict, current_price: float) -> None:
|
||||||
"""Execute entry trade."""
|
"""Execute entry trade."""
|
||||||
symbol = self.trading_config.eth_symbol
|
symbol = self.trading_config.eth_symbol
|
||||||
@@ -314,46 +191,43 @@ class LiveTradingBot:
|
|||||||
# Get account balance
|
# Get account balance
|
||||||
balance = self.okx_client.get_balance()
|
balance = self.okx_client.get_balance()
|
||||||
available_usdt = balance['free']
|
available_usdt = balance['free']
|
||||||
self.logger.info(f"Account balance: ${available_usdt:.2f} USDT available")
|
|
||||||
|
|
||||||
# Calculate position size
|
# Calculate position size
|
||||||
size_usdt = self.strategy.calculate_position_size(signal, available_usdt)
|
size_usdt = self.strategy.calculate_position_size(signal, available_usdt)
|
||||||
if size_usdt <= 0:
|
if size_usdt <= 0:
|
||||||
self.logger.info(
|
self.logger.info("Position size too small, skipping entry")
|
||||||
f"Position size too small (${size_usdt:.2f}), skipping entry. "
|
|
||||||
f"Min required: ${self.strategy.config.min_position_usdt:.2f}"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
size_eth = size_usdt / current_price
|
size_eth = size_usdt / current_price
|
||||||
|
|
||||||
# Calculate SL/TP for logging
|
# Calculate SL/TP
|
||||||
stop_loss, take_profit = self.strategy.calculate_sl_tp(current_price, side)
|
stop_loss, take_profit = self.strategy.calculate_sl_tp(current_price, side)
|
||||||
|
|
||||||
sl_str = f"{stop_loss:.2f}" if stop_loss else "N/A"
|
|
||||||
tp_str = f"{take_profit:.2f}" if take_profit else "N/A"
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Executing {side.upper()} entry: {size_eth:.4f} ETH @ {current_price:.2f} "
|
f"Executing {side.upper()} entry: {size_eth:.4f} ETH @ {current_price:.2f} "
|
||||||
f"(${size_usdt:.2f}), SL={sl_str}, TP={tp_str}"
|
f"(${size_usdt:.2f}), SL={stop_loss:.2f}, TP={take_profit:.2f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Place market order (guaranteed to have fill price or raises)
|
# Place market order
|
||||||
order_side = "buy" if side == "long" else "sell"
|
order_side = "buy" if side == "long" else "sell"
|
||||||
order = self.okx_client.place_market_order(symbol, order_side, size_eth)
|
order = self.okx_client.place_market_order(symbol, order_side, size_eth)
|
||||||
|
|
||||||
# Get filled price and amount (guaranteed by OKX client)
|
# Get filled price (handle None values from OKX response)
|
||||||
filled_price = order['average']
|
filled_price = order.get('average') or order.get('price') or current_price
|
||||||
filled_amount = order.get('filled') or size_eth
|
filled_amount = order.get('filled') or order.get('amount') or size_eth
|
||||||
|
|
||||||
# Calculate SL/TP with filled price
|
# Ensure we have valid numeric values
|
||||||
|
if filled_price is None or filled_price == 0:
|
||||||
|
self.logger.warning(f"No fill price in order response, using current price: {current_price}")
|
||||||
|
filled_price = current_price
|
||||||
|
if filled_amount is None or filled_amount == 0:
|
||||||
|
self.logger.warning(f"No fill amount in order response, using requested: {size_eth}")
|
||||||
|
filled_amount = size_eth
|
||||||
|
|
||||||
|
# Recalculate SL/TP with filled price
|
||||||
stop_loss, take_profit = self.strategy.calculate_sl_tp(filled_price, side)
|
stop_loss, take_profit = self.strategy.calculate_sl_tp(filled_price, side)
|
||||||
|
|
||||||
if stop_loss is None or take_profit is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Failed to calculate SL/TP: filled_price={filled_price}, side={side}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get order ID from response
|
# Get order ID from response
|
||||||
order_id = order.get('id', '')
|
order_id = order.get('id', '')
|
||||||
|
|
||||||
@@ -417,30 +291,22 @@ class LiveTradingBot:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Exit execution failed: {e}", exc_info=True)
|
self.logger.error(f"Exit execution failed: {e}", exc_info=True)
|
||||||
|
|
||||||
def _is_running(self) -> bool:
|
|
||||||
"""Check if bot should continue running."""
|
|
||||||
if not self.running:
|
|
||||||
return False
|
|
||||||
if self.state and not self.state.is_running():
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""Main trading loop."""
|
"""Main trading loop."""
|
||||||
self.logger.info("Starting trading loop...")
|
self.logger.info("Starting trading loop...")
|
||||||
|
|
||||||
while self._is_running():
|
while self.running:
|
||||||
try:
|
try:
|
||||||
self.run_trading_cycle()
|
self.run_trading_cycle()
|
||||||
|
|
||||||
if self._is_running():
|
if self.running:
|
||||||
sleep_seconds = self.trading_config.sleep_seconds
|
sleep_seconds = self.trading_config.sleep_seconds
|
||||||
minutes = sleep_seconds // 60
|
minutes = sleep_seconds // 60
|
||||||
self.logger.info(f"Sleeping for {minutes} minutes...")
|
self.logger.info(f"Sleeping for {minutes} minutes...")
|
||||||
|
|
||||||
# Sleep in smaller chunks to allow faster shutdown
|
# Sleep in smaller chunks to allow faster shutdown
|
||||||
for _ in range(sleep_seconds):
|
for _ in range(sleep_seconds):
|
||||||
if not self._is_running():
|
if not self.running:
|
||||||
break
|
break
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
@@ -454,8 +320,6 @@ class LiveTradingBot:
|
|||||||
# Cleanup
|
# Cleanup
|
||||||
self.logger.info("Shutting down...")
|
self.logger.info("Shutting down...")
|
||||||
self.position_manager.save_positions()
|
self.position_manager.save_positions()
|
||||||
if self.db:
|
|
||||||
self.db.close()
|
|
||||||
self.logger.info("Shutdown complete")
|
self.logger.info("Shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
@@ -487,11 +351,6 @@ def parse_args():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
|
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--no-ui",
|
|
||||||
action="store_true",
|
|
||||||
help="Run in headless mode without terminal UI"
|
|
||||||
)
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -512,64 +371,19 @@ def main():
|
|||||||
if args.live:
|
if args.live:
|
||||||
okx_config.demo_mode = False
|
okx_config.demo_mode = False
|
||||||
|
|
||||||
# Determine if UI should be enabled
|
|
||||||
use_ui = not args.no_ui and sys.stdin.isatty()
|
|
||||||
|
|
||||||
# Initialize database
|
|
||||||
db_path = path_config.base_dir / "live_trading" / "trading.db"
|
|
||||||
db = init_db(db_path)
|
|
||||||
|
|
||||||
# Run migrations (imports CSV if exists)
|
|
||||||
run_migrations(db, path_config.trade_log_file)
|
|
||||||
|
|
||||||
# Initialize UI components if enabled
|
|
||||||
log_queue: Optional[queue.Queue] = None
|
|
||||||
shared_state: Optional[SharedState] = None
|
|
||||||
dashboard: Optional[TradingDashboard] = None
|
|
||||||
|
|
||||||
if use_ui:
|
|
||||||
log_queue = queue.Queue(maxsize=1000)
|
|
||||||
shared_state = SharedState()
|
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logger = setup_logging(path_config.logs_dir, log_queue)
|
logger = setup_logging(path_config.logs_dir)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create bot
|
# Create and run bot
|
||||||
bot = LiveTradingBot(
|
bot = LiveTradingBot(okx_config, trading_config, path_config)
|
||||||
okx_config,
|
|
||||||
trading_config,
|
|
||||||
path_config,
|
|
||||||
database=db,
|
|
||||||
shared_state=shared_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start dashboard if UI enabled
|
|
||||||
if use_ui and shared_state and log_queue:
|
|
||||||
dashboard = TradingDashboard(
|
|
||||||
state=shared_state,
|
|
||||||
db=db,
|
|
||||||
log_queue=log_queue,
|
|
||||||
on_quit=lambda: setattr(bot, 'running', False),
|
|
||||||
)
|
|
||||||
dashboard.start()
|
|
||||||
logger.info("Dashboard started")
|
|
||||||
|
|
||||||
# Run bot
|
|
||||||
bot.run()
|
bot.run()
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Configuration error: {e}")
|
logger.error(f"Configuration error: {e}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Fatal error: {e}", exc_info=True)
|
logger.error(f"Fatal error: {e}", exc_info=True)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
finally:
|
|
||||||
# Cleanup
|
|
||||||
if dashboard:
|
|
||||||
dashboard.stop()
|
|
||||||
if db:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
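The chunked sleep retained in `run()` on both sides of the diff is what keeps shutdown responsive: the loop re-checks the running flag every second instead of blocking for the whole interval. A standalone sketch of the pattern (the `interruptible_sleep` helper name is hypothetical, not part of the repo):

```python
import time

def interruptible_sleep(total_seconds: int, is_running) -> bool:
    """Sleep in 1-second chunks, re-checking a shutdown flag each second.

    Returns True if the full duration elapsed, False if interrupted early.
    """
    for _ in range(total_seconds):
        if not is_running():
            return False
        time.sleep(1)
    return True

# A flag source that reports True twice, then False: the sleep is cut short.
checks = iter([True, True, False])
interrupted = not interruptible_sleep(5, lambda: next(checks))
```

Compared with a single `time.sleep(3600)`, the worst-case shutdown latency drops from an hour to about one second.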
live_trading/multi_pair/README.md (new file, 145 lines)
@@ -0,0 +1,145 @@
# Multi-Pair Divergence Live Trading

This module implements live trading for the Multi-Pair Divergence Selection Strategy on OKX perpetual futures.

## Overview

The strategy scans 10 cryptocurrency pairs for spread divergence opportunities:

1. **Pair Universe**: Top 10 assets by market cap (BTC, ETH, SOL, XRP, BNB, DOGE, ADA, AVAX, LINK, DOT)
2. **Spread Z-Score**: Identifies when pairs are divergent from their historical mean
3. **Universal ML Model**: Predicts probability of successful mean reversion
4. **Dynamic Selection**: Trades the pair with highest divergence score

## Prerequisites

Before running live trading, you must train the model via backtesting:

```bash
uv run python scripts/run_multi_pair_backtest.py
```

This creates `data/multi_pair_model.pkl` which the live trading bot requires.

## Setup

### 1. API Keys

Same as single-pair trading. Set in `.env`:

```env
OKX_API_KEY=your_api_key
OKX_SECRET=your_secret
OKX_PASSWORD=your_passphrase
OKX_DEMO_MODE=true  # Use demo for testing
```

### 2. Dependencies

All dependencies are in `pyproject.toml`. No additional installation needed.

## Usage

### Run with Demo Account (Recommended First)

```bash
uv run python -m live_trading.multi_pair.main
```

### Command Line Options

```bash
# Custom position size
uv run python -m live_trading.multi_pair.main --max-position 500

# Custom leverage
uv run python -m live_trading.multi_pair.main --leverage 2

# Custom cycle interval (in seconds)
uv run python -m live_trading.multi_pair.main --interval 1800

# Combine options
uv run python -m live_trading.multi_pair.main --max-position 1000 --leverage 3 --interval 3600
```

### Live Trading (Use with Caution)

```bash
uv run python -m live_trading.multi_pair.main --live
```

## How It Works

### Each Trading Cycle

1. **Fetch Data**: Gets OHLCV for all 10 assets from OKX
2. **Calculate Features**: Computes Z-Score, RSI, volatility for all 45 pair combinations
3. **Score Pairs**: Uses ML model to rank pairs by divergence score (|Z| x probability)
4. **Check Exits**: If holding, check mean reversion or SL/TP
5. **Enter Best**: If no position, enter the highest-scoring divergent pair

### Entry Conditions

- |Z-Score| > 1.0 (spread diverged from mean)
- ML probability > 0.5 (model predicts successful reversion)
- Funding rate filter passes (avoid crowded trades)

### Exit Conditions

- Mean reversion: |Z-Score| returns to ~0
- Stop-loss: ATR-based (default ~6%)
- Take-profit: ATR-based (default ~5%)

## Strategy Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `z_entry_threshold` | 1.0 | Enter when \|Z-Score\| > threshold |
| `z_exit_threshold` | 0.0 | Exit when Z reverts to mean |
| `z_window` | 24 | Rolling window for Z-Score (hours) |
| `prob_threshold` | 0.5 | ML probability threshold for entry |
| `funding_threshold` | 0.0005 | Funding rate filter (0.05%) |
| `sl_atr_multiplier` | 10.0 | Stop-loss as ATR multiple |
| `tp_atr_multiplier` | 8.0 | Take-profit as ATR multiple |

## Files

### Input

- `data/multi_pair_model.pkl` - Pre-trained ML model (required)

### Output

- `logs/multi_pair_live.log` - Trading logs
- `live_trading/multi_pair_positions.json` - Position persistence
- `live_trading/multi_pair_trade_log.csv` - Trade history

## Architecture

```
live_trading/multi_pair/
    __init__.py     # Module exports
    config.py       # Configuration classes
    data_feed.py    # Multi-asset OHLCV fetcher
    strategy.py     # ML scoring and signal generation
    main.py         # Bot orchestrator
    README.md       # This file
```

## Differences from Single-Pair

| Aspect | Single-Pair | Multi-Pair |
|--------|-------------|------------|
| Assets | ETH only (BTC context) | 10 assets, 45 pairs |
| Model | ETH-specific | Universal across pairs |
| Selection | Fixed pair | Dynamic best pair |
| Stops | Fixed 6%/5% | ATR-based dynamic |

## Risk Warning

This is experimental trading software. Use at your own risk:

- Always start with demo trading
- Never risk more than you can afford to lose
- Monitor the bot regularly
- The model was trained on historical data and may not predict future performance
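The entry rule described in the README (|Z-Score| > 1.0 combined with an ML probability > 0.5, pairs ranked by |Z| x probability) can be sketched as follows. The helper names and the synthetic prices are illustrative only, and a constant stands in for the ML model's probability output:

```python
import numpy as np
import pandas as pd

def spread_zscore(base_close: pd.Series, quote_close: pd.Series, window: int = 24) -> pd.Series:
    """Rolling z-score of the log-price spread between two assets."""
    spread = np.log(base_close) - np.log(quote_close)
    rolling = spread.rolling(window)
    return (spread - rolling.mean()) / rolling.std()

def divergence_score(z: float, probability: float) -> float:
    """Rank pairs by |Z| x model probability, as in the README."""
    return abs(z) * probability

# Synthetic prices: the base asset drifts steadily away from the quote asset,
# so the spread's z-score grows past the entry threshold.
base = pd.Series(np.linspace(100.0, 130.0, 100))
quote = pd.Series(np.full(100, 100.0))
z = float(spread_zscore(base, quote).iloc[-1])
score = divergence_score(z, probability=0.6)
entry = abs(z) > 1.0 and 0.6 > 0.5
```

In the live bot this score is what makes the selection dynamic: the same formula is evaluated for every pair, and only the top-ranked divergent pair is traded.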
live_trading/multi_pair/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""Multi-Pair Divergence Live Trading Module."""
from .config import MultiPairLiveConfig, get_multi_pair_config
from .data_feed import MultiPairDataFeed
from .strategy import LiveMultiPairStrategy

__all__ = [
    "MultiPairLiveConfig",
    "get_multi_pair_config",
    "MultiPairDataFeed",
    "LiveMultiPairStrategy",
]
live_trading/multi_pair/config.py (new file, 145 lines)
@@ -0,0 +1,145 @@
"""
Configuration for Multi-Pair Live Trading.

Extends the base live trading config with multi-pair specific settings.
"""
import os
from pathlib import Path
from dataclasses import dataclass, field
from dotenv import load_dotenv

load_dotenv()


@dataclass
class OKXConfig:
    """OKX API configuration."""
    api_key: str = field(default_factory=lambda: "")
    secret: str = field(default_factory=lambda: "")
    password: str = field(default_factory=lambda: "")
    demo_mode: bool = field(default_factory=lambda: True)

    def __post_init__(self):
        """Load credentials based on demo mode setting."""
        self.demo_mode = os.getenv("OKX_DEMO_MODE", "true").lower() in ("true", "1", "yes")

        if self.demo_mode:
            self.api_key = os.getenv("OKX_DEMO_API_KEY", os.getenv("OKX_API_KEY", ""))
            self.secret = os.getenv("OKX_DEMO_SECRET", os.getenv("OKX_SECRET", ""))
            self.password = os.getenv("OKX_DEMO_PASSWORD", os.getenv("OKX_PASSWORD", ""))
        else:
            self.api_key = os.getenv("OKX_API_KEY", "")
            self.secret = os.getenv("OKX_SECRET", "")
            self.password = os.getenv("OKX_PASSWORD", "")

    def validate(self) -> None:
        """Validate that required credentials are present."""
        mode = "demo" if self.demo_mode else "live"
        if not self.api_key:
            raise ValueError(f"OKX API key not set for {mode} mode")
        if not self.secret:
            raise ValueError(f"OKX secret not set for {mode} mode")
        if not self.password:
            raise ValueError(f"OKX password not set for {mode} mode")


@dataclass
class MultiPairLiveConfig:
    """
    Configuration for multi-pair live trading.

    Combines trading parameters, strategy settings, and risk management.
    """
    # Asset Universe (top 10 by market cap perpetuals)
    assets: list[str] = field(default_factory=lambda: [
        "BTC/USDT:USDT", "ETH/USDT:USDT", "SOL/USDT:USDT", "XRP/USDT:USDT",
        "BNB/USDT:USDT", "DOGE/USDT:USDT", "ADA/USDT:USDT", "AVAX/USDT:USDT",
        "LINK/USDT:USDT", "DOT/USDT:USDT"
    ])

    # Timeframe
    timeframe: str = "1h"
    candles_to_fetch: int = 500  # Enough for feature calculation

    # Z-Score Thresholds
    z_window: int = 24
    z_entry_threshold: float = 1.0
    z_exit_threshold: float = 0.0  # Exit at mean reversion

    # ML Thresholds
    prob_threshold: float = 0.5

    # Position sizing
    max_position_usdt: float = -1.0  # If <= 0, use all available funds
    min_position_usdt: float = 10.0
    leverage: int = 1
    margin_mode: str = "cross"
    max_concurrent_positions: int = 1  # Trade one pair at a time

    # Risk Management - ATR-Based Stops
    atr_period: int = 14
    sl_atr_multiplier: float = 10.0
    tp_atr_multiplier: float = 8.0

    # Fallback fixed percentages
    base_sl_pct: float = 0.06
    base_tp_pct: float = 0.05

    # ATR bounds
    min_sl_pct: float = 0.02
    max_sl_pct: float = 0.10
    min_tp_pct: float = 0.02
    max_tp_pct: float = 0.15

    # Funding Rate Filter
    funding_threshold: float = 0.0005  # 0.05%

    # Trade Management
    min_hold_bars: int = 0
    cooldown_bars: int = 0

    # Execution
    sleep_seconds: int = 3600  # Run every hour
    slippage_pct: float = 0.001

    def get_asset_short_name(self, symbol: str) -> str:
        """Convert symbol to short name (e.g., BTC/USDT:USDT -> btc)."""
        return symbol.split("/")[0].lower()

    def get_pair_count(self) -> int:
        """Calculate number of unique pairs from asset list."""
        n = len(self.assets)
        return n * (n - 1) // 2


@dataclass
class PathConfig:
    """File paths configuration."""
    base_dir: Path = field(
        default_factory=lambda: Path(__file__).parent.parent.parent
    )
    data_dir: Path = field(default=None)
    logs_dir: Path = field(default=None)
    model_path: Path = field(default=None)
    positions_file: Path = field(default=None)
    trade_log_file: Path = field(default=None)

    def __post_init__(self):
        self.data_dir = self.base_dir / "data"
        self.logs_dir = self.base_dir / "logs"
        # Use the same model as backtesting
        self.model_path = self.base_dir / "data" / "multi_pair_model.pkl"
        self.positions_file = self.base_dir / "live_trading" / "multi_pair_positions.json"
        self.trade_log_file = self.base_dir / "live_trading" / "multi_pair_trade_log.csv"

        # Ensure directories exist
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.logs_dir.mkdir(parents=True, exist_ok=True)


def get_multi_pair_config() -> tuple[OKXConfig, MultiPairLiveConfig, PathConfig]:
    """Get all configuration objects for multi-pair trading."""
    okx = OKXConfig()
    trading = MultiPairLiveConfig()
    paths = PathConfig()
    return okx, trading, paths
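`get_pair_count()` above uses the closed form n(n-1)/2 for the number of unordered pairs, which should agree with the `itertools.combinations(..., 2)` enumeration the data feed performs. A quick standalone check with a stand-in asset list (short names only, not the config's full symbols):

```python
from itertools import combinations

# Stand-in for the config's 10-asset universe.
assets = ["BTC", "ETH", "SOL", "XRP", "BNB", "DOGE", "ADA", "AVAX", "LINK", "DOT"]

n = len(assets)
formula_count = n * (n - 1) // 2            # same arithmetic as get_pair_count()
enumerated = list(combinations(assets, 2))  # the pairs the data feed would scan

# Both give 45 pairs for 10 assets, matching the README.
```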
live_trading/multi_pair/data_feed.py (new file, 336 lines)
@@ -0,0 +1,336 @@
"""
Multi-Pair Data Feed for Live Trading.

Fetches real-time OHLCV and funding data for all assets in the universe.
"""
import logging
from itertools import combinations

import pandas as pd
import numpy as np
import ta

from live_trading.okx_client import OKXClient
from .config import MultiPairLiveConfig, PathConfig

logger = logging.getLogger(__name__)


class TradingPair:
    """
    Represents a tradeable pair for spread analysis.

    Attributes:
        base_asset: First asset symbol (e.g., ETH/USDT:USDT)
        quote_asset: Second asset symbol (e.g., BTC/USDT:USDT)
        pair_id: Unique identifier
    """
    def __init__(self, base_asset: str, quote_asset: str):
        self.base_asset = base_asset
        self.quote_asset = quote_asset
        self.pair_id = f"{base_asset}__{quote_asset}"

    @property
    def name(self) -> str:
        """Human-readable pair name."""
        base = self.base_asset.split("/")[0]
        quote = self.quote_asset.split("/")[0]
        return f"{base}/{quote}"

    def __hash__(self):
        return hash(self.pair_id)

    def __eq__(self, other):
        if not isinstance(other, TradingPair):
            return False
        return self.pair_id == other.pair_id


class MultiPairDataFeed:
    """
    Real-time data feed for multi-pair strategy.

    Fetches OHLCV data for all assets and calculates spread features
    for all pair combinations.
    """

    def __init__(
        self,
        okx_client: OKXClient,
        config: MultiPairLiveConfig,
        path_config: PathConfig
    ):
        self.client = okx_client
        self.config = config
        self.paths = path_config

        # Cache for asset data
        self._asset_data: dict[str, pd.DataFrame] = {}
        self._funding_rates: dict[str, float] = {}
        self._pairs: list[TradingPair] = []

        # Generate pairs
        self._generate_pairs()

    def _generate_pairs(self) -> None:
        """Generate all unique pairs from asset universe."""
        self._pairs = []
        for base, quote in combinations(self.config.assets, 2):
            pair = TradingPair(base_asset=base, quote_asset=quote)
            self._pairs.append(pair)

        logger.info("Generated %d pairs from %d assets",
                    len(self._pairs), len(self.config.assets))

    @property
    def pairs(self) -> list[TradingPair]:
        """Get list of trading pairs."""
        return self._pairs

    def fetch_all_ohlcv(self) -> dict[str, pd.DataFrame]:
        """
        Fetch OHLCV data for all assets.

        Returns:
            Dictionary mapping symbol to OHLCV DataFrame
        """
        self._asset_data = {}

        for symbol in self.config.assets:
            try:
                ohlcv = self.client.fetch_ohlcv(
                    symbol,
                    self.config.timeframe,
                    self.config.candles_to_fetch
                )
                df = self._ohlcv_to_dataframe(ohlcv)

                if len(df) >= 200:
                    self._asset_data[symbol] = df
                    logger.debug("Fetched %s: %d candles", symbol, len(df))
                else:
                    logger.warning("Skipping %s: insufficient data (%d)",
                                   symbol, len(df))
            except Exception as e:
                logger.error("Error fetching %s: %s", symbol, e)

        logger.info("Fetched data for %d/%d assets",
                    len(self._asset_data), len(self.config.assets))
        return self._asset_data

    def _ohlcv_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
        """Convert OHLCV list to DataFrame."""
        df = pd.DataFrame(
            ohlcv,
            columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
        )
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
        df.set_index('timestamp', inplace=True)
        return df

    def fetch_all_funding_rates(self) -> dict[str, float]:
        """
        Fetch current funding rates for all assets.

        Returns:
            Dictionary mapping symbol to funding rate
        """
        self._funding_rates = {}

        for symbol in self.config.assets:
            try:
                rate = self.client.get_funding_rate(symbol)
                self._funding_rates[symbol] = rate
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not get funding for %s: %s", symbol, e)
|
||||||
|
self._funding_rates[symbol] = 0.0
|
||||||
|
|
||||||
|
return self._funding_rates
|
||||||
|
|
||||||
|
def calculate_pair_features(
|
||||||
|
self,
|
||||||
|
pair: TradingPair
|
||||||
|
) -> pd.DataFrame | None:
|
||||||
|
"""
|
||||||
|
Calculate features for a single pair.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pair: Trading pair
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with features, or None if insufficient data
|
||||||
|
"""
|
||||||
|
base = pair.base_asset
|
||||||
|
quote = pair.quote_asset
|
||||||
|
|
||||||
|
if base not in self._asset_data or quote not in self._asset_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
df_base = self._asset_data[base]
|
||||||
|
df_quote = self._asset_data[quote]
|
||||||
|
|
||||||
|
# Align indices
|
||||||
|
common_idx = df_base.index.intersection(df_quote.index)
|
||||||
|
if len(common_idx) < 200:
|
||||||
|
return None
|
||||||
|
|
||||||
|
df_a = df_base.loc[common_idx]
|
||||||
|
df_b = df_quote.loc[common_idx]
|
||||||
|
|
||||||
|
# Calculate spread (base / quote)
|
||||||
|
spread = df_a['close'] / df_b['close']
|
||||||
|
|
||||||
|
# Z-Score
|
||||||
|
z_window = self.config.z_window
|
||||||
|
rolling_mean = spread.rolling(window=z_window).mean()
|
||||||
|
rolling_std = spread.rolling(window=z_window).std()
|
||||||
|
z_score = (spread - rolling_mean) / rolling_std
|
||||||
|
|
||||||
|
# Spread Technicals
|
||||||
|
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
|
||||||
|
spread_roc = spread.pct_change(periods=5) * 100
|
||||||
|
spread_change_1h = spread.pct_change(periods=1)
|
||||||
|
|
||||||
|
# Volume Analysis
|
||||||
|
vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10)
|
||||||
|
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
|
||||||
|
vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10)
|
||||||
|
|
||||||
|
# Volatility
|
||||||
|
ret_a = df_a['close'].pct_change()
|
||||||
|
ret_b = df_b['close'].pct_change()
|
||||||
|
vol_a = ret_a.rolling(window=z_window).std()
|
||||||
|
vol_b = ret_b.rolling(window=z_window).std()
|
||||||
|
vol_spread_ratio = vol_a / (vol_b + 1e-10)
|
||||||
|
|
||||||
|
# Realized Volatility
|
||||||
|
realized_vol_a = ret_a.rolling(window=24).std()
|
||||||
|
realized_vol_b = ret_b.rolling(window=24).std()
|
||||||
|
|
||||||
|
# ATR (Average True Range)
|
||||||
|
high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close']
|
||||||
|
|
||||||
|
tr_a = pd.concat([
|
||||||
|
high_a - low_a,
|
||||||
|
(high_a - close_a.shift(1)).abs(),
|
||||||
|
(low_a - close_a.shift(1)).abs()
|
||||||
|
], axis=1).max(axis=1)
|
||||||
|
atr_a = tr_a.rolling(window=self.config.atr_period).mean()
|
||||||
|
atr_pct_a = atr_a / close_a
|
||||||
|
|
||||||
|
# Build feature DataFrame
|
||||||
|
features = pd.DataFrame(index=common_idx)
|
||||||
|
features['pair_id'] = pair.pair_id
|
||||||
|
features['base_asset'] = base
|
||||||
|
features['quote_asset'] = quote
|
||||||
|
|
||||||
|
# Price data
|
||||||
|
features['spread'] = spread
|
||||||
|
features['base_close'] = df_a['close']
|
||||||
|
features['quote_close'] = df_b['close']
|
||||||
|
features['base_volume'] = df_a['volume']
|
||||||
|
|
||||||
|
# Core Features
|
||||||
|
features['z_score'] = z_score
|
||||||
|
features['spread_rsi'] = spread_rsi
|
||||||
|
features['spread_roc'] = spread_roc
|
||||||
|
features['spread_change_1h'] = spread_change_1h
|
||||||
|
features['vol_ratio'] = vol_ratio
|
||||||
|
features['vol_ratio_rel'] = vol_ratio_rel
|
||||||
|
features['vol_diff_ratio'] = vol_spread_ratio
|
||||||
|
|
||||||
|
# Volatility
|
||||||
|
features['realized_vol_base'] = realized_vol_a
|
||||||
|
features['realized_vol_quote'] = realized_vol_b
|
||||||
|
features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2
|
||||||
|
|
||||||
|
# ATR
|
||||||
|
features['atr_base'] = atr_a
|
||||||
|
features['atr_pct_base'] = atr_pct_a
|
||||||
|
|
||||||
|
# Pair encoding
|
||||||
|
assets = self.config.assets
|
||||||
|
features['base_idx'] = assets.index(base) if base in assets else -1
|
||||||
|
features['quote_idx'] = assets.index(quote) if quote in assets else -1
|
||||||
|
|
||||||
|
# Funding rates
|
||||||
|
base_funding = self._funding_rates.get(base, 0.0)
|
||||||
|
quote_funding = self._funding_rates.get(quote, 0.0)
|
||||||
|
features['base_funding'] = base_funding
|
||||||
|
features['quote_funding'] = quote_funding
|
||||||
|
features['funding_diff'] = base_funding - quote_funding
|
||||||
|
features['funding_avg'] = (base_funding + quote_funding) / 2
|
||||||
|
|
||||||
|
# Drop NaN rows in core features
|
||||||
|
core_cols = [
|
||||||
|
'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
|
||||||
|
'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
|
||||||
|
'realized_vol_base', 'atr_base', 'atr_pct_base'
|
||||||
|
]
|
||||||
|
features = features.dropna(subset=core_cols)
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
def calculate_all_pair_features(self) -> dict[str, pd.DataFrame]:
|
||||||
|
"""
|
||||||
|
Calculate features for all pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping pair_id to feature DataFrame
|
||||||
|
"""
|
||||||
|
all_features = {}
|
||||||
|
|
||||||
|
for pair in self._pairs:
|
||||||
|
features = self.calculate_pair_features(pair)
|
||||||
|
if features is not None and len(features) > 0:
|
||||||
|
all_features[pair.pair_id] = features
|
||||||
|
|
||||||
|
logger.info("Calculated features for %d/%d pairs",
|
||||||
|
len(all_features), len(self._pairs))
|
||||||
|
|
||||||
|
return all_features
|
||||||
|
|
||||||
|
def get_latest_data(self) -> dict[str, pd.DataFrame] | None:
|
||||||
|
"""
|
||||||
|
Fetch and process latest market data for all pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of pair features or None on error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Fetch OHLCV for all assets
|
||||||
|
self.fetch_all_ohlcv()
|
||||||
|
|
||||||
|
if len(self._asset_data) < 2:
|
||||||
|
logger.warning("Insufficient assets fetched")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Fetch funding rates
|
||||||
|
self.fetch_all_funding_rates()
|
||||||
|
|
||||||
|
# Calculate features for all pairs
|
||||||
|
pair_features = self.calculate_all_pair_features()
|
||||||
|
|
||||||
|
if not pair_features:
|
||||||
|
logger.warning("No pair features calculated")
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.info("Processed %d pairs with valid features", len(pair_features))
|
||||||
|
return pair_features
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error fetching market data: %s", e, exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_pair_by_id(self, pair_id: str) -> TradingPair | None:
|
||||||
|
"""Get pair object by ID."""
|
||||||
|
for pair in self._pairs:
|
||||||
|
if pair.pair_id == pair_id:
|
||||||
|
return pair
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_current_price(self, symbol: str) -> float | None:
|
||||||
|
"""Get current price for a symbol."""
|
||||||
|
if symbol in self._asset_data:
|
||||||
|
return self._asset_data[symbol]['close'].iloc[-1]
|
||||||
|
return None
|
||||||
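The spread/z-score construction at the heart of `calculate_pair_features` can be exercised in isolation. This is a minimal sketch with synthetic prices: the variable names and `z_window=24` mirror the feature code above, while the random-walk series are made up for illustration.

```python
import numpy as np
import pandas as pd

# Synthetic hourly closes for a base and a quote asset (illustrative only)
idx = pd.date_range("2024-01-01", periods=100, freq="h", tz="UTC")
rng = np.random.default_rng(0)
base = pd.Series(100 + rng.normal(0, 1.0, 100).cumsum(), index=idx)
quote = pd.Series(50 + rng.normal(0, 0.5, 100).cumsum(), index=idx)

# Spread and rolling z-score, as in calculate_pair_features
z_window = 24
spread = base / quote
rolling_mean = spread.rolling(window=z_window).mean()
rolling_std = spread.rolling(window=z_window).std()
z_score = (spread - rolling_mean) / rolling_std

# The first z_window - 1 values are NaN; the rest are finite
print(int(z_score.notna().sum()))  # prints 77
```

Note the rolling window implies a warm-up period: with 100 candles and `z_window=24`, only `100 - 23 = 77` rows survive, which is why the feed requires at least 200 aligned candles before computing features.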
live_trading/multi_pair/main.py (new file, 609 lines)
@@ -0,0 +1,609 @@
#!/usr/bin/env python3
"""
Multi-Pair Divergence Live Trading Bot.

Trades the top 10 cryptocurrency pairs based on spread divergence,
using a universal ML model for signal generation.

Usage:
    # Run with demo account (default)
    uv run python -m live_trading.multi_pair.main

    # Run with specific settings
    uv run python -m live_trading.multi_pair.main --max-position 500 --leverage 2
"""
import argparse
import logging
import signal
import sys
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from live_trading.okx_client import OKXClient
from live_trading.position_manager import PositionManager
from live_trading.multi_pair.config import (
    OKXConfig, MultiPairLiveConfig, PathConfig, get_multi_pair_config
)
from live_trading.multi_pair.data_feed import MultiPairDataFeed, TradingPair
from live_trading.multi_pair.strategy import LiveMultiPairStrategy


def setup_logging(log_dir: Path) -> logging.Logger:
    """Configure logging for the trading bot."""
    log_file = log_dir / "multi_pair_live.log"

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stdout),
        ],
        force=True
    )

    return logging.getLogger(__name__)


@dataclass
class PositionState:
    """Track current position state for multi-pair trading."""
    pair: TradingPair | None = None
    pair_id: str | None = None
    direction: str | None = None
    entry_price: float = 0.0
    size: float = 0.0
    stop_loss: float = 0.0
    take_profit: float = 0.0
    entry_time: datetime | None = None


class MultiPairLiveTradingBot:
    """
    Main trading bot for the multi-pair divergence strategy.

    Coordinates data fetching, pair scoring, and order execution.
    """

    def __init__(
        self,
        okx_config: OKXConfig,
        trading_config: MultiPairLiveConfig,
        path_config: PathConfig
    ):
        self.okx_config = okx_config
        self.trading_config = trading_config
        self.path_config = path_config

        self.logger = logging.getLogger(__name__)
        self.running = True

        # Initialize components
        self.logger.info("Initializing multi-pair trading bot...")

        # Create OKX client with adapted config
        self._adapted_trading_config = self._adapt_config_for_okx_client()
        self.okx_client = OKXClient(okx_config, self._adapted_trading_config)

        # Initialize data feed
        self.data_feed = MultiPairDataFeed(
            self.okx_client, trading_config, path_config
        )

        # Initialize position manager (reused from the single-pair bot)
        self.position_manager = PositionManager(
            self.okx_client, self._adapted_trading_config, path_config
        )

        # Initialize strategy
        self.strategy = LiveMultiPairStrategy(trading_config, path_config)

        # Current position state
        self.position = PositionState()

        # Register signal handlers
        signal.signal(signal.SIGINT, self._handle_shutdown)
        signal.signal(signal.SIGTERM, self._handle_shutdown)

        self._print_startup_banner()

        # Sync with exchange positions on startup
        self._sync_position_from_exchange()

    def _adapt_config_for_okx_client(self):
        """Create a config object compatible with OKXClient."""
        # OKXClient expects specific attributes
        @dataclass
        class AdaptedConfig:
            eth_symbol: str = "ETH/USDT:USDT"
            btc_symbol: str = "BTC/USDT:USDT"
            timeframe: str = "1h"
            candles_to_fetch: int = 500
            max_position_usdt: float = -1.0
            min_position_usdt: float = 10.0
            leverage: int = 1
            margin_mode: str = "cross"
            stop_loss_pct: float = 0.06
            take_profit_pct: float = 0.05
            max_concurrent_positions: int = 1
            z_entry_threshold: float = 1.0
            z_window: int = 24
            model_prob_threshold: float = 0.5
            funding_threshold: float = 0.0005
            sleep_seconds: int = 3600
            slippage_pct: float = 0.001

        adapted = AdaptedConfig()
        adapted.timeframe = self.trading_config.timeframe
        adapted.candles_to_fetch = self.trading_config.candles_to_fetch
        adapted.max_position_usdt = self.trading_config.max_position_usdt
        adapted.min_position_usdt = self.trading_config.min_position_usdt
        adapted.leverage = self.trading_config.leverage
        adapted.margin_mode = self.trading_config.margin_mode
        adapted.max_concurrent_positions = self.trading_config.max_concurrent_positions
        adapted.sleep_seconds = self.trading_config.sleep_seconds
        adapted.slippage_pct = self.trading_config.slippage_pct

        return adapted

    def _print_startup_banner(self) -> None:
        """Print startup information."""
        mode = "DEMO/SANDBOX" if self.okx_config.demo_mode else "LIVE"

        print("=" * 60)
        print(" Multi-Pair Divergence Strategy - Live Trading Bot")
        print("=" * 60)
        print(f" Mode: {mode}")
        print(f" Assets: {len(self.trading_config.assets)} assets")
        print(f" Pairs: {self.trading_config.get_pair_count()} pairs")
        print(f" Timeframe: {self.trading_config.timeframe}")
        print(f" Max Position: ${self.trading_config.max_position_usdt if self.trading_config.max_position_usdt > 0 else 'All available'}")
        print(f" Leverage: {self.trading_config.leverage}x")
        print(f" Z-Entry: > {self.trading_config.z_entry_threshold}")
        print(f" Prob Threshold: > {self.trading_config.prob_threshold}")
        print(f" Cycle Interval: {self.trading_config.sleep_seconds // 60} minutes")
        print("=" * 60)
        print(f" Assets: {', '.join([a.split('/')[0] for a in self.trading_config.assets])}")
        print("=" * 60)

        if not self.okx_config.demo_mode:
            print("\n *** WARNING: LIVE TRADING MODE - REAL FUNDS AT RISK ***\n")

    def _handle_shutdown(self, signum, frame) -> None:
        """Handle shutdown signals gracefully."""
        self.logger.info("Shutdown signal received, stopping...")
        self.running = False

    def _sync_position_from_exchange(self) -> bool:
        """
        Sync internal position state with exchange positions.

        Checks for existing open positions on the exchange and updates
        internal state to match. This prevents stacking positions when
        the bot is restarted.

        Returns:
            True if a position was synced, False otherwise
        """
        try:
            positions = self.okx_client.get_positions()

            if not positions:
                if self.position.pair is not None:
                    # Position was closed externally (e.g., SL/TP hit)
                    self.logger.info(
                        "Position %s was closed externally, resetting state",
                        self.position.pair.name if self.position.pair else "unknown"
                    )
                    self.position = PositionState()
                return False

            # Check each position against our tradeable assets
            our_assets = set(self.trading_config.assets)

            for pos in positions:
                pos_symbol = pos.get('symbol', '')
                contracts = abs(float(pos.get('contracts', 0)))

                if contracts == 0:
                    continue

                # Check if this position is for one of our assets
                if pos_symbol not in our_assets:
                    continue

                # Found a position for one of our assets
                side = pos.get('side', 'long')
                entry_price = float(pos.get('entryPrice', 0))
                unrealized_pnl = float(pos.get('unrealizedPnl', 0))

                # If we already track this position, just update
                if (self.position.pair is not None and
                        self.position.pair.base_asset == pos_symbol):
                    self.logger.debug(
                        "Position already tracked: %s %s %.2f contracts",
                        side, pos_symbol, contracts
                    )
                    return True

                # New position found - sync it.
                # Find or create a TradingPair for this position.
                matched_pair = None
                for pair in self.data_feed.pairs:
                    if pair.base_asset == pos_symbol:
                        matched_pair = pair
                        break

                if matched_pair is None:
                    # Create a placeholder pair (we don't know the quote asset)
                    matched_pair = TradingPair(
                        base_asset=pos_symbol,
                        quote_asset="UNKNOWN"
                    )

                # Calculate approximate SL/TP based on config defaults
                sl_pct = self.trading_config.base_sl_pct
                tp_pct = self.trading_config.base_tp_pct

                if side == 'long':
                    stop_loss = entry_price * (1 - sl_pct)
                    take_profit = entry_price * (1 + tp_pct)
                else:
                    stop_loss = entry_price * (1 + sl_pct)
                    take_profit = entry_price * (1 - tp_pct)

                self.position = PositionState(
                    pair=matched_pair,
                    pair_id=matched_pair.pair_id,
                    direction=side,
                    entry_price=entry_price,
                    size=contracts,
                    stop_loss=stop_loss,
                    take_profit=take_profit,
                    entry_time=None  # Unknown for synced positions
                )

                self.logger.info(
                    "Synced existing position from exchange: %s %s %.4f @ %.4f (PnL: %.2f)",
                    side.upper(),
                    pos_symbol,
                    contracts,
                    entry_price,
                    unrealized_pnl
                )
                return True

            # No matching positions found
            if self.position.pair is not None:
                self.logger.info(
                    "Position %s no longer exists on exchange, resetting state",
                    self.position.pair.name
                )
                self.position = PositionState()

            return False

        except Exception as e:
            self.logger.error("Failed to sync position from exchange: %s", e)
            return False

    def run_trading_cycle(self) -> None:
        """
        Execute one trading cycle.

        1. Sync position state with the exchange
        2. Fetch the latest market data and pair features
        3. Check exit conditions (signal and SL/TP) for the current position
        4. Generate and execute an entry signal if no position is open
        5. Log status
        """
        cycle_start = datetime.now(timezone.utc)
        self.logger.info("--- Trading Cycle Start: %s ---", cycle_start.isoformat())

        try:
            # 1. Sync position state with exchange (detect SL/TP closures)
            self._sync_position_from_exchange()

            # 2. Fetch all market data
            pair_features = self.data_feed.get_latest_data()
            if pair_features is None:
                self.logger.warning("No market data available, skipping cycle")
                return

            # 3. Check exit conditions for the current position
            if self.position.pair is not None:
                exit_signal = self.strategy.check_exit_signal(
                    pair_features,
                    self.position.pair_id
                )

                if exit_signal['action'] == 'exit':
                    self._execute_exit(exit_signal)
                else:
                    # Check SL/TP
                    current_price = self.data_feed.get_current_price(
                        self.position.pair.base_asset
                    )
                    if current_price:
                        sl_tp_exit = self._check_sl_tp(current_price)
                        if sl_tp_exit:
                            self._execute_exit({'reason': sl_tp_exit})

            # 4. Generate entry signal if no position
            if self.position.pair is None:
                entry_signal = self.strategy.generate_signal(
                    pair_features,
                    self.data_feed.pairs
                )

                if entry_signal['action'] == 'entry':
                    self._execute_entry(entry_signal)

            # 5. Log status
            if self.position.pair:
                self.logger.info(
                    "Position: %s %s, entry=%.4f, current PnL check pending",
                    self.position.direction,
                    self.position.pair.name,
                    self.position.entry_price
                )
            else:
                self.logger.info("No open position")

        except Exception as e:
            self.logger.error("Trading cycle error: %s", e, exc_info=True)

        cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
        self.logger.info("--- Cycle completed in %.1fs ---", cycle_duration)

    def _check_sl_tp(self, current_price: float) -> str | None:
        """Check stop-loss and take-profit levels."""
        if self.position.direction == 'long':
            if current_price <= self.position.stop_loss:
                return f"stop_loss ({current_price:.4f} <= {self.position.stop_loss:.4f})"
            if current_price >= self.position.take_profit:
                return f"take_profit ({current_price:.4f} >= {self.position.take_profit:.4f})"
        else:  # short
            if current_price >= self.position.stop_loss:
                return f"stop_loss ({current_price:.4f} >= {self.position.stop_loss:.4f})"
            if current_price <= self.position.take_profit:
                return f"take_profit ({current_price:.4f} <= {self.position.take_profit:.4f})"
        return None

    def _execute_entry(self, signal: dict) -> None:
        """Execute an entry trade."""
        pair = signal['pair']
        symbol = pair.base_asset  # Trade the base asset
        direction = signal['direction']

        self.logger.info(
            "Entry signal: %s %s (z=%.2f, p=%.2f, score=%.3f)",
            direction.upper(),
            pair.name,
            signal['z_score'],
            signal['probability'],
            signal['divergence_score']
        )

        # Get account balance
        try:
            balance = self.okx_client.get_balance()
            available_usdt = balance['free']
        except Exception as e:
            self.logger.error("Could not get balance: %s", e)
            return

        # Calculate position size
        size_usdt = self.strategy.calculate_position_size(
            signal['divergence_score'],
            available_usdt
        )

        if size_usdt <= 0:
            self.logger.info("Position size too small, skipping entry")
            return

        current_price = signal['base_price']
        size_asset = size_usdt / current_price

        # Calculate SL/TP
        stop_loss, take_profit = self.strategy.calculate_sl_tp(
            current_price,
            direction,
            signal['atr'],
            signal['atr_pct']
        )

        self.logger.info(
            "Executing %s entry: %.6f %s @ %.4f ($%.2f), SL=%.4f, TP=%.4f",
            direction.upper(),
            size_asset,
            symbol.split('/')[0],
            current_price,
            size_usdt,
            stop_loss,
            take_profit
        )

        try:
            # Place market order
            order_side = "buy" if direction == "long" else "sell"
            order = self.okx_client.place_market_order(symbol, order_side, size_asset)

            filled_price = order.get('average') or order.get('price') or current_price
            filled_amount = order.get('filled') or order.get('amount') or size_asset

            if filled_price is None or filled_price == 0:
                filled_price = current_price
            if filled_amount is None or filled_amount == 0:
                filled_amount = size_asset

            # Recalculate SL/TP with the filled price
            stop_loss, take_profit = self.strategy.calculate_sl_tp(
                filled_price, direction, signal['atr'], signal['atr_pct']
            )

            # Update position state
            self.position = PositionState(
                pair=pair,
                pair_id=pair.pair_id,
                direction=direction,
                entry_price=filled_price,
                size=filled_amount,
                stop_loss=stop_loss,
                take_profit=take_profit,
                entry_time=datetime.now(timezone.utc)
            )

            self.logger.info(
                "Position opened: %s %s %.6f @ %.4f",
                direction.upper(),
                pair.name,
                filled_amount,
                filled_price
            )

            # Try to set SL/TP on the exchange
            try:
                self.okx_client.set_stop_loss_take_profit(
                    symbol, direction, filled_amount, stop_loss, take_profit
                )
            except Exception as e:
                self.logger.warning("Could not set SL/TP on exchange: %s", e)

        except Exception as e:
            self.logger.error("Order execution failed: %s", e, exc_info=True)

    def _execute_exit(self, signal: dict) -> None:
        """Execute an exit trade."""
        if self.position.pair is None:
            return

        symbol = self.position.pair.base_asset
        reason = signal.get('reason', 'unknown')

        self.logger.info(
            "Exit signal: %s %s, reason: %s",
            self.position.direction,
            self.position.pair.name,
            reason
        )

        try:
            # Close position on exchange
            self.okx_client.close_position(symbol)

            self.logger.info(
                "Position closed: %s %s",
                self.position.direction,
                self.position.pair.name
            )

            # Reset position state
            self.position = PositionState()

        except Exception as e:
            self.logger.error("Exit execution failed: %s", e, exc_info=True)

    def run(self) -> None:
        """Main trading loop."""
        self.logger.info("Starting multi-pair trading loop...")

        while self.running:
            try:
                self.run_trading_cycle()

                if self.running:
                    sleep_seconds = self.trading_config.sleep_seconds
                    minutes = sleep_seconds // 60
                    self.logger.info("Sleeping for %d minutes...", minutes)

                    for _ in range(sleep_seconds):
                        if not self.running:
                            break
                        time.sleep(1)

            except KeyboardInterrupt:
                self.logger.info("Keyboard interrupt received")
                break
            except Exception as e:
                self.logger.error("Unexpected error in main loop: %s", e, exc_info=True)
                time.sleep(60)

        self.logger.info("Shutting down...")
        self.logger.info("Shutdown complete")


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Multi-Pair Divergence Live Trading Bot"
    )
    parser.add_argument(
        "--max-position",
        type=float,
        default=None,
        help="Maximum position size in USDT"
    )
    parser.add_argument(
        "--leverage",
        type=int,
        default=None,
        help="Trading leverage (1-125)"
    )
    parser.add_argument(
        "--interval",
        type=int,
        default=None,
        help="Trading cycle interval in seconds"
    )
    parser.add_argument(
        "--live",
        action="store_true",
        help="Use live trading mode (requires OKX_DEMO_MODE=false)"
    )
    return parser.parse_args()


def main():
    """Main entry point."""
    args = parse_args()

    # Load configuration
    okx_config, trading_config, path_config = get_multi_pair_config()

    # Apply command line overrides
    if args.max_position is not None:
        trading_config.max_position_usdt = args.max_position
    if args.leverage is not None:
        trading_config.leverage = args.leverage
    if args.interval is not None:
        trading_config.sleep_seconds = args.interval
    if args.live:
        okx_config.demo_mode = False

    # Setup logging
    logger = setup_logging(path_config.logs_dir)

    try:
        # Validate config
        okx_config.validate()

        # Create and run bot
        bot = MultiPairLiveTradingBot(okx_config, trading_config, path_config)
        bot.run()
    except ValueError as e:
        logger.error("Configuration error: %s", e)
        sys.exit(1)
    except Exception as e:
        logger.error("Fatal error: %s", e, exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()
live_trading/multi_pair/strategy.py (new file, 396 lines)
@@ -0,0 +1,396 @@
"""
Live Multi-Pair Divergence Strategy.

Scores all pairs and selects the best divergence opportunity for trading.
Uses the pre-trained universal ML model from backtesting.
"""
import logging
import pickle
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Opt-in to future pandas behavior to silence FutureWarning on fillna
pd.set_option('future.no_silent_downcasting', True)

from .config import MultiPairLiveConfig, PathConfig
from .data_feed import TradingPair

logger = logging.getLogger(__name__)


@dataclass
class DivergenceSignal:
    """
    Signal for a divergent pair.

    Attributes:
        pair: Trading pair
        z_score: Current Z-Score of the spread
        probability: ML model probability of profitable reversion
        divergence_score: Combined score (|z_score| * probability)
        direction: 'long' or 'short' (relative to base asset)
        base_price: Current price of base asset
        quote_price: Current price of quote asset
        atr: Average True Range in price units
        atr_pct: ATR as percentage of price
    """
    pair: TradingPair
    z_score: float
    probability: float
    divergence_score: float
    direction: str
    base_price: float
    quote_price: float
    atr: float
    atr_pct: float
    base_funding: float = 0.0


class LiveMultiPairStrategy:
    """
    Live trading implementation of multi-pair divergence strategy.

    Scores all pairs using the universal ML model and selects
    the best opportunity for mean-reversion trading.
    """

    def __init__(
        self,
        config: MultiPairLiveConfig,
        path_config: PathConfig
    ):
        self.config = config
        self.paths = path_config
        self.model: RandomForestClassifier | None = None
        self.feature_cols: list[str] | None = None
        self._load_model()

    def _load_model(self) -> None:
        """Load pre-trained model from backtesting."""
        if self.paths.model_path.exists():
            try:
                with open(self.paths.model_path, 'rb') as f:
                    saved = pickle.load(f)
                self.model = saved['model']
                self.feature_cols = saved['feature_cols']
                logger.info("Loaded model from %s", self.paths.model_path)
            except Exception as e:
                logger.error("Could not load model: %s", e)
                raise ValueError(
                    f"Multi-pair model not found at {self.paths.model_path}. "
                    "Run the backtest first to train the model."
                )
        else:
            raise ValueError(
                f"Multi-pair model not found at {self.paths.model_path}. "
                "Run the backtest first to train the model."
            )

    def score_pairs(
        self,
        pair_features: dict[str, pd.DataFrame],
        pairs: list[TradingPair]
    ) -> list[DivergenceSignal]:
        """
        Score all pairs and return ranked signals.

        Args:
            pair_features: Feature DataFrames by pair_id
            pairs: List of TradingPair objects

        Returns:
            List of DivergenceSignal sorted by score (descending)
        """
        if self.model is None:
            logger.warning("Model not loaded")
            return []

        signals = []
        pair_map = {p.pair_id: p for p in pairs}

        for pair_id, features in pair_features.items():
            if pair_id not in pair_map:
                continue

            pair = pair_map[pair_id]

            # Get latest features
            if len(features) == 0:
                continue

            latest = features.iloc[-1]
            z_score = latest['z_score']

            # Skip if Z-score below threshold
            if abs(z_score) < self.config.z_entry_threshold:
                continue

            # Prepare features for prediction
            # Handle missing feature columns gracefully
            available_cols = [c for c in self.feature_cols if c in latest.index]
            missing_cols = [c for c in self.feature_cols if c not in latest.index]

            if missing_cols:
                logger.debug("Missing feature columns: %s", missing_cols)

            feature_row = latest[available_cols].fillna(0)
            feature_row = feature_row.replace([np.inf, -np.inf], 0)

            # Create full feature vector with zeros for missing
            X_dict = {c: 0 for c in self.feature_cols}
            for col in available_cols:
                X_dict[col] = feature_row[col]

            X = pd.DataFrame([X_dict])

            # Predict probability
            prob = self.model.predict_proba(X)[0, 1]

            # Skip if probability below threshold
            if prob < self.config.prob_threshold:
                continue

            # Apply funding rate filter
            base_funding = latest.get('base_funding', 0) or 0
            funding_thresh = self.config.funding_threshold

            if z_score > 0:  # Short signal
                if base_funding < -funding_thresh:
                    logger.debug(
                        "Skipping %s short: funding too negative (%.4f)",
                        pair.name, base_funding
                    )
                    continue
            else:  # Long signal
                if base_funding > funding_thresh:
                    logger.debug(
                        "Skipping %s long: funding too positive (%.4f)",
                        pair.name, base_funding
                    )
                    continue

            # Calculate divergence score
            divergence_score = abs(z_score) * prob

            # Determine direction
            direction = 'short' if z_score > 0 else 'long'

            signal = DivergenceSignal(
                pair=pair,
                z_score=z_score,
                probability=prob,
                divergence_score=divergence_score,
                direction=direction,
                base_price=latest['base_close'],
                quote_price=latest['quote_close'],
                atr=latest.get('atr_base', 0),
                atr_pct=latest.get('atr_pct_base', 0.02),
                base_funding=base_funding
            )
            signals.append(signal)

        # Sort by divergence score (highest first)
        signals.sort(key=lambda s: s.divergence_score, reverse=True)

        if signals:
            logger.info(
                "Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f, dir=%s)",
                len(signals),
                signals[0].pair.name,
                signals[0].divergence_score,
                signals[0].z_score,
                signals[0].probability,
                signals[0].direction
            )
        else:
            logger.info("No pairs meet entry criteria")

        return signals

    def select_best_pair(
        self,
        signals: list[DivergenceSignal]
    ) -> DivergenceSignal | None:
        """
        Select the best pair from scored signals.

        Args:
            signals: List of DivergenceSignal (pre-sorted by score)

        Returns:
            Best signal or None if no valid candidates
        """
        if not signals:
            return None
        return signals[0]

    def generate_signal(
        self,
        pair_features: dict[str, pd.DataFrame],
        pairs: list[TradingPair]
    ) -> dict:
        """
        Generate trading signal from latest features.

        Args:
            pair_features: Feature DataFrames by pair_id
            pairs: List of TradingPair objects

        Returns:
            Signal dictionary with action, pair, direction, etc.
        """
        # Score all pairs
        signals = self.score_pairs(pair_features, pairs)

        # Select best
        best = self.select_best_pair(signals)

        if best is None:
            return {
                'action': 'hold',
                'reason': 'no_valid_signals'
            }

        return {
            'action': 'entry',
            'pair': best.pair,
            'pair_id': best.pair.pair_id,
            'direction': best.direction,
            'z_score': best.z_score,
            'probability': best.probability,
            'divergence_score': best.divergence_score,
            'base_price': best.base_price,
            'quote_price': best.quote_price,
            'atr': best.atr,
            'atr_pct': best.atr_pct,
            'base_funding': best.base_funding,
            'reason': f'{best.pair.name} z={best.z_score:.2f} p={best.probability:.2f}'
        }

    def check_exit_signal(
        self,
        pair_features: dict[str, pd.DataFrame],
        current_pair_id: str
    ) -> dict:
        """
        Check if current position should be exited.

        Exit conditions:
        1. Z-Score reverted to mean (|Z| < threshold)

        Args:
            pair_features: Feature DataFrames by pair_id
            current_pair_id: Current position's pair ID

        Returns:
            Signal dictionary with action and reason
        """
        if current_pair_id not in pair_features:
            return {
                'action': 'exit',
                'reason': 'pair_data_missing'
            }

        features = pair_features[current_pair_id]
        if len(features) == 0:
            return {
                'action': 'exit',
                'reason': 'no_data'
            }

        latest = features.iloc[-1]
        z_score = latest['z_score']

        # Check mean reversion
        if abs(z_score) < self.config.z_exit_threshold:
            return {
                'action': 'exit',
                'reason': f'mean_reversion (z={z_score:.2f})'
            }

        return {
            'action': 'hold',
            'z_score': z_score,
            'reason': f'holding (z={z_score:.2f})'
        }

    def calculate_sl_tp(
        self,
        entry_price: float,
        direction: str,
        atr: float,
        atr_pct: float
    ) -> tuple[float, float]:
        """
        Calculate ATR-based dynamic stop-loss and take-profit prices.

        Args:
            entry_price: Entry price
            direction: 'long' or 'short'
            atr: ATR in price units
            atr_pct: ATR as percentage of price

        Returns:
            Tuple of (stop_loss_price, take_profit_price)
        """
        if atr > 0 and atr_pct > 0:
            sl_distance = atr * self.config.sl_atr_multiplier
            tp_distance = atr * self.config.tp_atr_multiplier

            sl_pct = sl_distance / entry_price
            tp_pct = tp_distance / entry_price
        else:
            sl_pct = self.config.base_sl_pct
            tp_pct = self.config.base_tp_pct

        # Apply bounds
        sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
        tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))

        if direction == 'long':
            stop_loss = entry_price * (1 - sl_pct)
            take_profit = entry_price * (1 + tp_pct)
        else:
            stop_loss = entry_price * (1 + sl_pct)
            take_profit = entry_price * (1 - tp_pct)

        return stop_loss, take_profit

    def calculate_position_size(
        self,
        divergence_score: float,
        available_usdt: float
    ) -> float:
        """
        Calculate position size based on divergence score.

        Args:
            divergence_score: Combined score (|z| * prob)
            available_usdt: Available USDT balance

        Returns:
            Position size in USDT
        """
        if self.config.max_position_usdt <= 0:
            base_size = available_usdt
        else:
            base_size = min(available_usdt, self.config.max_position_usdt)

        # Scale by divergence (1.0 at 0.5 score, up to 2.0 at 1.0+ score)
        base_threshold = 0.5
        if divergence_score <= base_threshold:
            scale = 1.0
        else:
            scale = 1.0 + (divergence_score - base_threshold) / base_threshold
            scale = min(scale, 2.0)

        size = base_size * scale

        if size < self.config.min_position_usdt:
            return 0.0

        return min(size, available_usdt * 0.95)
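The score-based sizing in `calculate_position_size` scales the base size by 1.0× up to a score of 0.5, then linearly up to a 2.0× cap at a score of 1.0 or above. That scaling rule in isolation, as a quick sanity-check sketch:

```python
def divergence_scale(score: float, base_threshold: float = 0.5) -> float:
    """Scale factor from calculate_position_size: flat 1.0 up to the
    threshold, then linear growth, capped at 2.0x."""
    if score <= base_threshold:
        return 1.0
    return min(1.0 + (score - base_threshold) / base_threshold, 2.0)


print(divergence_scale(0.3))   # → 1.0 (below threshold)
print(divergence_scale(0.75))  # → 1.5 (halfway to the cap)
print(divergence_scale(1.0))   # → 2.0 (cap reached)
print(divergence_scale(1.8))   # → 2.0 (still capped)
```

So a signal with |z| = 2.0 and probability 0.5 (score 1.0) commits twice the base size of a marginal signal, before the `min_position_usdt` and 95%-of-balance guards are applied.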
@@ -153,7 +153,7 @@ class OKXClient:
         reduce_only: bool = False
     ) -> dict:
         """
-        Place a market order and fetch the fill price.
+        Place a market order.

         Args:
             symbol: Trading pair symbol
@@ -162,10 +162,7 @@ class OKXClient:
             reduce_only: If True, only reduce existing position

         Returns:
-            Order result dictionary with guaranteed 'average' fill price
-
-        Raises:
-            RuntimeError: If order placement fails or fill price unavailable
+            Order result dictionary
         """
         params = {
             'tdMode': self.trading_config.margin_mode,
@@ -176,48 +173,10 @@ class OKXClient:
         order = self.exchange.create_market_order(
             symbol, side, amount, params=params
         )
-
-        order_id = order.get('id')
-        if not order_id:
-            raise RuntimeError(f"Order placement failed: no order ID returned")
-
         logger.info(
             f"Market {side.upper()} order placed: {amount} {symbol} "
-            f"@ market price, order_id={order_id}"
+            f"@ market price, order_id={order['id']}"
         )
-
-        # Fetch order to get actual fill price if not in initial response
-        fill_price = order.get('average')
-        if fill_price is None or fill_price == 0:
-            logger.info(f"Fetching order {order_id} for fill price...")
-            try:
-                fetched_order = self.exchange.fetch_order(order_id, symbol)
-                fill_price = fetched_order.get('average')
-                order['average'] = fill_price
-                order['filled'] = fetched_order.get('filled', order.get('filled'))
-                order['status'] = fetched_order.get('status', order.get('status'))
-            except Exception as e:
-                logger.warning(f"Could not fetch order details: {e}")
-
-        # Final fallback: use current ticker price
-        if fill_price is None or fill_price == 0:
-            logger.warning(
-                f"No fill price from order response, fetching ticker..."
-            )
-            try:
-                ticker = self.get_ticker(symbol)
-                fill_price = ticker.get('last')
-                order['average'] = fill_price
-            except Exception as e:
-                logger.error(f"Could not fetch ticker: {e}")
-
-        if fill_price is None or fill_price <= 0:
-            raise RuntimeError(
-                f"Could not determine fill price for order {order_id}. "
-                f"Order response: {order}"
-            )
-
-        logger.info(f"Order {order_id} filled at {fill_price}")
         return order

     def place_limit_order(
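The code removed in the hunk above implemented a three-step fill-price fallback: use `average` from the order response, else re-fetch the order, else fall back to the last ticker price. A minimal sketch of that chain with hypothetical stand-in callables (plain dicts and lambdas here, not the real ccxt exchange object):

```python
def resolve_fill_price(order: dict, fetch_order, last_price) -> float:
    """Resolve a usable fill price for a market order.

    order: order response dict (may lack 'average')
    fetch_order: callable(order_id) -> order dict, stand-in for a re-fetch
    last_price: callable() -> float, stand-in for a ticker lookup
    """
    fill = order.get('average')
    if not fill:                      # None or 0: try re-fetching the order
        fill = fetch_order(order['id']).get('average')
    if not fill:                      # still nothing: fall back to ticker
        fill = last_price()
    if not fill or fill <= 0:
        raise RuntimeError(f"Could not determine fill price for order {order['id']}")
    order['average'] = fill
    return fill


# Fill price already present: no fallback needed
print(resolve_fill_price({'id': '1', 'average': 101.0}, lambda _: {}, lambda: 99.0))
```

The commit drops this guarantee, so downstream code must now tolerate an order dict whose `average` is missing.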
@@ -3,21 +3,16 @@ Position Manager for Live Trading.

 Tracks open positions, manages risk, and handles SL/TP logic.
 """
-import csv
 import json
 import logging
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Optional, TYPE_CHECKING
-from dataclasses import dataclass, asdict
+from typing import Optional
+from dataclasses import dataclass, field, asdict

 from .okx_client import OKXClient
 from .config import TradingConfig, PathConfig

-if TYPE_CHECKING:
-    from .db.database import TradingDatabase
-    from .db.models import Trade
-
 logger = logging.getLogger(__name__)

@@ -83,13 +78,11 @@ class PositionManager:
         self,
         okx_client: OKXClient,
         trading_config: TradingConfig,
-        path_config: PathConfig,
-        database: Optional["TradingDatabase"] = None,
+        path_config: PathConfig
     ):
         self.client = okx_client
         self.config = trading_config
         self.paths = path_config
-        self.db = database
         self.positions: dict[str, Position] = {}
         self.trade_log: list[dict] = []
         self._load_positions()
@@ -256,55 +249,16 @@ class PositionManager:
         return trade_record

     def _append_trade_log(self, trade_record: dict) -> None:
-        """Append trade record to CSV and SQLite database."""
-        # Write to CSV (backup/compatibility)
-        self._append_trade_csv(trade_record)
-
-        # Write to SQLite (primary)
-        self._append_trade_db(trade_record)
-
-    def _append_trade_csv(self, trade_record: dict) -> None:
         """Append trade record to CSV log file."""
+        import csv
+
         file_exists = self.paths.trade_log_file.exists()

-        try:
-            with open(self.paths.trade_log_file, 'a', newline='') as f:
-                writer = csv.DictWriter(f, fieldnames=trade_record.keys())
-                if not file_exists:
-                    writer.writeheader()
-                writer.writerow(trade_record)
-        except Exception as e:
-            logger.error(f"Failed to write trade to CSV: {e}")
-
-    def _append_trade_db(self, trade_record: dict) -> None:
-        """Append trade record to SQLite database."""
-        if self.db is None:
-            return
-
-        try:
-            from .db.models import Trade
-
-            trade = Trade(
-                trade_id=trade_record['trade_id'],
-                symbol=trade_record['symbol'],
-                side=trade_record['side'],
-                entry_price=trade_record['entry_price'],
-                exit_price=trade_record.get('exit_price'),
-                size=trade_record['size'],
-                size_usdt=trade_record['size_usdt'],
-                pnl_usd=trade_record.get('pnl_usd'),
-                pnl_pct=trade_record.get('pnl_pct'),
-                entry_time=trade_record['entry_time'],
-                exit_time=trade_record.get('exit_time'),
-                hold_duration_hours=trade_record.get('hold_duration_hours'),
-                reason=trade_record.get('reason'),
-                order_id_entry=trade_record.get('order_id_entry'),
-                order_id_exit=trade_record.get('order_id_exit'),
-            )
-            self.db.insert_trade(trade)
-            logger.debug(f"Trade {trade.trade_id} saved to database")
-        except Exception as e:
-            logger.error(f"Failed to write trade to database: {e}")
+        with open(self.paths.trade_log_file, 'a', newline='') as f:
+            writer = csv.DictWriter(f, fieldnames=trade_record.keys())
+            if not file_exists:
+                writer.writeheader()
+            writer.writerow(trade_record)

     def update_positions(self, current_prices: dict[str, float]) -> list[dict]:
         """
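After this change, `_append_trade_log` keeps only the CSV path: write the header once on first append, then append rows. The same pattern as a self-contained sketch (a temp file stands in for `paths.trade_log_file`):

```python
import csv
import tempfile
from pathlib import Path


def append_trade(path: Path, record: dict) -> None:
    # Header is written only when the file does not yet exist,
    # matching the simplified _append_trade_log
    file_exists = path.exists()
    with open(path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=record.keys())
        if not file_exists:
            writer.writeheader()
        writer.writerow(record)


log = Path(tempfile.mkdtemp()) / "trades.csv"
append_trade(log, {"trade_id": "t1", "pnl_usd": "12.5"})
append_trade(log, {"trade_id": "t2", "pnl_usd": "-3.0"})

with open(log, newline='') as f:
    rows = list(csv.DictReader(f))
print(len(rows))  # → 2
```

Note the trade-off the commit makes: the `try`/`except` around the write is gone, so an I/O failure now propagates to the caller instead of being logged and swallowed.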
@@ -1,10 +0,0 @@
"""Terminal UI module for live trading dashboard."""
from .dashboard import TradingDashboard
from .state import SharedState
from .log_handler import UILogHandler

__all__ = [
    "TradingDashboard",
    "SharedState",
    "UILogHandler",
]
@@ -1,240 +0,0 @@
"""Main trading dashboard UI orchestration."""
import logging
import queue
import threading
import time
from typing import Optional, Callable

from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.panel import Panel
from rich.text import Text

from .state import SharedState
from .log_handler import LogBuffer, UILogHandler
from .keyboard import KeyboardHandler
from .panels import (
    HeaderPanel,
    TabBar,
    LogPanel,
    HelpBar,
    build_summary_panel,
)
from ..db.database import TradingDatabase
from ..db.metrics import MetricsCalculator, PeriodMetrics

logger = logging.getLogger(__name__)


class TradingDashboard:
    """
    Main trading dashboard orchestrator.

    Runs in a separate thread and provides real-time UI updates
    while the trading loop runs in the main thread.
    """

    def __init__(
        self,
        state: SharedState,
        db: TradingDatabase,
        log_queue: queue.Queue,
        on_quit: Optional[Callable] = None,
    ):
        self.state = state
        self.db = db
        self.log_queue = log_queue
        self.on_quit = on_quit

        self.console = Console()
        self.log_buffer = LogBuffer(max_entries=1000)
        self.keyboard = KeyboardHandler()
        self.metrics_calculator = MetricsCalculator(db)

        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._active_tab = 0
        self._cached_metrics: dict[int, PeriodMetrics] = {}
        self._last_metrics_refresh = 0.0

    def start(self) -> None:
        """Start the dashboard in a separate thread."""
        if self._running:
            return

        self._running = True
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        logger.debug("Dashboard thread started")

    def stop(self) -> None:
        """Stop the dashboard."""
        self._running = False
        if self._thread:
            self._thread.join(timeout=2.0)
        logger.debug("Dashboard thread stopped")

    def _run(self) -> None:
        """Main dashboard loop."""
        try:
            with self.keyboard:
                with Live(
                    self._build_layout(),
                    console=self.console,
                    refresh_per_second=1,
                    screen=True,
                ) as live:
                    while self._running and self.state.is_running():
                        # Process keyboard input
                        action = self.keyboard.get_action(timeout=0.1)
                        if action:
                            self._handle_action(action)

                        # Drain log queue
                        self.log_buffer.drain_queue(self.log_queue)

                        # Refresh metrics periodically (every 5 seconds)
                        now = time.time()
                        if now - self._last_metrics_refresh > 5.0:
                            self._refresh_metrics()
                            self._last_metrics_refresh = now

                        # Update display
                        live.update(self._build_layout())

                        # Small sleep to prevent CPU spinning
                        time.sleep(0.1)
        except Exception as e:
            logger.error(f"Dashboard error: {e}", exc_info=True)
        finally:
            self._running = False

    def _handle_action(self, action: str) -> None:
        """Handle keyboard action."""
        if action == "quit":
            logger.info("Quit requested from UI")
            self.state.stop()
            if self.on_quit:
                self.on_quit()

        elif action == "refresh":
            self._refresh_metrics()
            logger.debug("Manual refresh triggered")

        elif action == "filter":
            new_filter = self.log_buffer.cycle_filter()
            logger.debug(f"Log filter changed to: {new_filter}")

        elif action == "filter_trades":
            self.log_buffer.set_filter(LogBuffer.FILTER_TRADES)
            logger.debug("Log filter set to: trades")

        elif action == "filter_all":
            self.log_buffer.set_filter(LogBuffer.FILTER_ALL)
            logger.debug("Log filter set to: all")

        elif action == "filter_errors":
            self.log_buffer.set_filter(LogBuffer.FILTER_ERRORS)
            logger.debug("Log filter set to: errors")

        elif action == "tab_general":
            self._active_tab = 0
        elif action == "tab_monthly":
            if self._has_monthly_data():
                self._active_tab = 1
        elif action == "tab_weekly":
            if self._has_weekly_data():
                self._active_tab = 2
        elif action == "tab_daily":
            self._active_tab = 3

    def _refresh_metrics(self) -> None:
        """Refresh metrics from database."""
        try:
            self._cached_metrics[0] = self.metrics_calculator.get_all_time_metrics()
            self._cached_metrics[1] = self.metrics_calculator.get_monthly_metrics()
            self._cached_metrics[2] = self.metrics_calculator.get_weekly_metrics()
            self._cached_metrics[3] = self.metrics_calculator.get_daily_metrics()
        except Exception as e:
            logger.warning(f"Failed to refresh metrics: {e}")

    def _has_monthly_data(self) -> bool:
        """Check if monthly tab should be shown."""
        try:
            return self.metrics_calculator.has_monthly_data()
        except Exception:
            return False

    def _has_weekly_data(self) -> bool:
        """Check if weekly tab should be shown."""
        try:
            return self.metrics_calculator.has_weekly_data()
        except Exception:
            return False

    def _build_layout(self) -> Layout:
        """Build the complete dashboard layout."""
        layout = Layout()

        # Calculate available height
        term_height = self.console.height or 40

        # Header takes 3 lines
        # Help bar takes 1 line
        # Summary panel takes about 12-14 lines
        # Rest goes to logs
        log_height = max(8, term_height - 20)

        layout.split_column(
            Layout(name="header", size=3),
            Layout(name="summary", size=14),
            Layout(name="logs", size=log_height),
            Layout(name="help", size=1),
        )

        # Header
        layout["header"].update(HeaderPanel(self.state).render())

        # Summary panel with tabs
        current_metrics = self._cached_metrics.get(self._active_tab)
        tab_bar = TabBar(active_tab=self._active_tab)

        layout["summary"].update(
            build_summary_panel(
                state=self.state,
                metrics=current_metrics,
                tab_bar=tab_bar,
                has_monthly=self._has_monthly_data(),
                has_weekly=self._has_weekly_data(),
            )
        )

        # Log panel
        layout["logs"].update(LogPanel(self.log_buffer).render(height=log_height))

        # Help bar
        layout["help"].update(HelpBar().render())

        return layout


def setup_ui_logging(log_queue: queue.Queue) -> UILogHandler:
    """
    Set up logging to capture messages for UI.

    Args:
        log_queue: Queue to send log messages to

    Returns:
        UILogHandler instance
    """
    handler = UILogHandler(log_queue)
    handler.setLevel(logging.INFO)

    # Add handler to root logger
    root_logger = logging.getLogger()
    root_logger.addHandler(handler)

    return handler
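The deleted `TradingDashboard` ran its render loop on a daemon thread guarded by a shared running flag and stopped with a bounded `join`. That lifecycle pattern, stripped of the Rich rendering so it stands alone:

```python
import threading
import time


class BackgroundLoop:
    """Minimal sketch of the dashboard's start/stop shape: daemon thread,
    shared running flag, join with timeout. Illustrative only."""

    def __init__(self):
        self._running = False
        self._thread: threading.Thread | None = None
        self.ticks = 0

    def _run(self) -> None:
        while self._running:
            self.ticks += 1          # stand-in for rendering one frame
            time.sleep(0.01)         # small sleep to prevent CPU spinning

    def start(self) -> None:
        if self._running:            # idempotent, like TradingDashboard.start
            return
        self._running = True
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        self._running = False
        if self._thread:
            self._thread.join(timeout=2.0)


loop = BackgroundLoop()
loop.start()
time.sleep(0.1)
loop.stop()
```

The daemon flag means the UI thread can never keep the process alive if the trading loop in the main thread exits abruptly.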
@@ -1,128 +0,0 @@
|
|||||||
"""Keyboard input handling for terminal UI."""
|
|
||||||
import sys
|
|
||||||
import select
|
|
||||||
import termios
|
|
||||||
import tty
|
|
||||||
from typing import Optional, Callable
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class KeyAction:
|
|
||||||
"""Represents a keyboard action."""
|
|
||||||
|
|
||||||
key: str
|
|
||||||
action: str
|
|
||||||
description: str
|
|
||||||
|
|
||||||
|
|
||||||
class KeyboardHandler:
|
|
||||||
"""
|
|
||||||
Non-blocking keyboard input handler.
|
|
||||||
|
|
||||||
Uses terminal raw mode to capture single keypresses
|
|
||||||
without waiting for Enter.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Key mappings
|
|
||||||
ACTIONS = {
|
|
||||||
"q": "quit",
|
|
||||||
"Q": "quit",
|
|
||||||
"\x03": "quit", # Ctrl+C
|
|
||||||
"r": "refresh",
|
|
||||||
"R": "refresh",
|
|
||||||
"f": "filter",
|
|
||||||
"F": "filter",
|
|
||||||
"t": "filter_trades",
|
|
||||||
"T": "filter_trades",
|
|
||||||
"l": "filter_all",
|
|
||||||
"L": "filter_all",
|
|
||||||
"e": "filter_errors",
|
|
||||||
"E": "filter_errors",
|
|
||||||
"1": "tab_general",
|
|
||||||
"2": "tab_monthly",
|
|
||||||
"3": "tab_weekly",
|
|
||||||
"4": "tab_daily",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._old_settings = None
|
|
||||||
self._enabled = False
|
|
||||||
|
|
||||||
def enable(self) -> bool:
|
|
||||||
"""
|
|
||||||
Enable raw keyboard input mode.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if enabled successfully
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not sys.stdin.isatty():
|
|
||||||
return False
|
|
||||||
|
|
||||||
self._old_settings = termios.tcgetattr(sys.stdin)
|
|
||||||
tty.setcbreak(sys.stdin.fileno())
|
|
||||||
self._enabled = True
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def disable(self) -> None:
|
|
||||||
"""Restore normal terminal mode."""
|
|
||||||
if self._enabled and self._old_settings:
|
|
||||||
try:
|
|
||||||
termios.tcsetattr(
|
|
||||||
sys.stdin,
|
|
||||||
termios.TCSADRAIN,
|
|
||||||
self._old_settings,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._enabled = False
|
|
||||||
|
|
||||||
def get_key(self, timeout: float = 0.1) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get a keypress if available (non-blocking).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timeout: Seconds to wait for input
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Key character or None if no input
|
|
||||||
"""
|
|
||||||
if not self._enabled:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
readable, _, _ = select.select([sys.stdin], [], [], timeout)
|
|
||||||
if readable:
|
|
||||||
return sys.stdin.read(1)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_action(self, timeout: float = 0.1) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get action name for pressed key.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timeout: Seconds to wait for input
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Action name or None
|
|
||||||
"""
|
|
||||||
key = self.get_key(timeout)
|
|
||||||
if key:
|
|
||||||
return self.ACTIONS.get(key)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
"""Context manager entry."""
|
|
||||||
self.enable()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
"""Context manager exit."""
|
|
||||||
self.disable()
|
|
||||||
return False
|
|
||||||
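The handler above combines cbreak mode with `select` to read single keypresses without blocking. A minimal self-contained sketch of that pattern (the `read_key` helper is illustrative, not part of the module):

```python
import select
import sys
import termios
import tty


def read_key(timeout: float = 0.1):
    """Return one keypress within `timeout` seconds, or None.

    Requires a real TTY; returns None otherwise (pipes, CI, tests).
    """
    if not sys.stdin.isatty():
        return None
    old = termios.tcgetattr(sys.stdin)
    try:
        tty.setcbreak(sys.stdin.fileno())  # deliver chars without waiting for Enter
        readable, _, _ = select.select([sys.stdin], [], [], timeout)
        return sys.stdin.read(1) if readable else None
    finally:
        # Always restore the previous terminal settings, even on error
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old)


print(read_key(0.0))
```

The `finally` block mirrors what `KeyboardHandler.disable()` does: raw mode must be undone or the shell is left in a broken state after exit.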
@@ -1,178 +0,0 @@
"""Custom logging handler for UI integration."""

import logging
import queue
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from collections import deque


@dataclass
class LogEntry:
    """A single log entry."""

    timestamp: str
    level: str
    message: str
    logger_name: str

    @property
    def level_color(self) -> str:
        """Get Rich color for log level."""
        colors = {
            "DEBUG": "dim",
            "INFO": "white",
            "WARNING": "yellow",
            "ERROR": "red",
            "CRITICAL": "bold red",
        }
        return colors.get(self.level, "white")


class UILogHandler(logging.Handler):
    """
    Custom logging handler that sends logs to UI.

    Uses a thread-safe queue to pass log entries from the trading
    thread to the UI thread.
    """

    def __init__(
        self,
        log_queue: queue.Queue,
        max_entries: int = 1000,
    ):
        super().__init__()
        self.log_queue = log_queue
        self.max_entries = max_entries
        self.setFormatter(
            logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
        )

    def emit(self, record: logging.LogRecord) -> None:
        """Emit a log record to the queue."""
        try:
            entry = LogEntry(
                timestamp=datetime.fromtimestamp(record.created).strftime(
                    "%H:%M:%S"
                ),
                level=record.levelname,
                message=self.format_message(record),
                logger_name=record.name,
            )
            # Non-blocking put, drop if queue is full
            try:
                self.log_queue.put_nowait(entry)
            except queue.Full:
                pass
        except Exception:
            self.handleError(record)

    def format_message(self, record: logging.LogRecord) -> str:
        """Format the log message."""
        return record.getMessage()


class LogBuffer:
    """
    Thread-safe buffer for log entries with filtering support.

    Maintains a fixed-size buffer of log entries and supports
    filtering by log type.
    """

    FILTER_ALL = "all"
    FILTER_ERRORS = "errors"
    FILTER_TRADES = "trades"
    FILTER_SIGNALS = "signals"

    FILTERS = [FILTER_ALL, FILTER_ERRORS, FILTER_TRADES, FILTER_SIGNALS]

    def __init__(self, max_entries: int = 1000):
        self.max_entries = max_entries
        self._entries: deque[LogEntry] = deque(maxlen=max_entries)
        self._current_filter = self.FILTER_ALL

    def add(self, entry: LogEntry) -> None:
        """Add a log entry to the buffer."""
        self._entries.append(entry)

    def get_filtered(self, limit: int = 50) -> list[LogEntry]:
        """
        Get filtered log entries.

        Args:
            limit: Maximum number of entries to return

        Returns:
            List of filtered LogEntry objects (most recent first)
        """
        entries = list(self._entries)

        if self._current_filter == self.FILTER_ERRORS:
            entries = [e for e in entries if e.level in ("ERROR", "CRITICAL")]
        elif self._current_filter == self.FILTER_TRADES:
            # Key terms indicating actual trading activity
            include_keywords = [
                "order", "entry", "exit", "executed", "filled",
                "opening", "closing", "position opened", "position closed"
            ]
            # Terms to exclude (noise)
            exclude_keywords = [
                "sync complete", "0 positions", "portfolio: 0 positions"
            ]

            entries = [
                e for e in entries
                if any(kw in e.message.lower() for kw in include_keywords)
                and not any(ex in e.message.lower() for ex in exclude_keywords)
            ]
        elif self._current_filter == self.FILTER_SIGNALS:
            signal_keywords = ["signal", "z_score", "prob", "z="]
            entries = [
                e for e in entries
                if any(kw in e.message.lower() for kw in signal_keywords)
            ]

        # Return most recent entries
        return list(reversed(entries[-limit:]))

    def set_filter(self, filter_name: str) -> None:
        """Set a specific filter."""
        if filter_name in self.FILTERS:
            self._current_filter = filter_name

    def cycle_filter(self) -> str:
        """Cycle to next filter and return its name."""
        current_idx = self.FILTERS.index(self._current_filter)
        next_idx = (current_idx + 1) % len(self.FILTERS)
        self._current_filter = self.FILTERS[next_idx]
        return self._current_filter

    def get_current_filter(self) -> str:
        """Get current filter name."""
        return self._current_filter

    def clear(self) -> None:
        """Clear all log entries."""
        self._entries.clear()

    def drain_queue(self, log_queue: queue.Queue) -> int:
        """
        Drain log entries from queue into buffer.

        Args:
            log_queue: Queue to drain from

        Returns:
            Number of entries drained
        """
        count = 0
        while True:
            try:
                entry = log_queue.get_nowait()
                self.add(entry)
                count += 1
            except queue.Empty:
                break
        return count
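The core of `UILogHandler.emit` is that `put_nowait` never blocks the logging (trading) thread, and a full queue drops the entry instead of stalling. That pattern can be reduced to a self-contained sketch (the `QueueHandler` class here is illustrative, not the module's class):

```python
import logging
import queue


class QueueHandler(logging.Handler):
    """Minimal sketch of the queue-based handler pattern:
    emit() never blocks; entries are dropped when the queue is full."""

    def __init__(self, log_queue: queue.Queue):
        super().__init__()
        self.log_queue = log_queue

    def emit(self, record: logging.LogRecord) -> None:
        try:
            # Non-blocking handoff to the consumer (UI) thread
            self.log_queue.put_nowait((record.levelname, record.getMessage()))
        except queue.Full:
            pass  # drop rather than stall the producer


q = queue.Queue(maxsize=1000)
log = logging.getLogger("demo")
log.setLevel(logging.INFO)
log.addHandler(QueueHandler(q))

log.info("position opened BTC-USDT")

level, msg = q.get_nowait()
print(level, msg)  # INFO position opened BTC-USDT
```

On the consumer side, `LogBuffer.drain_queue` plays the matching role: it empties the queue with `get_nowait` until `queue.Empty`, so the UI thread never blocks either.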
@@ -1,399 +0,0 @@
"""UI panel components using Rich."""
from typing import Optional

from rich.console import Console, Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from rich.layout import Layout

from .state import SharedState, PositionState, StrategyState, AccountState
from .log_handler import LogBuffer, LogEntry
from ..db.metrics import PeriodMetrics


def format_pnl(value: float, include_sign: bool = True) -> Text:
    """Format PnL value with color."""
    if value > 0:
        sign = "+" if include_sign else ""
        return Text(f"{sign}${value:.2f}", style="green")
    elif value < 0:
        return Text(f"${value:.2f}", style="red")
    else:
        return Text(f"${value:.2f}", style="white")


def format_pct(value: float, include_sign: bool = True) -> Text:
    """Format percentage value with color."""
    if value > 0:
        sign = "+" if include_sign else ""
        return Text(f"{sign}{value:.2f}%", style="green")
    elif value < 0:
        return Text(f"{value:.2f}%", style="red")
    else:
        return Text(f"{value:.2f}%", style="white")


def format_side(side: str) -> Text:
    """Format position side with color."""
    if side.lower() == "long":
        return Text("LONG", style="bold green")
    else:
        return Text("SHORT", style="bold red")


class HeaderPanel:
    """Header panel with title and mode indicator."""

    def __init__(self, state: SharedState):
        self.state = state

    def render(self) -> Panel:
        """Render the header panel."""
        mode = self.state.get_mode()
        eth_symbol, _ = self.state.get_symbols()

        mode_style = "yellow" if mode == "DEMO" else "bold red"
        mode_text = Text(f"[{mode}]", style=mode_style)

        title = Text()
        title.append("REGIME REVERSION STRATEGY - LIVE TRADING", style="bold white")
        title.append(" ")
        title.append(mode_text)
        title.append(" ")
        title.append(eth_symbol, style="cyan")

        return Panel(title, style="blue", height=3)


class TabBar:
    """Tab bar for period selection."""

    TABS = ["1:General", "2:Monthly", "3:Weekly", "4:Daily"]

    def __init__(self, active_tab: int = 0):
        self.active_tab = active_tab

    def render(
        self,
        has_monthly: bool = True,
        has_weekly: bool = True,
    ) -> Text:
        """Render the tab bar."""
        text = Text()
        text.append(" ")

        for i, tab in enumerate(self.TABS):
            # Check if tab should be shown
            if i == 1 and not has_monthly:
                continue
            if i == 2 and not has_weekly:
                continue

            if i == self.active_tab:
                text.append(f"[{tab}]", style="bold white on blue")
            else:
                text.append(f"[{tab}]", style="dim")
            text.append(" ")

        return text


class MetricsPanel:
    """Panel showing trading metrics."""

    def __init__(self, metrics: Optional[PeriodMetrics] = None):
        self.metrics = metrics

    def render(self) -> Table:
        """Render metrics as a table."""
        table = Table(
            show_header=False,
            show_edge=False,
            box=None,
            padding=(0, 1),
        )
        table.add_column("Label", style="dim")
        table.add_column("Value")

        if self.metrics is None or self.metrics.total_trades == 0:
            table.add_row("Status", Text("No trade data", style="dim"))
            return table

        m = self.metrics

        table.add_row("Total PnL:", format_pnl(m.total_pnl))
        table.add_row("Win Rate:", Text(f"{m.win_rate:.1f}%", style="white"))
        table.add_row("Total Trades:", Text(str(m.total_trades), style="white"))
        table.add_row(
            "Win/Loss:",
            Text(f"{m.winning_trades}/{m.losing_trades}", style="white"),
        )
        table.add_row(
            "Avg Duration:",
            Text(f"{m.avg_trade_duration_hours:.1f}h", style="white"),
        )
        table.add_row("Max Drawdown:", format_pnl(-m.max_drawdown))
        table.add_row("Best Trade:", format_pnl(m.best_trade))
        table.add_row("Worst Trade:", format_pnl(m.worst_trade))

        return table


class PositionPanel:
    """Panel showing current position."""

    def __init__(self, position: Optional[PositionState] = None):
        self.position = position

    def render(self) -> Table:
        """Render position as a table."""
        table = Table(
            show_header=False,
            show_edge=False,
            box=None,
            padding=(0, 1),
        )
        table.add_column("Label", style="dim")
        table.add_column("Value")

        if self.position is None:
            table.add_row("Status", Text("No open position", style="dim"))
            return table

        p = self.position

        table.add_row("Side:", format_side(p.side))
        table.add_row("Entry:", Text(f"${p.entry_price:.2f}", style="white"))
        table.add_row("Current:", Text(f"${p.current_price:.2f}", style="white"))

        # Unrealized PnL
        pnl_text = Text()
        pnl_text.append_text(format_pnl(p.unrealized_pnl))
        pnl_text.append(" (")
        pnl_text.append_text(format_pct(p.unrealized_pnl_pct))
        pnl_text.append(")")
        table.add_row("Unrealized:", pnl_text)

        table.add_row("Size:", Text(f"${p.size_usdt:.2f}", style="white"))

        # SL/TP
        if p.side == "long":
            sl_dist = (p.stop_loss_price / p.entry_price - 1) * 100
            tp_dist = (p.take_profit_price / p.entry_price - 1) * 100
        else:
            sl_dist = (1 - p.stop_loss_price / p.entry_price) * 100
            tp_dist = (1 - p.take_profit_price / p.entry_price) * 100

        sl_text = Text(f"${p.stop_loss_price:.2f} ({sl_dist:+.1f}%)", style="red")
        tp_text = Text(f"${p.take_profit_price:.2f} ({tp_dist:+.1f}%)", style="green")

        table.add_row("Stop Loss:", sl_text)
        table.add_row("Take Profit:", tp_text)

        return table


class AccountPanel:
    """Panel showing account information."""

    def __init__(self, account: Optional[AccountState] = None):
        self.account = account

    def render(self) -> Table:
        """Render account info as a table."""
        table = Table(
            show_header=False,
            show_edge=False,
            box=None,
            padding=(0, 1),
        )
        table.add_column("Label", style="dim")
        table.add_column("Value")

        if self.account is None:
            table.add_row("Status", Text("Loading...", style="dim"))
            return table

        a = self.account

        table.add_row("Balance:", Text(f"${a.balance:.2f}", style="white"))
        table.add_row("Available:", Text(f"${a.available:.2f}", style="white"))
        table.add_row("Leverage:", Text(f"{a.leverage}x", style="cyan"))

        return table


class StrategyPanel:
    """Panel showing strategy state."""

    def __init__(self, strategy: Optional[StrategyState] = None):
        self.strategy = strategy

    def render(self) -> Table:
        """Render strategy state as a table."""
        table = Table(
            show_header=False,
            show_edge=False,
            box=None,
            padding=(0, 1),
        )
        table.add_column("Label", style="dim")
        table.add_column("Value")

        if self.strategy is None:
            table.add_row("Status", Text("Waiting...", style="dim"))
            return table

        s = self.strategy

        # Z-score with color based on threshold
        z_style = "white"
        if abs(s.z_score) > 1.0:
            z_style = "yellow"
        if abs(s.z_score) > 1.5:
            z_style = "bold yellow"
        table.add_row("Z-Score:", Text(f"{s.z_score:.2f}", style=z_style))

        # Probability with color
        prob_style = "white"
        if s.probability > 0.5:
            prob_style = "green"
        if s.probability > 0.7:
            prob_style = "bold green"
        table.add_row("Probability:", Text(f"{s.probability:.2f}", style=prob_style))

        # Funding rate
        funding_style = "green" if s.funding_rate >= 0 else "red"
        table.add_row(
            "Funding:",
            Text(f"{s.funding_rate:.4f}", style=funding_style),
        )

        # Last action
        action_style = "white"
        if s.last_action == "entry":
            action_style = "bold cyan"
        elif s.last_action == "check_exit":
            action_style = "yellow"
        table.add_row("Last Action:", Text(s.last_action, style=action_style))

        return table


class LogPanel:
    """Panel showing log entries."""

    def __init__(self, log_buffer: LogBuffer):
        self.log_buffer = log_buffer

    def render(self, height: int = 10) -> Panel:
        """Render log panel."""
        filter_name = self.log_buffer.get_current_filter().title()
        entries = self.log_buffer.get_filtered(limit=height - 2)

        lines = []
        for entry in entries:
            line = Text()
            line.append(f"{entry.timestamp} ", style="dim")
            line.append(f"[{entry.level}] ", style=entry.level_color)
            line.append(entry.message)
            lines.append(line)

        if not lines:
            lines.append(Text("No logs to display", style="dim"))

        content = Group(*lines)

        # Build "tabbed" title
        tabs = []

        # All Logs tab
        if filter_name == "All":
            tabs.append("[bold white on blue] [L]ogs [/]")
        else:
            tabs.append("[dim] [L]ogs [/]")

        # Trades tab
        if filter_name == "Trades":
            tabs.append("[bold white on blue] [T]rades [/]")
        else:
            tabs.append("[dim] [T]rades [/]")

        # Errors tab
        if filter_name == "Errors":
            tabs.append("[bold white on blue] [E]rrors [/]")
        else:
            tabs.append("[dim] [E]rrors [/]")

        title = " ".join(tabs)
        subtitle = "Press 'l', 't', 'e' to switch tabs"

        return Panel(
            content,
            title=title,
            subtitle=subtitle,
            title_align="left",
            subtitle_align="right",
            border_style="blue",
        )


class HelpBar:
    """Bottom help bar with keyboard shortcuts."""

    def render(self) -> Text:
        """Render help bar."""
        text = Text()
        text.append(" [q]", style="bold")
        text.append("Quit ", style="dim")
        text.append("[r]", style="bold")
        text.append("Refresh ", style="dim")
        text.append("[1-4]", style="bold")
        text.append("Tabs ", style="dim")
        text.append("[l/t/e]", style="bold")
        text.append("LogView", style="dim")

        return text


def build_summary_panel(
    state: SharedState,
    metrics: Optional[PeriodMetrics],
    tab_bar: TabBar,
    has_monthly: bool,
    has_weekly: bool,
) -> Panel:
    """Build the complete summary panel with all sections."""
    # Create layout for summary content
    layout = Layout()

    # Tab bar at top
    tabs = tab_bar.render(has_monthly, has_weekly)

    # Create tables for each section
    metrics_table = MetricsPanel(metrics).render()
    position_table = PositionPanel(state.get_position()).render()
    account_table = AccountPanel(state.get_account()).render()
    strategy_table = StrategyPanel(state.get_strategy()).render()

    # Build two-column layout
    left_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
    left_col.add_column("PERFORMANCE", style="bold cyan")
    left_col.add_column("ACCOUNT", style="bold cyan")
    left_col.add_row(metrics_table, account_table)

    right_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
    right_col.add_column("CURRENT POSITION", style="bold cyan")
    right_col.add_column("STRATEGY STATE", style="bold cyan")
    right_col.add_row(position_table, strategy_table)

    # Combine into main table
    main_table = Table(show_header=False, show_edge=False, box=None, expand=True)
    main_table.add_column(ratio=1)
    main_table.add_column(ratio=1)
    main_table.add_row(left_col, right_col)

    content = Group(tabs, Text(""), main_table)

    return Panel(content, border_style="blue")
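The sign and color convention used by `format_pnl`/`format_pct` (explicit `+` for gains, green/red/white styling, the `-` supplied by the number itself for losses) can be sketched without Rich. The `pnl_style` helper below is a hypothetical plain-text analogue returning `(text, style)` pairs:

```python
def pnl_style(value: float) -> tuple[str, str]:
    """Plain-text analogue of format_pnl: (rendered text, color name)."""
    if value > 0:
        return f"+${value:.2f}", "green"   # explicit '+' for gains
    elif value < 0:
        return f"${value:.2f}", "red"      # '-' comes from the number itself
    return f"${value:.2f}", "white"


print(pnl_style(12.5))   # ('+$12.50', 'green')
print(pnl_style(-3.2))   # ('$-3.20', 'red')
print(pnl_style(0.0))    # ('$0.00', 'white')
```

Note the `$-3.20` rendering: as in the original, the dollar sign precedes the minus because the format string places `$` before the signed number.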
@@ -1,195 +0,0 @@
"""Thread-safe shared state for UI and trading loop."""
import threading
from dataclasses import dataclass, field
from typing import Optional
from datetime import datetime, timezone


@dataclass
class PositionState:
    """Current position information."""

    trade_id: str = ""
    symbol: str = ""
    side: str = ""
    entry_price: float = 0.0
    current_price: float = 0.0
    size: float = 0.0
    size_usdt: float = 0.0
    unrealized_pnl: float = 0.0
    unrealized_pnl_pct: float = 0.0
    stop_loss_price: float = 0.0
    take_profit_price: float = 0.0


@dataclass
class StrategyState:
    """Current strategy signal state."""

    z_score: float = 0.0
    probability: float = 0.0
    funding_rate: float = 0.0
    last_action: str = "hold"
    last_reason: str = ""
    last_signal_time: str = ""


@dataclass
class AccountState:
    """Account balance information."""

    balance: float = 0.0
    available: float = 0.0
    leverage: int = 1


class SharedState:
    """
    Thread-safe shared state between trading loop and UI.

    All access to state fields should go through the getter/setter methods
    which use a lock for thread safety.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._position: Optional[PositionState] = None
        self._strategy = StrategyState()
        self._account = AccountState()
        self._is_running = True
        self._last_cycle_time: Optional[str] = None
        self._mode = "DEMO"
        self._eth_symbol = "ETH/USDT:USDT"
        self._btc_symbol = "BTC/USDT:USDT"

    # Position methods
    def get_position(self) -> Optional[PositionState]:
        """Get current position state."""
        with self._lock:
            return self._position

    def set_position(self, position: Optional[PositionState]) -> None:
        """Set current position state."""
        with self._lock:
            self._position = position

    def update_position_price(self, current_price: float) -> None:
        """Update current price and recalculate PnL."""
        with self._lock:
            if self._position is None:
                return

            self._position.current_price = current_price

            if self._position.side == "long":
                pnl = (current_price - self._position.entry_price)
                self._position.unrealized_pnl = pnl * self._position.size
                pnl_pct = (current_price / self._position.entry_price - 1) * 100
            else:
                pnl = (self._position.entry_price - current_price)
                self._position.unrealized_pnl = pnl * self._position.size
                pnl_pct = (1 - current_price / self._position.entry_price) * 100

            self._position.unrealized_pnl_pct = pnl_pct

    def clear_position(self) -> None:
        """Clear current position."""
        with self._lock:
            self._position = None

    # Strategy methods
    def get_strategy(self) -> StrategyState:
        """Get current strategy state."""
        with self._lock:
            return StrategyState(
                z_score=self._strategy.z_score,
                probability=self._strategy.probability,
                funding_rate=self._strategy.funding_rate,
                last_action=self._strategy.last_action,
                last_reason=self._strategy.last_reason,
                last_signal_time=self._strategy.last_signal_time,
            )

    def update_strategy(
        self,
        z_score: float,
        probability: float,
        funding_rate: float,
        action: str,
        reason: str,
    ) -> None:
        """Update strategy state."""
        with self._lock:
            self._strategy.z_score = z_score
            self._strategy.probability = probability
            self._strategy.funding_rate = funding_rate
            self._strategy.last_action = action
            self._strategy.last_reason = reason
            self._strategy.last_signal_time = datetime.now(
                timezone.utc
            ).isoformat()

    # Account methods
    def get_account(self) -> AccountState:
        """Get current account state."""
        with self._lock:
            return AccountState(
                balance=self._account.balance,
                available=self._account.available,
                leverage=self._account.leverage,
            )

    def update_account(
        self,
        balance: float,
        available: float,
        leverage: int,
    ) -> None:
        """Update account state."""
        with self._lock:
            self._account.balance = balance
            self._account.available = available
            self._account.leverage = leverage

    # Control methods
    def is_running(self) -> bool:
        """Check if trading loop is running."""
        with self._lock:
            return self._is_running

    def stop(self) -> None:
        """Signal to stop trading loop."""
        with self._lock:
            self._is_running = False

    def get_last_cycle_time(self) -> Optional[str]:
        """Get last trading cycle time."""
        with self._lock:
            return self._last_cycle_time

    def set_last_cycle_time(self, time_str: str) -> None:
        """Set last trading cycle time."""
        with self._lock:
            self._last_cycle_time = time_str

    # Config methods
    def get_mode(self) -> str:
        """Get trading mode (DEMO/LIVE)."""
        with self._lock:
            return self._mode

    def set_mode(self, mode: str) -> None:
        """Set trading mode."""
        with self._lock:
            self._mode = mode

    def get_symbols(self) -> tuple[str, str]:
        """Get trading symbols (eth, btc)."""
        with self._lock:
            return self._eth_symbol, self._btc_symbol

    def set_symbols(self, eth_symbol: str, btc_symbol: str) -> None:
        """Set trading symbols."""
        with self._lock:
            self._eth_symbol = eth_symbol
            self._btc_symbol = btc_symbol
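`SharedState.get_strategy` and `get_account` return freshly constructed copies under the lock, so a UI-thread reader never observes a half-written update and cannot mutate the shared instance. A minimal self-contained sketch of that pattern (the `SharedAccount` class is illustrative, not the module's class):

```python
import threading
from dataclasses import dataclass, replace


@dataclass
class AccountState:
    balance: float = 0.0
    available: float = 0.0
    leverage: int = 1


class SharedAccount:
    """Every read returns a defensive copy taken under the lock."""

    def __init__(self):
        self._lock = threading.Lock()
        self._account = AccountState()

    def update(self, balance: float, available: float, leverage: int) -> None:
        with self._lock:
            self._account = AccountState(balance, available, leverage)

    def get(self) -> AccountState:
        with self._lock:
            return replace(self._account)  # copy, not the shared instance


state = SharedAccount()
state.update(1000.0, 800.0, 3)

snap = state.get()
snap.balance = 0.0              # mutating the snapshot...
print(state.get().balance)      # ...does not touch shared state → 1000.0
```

Note that `get_position` in the module above returns the shared `PositionState` object itself rather than a copy; the copy-on-read convention of the other getters is the safer default when a second thread mutates the state.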
@@ -15,8 +15,6 @@ dependencies = [
     "plotly>=5.24.0",
     "requests>=2.32.5",
     "python-dotenv>=1.2.1",
-    # Terminal UI
-    "rich>=13.0.0",
     # API dependencies
     "fastapi>=0.115.0",
     "uvicorn[standard]>=0.34.0",
@@ -3,16 +3,7 @@ Regime Detection Research Script with Walk-Forward Training.
 
 Tests multiple holding horizons to find optimal parameters
 without look-ahead bias.
-
-Usage:
-    uv run python research/regime_detection.py [options]
-
-Options:
-    --days DAYS     Number of days of data (default: 90)
-    --start DATE    Start date (YYYY-MM-DD), overrides --days
-    --end DATE      End date (YYYY-MM-DD), defaults to now
 """
-import argparse
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -32,39 +23,20 @@ logger = get_logger(__name__)
 # Configuration
 TRAIN_RATIO = 0.7  # 70% train, 30% test
 PROFIT_THRESHOLD = 0.005  # 0.5% profit target
-STOP_LOSS_PCT = 0.06  # 6% stop loss
 Z_WINDOW = 24
 FEE_RATE = 0.001  # 0.1% round-trip fee
-DEFAULT_DAYS = 90  # Default lookback period in days
 
 
-def load_data(days: int = DEFAULT_DAYS, start_date: str = None, end_date: str = None):
-    """
-    Load and align BTC/ETH data.
-
-    Args:
-        days: Number of days of historical data (default: 90)
-        start_date: Optional start date (YYYY-MM-DD), overrides days
-        end_date: Optional end date (YYYY-MM-DD), defaults to now
-
-    Returns:
-        Tuple of (df_btc, df_eth) DataFrames
-    """
+def load_data():
+    """Load and align BTC/ETH data."""
     dm = DataManager()
 
     df_btc = dm.load_data("okx", "BTC-USDT", "1h", MarketType.SPOT)
    df_eth = dm.load_data("okx", "ETH-USDT", "1h", MarketType.SPOT)
 
-    # Determine date range
-    if end_date:
-        end = pd.Timestamp(end_date, tz="UTC")
-    else:
-        end = pd.Timestamp.now(tz="UTC")
-
-    if start_date:
-        start = pd.Timestamp(start_date, tz="UTC")
-    else:
-        start = end - pd.Timedelta(days=days)
+    # Filter to Oct-Dec 2025
+    start = pd.Timestamp("2025-10-01", tz="UTC")
+    end = pd.Timestamp("2025-12-31", tz="UTC")
 
     df_btc = df_btc[(df_btc.index >= start) & (df_btc.index <= end)]
     df_eth = df_eth[(df_eth.index >= start) & (df_eth.index <= end)]
@@ -74,7 +46,7 @@ def load_data(days: int = DEFAULT_DAYS, start_date: str = None, end_date: str =
     df_btc = df_btc.loc[common]
     df_eth = df_eth.loc[common]
 
-    logger.info(f"Loaded {len(common)} aligned hourly bars from {start} to {end}")
+    logger.info(f"Loaded {len(common)} aligned hourly bars")
     return df_btc, df_eth
 
 
@@ -140,74 +112,26 @@ def calculate_features(df_btc, df_eth, cq_df=None):
 
 
 def calculate_targets(features, horizon):
-    """
-    Calculate target labels for a given horizon.
-
-    Uses path-dependent labeling: Success is hitting Profit Target BEFORE Stop Loss.
-    """
-    spread = features['spread'].values
-    z_score = features['z_score'].values
-    n = len(spread)
-
-    targets = np.zeros(n, dtype=int)
-
+    """Calculate target labels for a given horizon."""
+    spread = features['spread']
+    z_score = features['z_score']
+
+    # For Short (Z > 1): Did spread drop below target?
+    future_min = spread.rolling(window=horizon).min().shift(-horizon)
+    target_short = spread * (1 - PROFIT_THRESHOLD)
+    success_short = (z_score > 1.0) & (future_min < target_short)
+
+    # For Long (Z < -1): Did spread rise above target?
+    future_max = spread.rolling(window=horizon).max().shift(-horizon)
+    target_long = spread * (1 + PROFIT_THRESHOLD)
+    success_long = (z_score < -1.0) & (future_max > target_long)
+
+    targets = np.select([success_short, success_long], [1, 1], default=0)
+
+    # Create valid mask (rows with complete future data)
|
# Create valid mask (rows with complete future data)
|
||||||
valid_mask = np.zeros(n, dtype=bool)
|
valid_mask = future_min.notna() & future_max.notna()
|
||||||
valid_mask[:n-horizon] = True
|
|
||||||
|
|
||||||
# Only iterate relevant rows for efficiency
|
return targets, valid_mask, future_min, future_max
|
||||||
candidates = np.where((z_score > 1.0) | (z_score < -1.0))[0]
|
|
||||||
|
|
||||||
for i in candidates:
|
|
||||||
if i + horizon >= n:
|
|
||||||
continue
|
|
||||||
|
|
||||||
entry_price = spread[i]
|
|
||||||
future_prices = spread[i+1 : i+1+horizon]
|
|
||||||
|
|
||||||
if z_score[i] > 1.0: # Short
|
|
||||||
target_price = entry_price * (1 - PROFIT_THRESHOLD)
|
|
||||||
stop_price = entry_price * (1 + STOP_LOSS_PCT)
|
|
||||||
|
|
||||||
# Identify first hit indices
|
|
||||||
hit_tp = future_prices <= target_price
|
|
||||||
hit_sl = future_prices >= stop_price
|
|
||||||
|
|
||||||
if not np.any(hit_tp):
|
|
||||||
targets[i] = 0 # Target never hit
|
|
||||||
elif not np.any(hit_sl):
|
|
||||||
targets[i] = 1 # Target hit, SL never hit
|
|
||||||
else:
|
|
||||||
first_tp_idx = np.argmax(hit_tp)
|
|
||||||
first_sl_idx = np.argmax(hit_sl)
|
|
||||||
|
|
||||||
# Success if TP hit before SL
|
|
||||||
if first_tp_idx < first_sl_idx:
|
|
||||||
targets[i] = 1
|
|
||||||
else:
|
|
||||||
targets[i] = 0
|
|
||||||
|
|
||||||
else: # Long
|
|
||||||
target_price = entry_price * (1 + PROFIT_THRESHOLD)
|
|
||||||
stop_price = entry_price * (1 - STOP_LOSS_PCT)
|
|
||||||
|
|
||||||
hit_tp = future_prices >= target_price
|
|
||||||
hit_sl = future_prices <= stop_price
|
|
||||||
|
|
||||||
if not np.any(hit_tp):
|
|
||||||
targets[i] = 0
|
|
||||||
elif not np.any(hit_sl):
|
|
||||||
targets[i] = 1
|
|
||||||
else:
|
|
||||||
first_tp_idx = np.argmax(hit_tp)
|
|
||||||
first_sl_idx = np.argmax(hit_sl)
|
|
||||||
|
|
||||||
if first_tp_idx < first_sl_idx:
|
|
||||||
targets[i] = 1
|
|
||||||
else:
|
|
||||||
targets[i] = 0
|
|
||||||
|
|
||||||
return targets, pd.Series(valid_mask, index=features.index), None, None
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_mae(features, predictions, test_idx, horizon):
|
def calculate_mae(features, predictions, test_idx, horizon):
|
||||||
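The rewritten `calculate_targets` above replaces the path-dependent loop with vectorized rolling look-ahead extrema. A standalone sketch of that labeling rule on toy data (the `features` frame below is synthetic; the constant matches the script's):

```python
import numpy as np
import pandas as pd

PROFIT_THRESHOLD = 0.005  # 0.5% profit target, as in the script

def calculate_targets(features: pd.DataFrame, horizon: int):
    """Label 1 if the spread moved 0.5% in the trade's favor within `horizon` bars."""
    spread = features['spread']
    z_score = features['z_score']

    # Short setups (Z > 1): success if the future rolling min drops below target
    future_min = spread.rolling(window=horizon).min().shift(-horizon)
    success_short = (z_score > 1.0) & (future_min < spread * (1 - PROFIT_THRESHOLD))

    # Long setups (Z < -1): success if the future rolling max rises above target
    future_max = spread.rolling(window=horizon).max().shift(-horizon)
    success_long = (z_score < -1.0) & (future_max > spread * (1 + PROFIT_THRESHOLD))

    targets = np.select([success_short, success_long], [1, 1], default=0)
    valid_mask = future_min.notna() & future_max.notna()
    return targets, valid_mask

# Toy series: the spread drops 1% right after bar 0 while Z is stretched high,
# so the bar-0 short setup is labeled a success.
features = pd.DataFrame({
    'spread':  [100.0, 99.0, 99.0, 99.0, 99.0, 99.0],
    'z_score': [1.5, 0.0, 0.0, 0.0, 0.0, 0.0],
})
targets, valid = calculate_targets(features, horizon=3)
```

Note the behavioral change the diff makes: unlike the loop it replaces, this labeling does not check whether a stop level was hit first, only whether the profit target was reached within the horizon.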
```diff
@@ -244,10 +168,7 @@ def calculate_mae(features, predictions, test_idx, horizon):
 
 def calculate_net_profit(features, predictions, test_idx, horizon):
-    """
-    Calculate estimated net profit including fees.
-    Enforces 'one trade at a time' and simulates SL/TP exits.
-    """
+    """Calculate estimated net profit including fees."""
     test_features = features.loc[test_idx]
     spread = test_features['spread']
     z_score = test_features['z_score']
@@ -255,17 +176,7 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
     total_pnl = 0.0
     n_trades = 0
 
-    # Track when we are free to trade again
-    next_trade_idx = 0
-
-    # Pre-calculate indices for speed
-    all_indices = features.index
-
     for i, (idx, pred) in enumerate(zip(test_idx, predictions)):
-        # Skip if we are still in a trade
-        if i < next_trade_idx:
-            continue
-
         if pred != 1:
             continue
 
@@ -273,77 +184,26 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
         z = z_score.loc[idx]
 
         # Get future spread values
-        current_loc = features.index.get_loc(idx)
-        future_end_loc = min(current_loc + horizon, len(features))
-        future_spreads = features['spread'].iloc[current_loc+1 : future_end_loc]
+        future_idx = features.index.get_loc(idx)
+        future_end = min(future_idx + horizon, len(features))
+        future_spreads = features['spread'].iloc[future_idx:future_end]
 
-        if len(future_spreads) < 1:
+        if len(future_spreads) < 2:
             continue
 
-        pnl = 0.0
-        trade_duration = len(future_spreads)
-
-        if z > 1.0:  # Short trade
-            tp_price = entry_spread * (1 - PROFIT_THRESHOLD)
-            sl_price = entry_spread * (1 + STOP_LOSS_PCT)
-
-            hit_tp = future_spreads <= tp_price
-            hit_sl = future_spreads >= sl_price
-
-            # Check what happened first
-            first_tp = np.argmax(hit_tp.values) if hit_tp.any() else 99999
-            first_sl = np.argmax(hit_sl.values) if hit_sl.any() else 99999
-
-            if first_sl < first_tp and first_sl < 99999:
-                # Stopped out
-                exit_price = future_spreads.iloc[first_sl]  # Approx SL price
-                # Use exact SL price for realistic simulation? Or close
-                # Let's use the close price of the bar where it crossed
-                pnl = (entry_spread - exit_price) / entry_spread
-                trade_duration = first_sl + 1
-            elif first_tp < first_sl and first_tp < 99999:
-                # Take profit
-                exit_price = future_spreads.iloc[first_tp]
-                pnl = (entry_spread - exit_price) / entry_spread
-                trade_duration = first_tp + 1
-            else:
-                # Held to horizon
-                exit_price = future_spreads.iloc[-1]
-                pnl = (entry_spread - exit_price) / entry_spread
-
-        else:  # Long trade
-            tp_price = entry_spread * (1 + PROFIT_THRESHOLD)
-            sl_price = entry_spread * (1 - STOP_LOSS_PCT)
-
-            hit_tp = future_spreads >= tp_price
-            hit_sl = future_spreads <= sl_price
-
-            first_tp = np.argmax(hit_tp.values) if hit_tp.any() else 99999
-            first_sl = np.argmax(hit_sl.values) if hit_sl.any() else 99999
-
-            if first_sl < first_tp and first_sl < 99999:
-                # Stopped out
-                exit_price = future_spreads.iloc[first_sl]
-                pnl = (exit_price - entry_spread) / entry_spread
-                trade_duration = first_sl + 1
-            elif first_tp < first_sl and first_tp < 99999:
-                # Take profit
-                exit_price = future_spreads.iloc[first_tp]
-                pnl = (exit_price - entry_spread) / entry_spread
-                trade_duration = first_tp + 1
-            else:
-                # Held to horizon
-                exit_price = future_spreads.iloc[-1]
-                pnl = (exit_price - entry_spread) / entry_spread
+        # Calculate PnL based on direction
+        if z > 1.0:  # Short trade - profit if spread drops
+            exit_spread = future_spreads.iloc[-1]  # Exit at horizon
+            pnl = (entry_spread - exit_spread) / entry_spread
+        else:  # Long trade - profit if spread rises
+            exit_spread = future_spreads.iloc[-1]
+            pnl = (exit_spread - entry_spread) / entry_spread
 
         # Subtract fees
         net_pnl = pnl - FEE_RATE
         total_pnl += net_pnl
         n_trades += 1
-
-        # Set next available trade index
-        next_trade_idx = i + trade_duration
 
     return total_pnl, n_trades
 
 
@@ -420,7 +280,7 @@ def test_horizons(features, horizons):
     print("\n" + "=" * 80)
     print("WALK-FORWARD HORIZON OPTIMIZATION")
-    print(f"Train Ratio: {TRAIN_RATIO*100:.0f}% | Profit Target: {PROFIT_THRESHOLD*100:.1f}% | Stop Loss: {STOP_LOSS_PCT*100:.1f}% | Fee Rate: {FEE_RATE*100:.2f}%")
+    print(f"Train Ratio: {TRAIN_RATIO*100:.0f}% | Profit Target: {PROFIT_THRESHOLD*100:.1f}% | Fee Rate: {FEE_RATE*100:.2f}%")
     print("=" * 80)
 
     for h in horizons:
@@ -435,54 +295,10 @@ def test_horizons(features, horizons):
     return results
 
 
-def parse_args():
-    """Parse command line arguments."""
-    parser = argparse.ArgumentParser(
-        description="Regime detection research - test multiple horizons"
-    )
-    parser.add_argument(
-        "--days",
-        type=int,
-        default=DEFAULT_DAYS,
-        help=f"Number of days of data (default: {DEFAULT_DAYS})"
-    )
-    parser.add_argument(
-        "--start",
-        type=str,
-        default=None,
-        help="Start date (YYYY-MM-DD), overrides --days"
-    )
-    parser.add_argument(
-        "--end",
-        type=str,
-        default=None,
-        help="End date (YYYY-MM-DD), defaults to now"
-    )
-    parser.add_argument(
-        "--output",
-        type=str,
-        default="research/horizon_optimization_results.csv",
-        help="Output CSV path"
-    )
-    parser.add_argument(
-        "--output-horizon",
-        type=str,
-        default=None,
-        help="Path to save the best horizon (integer) to a file"
-    )
-    return parser.parse_args()
-
-
 def main():
     """Main research function."""
-    args = parse_args()
-
-    # Load data with dynamic date range
-    df_btc, df_eth = load_data(
-        days=args.days,
-        start_date=args.start,
-        end_date=args.end
-    )
+    # Load data
+    df_btc, df_eth = load_data()
     cq_df = load_cryptoquant_data()
 
     # Calculate features
@@ -496,7 +312,7 @@ def main():
     if not results:
         print("No valid results!")
-        return None
+        return
 
     # Find best by different metrics
     results_df = pd.DataFrame(results)
@@ -515,15 +331,9 @@ def main():
     print(f"Lowest MAE: {lowest_mae['horizon']:.0f}h (MAE={lowest_mae['avg_mae']:.2f}%)")
 
     # Save results
-    results_df.to_csv(args.output, index=False)
-    print(f"\nResults saved to {args.output}")
-
-    # Save best horizon if requested
-    if args.output_horizon:
-        best_h = int(best_pnl['horizon'])
-        with open(args.output_horizon, 'w') as f:
-            f.write(str(best_h))
-        print(f"Best horizon {best_h}h saved to {args.output_horizon}")
+    output_path = "research/horizon_optimization_results.csv"
+    results_df.to_csv(output_path, index=False)
+    print(f"\nResults saved to {output_path}")
 
     return results_df
```
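After the simplification above, per-trade PnL reduces to a directional return at the horizon minus a flat fee. That arithmetic in isolation (toy numbers; `FEE_RATE` as defined at the top of the script):

```python
FEE_RATE = 0.001  # 0.1% round-trip fee, as in the script

def trade_pnl(entry_spread: float, exit_spread: float, short: bool) -> float:
    """Directional fractional return at the horizon, net of the flat fee."""
    if short:  # profit if spread drops
        gross = (entry_spread - exit_spread) / entry_spread
    else:      # profit if spread rises
        gross = (exit_spread - entry_spread) / entry_spread
    return gross - FEE_RATE

# A 1% favorable move nets ~0.9% after the 0.1% fee; the same move against
# a long position loses ~1.1%.
short_pnl = trade_pnl(100.0, 99.0, short=True)
long_pnl = trade_pnl(100.0, 99.0, short=False)
```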
### scripts/download_multi_pair_data.py (new file, 47 lines)

```python
#!/usr/bin/env python3
"""
Download historical data for Multi-Pair Divergence Strategy.

Downloads 1h OHLCV data for top 10 cryptocurrencies from OKX.
"""
import sys
sys.path.insert(0, '.')

from engine.data_manager import DataManager
from engine.market import MarketType
from engine.logging_config import setup_logging, get_logger
from strategies.multi_pair import MultiPairConfig

logger = get_logger(__name__)


def main():
    """Download data for all configured assets."""
    setup_logging()

    config = MultiPairConfig()
    dm = DataManager()

    logger.info("Downloading data for %d assets...", len(config.assets))

    for symbol in config.assets:
        logger.info("Downloading %s perpetual 1h data...", symbol)
        try:
            df = dm.download_data(
                exchange_id=config.exchange_id,
                symbol=symbol,
                timeframe=config.timeframe,
                market_type=MarketType.PERPETUAL
            )
            if df is not None:
                logger.info("Downloaded %d candles for %s", len(df), symbol)
            else:
                logger.warning("No data downloaded for %s", symbol)
        except Exception as e:
            logger.error("Failed to download %s: %s", symbol, e)

    logger.info("Download complete!")


if __name__ == "__main__":
    main()
```
### scripts/run_multi_pair_backtest.py (new file, 156 lines)

```python
#!/usr/bin/env python3
"""
Run Multi-Pair Divergence Strategy backtest and compare with baseline.

Compares the multi-pair strategy against the single-pair BTC/ETH regime strategy.
"""
import sys
sys.path.insert(0, '.')

from engine.backtester import Backtester
from engine.data_manager import DataManager
from engine.logging_config import setup_logging, get_logger
from engine.reporting import Reporter
from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig
from strategies.regime_strategy import RegimeReversionStrategy
from engine.market import MarketType

logger = get_logger(__name__)


def run_baseline():
    """Run baseline BTC/ETH regime strategy."""
    logger.info("=" * 60)
    logger.info("BASELINE: BTC/ETH Regime Reversion Strategy")
    logger.info("=" * 60)

    dm = DataManager()
    bt = Backtester(dm)

    strategy = RegimeReversionStrategy()

    result = bt.run_strategy(
        strategy,
        'okx',
        'ETH-USDT',
        timeframe='1h',
        init_cash=10000
    )

    logger.info("Baseline Results:")
    logger.info("  Total Return: %.2f%%", result.portfolio.total_return() * 100)
    logger.info("  Total Trades: %d", result.portfolio.trades.count())
    logger.info("  Win Rate: %.1f%%", result.portfolio.trades.win_rate() * 100)

    return result


def run_multi_pair(assets: list[str] | None = None):
    """Run multi-pair divergence strategy."""
    logger.info("=" * 60)
    logger.info("MULTI-PAIR: Divergence Selection Strategy")
    logger.info("=" * 60)

    dm = DataManager()
    bt = Backtester(dm)

    # Use provided assets or default
    if assets:
        config = MultiPairConfig(assets=assets)
    else:
        config = MultiPairConfig()

    logger.info("Configured %d assets, %d pairs", len(config.assets), config.get_pair_count())

    strategy = MultiPairDivergenceStrategy(config=config)

    result = bt.run_strategy(
        strategy,
        'okx',
        'ETH-USDT',  # Reference asset (not used for trading, just index alignment)
        timeframe='1h',
        init_cash=10000
    )

    logger.info("Multi-Pair Results:")
    logger.info("  Total Return: %.2f%%", result.portfolio.total_return() * 100)
    logger.info("  Total Trades: %d", result.portfolio.trades.count())
    logger.info("  Win Rate: %.1f%%", result.portfolio.trades.win_rate() * 100)

    return result


def compare_results(baseline, multi_pair):
    """Compare and display results."""
    logger.info("=" * 60)
    logger.info("COMPARISON")
    logger.info("=" * 60)

    baseline_return = baseline.portfolio.total_return() * 100
    multi_return = multi_pair.portfolio.total_return() * 100

    improvement = multi_return - baseline_return

    logger.info("Baseline Return: %.2f%%", baseline_return)
    logger.info("Multi-Pair Return: %.2f%%", multi_return)
    logger.info("Improvement: %.2f%% (%.1fx)",
                improvement,
                multi_return / baseline_return if baseline_return != 0 else 0)

    baseline_trades = baseline.portfolio.trades.count()
    multi_trades = multi_pair.portfolio.trades.count()

    logger.info("Baseline Trades: %d", baseline_trades)
    logger.info("Multi-Pair Trades: %d", multi_trades)

    return {
        'baseline_return': baseline_return,
        'multi_pair_return': multi_return,
        'improvement': improvement,
        'baseline_trades': baseline_trades,
        'multi_pair_trades': multi_trades
    }


def main():
    """Main entry point."""
    setup_logging()

    # Check available assets
    dm = DataManager()
    available = []

    for symbol in MultiPairConfig().assets:
        try:
            dm.load_data('okx', symbol, '1h', market_type=MarketType.PERPETUAL)
            available.append(symbol)
        except FileNotFoundError:
            pass

    if len(available) < 2:
        logger.error(
            "Need at least 2 assets to run multi-pair strategy. "
            "Run: uv run python scripts/download_multi_pair_data.py"
        )
        return

    logger.info("Found data for %d assets: %s", len(available), available)

    # Run baseline
    baseline_result = run_baseline()

    # Run multi-pair
    multi_result = run_multi_pair(available)

    # Compare
    comparison = compare_results(baseline_result, multi_result)

    # Save reports
    reporter = Reporter()
    reporter.save_reports(multi_result, "multi_pair_divergence")

    logger.info("Reports saved to backtest_logs/")


if __name__ == "__main__":
    main()
```
### Strategy registry (modified file)

```diff
@@ -37,6 +37,7 @@ def _build_registry() -> dict[str, StrategyConfig]:
     from strategies.examples import MaCrossStrategy, RsiStrategy
     from strategies.supertrend import MetaSupertrendStrategy
     from strategies.regime_strategy import RegimeReversionStrategy
+    from strategies.multi_pair import MultiPairDivergenceStrategy, MultiPairConfig
 
     return {
         "rsi": StrategyConfig(
@@ -98,6 +99,18 @@ def _build_registry() -> dict[str, StrategyConfig]:
                 'stop_loss': [0.04, 0.06, 0.08],
                 'funding_threshold': [0.005, 0.01, 0.02]
             }
+        ),
+        "multi_pair": StrategyConfig(
+            strategy_class=MultiPairDivergenceStrategy,
+            default_params={
+                # Multi-pair divergence strategy uses config object
+                # Parameters passed here will override MultiPairConfig defaults
+            },
+            grid_params={
+                'z_entry_threshold': [0.8, 1.0, 1.2],
+                'prob_threshold': [0.4, 0.5, 0.6],
+                'correlation_threshold': [0.75, 0.85, 0.95]
+            }
         )
     }
```
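The `grid_params` registered above presumably feed a grid search over parameter combinations; a sketch of how such a grid expands (the expansion helper is illustrative, not the registry's actual code):

```python
from itertools import product

# Copied from the "multi_pair" registry entry above
grid_params = {
    'z_entry_threshold': [0.8, 1.0, 1.2],
    'prob_threshold': [0.4, 0.5, 0.6],
    'correlation_threshold': [0.75, 0.85, 0.95],
}

# Cartesian product: one backtest per combination
keys = list(grid_params)
combos = [dict(zip(keys, values)) for values in product(*grid_params.values())]
print(len(combos))  # 27 combinations (3 * 3 * 3)
```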
### strategies/multi_pair/__init__.py (new file, 24 lines)

```python
"""
Multi-Pair Divergence Selection Strategy.

Extends regime detection to multiple cryptocurrency pairs and dynamically
selects the most divergent pair for trading.
"""
from .config import MultiPairConfig
from .pair_scanner import PairScanner, TradingPair
from .correlation import CorrelationFilter
from .feature_engine import MultiPairFeatureEngine
from .divergence_scorer import DivergenceScorer
from .strategy import MultiPairDivergenceStrategy
from .funding import FundingRateFetcher

__all__ = [
    "MultiPairConfig",
    "PairScanner",
    "TradingPair",
    "CorrelationFilter",
    "MultiPairFeatureEngine",
    "DivergenceScorer",
    "MultiPairDivergenceStrategy",
    "FundingRateFetcher",
]
```
### strategies/multi_pair/config.py (new file, 88 lines)

```python
"""
Configuration for Multi-Pair Divergence Strategy.
"""
from dataclasses import dataclass, field


@dataclass
class MultiPairConfig:
    """
    Configuration parameters for multi-pair divergence strategy.

    Attributes:
        assets: List of asset symbols to analyze (top 10 by market cap)
        z_window: Rolling window for Z-Score calculation (hours)
        z_entry_threshold: Minimum |Z-Score| to consider for entry
        prob_threshold: Minimum ML probability to consider for entry
        correlation_threshold: Max correlation to allow between pairs
        correlation_window: Rolling window for correlation (hours)
        atr_period: ATR lookback period for dynamic stops
        sl_atr_multiplier: Stop-loss as multiple of ATR
        tp_atr_multiplier: Take-profit as multiple of ATR
        train_ratio: Walk-forward train/test split ratio
        horizon: Look-ahead horizon for target calculation (hours)
        profit_target: Minimum profit threshold for target labels
        funding_threshold: Funding rate threshold for filtering
    """
    # Asset Universe
    assets: list[str] = field(default_factory=lambda: [
        "BTC-USDT", "ETH-USDT", "SOL-USDT", "XRP-USDT", "BNB-USDT",
        "DOGE-USDT", "ADA-USDT", "AVAX-USDT", "LINK-USDT", "DOT-USDT"
    ])

    # Z-Score Thresholds
    z_window: int = 24
    z_entry_threshold: float = 1.0

    # ML Thresholds
    prob_threshold: float = 0.5
    train_ratio: float = 0.7
    horizon: int = 102
    profit_target: float = 0.005

    # Correlation Filtering
    correlation_threshold: float = 0.85
    correlation_window: int = 168  # 7 days in hours

    # Risk Management - ATR-Based Stops
    # SL/TP are calculated as multiples of ATR
    # Mean ATR for crypto is ~0.6% per hour, so:
    # - 10x ATR = ~6% SL (matches previous fixed 6%)
    # - 8x ATR = ~5% TP (matches previous fixed 5%)
    atr_period: int = 14  # ATR lookback period (hours for 1h timeframe)
    sl_atr_multiplier: float = 10.0  # Stop-loss = entry +/- (ATR * multiplier)
    tp_atr_multiplier: float = 8.0  # Take-profit = entry +/- (ATR * multiplier)

    # Fallback fixed percentages (used if ATR is unavailable)
    base_sl_pct: float = 0.06
    base_tp_pct: float = 0.05

    # ATR bounds to prevent extreme stops
    min_sl_pct: float = 0.02  # Minimum 2% stop-loss
    max_sl_pct: float = 0.10  # Maximum 10% stop-loss
    min_tp_pct: float = 0.02  # Minimum 2% take-profit
    max_tp_pct: float = 0.15  # Maximum 15% take-profit

    volatility_window: int = 24

    # Funding Rate Filter
    # OKX funding rates are typically 0.0001 (0.01%) per 8h
    # Extreme funding is > 0.0005 (0.05%) which indicates crowded trade
    funding_threshold: float = 0.0005  # 0.05% - filter extreme funding

    # Trade Management
    # Note: Setting min_hold_bars=0 and z_exit_threshold=0 gives best results
    # The mean-reversion exit at Z=0 is the primary profit driver
    min_hold_bars: int = 0  # Disabled - let mean reversion drive exits
    switch_threshold: float = 999.0  # Disabled - don't switch mid-trade
    cooldown_bars: int = 0  # Disabled - enter when signal appears
    z_exit_threshold: float = 0.0  # Exit at Z=0 (mean reversion complete)

    # Exchange
    exchange_id: str = "okx"
    timeframe: str = "1h"

    def get_pair_count(self) -> int:
        """Calculate number of unique pairs from asset list."""
        n = len(self.assets)
        return n * (n - 1) // 2
```
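The ATR comments in `MultiPairConfig` above imply a clamped stop-distance formula; a hedged sketch of that calculation (the `stop_loss_pct` helper below is illustrative only — the strategy's own stop code is not part of this diff):

```python
def stop_loss_pct(atr: float, price: float,
                  sl_atr_multiplier: float = 10.0,
                  min_sl_pct: float = 0.02,
                  max_sl_pct: float = 0.10) -> float:
    """Stop distance as a fraction of price, clamped to the configured bounds."""
    raw = (atr * sl_atr_multiplier) / price
    return min(max(raw, min_sl_pct), max_sl_pct)

# With the ~0.6%-per-hour mean ATR cited in the config comments,
# a 10x multiplier reproduces the previous fixed 6% stop.
sl = stop_loss_pct(atr=0.006 * 50_000, price=50_000)

# Sanity check on the pair universe: 10 assets -> 10 * 9 / 2 = 45 unique pairs,
# matching what MultiPairConfig.get_pair_count() returns.
n_assets = 10
n_pairs = n_assets * (n_assets - 1) // 2
```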
173
strategies/multi_pair/correlation.py
Normal file
173
strategies/multi_pair/correlation.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""
|
||||||
|
Correlation Filter for Multi-Pair Divergence Strategy.
|
||||||
|
|
||||||
|
Calculates rolling correlation matrix and filters pairs
|
||||||
|
to avoid highly correlated positions.
|
||||||
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from engine.logging_config import get_logger
|
||||||
|
from .config import MultiPairConfig
|
||||||
|
from .pair_scanner import TradingPair
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CorrelationFilter:
|
||||||
|
"""
|
||||||
|
Calculates and filters based on asset correlations.
|
||||||
|
|
||||||
|
Uses rolling correlation of returns to identify assets
|
||||||
|
moving together, avoiding redundant positions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: MultiPairConfig):
|
||||||
|
self.config = config
|
||||||
|
self._correlation_matrix: pd.DataFrame | None = None
|
||||||
|
self._last_update_idx: int = -1
|
||||||
|
|
||||||
|
def calculate_correlation_matrix(
|
||||||
|
self,
|
||||||
|
price_data: dict[str, pd.Series],
|
||||||
|
current_idx: int | None = None
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Calculate rolling correlation matrix between all assets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
price_data: Dictionary mapping asset symbols to price series
|
||||||
|
current_idx: Current bar index (for caching)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Correlation matrix DataFrame
|
||||||
|
"""
|
||||||
|
# Use cached if recent
|
||||||
|
if (
|
||||||
|
current_idx is not None
|
||||||
|
and self._correlation_matrix is not None
|
||||||
|
and current_idx - self._last_update_idx < 24 # Update every 24 bars
|
||||||
|
):
|
||||||
|
return self._correlation_matrix
|
||||||
|
|
||||||
|
# Calculate returns
|
||||||
|
returns = {}
|
||||||
|
for symbol, prices in price_data.items():
|
||||||
|
returns[symbol] = prices.pct_change()
|
||||||
|
|
||||||
|
returns_df = pd.DataFrame(returns)
|
||||||
|
|
||||||
|
# Rolling correlation
|
||||||
|
window = self.config.correlation_window
|
||||||
|
|
||||||
|
# Get latest correlation (last row of rolling correlation)
|
||||||
|
if len(returns_df) >= window:
|
||||||
|
rolling_corr = returns_df.rolling(window=window).corr()
|
||||||
|
# Extract last timestamp correlation matrix
|
||||||
|
last_idx = returns_df.index[-1]
|
||||||
|
corr_matrix = rolling_corr.loc[last_idx]
|
||||||
|
else:
|
||||||
|
```python
        # Fallback to full-period correlation if not enough data
        corr_matrix = returns_df.corr()

        self._correlation_matrix = corr_matrix
        if current_idx is not None:
            self._last_update_idx = current_idx

        return corr_matrix

    def filter_pairs(
        self,
        pairs: list[TradingPair],
        current_position_asset: str | None,
        price_data: dict[str, pd.Series],
        current_idx: int | None = None
    ) -> list[TradingPair]:
        """
        Filter pairs based on correlation with current position.

        If we have an open position in an asset, exclude pairs where
        either asset is highly correlated with the held asset.

        Args:
            pairs: List of candidate pairs
            current_position_asset: Currently held asset (or None)
            price_data: Dictionary of price series by symbol
            current_idx: Current bar index for caching

        Returns:
            Filtered list of pairs
        """
        if current_position_asset is None:
            return pairs

        corr_matrix = self.calculate_correlation_matrix(price_data, current_idx)
        threshold = self.config.correlation_threshold

        filtered = []
        for pair in pairs:
            # Check correlation of base and quote with held asset
            base_corr = self._get_correlation(
                corr_matrix, pair.base_asset, current_position_asset
            )
            quote_corr = self._get_correlation(
                corr_matrix, pair.quote_asset, current_position_asset
            )

            # Filter if either asset highly correlated with position
            if abs(base_corr) > threshold or abs(quote_corr) > threshold:
                logger.debug(
                    "Filtered %s: base_corr=%.2f, quote_corr=%.2f (held: %s)",
                    pair.name, base_corr, quote_corr, current_position_asset
                )
                continue

            filtered.append(pair)

        if len(filtered) < len(pairs):
            logger.info(
                "Correlation filter: %d/%d pairs remaining (held: %s)",
                len(filtered), len(pairs), current_position_asset
            )

        return filtered

    def _get_correlation(
        self,
        corr_matrix: pd.DataFrame,
        asset1: str,
        asset2: str
    ) -> float:
        """
        Get correlation between two assets from matrix.

        Args:
            corr_matrix: Correlation matrix
            asset1: First asset symbol
            asset2: Second asset symbol

        Returns:
            Correlation coefficient (-1 to 1), or 0 if not found
        """
        if asset1 == asset2:
            return 1.0

        try:
            return corr_matrix.loc[asset1, asset2]
        except KeyError:
            return 0.0

    def get_correlation_report(
        self,
        price_data: dict[str, pd.Series]
    ) -> pd.DataFrame:
        """
        Generate a readable correlation report.

        Args:
            price_data: Dictionary of price series

        Returns:
            Correlation matrix as DataFrame
        """
        return self.calculate_correlation_matrix(price_data)
```
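The keep/drop rule in `filter_pairs` can be exercised standalone. Below is a minimal sketch, not the repo's implementation: toy symbols, a hand-built correlation matrix, and plain `(base, quote)` tuples standing in for `TradingPair` objects.

```python
import pandas as pd

def filter_pairs_by_correlation(pairs, held_asset, corr_matrix, threshold):
    """Drop any (base, quote) pair where either leg is too correlated
    with the currently held asset; mirrors filter_pairs, minus logging."""
    def corr(a, b):
        if a == b:
            return 1.0
        try:
            return float(corr_matrix.loc[a, b])
        except KeyError:
            return 0.0  # unknown asset: treat as uncorrelated

    if held_asset is None:
        return pairs
    return [
        (base, quote) for base, quote in pairs
        if abs(corr(base, held_asset)) <= threshold
        and abs(corr(quote, held_asset)) <= threshold
    ]

# Toy correlation matrix for four assets (illustrative values)
syms = ['BTC', 'ETH', 'XRP', 'SOL']
corr = pd.DataFrame(
    [[1.0, 0.9, 0.1, 0.2],
     [0.9, 1.0, 0.2, 0.3],
     [0.1, 0.2, 1.0, 0.4],
     [0.2, 0.3, 0.4, 1.0]],
    index=syms, columns=syms
)
pairs = [('ETH', 'XRP'), ('XRP', 'SOL')]
# Holding BTC: ETH/XRP is dropped (corr(ETH, BTC) = 0.9 > 0.8), XRP/SOL survives
kept = filter_pairs_by_correlation(pairs, 'BTC', corr, threshold=0.8)
```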
strategies/multi_pair/divergence_scorer.py (new file, 311 lines)
```python
"""
Divergence Scorer for Multi-Pair Strategy.

Ranks pairs by divergence score and selects the best candidate.
"""
from dataclasses import dataclass
from typing import Optional

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import pickle
from pathlib import Path

from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import TradingPair

logger = get_logger(__name__)


@dataclass
class DivergenceSignal:
    """
    Signal for a divergent pair.

    Attributes:
        pair: Trading pair
        z_score: Current Z-Score of the spread
        probability: ML model probability of profitable reversion
        divergence_score: Combined score (|z_score| * probability)
        direction: 'long' or 'short' (relative to base asset)
        base_price: Current price of base asset
        quote_price: Current price of quote asset
        atr: Average True Range in price units
        atr_pct: ATR as percentage of price
    """
    pair: TradingPair
    z_score: float
    probability: float
    divergence_score: float
    direction: str
    base_price: float
    quote_price: float
    atr: float
    atr_pct: float
    timestamp: pd.Timestamp


class DivergenceScorer:
    """
    Scores and ranks pairs by divergence potential.

    Uses ML model predictions combined with Z-Score magnitude
    to identify the most promising mean-reversion opportunity.
    """

    def __init__(self, config: MultiPairConfig, model_path: str = "data/multi_pair_model.pkl"):
        self.config = config
        self.model_path = Path(model_path)
        self.model: RandomForestClassifier | None = None
        self.feature_cols: list[str] | None = None
        self._load_model()

    def _load_model(self) -> None:
        """Load pre-trained model if available."""
        if self.model_path.exists():
            try:
                with open(self.model_path, 'rb') as f:
                    saved = pickle.load(f)
                self.model = saved['model']
                self.feature_cols = saved['feature_cols']
                logger.info("Loaded model from %s", self.model_path)
            except Exception as e:
                logger.warning("Could not load model: %s", e)

    def save_model(self) -> None:
        """Save trained model."""
        if self.model is None:
            return

        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.model_path, 'wb') as f:
            pickle.dump({
                'model': self.model,
                'feature_cols': self.feature_cols,
            }, f)
        logger.info("Saved model to %s", self.model_path)

    def train_model(
        self,
        combined_features: pd.DataFrame,
        pair_features: dict[str, pd.DataFrame]
    ) -> None:
        """
        Train universal model on all pairs.

        Args:
            combined_features: Combined feature DataFrame from all pairs
            pair_features: Individual pair feature DataFrames (for target calculation)
        """
        logger.info("Training universal model on %d samples...", len(combined_features))

        z_thresh = self.config.z_entry_threshold
        horizon = self.config.horizon
        profit_target = self.config.profit_target

        # Calculate targets for each pair
        all_targets = []
        all_features = []

        for pair_id, features in pair_features.items():
            if len(features) < horizon + 50:
                continue

            spread = features['spread']
            z_score = features['z_score']

            # Future price movements
            future_min = spread.rolling(window=horizon).min().shift(-horizon)
            future_max = spread.rolling(window=horizon).max().shift(-horizon)

            # Target labels
            target_short = spread * (1 - profit_target)
            target_long = spread * (1 + profit_target)

            success_short = (z_score > z_thresh) & (future_min < target_short)
            success_long = (z_score < -z_thresh) & (future_max > target_long)

            targets = np.select([success_short, success_long], [1, 1], default=0)

            # Valid mask (exclude rows without complete future data)
            valid_mask = future_min.notna() & future_max.notna()

            # Collect valid samples
            valid_features = features[valid_mask]
            valid_targets = targets[valid_mask.values]

            if len(valid_features) > 0:
                all_features.append(valid_features)
                all_targets.extend(valid_targets)

        if not all_features:
            logger.warning("No valid training samples")
            return

        # Combine all training data
        X_df = pd.concat(all_features, ignore_index=True)
        y = np.array(all_targets)

        # Get feature columns
        exclude_cols = [
            'pair_id', 'base_asset', 'quote_asset',
            'spread', 'base_close', 'quote_close', 'base_volume'
        ]
        self.feature_cols = [c for c in X_df.columns if c not in exclude_cols]

        # Prepare features
        X = X_df[self.feature_cols].fillna(0)
        X = X.replace([np.inf, -np.inf], 0)

        # Train model
        self.model = RandomForestClassifier(
            n_estimators=300,
            max_depth=5,
            min_samples_leaf=30,
            class_weight={0: 1, 1: 3},
            random_state=42
        )
        self.model.fit(X, y)

        logger.info(
            "Model trained on %d samples, %d features, %.1f%% positive class",
            len(X), len(self.feature_cols), y.mean() * 100
        )
        self.save_model()

    def score_pairs(
        self,
        pair_features: dict[str, pd.DataFrame],
        pairs: list[TradingPair],
        timestamp: pd.Timestamp | None = None
    ) -> list[DivergenceSignal]:
        """
        Score all pairs and return ranked signals.

        Args:
            pair_features: Feature DataFrames by pair_id
            pairs: List of TradingPair objects
            timestamp: Current timestamp for feature extraction

        Returns:
            List of DivergenceSignal sorted by score (descending)
        """
        if self.model is None:
            logger.warning("Model not trained, returning empty signals")
            return []

        signals = []
        pair_map = {p.pair_id: p for p in pairs}

        for pair_id, features in pair_features.items():
            if pair_id not in pair_map:
                continue

            pair = pair_map[pair_id]

            # Get latest features
            if timestamp is not None:
                valid = features[features.index <= timestamp]
                if len(valid) == 0:
                    continue
                latest = valid.iloc[-1]
                ts = valid.index[-1]
            else:
                latest = features.iloc[-1]
                ts = features.index[-1]

            z_score = latest['z_score']

            # Skip if Z-score below threshold
            if abs(z_score) < self.config.z_entry_threshold:
                continue

            # Prepare features for prediction
            feature_row = latest[self.feature_cols].fillna(0).infer_objects(copy=False)
            feature_row = feature_row.replace([np.inf, -np.inf], 0)
            X = pd.DataFrame([feature_row.values], columns=self.feature_cols)

            # Predict probability
            prob = self.model.predict_proba(X)[0, 1]

            # Skip if probability below threshold
            if prob < self.config.prob_threshold:
                continue

            # Apply funding rate filter
            # Block trades where funding opposes our direction
            base_funding = latest.get('base_funding', 0) or 0
            funding_thresh = self.config.funding_threshold

            if z_score > 0:  # Short signal
                # High negative funding = shorts are paying -> skip
                if base_funding < -funding_thresh:
                    logger.debug(
                        "Skipping %s short: funding too negative (%.4f)",
                        pair.name, base_funding
                    )
                    continue
            else:  # Long signal
                # High positive funding = longs are paying -> skip
                if base_funding > funding_thresh:
                    logger.debug(
                        "Skipping %s long: funding too positive (%.4f)",
                        pair.name, base_funding
                    )
                    continue

            # Calculate divergence score
            divergence_score = abs(z_score) * prob

            # Determine direction
            # Z > 0: Spread high (base expensive vs quote) -> Short base
            # Z < 0: Spread low (base cheap vs quote) -> Long base
            direction = 'short' if z_score > 0 else 'long'

            signal = DivergenceSignal(
                pair=pair,
                z_score=z_score,
                probability=prob,
                divergence_score=divergence_score,
                direction=direction,
                base_price=latest['base_close'],
                quote_price=latest['quote_close'],
                atr=latest.get('atr_base', 0),
                atr_pct=latest.get('atr_pct_base', 0.02),
                timestamp=ts
            )
            signals.append(signal)

        # Sort by divergence score (highest first)
        signals.sort(key=lambda s: s.divergence_score, reverse=True)

        if signals:
            logger.debug(
                "Scored %d pairs, top: %s (score=%.3f, z=%.2f, p=%.2f)",
                len(signals),
                signals[0].pair.name,
                signals[0].divergence_score,
                signals[0].z_score,
                signals[0].probability
            )

        return signals

    def select_best_pair(
        self,
        signals: list[DivergenceSignal]
    ) -> DivergenceSignal | None:
        """
        Select the best pair from scored signals.

        Args:
            signals: List of DivergenceSignal (pre-sorted by score)

        Returns:
            Best signal or None if no valid candidates
        """
        if not signals:
            return None
        return signals[0]
```
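The label construction in `train_model` is easiest to see on a toy series. The sketch below uses a hand-made spread and z-score with a shrunken `horizon`; all values are illustrative, not from the repo. Bar 2 (z = 2.5, spread 105) is labeled 1 because the spread falls below the 2% short target within the next 3 bars; the last bars get no label because their future window is incomplete (NaN comparisons are False in pandas, which is exactly what the `valid_mask` in `train_model` then excludes).

```python
import pandas as pd
import numpy as np

horizon, profit_target, z_thresh = 3, 0.02, 2.0

spread = pd.Series([100.0, 101.0, 105.0, 102.0, 99.0, 98.0, 97.0, 100.0])
z_score = pd.Series([0.0, 0.5, 2.5, 1.0, -0.5, -1.0, -2.5, 0.0])

# Extremes of the spread over the *next* `horizon` bars
future_min = spread.rolling(window=horizon).min().shift(-horizon)
future_max = spread.rolling(window=horizon).max().shift(-horizon)

# A short entry (z > threshold) succeeds if the spread later drops below
# spread * (1 - profit_target); symmetrically for long entries.
success_short = (z_score > z_thresh) & (future_min < spread * (1 - profit_target))
success_long = (z_score < -z_thresh) & (future_max > spread * (1 + profit_target))
targets = np.select([success_short, success_long], [1, 1], default=0)
```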
strategies/multi_pair/feature_engine.py (new file, 433 lines)
```python
"""
Feature Engineering for Multi-Pair Divergence Strategy.

Calculates features for all pairs in the universe, including
spread technicals, volatility, and on-chain data.
"""
import pandas as pd
import numpy as np
import ta

from engine.logging_config import get_logger
from engine.data_manager import DataManager
from engine.market import MarketType
from .config import MultiPairConfig
from .pair_scanner import TradingPair
from .funding import FundingRateFetcher

logger = get_logger(__name__)


class MultiPairFeatureEngine:
    """
    Calculates features for multiple trading pairs.

    Generates consistent feature sets across all pairs for
    the universal ML model.
    """

    def __init__(self, config: MultiPairConfig):
        self.config = config
        self.dm = DataManager()
        self.funding_fetcher = FundingRateFetcher()
        self._funding_data: pd.DataFrame | None = None

    def load_all_assets(
        self,
        start_date: str | None = None,
        end_date: str | None = None
    ) -> dict[str, pd.DataFrame]:
        """
        Load OHLCV data for all assets in the universe.

        Args:
            start_date: Start date filter (YYYY-MM-DD)
            end_date: End date filter (YYYY-MM-DD)

        Returns:
            Dictionary mapping symbol to OHLCV DataFrame
        """
        data = {}
        market_type = MarketType.PERPETUAL

        for symbol in self.config.assets:
            try:
                df = self.dm.load_data(
                    self.config.exchange_id,
                    symbol,
                    self.config.timeframe,
                    market_type
                )

                # Apply date filters
                if start_date:
                    df = df[df.index >= pd.Timestamp(start_date, tz="UTC")]
                if end_date:
                    df = df[df.index <= pd.Timestamp(end_date, tz="UTC")]

                if len(df) >= 200:  # Minimum data requirement
                    data[symbol] = df
                    logger.debug("Loaded %s: %d bars", symbol, len(df))
                else:
                    logger.warning(
                        "Skipping %s: insufficient data (%d bars)",
                        symbol, len(df)
                    )
            except FileNotFoundError:
                logger.warning("Data not found for %s", symbol)
            except Exception as e:
                logger.error("Error loading %s: %s", symbol, e)

        logger.info("Loaded %d/%d assets", len(data), len(self.config.assets))
        return data

    def load_funding_data(
        self,
        start_date: str | None = None,
        end_date: str | None = None,
        use_cache: bool = True
    ) -> pd.DataFrame:
        """
        Load funding rate data for all assets.

        Args:
            start_date: Start date filter
            end_date: End date filter
            use_cache: Whether to use cached data

        Returns:
            DataFrame with funding rates for all assets
        """
        self._funding_data = self.funding_fetcher.get_funding_data(
            self.config.assets,
            start_date=start_date,
            end_date=end_date,
            use_cache=use_cache
        )

        if self._funding_data is not None and not self._funding_data.empty:
            logger.info(
                "Loaded funding data: %d rows, %d assets",
                len(self._funding_data),
                len(self._funding_data.columns)
            )
        else:
            logger.warning("No funding data available")

        return self._funding_data

    def calculate_pair_features(
        self,
        pair: TradingPair,
        asset_data: dict[str, pd.DataFrame],
        on_chain_data: pd.DataFrame | None = None
    ) -> pd.DataFrame | None:
        """
        Calculate features for a single pair.

        Args:
            pair: Trading pair
            asset_data: Dictionary of OHLCV DataFrames by symbol
            on_chain_data: Optional on-chain data (funding, inflows)

        Returns:
            DataFrame with features, or None if insufficient data
        """
        base = pair.base_asset
        quote = pair.quote_asset

        if base not in asset_data or quote not in asset_data:
            return None

        df_base = asset_data[base]
        df_quote = asset_data[quote]

        # Align indices
        common_idx = df_base.index.intersection(df_quote.index)
        if len(common_idx) < 200:
            logger.debug("Pair %s: insufficient aligned data", pair.name)
            return None

        df_a = df_base.loc[common_idx]
        df_b = df_quote.loc[common_idx]

        # Calculate spread (base / quote)
        spread = df_a['close'] / df_b['close']

        # Z-Score
        z_window = self.config.z_window
        rolling_mean = spread.rolling(window=z_window).mean()
        rolling_std = spread.rolling(window=z_window).std()
        z_score = (spread - rolling_mean) / rolling_std

        # Spread Technicals
        spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
        spread_roc = spread.pct_change(periods=5) * 100
        spread_change_1h = spread.pct_change(periods=1)

        # Volume Analysis
        vol_ratio = df_a['volume'] / (df_b['volume'] + 1e-10)
        vol_ratio_ma = vol_ratio.rolling(window=12).mean()
        vol_ratio_rel = vol_ratio / (vol_ratio_ma + 1e-10)

        # Volatility
        ret_a = df_a['close'].pct_change()
        ret_b = df_b['close'].pct_change()
        vol_a = ret_a.rolling(window=z_window).std()
        vol_b = ret_b.rolling(window=z_window).std()
        vol_spread_ratio = vol_a / (vol_b + 1e-10)

        # Realized Volatility (for dynamic SL/TP)
        realized_vol_a = ret_a.rolling(window=self.config.volatility_window).std()
        realized_vol_b = ret_b.rolling(window=self.config.volatility_window).std()

        # ATR (Average True Range) for dynamic stops
        # ATR = average of max(high-low, |high-prev_close|, |low-prev_close|)
        high_a, low_a, close_a = df_a['high'], df_a['low'], df_a['close']
        high_b, low_b, close_b = df_b['high'], df_b['low'], df_b['close']

        # True Range for base asset
        tr_a = pd.concat([
            high_a - low_a,
            (high_a - close_a.shift(1)).abs(),
            (low_a - close_a.shift(1)).abs()
        ], axis=1).max(axis=1)
        atr_a = tr_a.rolling(window=self.config.atr_period).mean()

        # True Range for quote asset
        tr_b = pd.concat([
            high_b - low_b,
            (high_b - close_b.shift(1)).abs(),
            (low_b - close_b.shift(1)).abs()
        ], axis=1).max(axis=1)
        atr_b = tr_b.rolling(window=self.config.atr_period).mean()

        # ATR as percentage of price (normalized)
        atr_pct_a = atr_a / close_a
        atr_pct_b = atr_b / close_b

        # Build feature DataFrame
        features = pd.DataFrame(index=common_idx)
        features['pair_id'] = pair.pair_id
        features['base_asset'] = base
        features['quote_asset'] = quote

        # Price data (for reference, not features)
        features['spread'] = spread
        features['base_close'] = df_a['close']
        features['quote_close'] = df_b['close']
        features['base_volume'] = df_a['volume']

        # Core Features
        features['z_score'] = z_score
        features['spread_rsi'] = spread_rsi
        features['spread_roc'] = spread_roc
        features['spread_change_1h'] = spread_change_1h
        features['vol_ratio'] = vol_ratio
        features['vol_ratio_rel'] = vol_ratio_rel
        features['vol_diff_ratio'] = vol_spread_ratio

        # Volatility for SL/TP
        features['realized_vol_base'] = realized_vol_a
        features['realized_vol_quote'] = realized_vol_b
        features['realized_vol_avg'] = (realized_vol_a + realized_vol_b) / 2

        # ATR for dynamic stops (in price units and as percentage)
        features['atr_base'] = atr_a
        features['atr_quote'] = atr_b
        features['atr_pct_base'] = atr_pct_a
        features['atr_pct_quote'] = atr_pct_b
        features['atr_pct_avg'] = (atr_pct_a + atr_pct_b) / 2

        # Pair encoding (for universal model)
        # Using base and quote indices for hierarchical encoding
        assets = self.config.assets
        features['base_idx'] = assets.index(base) if base in assets else -1
        features['quote_idx'] = assets.index(quote) if quote in assets else -1

        # Add funding and on-chain features
        # Funding data is always added from self._funding_data (OKX, all 10 assets)
        # On-chain data is optional (CryptoQuant, BTC/ETH only)
        features = self._add_on_chain_features(
            features, on_chain_data, base, quote
        )

        # Drop rows with NaN in core features only (not funding/on-chain)
        core_cols = [
            'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
            'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
            'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
            'atr_base', 'atr_pct_base'  # ATR is core for SL/TP
        ]
        features = features.dropna(subset=core_cols)

        # Fill missing funding/on-chain features with 0 (neutral)
        optional_cols = [
            'base_funding', 'quote_funding', 'funding_diff', 'funding_avg',
            'base_inflow', 'quote_inflow', 'inflow_ratio'
        ]
        for col in optional_cols:
            if col in features.columns:
                features[col] = features[col].fillna(0)

        return features

    def calculate_all_pair_features(
        self,
        pairs: list[TradingPair],
        asset_data: dict[str, pd.DataFrame],
        on_chain_data: pd.DataFrame | None = None
    ) -> dict[str, pd.DataFrame]:
        """
        Calculate features for all pairs.

        Args:
            pairs: List of trading pairs
            asset_data: Dictionary of OHLCV DataFrames
            on_chain_data: Optional on-chain data

        Returns:
            Dictionary mapping pair_id to feature DataFrame
        """
        all_features = {}

        for pair in pairs:
            features = self.calculate_pair_features(
                pair, asset_data, on_chain_data
            )
            if features is not None and len(features) > 0:
                all_features[pair.pair_id] = features

        logger.info(
            "Calculated features for %d/%d pairs",
            len(all_features), len(pairs)
        )

        return all_features

    def get_combined_features(
        self,
        pair_features: dict[str, pd.DataFrame],
        timestamp: pd.Timestamp | None = None
    ) -> pd.DataFrame:
        """
        Combine all pair features into a single DataFrame.

        Useful for batch model prediction across all pairs.

        Args:
            pair_features: Dictionary of feature DataFrames by pair_id
            timestamp: Optional specific timestamp to filter to

        Returns:
            Combined DataFrame with all pairs as rows
        """
        if not pair_features:
            return pd.DataFrame()

        if timestamp is not None:
            # Get latest row from each pair at or before timestamp
            rows = []
            for pair_id, features in pair_features.items():
                valid = features[features.index <= timestamp]
                if len(valid) > 0:
                    row = valid.iloc[-1:].copy()
                    rows.append(row)

            if rows:
                return pd.concat(rows, ignore_index=False)
            return pd.DataFrame()

        # Combine all features (for training)
        return pd.concat(pair_features.values(), ignore_index=False)

    def _add_on_chain_features(
        self,
        features: pd.DataFrame,
        on_chain_data: pd.DataFrame | None,
        base_asset: str,
        quote_asset: str
    ) -> pd.DataFrame:
        """
        Add on-chain and funding rate features for the pair.

        Uses funding data from OKX (all 10 assets) and on-chain data
        from CryptoQuant (BTC/ETH only for inflows).
        """
        base_short = base_asset.replace('-USDT', '').lower()
        quote_short = quote_asset.replace('-USDT', '').lower()

        # Add funding rates from cached funding data
        if self._funding_data is not None and not self._funding_data.empty:
            funding_aligned = self._funding_data.reindex(
                features.index, method='ffill'
            )

            base_funding_col = f'{base_short}_funding'
            quote_funding_col = f'{quote_short}_funding'

            if base_funding_col in funding_aligned.columns:
                features['base_funding'] = funding_aligned[base_funding_col]
            if quote_funding_col in funding_aligned.columns:
                features['quote_funding'] = funding_aligned[quote_funding_col]

            # Funding difference (positive = base has higher funding)
            if 'base_funding' in features.columns and 'quote_funding' in features.columns:
                features['funding_diff'] = (
                    features['base_funding'] - features['quote_funding']
                )

                # Funding sentiment: average of both assets
                features['funding_avg'] = (
                    features['base_funding'] + features['quote_funding']
                ) / 2

        # Add on-chain features from CryptoQuant (BTC/ETH only)
        if on_chain_data is not None and not on_chain_data.empty:
            cq_aligned = on_chain_data.reindex(features.index, method='ffill')

            # Inflows (only available for BTC/ETH)
            base_inflow_col = f'{base_short}_inflow'
            quote_inflow_col = f'{quote_short}_inflow'

            if base_inflow_col in cq_aligned.columns:
                features['base_inflow'] = cq_aligned[base_inflow_col]
            if quote_inflow_col in cq_aligned.columns:
                features['quote_inflow'] = cq_aligned[quote_inflow_col]

            if 'base_inflow' in features.columns and 'quote_inflow' in features.columns:
                features['inflow_ratio'] = (
                    features['base_inflow'] /
                    (features['quote_inflow'] + 1)
                )

        return features

    def get_feature_columns(self) -> list[str]:
        """
        Get list of feature columns for ML model.

        Excludes metadata and target-related columns.

        Returns:
            List of feature column names
        """
        # Core features (always present)
        core_features = [
            'z_score', 'spread_rsi', 'spread_roc', 'spread_change_1h',
            'vol_ratio', 'vol_ratio_rel', 'vol_diff_ratio',
            'realized_vol_base', 'realized_vol_quote', 'realized_vol_avg',
            'base_idx', 'quote_idx'
        ]

        # Funding features (now available for all 10 assets via OKX)
        funding_features = [
            'base_funding', 'quote_funding', 'funding_diff', 'funding_avg'
        ]

        # On-chain features (BTC/ETH only via CryptoQuant)
        onchain_features = [
            'base_inflow', 'quote_inflow', 'inflow_ratio'
        ]

        return core_features + funding_features + onchain_features
```
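The true-range/ATR computation in `calculate_pair_features` can be checked on synthetic OHLC data. This is a minimal sketch with hand-picked prices and a shortened window of 2 (the repo uses `config.atr_period`); the numbers are illustrative only.

```python
import pandas as pd

high = pd.Series([10.0, 11.0, 12.0, 11.5])
low = pd.Series([9.0, 10.0, 10.5, 10.8])
close = pd.Series([9.5, 10.8, 11.0, 11.2])

# True range: max of (high - low), |high - prev_close|, |low - prev_close|
tr = pd.concat([
    high - low,
    (high - close.shift(1)).abs(),
    (low - close.shift(1)).abs(),
], axis=1).max(axis=1)

atr = tr.rolling(window=2).mean()   # simple moving average of true range
atr_pct = atr / close               # normalized by price, as in the feature engine
```

Bar 1, for example, has true range max(1.0, |11 − 9.5|, |10 − 9.5|) = 1.5, and the final ATR is (1.5 + 0.7) / 2 = 1.1.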
strategies/multi_pair/funding.py (new file, 272 lines)
```python
"""
Funding Rate Fetcher for Multi-Pair Strategy.

Fetches historical funding rates from OKX for all assets.
CryptoQuant only supports BTC/ETH, so we use OKX for the full universe.
"""
import time
from pathlib import Path
from datetime import datetime, timezone

import ccxt
import pandas as pd

from engine.logging_config import get_logger

logger = get_logger(__name__)


class FundingRateFetcher:
    """
    Fetches and caches funding rate data from OKX.

    OKX funding rates are settled every 8 hours (00:00, 08:00, 16:00 UTC).
    This fetcher retrieves historical funding rate data and aligns it
    to hourly candles for use in the multi-pair strategy.
    """

    def __init__(self, cache_dir: str = "data/funding"):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.exchange: ccxt.okx | None = None

    def _init_exchange(self) -> None:
        """Initialize OKX exchange connection."""
        if self.exchange is None:
            self.exchange = ccxt.okx({
                'enableRateLimit': True,
                'options': {'defaultType': 'swap'}
            })
            self.exchange.load_markets()

    def fetch_funding_history(
        self,
        symbol: str,
        start_date: str | None = None,
        end_date: str | None = None,
        limit: int = 100
    ) -> pd.DataFrame:
        """
        Fetch historical funding rates for a symbol.

        Args:
            symbol: Asset symbol (e.g., 'BTC-USDT')
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            limit: Max records per request

        Returns:
            DataFrame with funding rate history
        """
        self._init_exchange()

        # Convert symbol format
        base = symbol.replace('-USDT', '')
        okx_symbol = f"{base}/USDT:USDT"

        try:
            # OKX funding rate history endpoint
            # Uses fetch_funding_rate_history if available
            all_funding = []

            # Parse dates
            if start_date:
                since = self.exchange.parse8601(f"{start_date}T00:00:00Z")
            else:
                # Default to 1 year ago
                since = self.exchange.milliseconds() - 365 * 24 * 60 * 60 * 1000

            if end_date:
                until = self.exchange.parse8601(f"{end_date}T23:59:59Z")
            else:
                until = self.exchange.milliseconds()

            # Fetch in batches
            current_since = since
            while current_since < until:
                try:
                    funding = self.exchange.fetch_funding_rate_history(
                        okx_symbol,
                        since=current_since,
                        limit=limit
                    )

                    if not funding:
                        break

                    all_funding.extend(funding)

                    # Move to next batch
                    last_ts = funding[-1]['timestamp']
                    if last_ts <= current_since:
                        break
                    current_since = last_ts + 1

                    time.sleep(0.1)  # Rate limit

                except Exception as e:
                    logger.warning(
                        "Error fetching funding batch for %s: %s",
                        symbol, str(e)[:50]
                    )
                    break

            if not all_funding:
                return pd.DataFrame()

            # Convert to DataFrame
            df = pd.DataFrame(all_funding)
            df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
            df.set_index('timestamp', inplace=True)
            df = df[['fundingRate']].rename(columns={'fundingRate': 'funding_rate'})
            df.sort_index(inplace=True)

            # Remove duplicates
            df = df[~df.index.duplicated(keep='first')]

            logger.info("Fetched %d funding records for %s", len(df), symbol)
            return df

        except Exception as e:
            logger.error("Failed to fetch funding for %s: %s", symbol, e)
            return pd.DataFrame()

    def fetch_all_assets(
        self,
        assets: list[str],
        start_date: str | None = None,
```
|
||||||
|
end_date: str | None = None
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Fetch funding rates for all assets and combine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
assets: List of asset symbols (e.g., ['BTC-USDT', 'ETH-USDT'])
|
||||||
|
start_date: Start date
|
||||||
|
end_date: End date
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined DataFrame with columns like 'btc_funding', 'eth_funding', etc.
|
||||||
|
"""
|
||||||
|
combined = pd.DataFrame()
|
||||||
|
|
||||||
|
for symbol in assets:
|
||||||
|
df = self.fetch_funding_history(symbol, start_date, end_date)
|
||||||
|
|
||||||
|
if df.empty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Rename column
|
||||||
|
asset_name = symbol.replace('-USDT', '').lower()
|
||||||
|
col_name = f"{asset_name}_funding"
|
||||||
|
df = df.rename(columns={'funding_rate': col_name})
|
||||||
|
|
||||||
|
if combined.empty:
|
||||||
|
combined = df
|
||||||
|
else:
|
||||||
|
combined = combined.join(df, how='outer')
|
||||||
|
|
||||||
|
time.sleep(0.2) # Be nice to API
|
||||||
|
|
||||||
|
# Forward fill to hourly (funding is every 8h)
|
||||||
|
if not combined.empty:
|
||||||
|
combined = combined.sort_index()
|
||||||
|
combined = combined.ffill()
|
||||||
|
|
||||||
|
return combined
|
||||||
|
|
||||||
|
def save_to_cache(self, df: pd.DataFrame, filename: str = "funding_rates.csv") -> None:
|
||||||
|
"""Save funding data to cache file."""
|
||||||
|
path = self.cache_dir / filename
|
||||||
|
df.to_csv(path)
|
||||||
|
logger.info("Saved funding rates to %s", path)
|
||||||
|
|
||||||
|
def load_from_cache(self, filename: str = "funding_rates.csv") -> pd.DataFrame | None:
|
||||||
|
"""Load funding data from cache if available."""
|
||||||
|
path = self.cache_dir / filename
|
||||||
|
if path.exists():
|
||||||
|
df = pd.read_csv(path, index_col='timestamp', parse_dates=True)
|
||||||
|
logger.info("Loaded funding rates from cache: %d rows", len(df))
|
||||||
|
return df
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_funding_data(
|
||||||
|
self,
|
||||||
|
assets: list[str],
|
||||||
|
start_date: str | None = None,
|
||||||
|
end_date: str | None = None,
|
||||||
|
use_cache: bool = True,
|
||||||
|
force_refresh: bool = False
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Get funding data, using cache if available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
assets: List of asset symbols
|
||||||
|
start_date: Start date
|
||||||
|
end_date: End date
|
||||||
|
use_cache: Whether to use cached data
|
||||||
|
force_refresh: Force refresh even if cache exists
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with funding rates for all assets
|
||||||
|
"""
|
||||||
|
cache_file = "funding_rates.csv"
|
||||||
|
|
||||||
|
# Try cache first
|
||||||
|
if use_cache and not force_refresh:
|
||||||
|
cached = self.load_from_cache(cache_file)
|
||||||
|
if cached is not None:
|
||||||
|
# Check if cache covers requested range
|
||||||
|
if start_date and end_date:
|
||||||
|
start_ts = pd.Timestamp(start_date, tz='UTC')
|
||||||
|
end_ts = pd.Timestamp(end_date, tz='UTC')
|
||||||
|
|
||||||
|
if cached.index.min() <= start_ts and cached.index.max() >= end_ts:
|
||||||
|
# Filter to requested range
|
||||||
|
return cached[(cached.index >= start_ts) & (cached.index <= end_ts)]
|
||||||
|
|
||||||
|
# Fetch fresh data
|
||||||
|
logger.info("Fetching fresh funding rate data...")
|
||||||
|
df = self.fetch_all_assets(assets, start_date, end_date)
|
||||||
|
|
||||||
|
if not df.empty and use_cache:
|
||||||
|
self.save_to_cache(df, cache_file)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def download_funding_data():
|
||||||
|
"""Download funding data for all multi-pair assets."""
|
||||||
|
from strategies.multi_pair.config import MultiPairConfig
|
||||||
|
|
||||||
|
config = MultiPairConfig()
|
||||||
|
fetcher = FundingRateFetcher()
|
||||||
|
|
||||||
|
# Fetch last year of data
|
||||||
|
end_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
start_date = (datetime.now(timezone.utc) - pd.Timedelta(days=365)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
logger.info("Downloading funding rates for %d assets...", len(config.assets))
|
||||||
|
logger.info("Date range: %s to %s", start_date, end_date)
|
||||||
|
|
||||||
|
df = fetcher.get_funding_data(
|
||||||
|
config.assets,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
force_refresh=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if not df.empty:
|
||||||
|
logger.info("Downloaded %d funding rate records", len(df))
|
||||||
|
logger.info("Columns: %s", list(df.columns))
|
||||||
|
else:
|
||||||
|
logger.warning("No funding data downloaded")
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from engine.logging_config import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
download_funding_data()
|
||||||
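The fetcher's docstring says 8-hour funding settlements are aligned to hourly candles, which `fetch_all_assets` does via `reindex`-style joining plus `ffill`. A minimal standalone sketch of that alignment step, using made-up funding values rather than real OKX data:

```python
import pandas as pd

# Hypothetical 8-hour funding settlements (00:00, 08:00, 16:00 UTC)
funding = pd.Series(
    [0.0001, -0.0002, 0.0003],
    index=pd.to_datetime(
        ["2024-01-01 00:00", "2024-01-01 08:00", "2024-01-01 16:00"], utc=True
    ),
    name="btc_funding",
)

# Reindex onto an hourly grid and forward-fill, so every hourly candle
# carries the most recent settled rate
hourly_index = pd.date_range(
    "2024-01-01 00:00", "2024-01-01 23:00", freq="1h", tz="UTC"
)
hourly = funding.reindex(hourly_index).ffill()

print(hourly[pd.Timestamp("2024-01-01 07:00", tz="UTC")])  # still the 00:00 rate
print(hourly[pd.Timestamp("2024-01-01 09:00", tz="UTC")])  # the 08:00 rate
```

Forward-filling (rather than interpolating) matches how funding actually accrues: the rate is constant between settlements.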
168	strategies/multi_pair/pair_scanner.py	Normal file
@@ -0,0 +1,168 @@
"""
Pair Scanner for Multi-Pair Divergence Strategy.

Generates all possible pairs from asset universe and checks tradeability.
"""
from dataclasses import dataclass
from itertools import combinations
from typing import Optional

import ccxt

from engine.logging_config import get_logger
from .config import MultiPairConfig

logger = get_logger(__name__)


@dataclass
class TradingPair:
    """
    Represents a tradeable pair for spread analysis.

    Attributes:
        base_asset: First asset in the pair (numerator)
        quote_asset: Second asset in the pair (denominator)
        pair_id: Unique identifier for the pair
        is_direct: Whether pair can be traded directly on exchange
        exchange_symbol: Symbol for direct trading (if available)
    """
    base_asset: str
    quote_asset: str
    pair_id: str
    is_direct: bool = False
    exchange_symbol: Optional[str] = None

    @property
    def name(self) -> str:
        """Human-readable pair name."""
        return f"{self.base_asset}/{self.quote_asset}"

    def __hash__(self):
        return hash(self.pair_id)

    def __eq__(self, other):
        if not isinstance(other, TradingPair):
            return False
        return self.pair_id == other.pair_id


class PairScanner:
    """
    Scans and generates tradeable pairs from asset universe.

    Checks OKX for directly tradeable cross-pairs and generates
    synthetic pairs via USDT for others.
    """

    def __init__(self, config: MultiPairConfig):
        self.config = config
        self.exchange: Optional[ccxt.Exchange] = None
        self._available_markets: set[str] = set()

    def _init_exchange(self) -> None:
        """Initialize exchange connection for market lookup."""
        if self.exchange is None:
            exchange_class = getattr(ccxt, self.config.exchange_id)
            self.exchange = exchange_class({'enableRateLimit': True})
            self.exchange.load_markets()
            self._available_markets = set(self.exchange.symbols)
            logger.info(
                "Loaded %d markets from %s",
                len(self._available_markets),
                self.config.exchange_id
            )

    def generate_pairs(self, check_exchange: bool = True) -> list[TradingPair]:
        """
        Generate all unique pairs from asset universe.

        Args:
            check_exchange: Whether to check OKX for direct trading

        Returns:
            List of TradingPair objects
        """
        if check_exchange:
            self._init_exchange()

        pairs = []
        assets = self.config.assets

        for base, quote in combinations(assets, 2):
            pair_id = f"{base}__{quote}"

            # Check if directly tradeable as cross-pair on OKX
            is_direct = False
            exchange_symbol = None

            if check_exchange:
                # Check perpetual cross-pair (e.g., ETH/BTC:BTC)
                # OKX perpetuals are typically quoted in USDT
                # Cross-pairs like ETH/BTC are less common
                cross_symbol = f"{base.replace('-USDT', '')}/{quote.replace('-USDT', '')}:USDT"
                if cross_symbol in self._available_markets:
                    is_direct = True
                    exchange_symbol = cross_symbol

            pair = TradingPair(
                base_asset=base,
                quote_asset=quote,
                pair_id=pair_id,
                is_direct=is_direct,
                exchange_symbol=exchange_symbol
            )
            pairs.append(pair)

        # Log summary
        direct_count = sum(1 for p in pairs if p.is_direct)
        logger.info(
            "Generated %d pairs: %d direct, %d synthetic",
            len(pairs), direct_count, len(pairs) - direct_count
        )

        return pairs

    def get_required_symbols(self, pairs: list[TradingPair]) -> list[str]:
        """
        Get list of symbols needed to calculate all pair spreads.

        For synthetic pairs, we need both USDT pairs.
        For direct pairs, we still load USDT pairs for simplicity.

        Args:
            pairs: List of trading pairs

        Returns:
            List of unique symbols to load (e.g., ['BTC-USDT', 'ETH-USDT'])
        """
        symbols = set()
        for pair in pairs:
            symbols.add(pair.base_asset)
            symbols.add(pair.quote_asset)
        return list(symbols)

    def filter_by_assets(
        self,
        pairs: list[TradingPair],
        exclude_assets: list[str]
    ) -> list[TradingPair]:
        """
        Filter pairs that contain any of the excluded assets.

        Args:
            pairs: List of trading pairs
            exclude_assets: Assets to exclude

        Returns:
            Filtered list of pairs
        """
        if not exclude_assets:
            return pairs

        exclude_set = set(exclude_assets)
        return [
            p for p in pairs
            if p.base_asset not in exclude_set
            and p.quote_asset not in exclude_set
        ]
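The core of `generate_pairs` is `itertools.combinations(assets, 2)`, which yields each unordered pair exactly once (n assets give n·(n−1)/2 pairs). A minimal sketch with a hypothetical asset list (the real universe comes from `MultiPairConfig.assets`):

```python
from itertools import combinations

# Hypothetical asset universe; real symbols come from MultiPairConfig.assets
assets = ["BTC-USDT", "ETH-USDT", "SOL-USDT", "XRP-USDT"]

# Each unordered pair appears once, in input order, so 4 assets -> 6 pairs
pair_ids = [f"{base}__{quote}" for base, quote in combinations(assets, 2)]

print(len(pair_ids))
print(pair_ids)
```

Because `combinations` is order-stable, the first-listed asset of each pair is always the base (spread numerator), matching how `TradingPair.pair_id` is built.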
525	strategies/multi_pair/strategy.py	Normal file
@@ -0,0 +1,525 @@
"""
Multi-Pair Divergence Selection Strategy.

Main strategy class that orchestrates pair scanning, feature calculation,
model training, and signal generation for backtesting.
"""
from dataclasses import dataclass
from typing import Optional

import pandas as pd
import numpy as np

from strategies.base import BaseStrategy
from engine.market import MarketType
from engine.logging_config import get_logger
from .config import MultiPairConfig
from .pair_scanner import PairScanner, TradingPair
from .correlation import CorrelationFilter
from .feature_engine import MultiPairFeatureEngine
from .divergence_scorer import DivergenceScorer, DivergenceSignal

logger = get_logger(__name__)


@dataclass
class PositionState:
    """Tracks current position state."""
    pair: TradingPair | None = None
    direction: str | None = None  # 'long' or 'short'
    entry_price: float = 0.0
    entry_idx: int = -1
    stop_loss: float = 0.0
    take_profit: float = 0.0
    atr: float = 0.0  # ATR at entry for reference
    last_exit_idx: int = -100  # For cooldown tracking


class MultiPairDivergenceStrategy(BaseStrategy):
    """
    Multi-Pair Divergence Selection Strategy.

    Scans multiple cryptocurrency pairs for spread divergence,
    selects the most divergent pair using ML-enhanced scoring,
    and trades mean-reversion opportunities.

    Key Features:
    - Universal ML model across all pairs
    - Correlation-based pair filtering
    - Dynamic SL/TP based on volatility
    - Walk-forward training
    """

    def __init__(
        self,
        config: MultiPairConfig | None = None,
        model_path: str = "data/multi_pair_model.pkl"
    ):
        super().__init__()
        self.config = config or MultiPairConfig()

        # Initialize components
        self.pair_scanner = PairScanner(self.config)
        self.correlation_filter = CorrelationFilter(self.config)
        self.feature_engine = MultiPairFeatureEngine(self.config)
        self.divergence_scorer = DivergenceScorer(self.config, model_path)

        # Strategy configuration
        self.default_market_type = MarketType.PERPETUAL
        self.default_leverage = 1

        # Runtime state
        self.pairs: list[TradingPair] = []
        self.asset_data: dict[str, pd.DataFrame] = {}
        self.pair_features: dict[str, pd.DataFrame] = {}
        self.position = PositionState()
        self.train_end_idx: int = 0

    def run(self, close: pd.Series, **kwargs) -> tuple:
        """
        Execute the multi-pair divergence strategy.

        This method is called by the backtester with the primary asset's
        close prices. For multi-pair, we load all assets internally.

        Args:
            close: Primary close prices (used for index alignment)
            **kwargs: Additional data (high, low, volume)

        Returns:
            Tuple of (long_entries, long_exits, short_entries, short_exits, size)
        """
        logger.info("Starting Multi-Pair Divergence Strategy")

        # 1. Load all asset data
        start_date = close.index.min().strftime("%Y-%m-%d")
        end_date = close.index.max().strftime("%Y-%m-%d")

        self.asset_data = self.feature_engine.load_all_assets(
            start_date=start_date,
            end_date=end_date
        )

        # 1b. Load funding rate data for all assets
        self.feature_engine.load_funding_data(
            start_date=start_date,
            end_date=end_date,
            use_cache=True
        )

        if len(self.asset_data) < 2:
            logger.error("Insufficient assets loaded, need at least 2")
            return self._empty_signals(close)

        # 2. Generate pairs
        self.pairs = self.pair_scanner.generate_pairs(check_exchange=False)

        # Filter to pairs with available data
        available_assets = set(self.asset_data.keys())
        self.pairs = [
            p for p in self.pairs
            if p.base_asset in available_assets
            and p.quote_asset in available_assets
        ]

        logger.info("Trading %d pairs from %d assets", len(self.pairs), len(self.asset_data))

        # 3. Calculate features for all pairs
        self.pair_features = self.feature_engine.calculate_all_pair_features(
            self.pairs, self.asset_data
        )

        if not self.pair_features:
            logger.error("No pair features calculated")
            return self._empty_signals(close)

        # 4. Align to common index
        common_index = self._get_common_index()
        if len(common_index) < 200:
            logger.error("Insufficient common data across pairs")
            return self._empty_signals(close)

        # 5. Walk-forward split
        n_samples = len(common_index)
        train_size = int(n_samples * self.config.train_ratio)
        self.train_end_idx = train_size

        train_end_date = common_index[train_size - 1]
        test_start_date = common_index[train_size]

        logger.info(
            "Walk-Forward Split: Train=%d bars (until %s), Test=%d bars (from %s)",
            train_size, train_end_date.strftime('%Y-%m-%d'),
            n_samples - train_size, test_start_date.strftime('%Y-%m-%d')
        )

        # 6. Train model on training period
        if self.divergence_scorer.model is None:
            train_features = {
                pid: feat[feat.index <= train_end_date]
                for pid, feat in self.pair_features.items()
            }
            combined = self.feature_engine.get_combined_features(train_features)
            self.divergence_scorer.train_model(combined, train_features)

        # 7. Generate signals for test period
        return self._generate_signals(common_index, train_size, close)

    def _generate_signals(
        self,
        index: pd.DatetimeIndex,
        train_size: int,
        reference_close: pd.Series
    ) -> tuple:
        """
        Generate entry/exit signals for the test period.

        Iterates through each bar in the test period, scoring pairs
        and generating signals based on divergence scores.
        """
        # Initialize signal arrays aligned to reference close
        long_entries = pd.Series(False, index=reference_close.index)
        long_exits = pd.Series(False, index=reference_close.index)
        short_entries = pd.Series(False, index=reference_close.index)
        short_exits = pd.Series(False, index=reference_close.index)
        size = pd.Series(1.0, index=reference_close.index)

        # Track position state
        self.position = PositionState()

        # Price data for correlation calculation
        price_data = {
            symbol: df['close'] for symbol, df in self.asset_data.items()
        }

        # Iterate through test period
        test_indices = index[train_size:]

        trade_count = 0

        for i, timestamp in enumerate(test_indices):
            current_idx = train_size + i

            # Check exit conditions first
            if self.position.pair is not None:
                # Enforce minimum hold period
                bars_held = current_idx - self.position.entry_idx
                if bars_held < self.config.min_hold_bars:
                    # Only allow SL/TP exits during min hold period
                    should_exit, exit_reason = self._check_sl_tp_only(timestamp)
                else:
                    should_exit, exit_reason = self._check_exit(timestamp)

                if should_exit:
                    # Map exit signal to reference index
                    if timestamp in reference_close.index:
                        if self.position.direction == 'long':
                            long_exits.loc[timestamp] = True
                        else:
                            short_exits.loc[timestamp] = True

                    logger.debug(
                        "Exit %s %s at %s: %s (held %d bars)",
                        self.position.direction,
                        self.position.pair.name,
                        timestamp.strftime('%Y-%m-%d %H:%M'),
                        exit_reason,
                        bars_held
                    )
                    self.position = PositionState(last_exit_idx=current_idx)

            # Score pairs (with correlation filter if position exists)
            held_asset = None
            if self.position.pair is not None:
                held_asset = self.position.pair.base_asset

            # Filter pairs by correlation
            candidate_pairs = self.correlation_filter.filter_pairs(
                self.pairs,
                held_asset,
                price_data,
                current_idx
            )

            # Get candidate features
            candidate_features = {
                pid: feat for pid, feat in self.pair_features.items()
                if any(p.pair_id == pid for p in candidate_pairs)
            }

            # Score pairs
            signals = self.divergence_scorer.score_pairs(
                candidate_features,
                candidate_pairs,
                timestamp
            )

            # Get best signal
            best = self.divergence_scorer.select_best_pair(signals)

            if best is None:
                continue

            # Check if we should switch positions or enter new
            should_enter = False

            # Check cooldown
            bars_since_exit = current_idx - self.position.last_exit_idx
            in_cooldown = bars_since_exit < self.config.cooldown_bars

            if self.position.pair is None and not in_cooldown:
                # No position and not in cooldown, can enter
                should_enter = True
            elif self.position.pair is not None:
                # Check if we should switch (requires min hold + significant improvement)
                bars_held = current_idx - self.position.entry_idx
                current_score = self._get_current_score(timestamp)

                if (bars_held >= self.config.min_hold_bars and
                        best.divergence_score > current_score * self.config.switch_threshold):
                    # New opportunity is significantly better
                    if timestamp in reference_close.index:
                        if self.position.direction == 'long':
                            long_exits.loc[timestamp] = True
                        else:
                            short_exits.loc[timestamp] = True
                    self.position = PositionState(last_exit_idx=current_idx)
                    should_enter = True

            if should_enter:
                # Calculate ATR-based dynamic SL/TP
                sl_price, tp_price = self._calculate_sl_tp(
                    best.base_price,
                    best.direction,
                    best.atr,
                    best.atr_pct
                )

                # Set position
                self.position = PositionState(
                    pair=best.pair,
                    direction=best.direction,
                    entry_price=best.base_price,
                    entry_idx=current_idx,
                    stop_loss=sl_price,
                    take_profit=tp_price,
                    atr=best.atr
                )

                # Calculate position size based on divergence
                pos_size = self._calculate_size(best.divergence_score)

                # Generate entry signal
                if timestamp in reference_close.index:
                    if best.direction == 'long':
                        long_entries.loc[timestamp] = True
                    else:
                        short_entries.loc[timestamp] = True
                    size.loc[timestamp] = pos_size

                trade_count += 1
                logger.debug(
                    "Entry %s %s at %s: z=%.2f, prob=%.2f, score=%.3f",
                    best.direction,
                    best.pair.name,
                    timestamp.strftime('%Y-%m-%d %H:%M'),
                    best.z_score,
                    best.probability,
                    best.divergence_score
                )

        logger.info("Generated %d trades in test period", trade_count)

        return long_entries, long_exits, short_entries, short_exits, size

    def _check_exit(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
        """
        Check if current position should be exited.

        Exit conditions:
        1. Z-Score reverted to mean (|Z| < threshold)
        2. Stop-loss hit
        3. Take-profit hit

        Returns:
            Tuple of (should_exit, reason)
        """
        if self.position.pair is None:
            return False, ""

        pair_id = self.position.pair.pair_id
        if pair_id not in self.pair_features:
            return True, "pair_data_missing"

        features = self.pair_features[pair_id]
        valid = features[features.index <= timestamp]

        if len(valid) == 0:
            return True, "no_data"

        latest = valid.iloc[-1]
        z_score = latest['z_score']
        current_price = latest['base_close']

        # Check mean reversion (primary exit)
        if abs(z_score) < self.config.z_exit_threshold:
            return True, f"mean_reversion (z={z_score:.2f})"

        # Check SL/TP
        return self._check_sl_tp(current_price)

    def _check_sl_tp_only(self, timestamp: pd.Timestamp) -> tuple[bool, str]:
        """
        Check only stop-loss and take-profit conditions.
        Used during minimum hold period.
        """
        if self.position.pair is None:
            return False, ""

        pair_id = self.position.pair.pair_id
        if pair_id not in self.pair_features:
            return True, "pair_data_missing"

        features = self.pair_features[pair_id]
        valid = features[features.index <= timestamp]

        if len(valid) == 0:
            return True, "no_data"

        latest = valid.iloc[-1]
        current_price = latest['base_close']

        return self._check_sl_tp(current_price)

    def _check_sl_tp(self, current_price: float) -> tuple[bool, str]:
        """Check stop-loss and take-profit levels."""
        if self.position.direction == 'long':
            if current_price <= self.position.stop_loss:
                return True, f"stop_loss ({current_price:.2f} <= {self.position.stop_loss:.2f})"
            if current_price >= self.position.take_profit:
                return True, f"take_profit ({current_price:.2f} >= {self.position.take_profit:.2f})"
        else:  # short
            if current_price >= self.position.stop_loss:
                return True, f"stop_loss ({current_price:.2f} >= {self.position.stop_loss:.2f})"
            if current_price <= self.position.take_profit:
                return True, f"take_profit ({current_price:.2f} <= {self.position.take_profit:.2f})"

        return False, ""

    def _get_current_score(self, timestamp: pd.Timestamp) -> float:
        """Get current position's divergence score for comparison."""
        if self.position.pair is None:
            return 0.0

        pair_id = self.position.pair.pair_id
        if pair_id not in self.pair_features:
            return 0.0

        features = self.pair_features[pair_id]
        valid = features[features.index <= timestamp]

        if len(valid) == 0:
            return 0.0

        latest = valid.iloc[-1]
        z_score = abs(latest['z_score'])

        # Re-score with model
        if self.divergence_scorer.model is not None:
            feature_row = latest[self.divergence_scorer.feature_cols].fillna(0)
            feature_row = feature_row.replace([np.inf, -np.inf], 0)
            X = pd.DataFrame(
                [feature_row.values],
                columns=self.divergence_scorer.feature_cols
            )
            prob = self.divergence_scorer.model.predict_proba(X)[0, 1]
            return z_score * prob

        return z_score * 0.5

    def _calculate_sl_tp(
        self,
        entry_price: float,
        direction: str,
        atr: float,
        atr_pct: float
    ) -> tuple[float, float]:
        """
        Calculate ATR-based dynamic stop-loss and take-profit prices.

        Uses ATR (Average True Range) to set stops that adapt to
        each asset's volatility. More volatile assets get wider stops.

        Args:
            entry_price: Entry price
            direction: 'long' or 'short'
            atr: ATR in price units
            atr_pct: ATR as percentage of price

        Returns:
            Tuple of (stop_loss_price, take_profit_price)
        """
        # Calculate SL/TP as ATR multiples
        if atr > 0 and atr_pct > 0:
            # ATR-based calculation
            sl_distance = atr * self.config.sl_atr_multiplier
            tp_distance = atr * self.config.tp_atr_multiplier

            # Convert to percentage for bounds checking
            sl_pct = sl_distance / entry_price
            tp_pct = tp_distance / entry_price
        else:
            # Fallback to fixed percentages if ATR unavailable
            sl_pct = self.config.base_sl_pct
            tp_pct = self.config.base_tp_pct

        # Apply bounds to prevent extreme stops
        sl_pct = max(self.config.min_sl_pct, min(sl_pct, self.config.max_sl_pct))
        tp_pct = max(self.config.min_tp_pct, min(tp_pct, self.config.max_tp_pct))

        # Calculate actual prices
        if direction == 'long':
            stop_loss = entry_price * (1 - sl_pct)
            take_profit = entry_price * (1 + tp_pct)
        else:  # short
            stop_loss = entry_price * (1 + sl_pct)
            take_profit = entry_price * (1 - tp_pct)

        return stop_loss, take_profit

    def _calculate_size(self, divergence_score: float) -> float:
        """
        Calculate position size based on divergence score.

        Higher divergence = larger position (up to 2x).
        """
        # Base score threshold (Z=1.0, prob=0.5 -> score=0.5)
        base_threshold = 0.5

        # Scale factor
        if divergence_score <= base_threshold:
            return 1.0

        # Linear scaling: 1.0 at threshold, up to 2.0 at 2x threshold
        scale = 1.0 + (divergence_score - base_threshold) / base_threshold
        return min(scale, 2.0)

    def _get_common_index(self) -> pd.DatetimeIndex:
        """Get the intersection of all pair feature indices."""
        if not self.pair_features:
            return pd.DatetimeIndex([])

        common = None
        for features in self.pair_features.values():
            if common is None:
                common = features.index
            else:
                common = common.intersection(features.index)

        return common.sort_values()

    def _empty_signals(self, close: pd.Series) -> tuple:
        """Return empty signal arrays."""
        empty = self.create_empty_signals(close)
        size = pd.Series(1.0, index=close.index)
        return empty, empty, empty, empty, size
@@ -30,7 +30,7 @@ class RegimeReversionStrategy(BaseStrategy):

     # Optimal parameters from walk-forward research (2025-10 to 2025-12)
     # Research: research/horizon_optimization_results.csv
-    OPTIMAL_HORIZON = 54  # Updated from 102h based on corrected labeling
+    OPTIMAL_HORIZON = 102  # 4.25 days - best Net PnL (+232%)
     OPTIMAL_Z_WINDOW = 24  # 24h rolling window for spread Z-score
     OPTIMAL_TRAIN_RATIO = 0.7  # 70% train / 30% test split
     OPTIMAL_PROFIT_TARGET = 0.005  # 0.5% profit threshold for target definition
@@ -321,64 +321,21 @@ class RegimeReversionStrategy(BaseStrategy):
             train_features: DataFrame containing features for training period only
         """
         threshold = self.profit_target
-        stop_loss_pct = self.stop_loss
         horizon = self.horizon
         z_thresh = self.z_entry_threshold

-        # Calculate targets path-dependently (checking SL before TP)
-        spread = train_features['spread'].values
-        z_score = train_features['z_score'].values
-        n = len(spread)
-
-        targets = np.zeros(n, dtype=int)
-
-        # Only iterate relevant rows for efficiency
-        candidates = np.where((z_score > z_thresh) | (z_score < -z_thresh))[0]
-
-        for i in candidates:
-            if i + horizon >= n:
-                continue
-
-            entry_price = spread[i]
-            future_prices = spread[i+1 : i+1+horizon]
-
-            if z_score[i] > z_thresh:  # Short
-                target_price = entry_price * (1 - threshold)
-                stop_price = entry_price * (1 + stop_loss_pct)
-
-                hit_tp = future_prices <= target_price
-                hit_sl = future_prices >= stop_price
-
-                if not np.any(hit_tp):
-                    targets[i] = 0
-                elif not np.any(hit_sl):
-                    targets[i] = 1
-                else:
-                    first_tp_idx = np.argmax(hit_tp)
-                    first_sl_idx = np.argmax(hit_sl)
-                    if first_tp_idx < first_sl_idx:
-                        targets[i] = 1
-                    else:
-                        targets[i] = 0
-
-            else:  # Long
-                target_price = entry_price * (1 + threshold)
-                stop_price = entry_price * (1 - stop_loss_pct)
-
-                hit_tp = future_prices >= target_price
-                hit_sl = future_prices <= stop_price
-
-                if not np.any(hit_tp):
-                    targets[i] = 0
-                elif not np.any(hit_sl):
-                    targets[i] = 1
-                else:
-                    first_tp_idx = np.argmax(hit_tp)
-                    first_sl_idx = np.argmax(hit_sl)
-                    if first_tp_idx < first_sl_idx:
-                        targets[i] = 1
-                    else:
-                        targets[i] = 0
+        # Define targets using ONLY training data
+        # For Short Spread (Z > threshold): Did spread drop below target within horizon?
+        future_min = train_features['spread'].rolling(window=horizon).min().shift(-horizon)
+        target_short = train_features['spread'] * (1 - threshold)
+        success_short = (train_features['z_score'] > z_thresh) & (future_min < target_short)
+
+        # For Long Spread (Z < -threshold): Did spread rise above target within horizon?
+        future_max = train_features['spread'].rolling(window=horizon).max().shift(-horizon)
+        target_long = train_features['spread'] * (1 + threshold)
+        success_long = (train_features['z_score'] < -z_thresh) & (future_max > target_long)
+
+        targets = np.select([success_short, success_long], [1, 1], default=0)

         # Build model
         model = RandomForestClassifier(
@@ -394,9 +351,10 @@ class RegimeReversionStrategy(BaseStrategy):
         X_train = train_features[cols].fillna(0)
         X_train = X_train.replace([np.inf, -np.inf], 0)

-        # Use rows where we had enough data to look ahead
-        valid_mask = np.zeros(n, dtype=bool)
-        valid_mask[:n-horizon] = True
+        # Remove rows with NaN targets (from rolling window at end of training period)
+        valid_mask = ~np.isnan(targets) & ~np.isinf(targets)
+        # Also check for rows where future data doesn't exist (shift created NaNs)
+        valid_mask = valid_mask & (future_min.notna().values) & (future_max.notna().values)

         X_train_clean = X_train[valid_mask]
         targets_clean = targets[valid_mask]
tasks/prd-multi-pair-divergence-strategy.md (new file, 321 lines)
@@ -0,0 +1,321 @@
# PRD: Multi-Pair Divergence Selection Strategy

## 1. Introduction / Overview

This document describes the **Multi-Pair Divergence Selection Strategy**, an extension of the existing BTC/ETH regime reversion system. The strategy expands spread analysis to the **top 10 cryptocurrencies by market cap**, calculates divergence scores for all tradeable pairs, and dynamically selects the **most divergent pair** for trading.

The core hypothesis: by scanning multiple pairs simultaneously, we can identify stronger mean-reversion opportunities than by focusing on a single pair, improving net PnL while maintaining the proven ML-based regime detection approach.

---

## 2. Goals

1. **Extend regime detection** to the top 10 market cap cryptocurrencies
2. **Dynamically select** the most divergent tradeable pair each cycle
3. **Integrate volatility** into dynamic SL/TP calculations
4. **Filter correlated pairs** to avoid redundant positions
5. **Improve net PnL** compared to the single-pair BTC/ETH strategy
6. **Backtest-first** implementation with walk-forward validation

---

## 3. User Stories

### US-1: Multi-Pair Analysis
> As a trader, I want the system to analyze spread divergence across multiple cryptocurrency pairs so that I can identify the best trading opportunity at any given moment.

### US-2: Dynamic Pair Selection
> As a trader, I want the system to automatically select and trade the pair with the highest divergence score (a combination of Z-score magnitude and ML probability) so that I maximize mean-reversion profit potential.

### US-3: Volatility-Adjusted Risk
> As a trader, I want stop-loss and take-profit levels to adapt to each pair's volatility so that I avoid being stopped out prematurely on volatile assets while protecting profits on stable ones.

### US-4: Correlation Filtering
> As a trader, I want the system to avoid selecting pairs that are highly correlated with my current position so that I don't inadvertently double down on the same market exposure.

### US-5: Backtest Validation
> As a researcher, I want to backtest this multi-pair strategy with walk-forward training so that I can validate improvement over the single-pair baseline without look-ahead bias.

---

## 4. Functional Requirements

### 4.1 Data Management

| ID | Requirement |
|----|-------------|
| FR-1.1 | System must support loading OHLCV data for the top 10 market cap cryptocurrencies |
| FR-1.2 | Target assets: BTC, ETH, SOL, XRP, BNB, DOGE, ADA, AVAX, LINK, DOT (configurable) |
| FR-1.3 | System must identify all directly tradeable cross-pairs on OKX perpetuals |
| FR-1.4 | System must align timestamps across all pairs for synchronized analysis |
| FR-1.5 | System must handle missing data gracefully (skip pair if insufficient history) |
### 4.2 Pair Generation

| ID | Requirement |
|----|-------------|
| FR-2.1 | Generate all unique pairs from the asset universe: N*(N-1)/2 pairs (e.g., 45 pairs for 10 assets) |
| FR-2.2 | Filter pairs to only those directly tradeable on OKX (no USDT intermediate) |
| FR-2.3 | Fallback: if a cross-pair is not available, calculate a synthetic spread via USDT pairs |
| FR-2.4 | Store pair metadata: base asset, quote asset, exchange symbol, tradeable flag |
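FR-2.1's combinatorics fit in a few lines; a minimal sketch, where the function name is illustrative and the OKX tradeable filter (FR-2.2) is deliberately omitted, since the real check would query the exchange:

```python
from itertools import combinations

def generate_pairs(assets: list[str]) -> list[tuple[str, str]]:
    """All unique unordered (base, quote) pairs: N*(N-1)/2 combinations."""
    return list(combinations(assets, 2))

assets = ["BTC", "ETH", "SOL", "XRP", "BNB", "DOGE", "ADA", "AVAX", "LINK", "DOT"]
pairs = generate_pairs(assets)
print(len(pairs))  # 45 pairs for 10 assets
```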
### 4.3 Feature Engineering (Per Pair)

| ID | Requirement |
|----|-------------|
| FR-3.1 | Calculate spread ratio: `asset_a_close / asset_b_close` |
| FR-3.2 | Calculate Z-Score with configurable rolling window (default: 24h) |
| FR-3.3 | Calculate spread technicals: RSI(14), ROC(5), 1h change |
| FR-3.4 | Calculate volume ratio and relative volume |
| FR-3.5 | Calculate volatility ratio: `std(returns_a) / std(returns_b)` over Z-window |
| FR-3.6 | Calculate realized volatility for each asset (for dynamic SL/TP) |
| FR-3.7 | Merge on-chain data (funding rates, inflows) if available per asset |
| FR-3.8 | Add pair identifier as categorical feature for universal model |
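A minimal sketch of FR-3.1/FR-3.2 (spread ratio plus rolling Z-score), assuming timestamp-aligned close-price series; the function name is illustrative, and the remaining features (RSI, ROC, volume ratios) would be appended to the same frame:

```python
import pandas as pd

def spread_features(close_a: pd.Series, close_b: pd.Series,
                    z_window: int = 24) -> pd.DataFrame:
    """Spread ratio (FR-3.1) and rolling Z-score over z_window bars (FR-3.2)."""
    spread = close_a / close_b
    mean = spread.rolling(z_window).mean()
    std = spread.rolling(z_window).std()
    return pd.DataFrame({"spread": spread, "z_score": (spread - mean) / std})
```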
### 4.4 Correlation Filtering

| ID | Requirement |
|----|-------------|
| FR-4.1 | Calculate rolling correlation matrix between all assets (default: 168h / 7 days) |
| FR-4.2 | Define correlation threshold (default: 0.85) |
| FR-4.3 | If a current position exists, exclude pairs where either asset has correlation > threshold with the held asset |
| FR-4.4 | Log filtered pairs with reason for exclusion |
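FR-4.1 through FR-4.3 amount to one correlation lookup per candidate pair. A sketch under assumptions: `returns` is an hourly returns DataFrame with one column per asset, and the function name is illustrative:

```python
import pandas as pd

def filter_correlated(pairs, returns: pd.DataFrame, held_asset: str,
                      window: int = 168, threshold: float = 0.85):
    """Drop pairs where either leg correlates > threshold with the held asset (FR-4.3)."""
    corr = returns.tail(window).corr()[held_asset]
    kept, dropped = [], []
    for a, b in pairs:
        # Assets missing from the window default to zero correlation
        if abs(corr.get(a, 0.0)) > threshold or abs(corr.get(b, 0.0)) > threshold:
            dropped.append((a, b))  # would be logged with a reason per FR-4.4
        else:
            kept.append((a, b))
    return kept, dropped
```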
### 4.5 Divergence Scoring & Pair Selection

| ID | Requirement |
|----|-------------|
| FR-5.1 | Calculate divergence score: `abs(z_score) * model_probability` |
| FR-5.2 | Only consider pairs where `abs(z_score) > z_entry_threshold` (default: 1.0) |
| FR-5.3 | Only consider pairs where `model_probability > prob_threshold` (default: 0.5) |
| FR-5.4 | Apply correlation filter to eligible pairs |
| FR-5.5 | Select pair with highest divergence score |
| FR-5.6 | If no pair qualifies, signal "hold" |
| FR-5.7 | Log all pair scores for analysis/debugging |
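Under the thresholds above, selection reduces to an argmax over eligible pairs. A sketch (the pair names and the `candidates` mapping are illustrative; the FR-5.4 correlation filter would be applied to the candidates first):

```python
def select_pair(candidates: dict, z_entry: float = 1.0, prob_min: float = 0.5):
    """FR-5.1..5.6: score = |z| * probability; return best (pair, score) or None."""
    scored = [
        (pair, abs(z) * prob)
        for pair, (z, prob) in candidates.items()
        if abs(z) > z_entry and prob > prob_min
    ]
    return max(scored, key=lambda item: item[1]) if scored else None

best = select_pair({
    "BTC/ETH": (-1.5, 0.60),  # (z_score, model_probability)
    "SOL/XRP": (2.0, 0.70),
    "BNB/ADA": (0.4, 0.90),   # |z| below entry threshold -> ineligible
})
print(best)  # ('SOL/XRP', 1.4)
```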
### 4.6 ML Model (Universal)

| ID | Requirement |
|----|-------------|
| FR-6.1 | Train a single Random Forest model on all pairs combined |
| FR-6.2 | Include `pair_id` as one-hot encoded or label-encoded feature |
| FR-6.3 | Target: binary (1 = profitable reversion within horizon, 0 = no reversion) |
| FR-6.4 | Walk-forward training: 70% train / 30% test split |
| FR-6.5 | Daily retraining schedule (for live, configurable for backtest) |
| FR-6.6 | Model hyperparameters: `n_estimators=300, max_depth=5, min_samples_leaf=30, class_weight={0:1, 1:3}` |
| FR-6.7 | Save/load model with feature column metadata |
### 4.7 Signal Generation

| ID | Requirement |
|----|-------------|
| FR-7.1 | Direction: if `z_score > threshold` -> short spread (short asset_a); if `z_score < -threshold` -> long spread (long asset_a) |
| FR-7.2 | Apply funding rate filter per asset (block if extreme funding opposes direction) |
| FR-7.3 | Output signal: `{pair, action, side, probability, z_score, divergence_score, reason}` |
### 4.8 Position Sizing

| ID | Requirement |
|----|-------------|
| FR-8.1 | Base size: 100% of available subaccount balance |
| FR-8.2 | Scale by divergence: `size_multiplier = 1.0 + (divergence_score - base_threshold) * scaling_factor` |
| FR-8.3 | Cap multiplier between 1.0x and 2.0x |
| FR-8.4 | Respect exchange minimum order size per asset |
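FR-8.2/FR-8.3 in code; a sketch, noting that with the assumed defaults `base_threshold = 0.5` and `scaling_factor = 2.0` it reproduces the `_calculate_size` method shown earlier (1.0x at the threshold, 2.0x at twice the threshold):

```python
def size_multiplier(divergence_score: float, base_threshold: float = 0.5,
                    scaling_factor: float = 2.0) -> float:
    """FR-8.2: linear scale above the base threshold; FR-8.3: clamp to [1.0, 2.0]."""
    raw = 1.0 + (divergence_score - base_threshold) * scaling_factor
    return max(1.0, min(raw, 2.0))
```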
### 4.9 Dynamic SL/TP (Volatility-Adjusted)

| ID | Requirement |
|----|-------------|
| FR-9.1 | Calculate asset realized volatility: `std(returns) * sqrt(24)` for daily vol |
| FR-9.2 | Base SL: `entry_price * (1 - base_sl_pct * vol_multiplier)` for longs |
| FR-9.3 | Base TP: `entry_price * (1 + base_tp_pct * vol_multiplier)` for longs |
| FR-9.4 | `vol_multiplier = asset_volatility / baseline_volatility` (baseline = BTC volatility) |
| FR-9.5 | Cap vol_multiplier between 0.5x and 2.0x to prevent extreme values |
| FR-9.6 | Invert logic for short positions |
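Putting FR-9.2 through FR-9.6 together, assuming the realized volatilities are already computed per FR-9.1 (the function and argument names are illustrative):

```python
def dynamic_sl_tp(entry_price: float, asset_vol: float, baseline_vol: float,
                  base_sl_pct: float = 0.06, base_tp_pct: float = 0.05,
                  direction: str = "long") -> tuple[float, float]:
    """Volatility-scaled stop-loss / take-profit with the FR-9.5 clamp."""
    vol_mult = max(0.5, min(asset_vol / baseline_vol, 2.0))  # FR-9.4 / FR-9.5
    sl_pct = base_sl_pct * vol_mult
    tp_pct = base_tp_pct * vol_mult
    if direction == "long":
        return entry_price * (1 - sl_pct), entry_price * (1 + tp_pct)
    return entry_price * (1 + sl_pct), entry_price * (1 - tp_pct)  # FR-9.6
```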
### 4.10 Exit Conditions

| ID | Requirement |
|----|-------------|
| FR-10.1 | Exit when Z-score crosses back through 0 (mean reversion complete) |
| FR-10.2 | Exit when dynamic SL or TP hit |
| FR-10.3 | No minimum holding period (can switch pairs immediately) |
| FR-10.4 | If a new pair has a higher divergence score, close current and open new |

### 4.11 Backtest Integration

| ID | Requirement |
|----|-------------|
| FR-11.1 | Integrate with existing `engine/backtester.py` framework |
| FR-11.2 | Support 1h timeframe (matching live trading) |
| FR-11.3 | Walk-forward validation: train on 70%, test on 30% |
| FR-11.4 | Output: trades log, equity curve, performance metrics |
| FR-11.5 | Compare against single-pair BTC/ETH baseline |

---
## 5. Non-Goals (Out of Scope)

1. **Live trading implementation** - Backtest validation first
2. **Multi-position portfolio** - Single pair at a time for v1
3. **Cross-exchange arbitrage** - OKX only
4. **Alternative ML models** - Stick with Random Forest for consistency
5. **Sub-1h timeframes** - 1h candles only for initial version
6. **Leveraged positions** - 1x leverage for backtest
7. **Portfolio-level VaR/risk budgeting** - Full subaccount allocation

---

## 6. Design Considerations
### 6.1 Architecture

```
strategies/
    multi_pair/
        __init__.py
        pair_scanner.py       # Generates all pairs, filters tradeable
        feature_engine.py     # Calculates features for all pairs
        correlation.py        # Rolling correlation matrix & filtering
        divergence_scorer.py  # Ranks pairs by divergence score
        strategy.py           # Main strategy orchestration
```

### 6.2 Data Flow

```
1. Load OHLCV for all 10 assets
2. Generate pair combinations (45 pairs)
3. Filter to tradeable pairs (OKX check)
4. Calculate features for each pair
5. Train/load universal ML model
6. Predict probability for all pairs
7. Calculate divergence scores
8. Apply correlation filter
9. Select top pair
10. Generate signal with dynamic SL/TP
11. Execute in backtest engine
```
### 6.3 Configuration

```python
from dataclasses import dataclass, field

@dataclass
class MultiPairConfig:
    # Assets
    assets: list[str] = field(default_factory=lambda: [
        "BTC", "ETH", "SOL", "XRP", "BNB",
        "DOGE", "ADA", "AVAX", "LINK", "DOT",
    ])

    # Thresholds
    z_window: int = 24
    z_entry_threshold: float = 1.0
    prob_threshold: float = 0.5
    correlation_threshold: float = 0.85
    correlation_window: int = 168  # 7 days in hours

    # Risk
    base_sl_pct: float = 0.06
    base_tp_pct: float = 0.05
    vol_multiplier_min: float = 0.5
    vol_multiplier_max: float = 2.0

    # Model
    train_ratio: float = 0.7
    horizon: int = 102
    profit_target: float = 0.005
```

---
## 7. Technical Considerations

### 7.1 Dependencies

- Extend `DataManager` to load multiple symbols
- Query OKX API for available perpetual cross-pairs
- Reuse existing feature engineering from `RegimeReversionStrategy`

### 7.2 Performance

- Pre-calculate all pair features in batch (vectorized)
- Cache the correlation matrix (update every N candles, not every minute)
- Model inference is fast (a single predict call with all pairs as rows)

### 7.3 Edge Cases

- Handle pairs with insufficient history (< 200 bars) - exclude
- Handle assets delisted mid-backtest - skip pair
- Handle zero-volume periods - use last valid price

---
## 8. Success Metrics

| Metric | Baseline (BTC/ETH) | Target |
|--------|--------------------|--------|
| Net PnL | Current performance | > 10% improvement |
| Number of Trades | N | Comparable or higher |
| Win Rate | Baseline % | Maintain or improve |
| Average Trade Duration | Baseline hours | Flexible |
| Max Drawdown | Baseline % | Not significantly worse |

---
## 9. Open Questions

1. **OKX Cross-Pairs**: Need to verify which cross-pairs are available on OKX perpetuals. May need to fall back to synthetic spreads for most pairs.

2. **On-Chain Data**: CryptoQuant data currently covers BTC/ETH. Should we:
   - Run without on-chain features for other assets?
   - Source alternative on-chain data?
   - Use funding rates only (available from OKX)?

3. **Pair ID Encoding**: For the universal model, should pair_id be:
   - One-hot encoded (adds 45 features)?
   - Label encoded (single ordinal feature)?
   - Hierarchical (base_asset + quote_asset as separate features)?

4. **Synthetic Spreads**: If trading the SOL/DOT spread but only USDT pairs are available:
   - Calculate the spread synthetically: `SOL-USDT / DOT-USDT`
   - Execute as two legs: long SOL-USDT, short DOT-USDT
   - This doubles fees and adds execution complexity. Include in v1?

---
## 10. Implementation Phases

### Phase 1: Data & Infrastructure (Est. 2-3 days)
- Extend DataManager for multi-symbol loading
- Build pair scanner with OKX tradeable filter
- Implement correlation matrix calculation

### Phase 2: Feature Engineering (Est. 2 days)
- Adapt existing feature calculation for arbitrary pairs
- Add pair identifier feature
- Batch feature calculation for all pairs

### Phase 3: Model & Scoring (Est. 2 days)
- Train universal model on all pairs
- Implement divergence scoring
- Add correlation filtering to pair selection

### Phase 4: Strategy Integration (Est. 2-3 days)
- Implement dynamic SL/TP with volatility
- Integrate with backtester
- Build strategy orchestration class

### Phase 5: Validation & Comparison (Est. 2 days)
- Run walk-forward backtest
- Compare against BTC/ETH baseline
- Generate performance report

**Total Estimated Effort: 10-12 days**

---

*Document Version: 1.0*
*Created: 2026-01-15*
*Author: AI Assistant*
*Status: Draft - Awaiting Review*
@@ -1,351 +0,0 @@
# PRD: Terminal UI for Live Trading Bot

## Introduction/Overview

The live trading bot currently uses basic console logging for output, making it difficult to monitor trading activity, track performance, and understand the system state at a glance. This feature introduces a Rich-based terminal UI that provides a professional, real-time dashboard for monitoring the live trading bot.

The UI will display a horizontal split layout with a **summary panel** at the top (with tabbed time-period views) and a **scrollable log panel** at the bottom. The interface will update every second and support keyboard navigation.

## Goals

1. Provide real-time visibility into trading performance (PnL, win rate, trade count)
2. Enable monitoring of current position state (entry, SL/TP, unrealized PnL)
3. Display strategy signals (Z-score, model probability) for transparency
4. Support historical performance tracking across time periods (daily, weekly, monthly, all-time)
5. Improve operational experience with keyboard shortcuts and log filtering
6. Create a responsive design that works across different terminal sizes

## User Stories

1. **As a trader**, I want to see my total PnL and daily PnL at a glance so I can quickly assess performance.
2. **As a trader**, I want to see my current position details (entry price, unrealized PnL, SL/TP levels) so I can monitor risk.
3. **As a trader**, I want to view performance metrics by time period (daily, weekly, monthly) so I can track trends.
4. **As a trader**, I want to filter logs by type (errors, trades, signals) so I can focus on relevant information.
5. **As a trader**, I want keyboard shortcuts to navigate the UI without using a mouse.
6. **As a trader**, I want the UI to show strategy state (Z-score, probability) so I understand why signals are generated.
## Functional Requirements

### FR1: Layout Structure

1.1. The UI must use a horizontal split layout with the summary panel at the top and the logs panel at the bottom.

1.2. The summary panel must contain tabbed views accessible via number keys:
- Tab 1 (`1`): **General** - Overall metrics since bot started
- Tab 2 (`2`): **Monthly** - Current month metrics (shown only if data spans > 1 month)
- Tab 3 (`3`): **Weekly** - Current week metrics (shown only if data spans > 1 week)
- Tab 4 (`4`): **Daily** - Today's metrics

1.3. The logs panel must be a scrollable area showing recent log entries.

1.4. The UI must be responsive and adapt to terminal size (minimum 80x24).

### FR2: Metrics Display

The summary panel must display the following metrics:

**Performance Metrics:**
2.1. Total PnL (USD) - cumulative profit/loss since tracking began
2.2. Period PnL (USD) - profit/loss for the selected time period (daily/weekly/monthly)
2.3. Win Rate (%) - percentage of winning trades
2.4. Total Number of Trades
2.5. Average Trade Duration (hours)
2.6. Max Drawdown (USD and %)

**Current Position (if open):**
2.7. Symbol and side (long/short)
2.8. Entry price
2.9. Current price
2.10. Unrealized PnL (USD and %)
2.11. Stop-loss price and distance (%)
2.12. Take-profit price and distance (%)
2.13. Position size (USD)

**Account Status:**
2.14. Account balance / available margin (USDT)
2.15. Current leverage setting

**Strategy State:**
2.16. Current Z-score
2.17. Model probability
2.18. Current funding rate (BTC)
2.19. Last signal action and reason
### FR3: Historical Data Loading

3.1. On startup, the system must initialize a SQLite database at `live_trading/trading.db`.

3.2. If `trade_log.csv` exists and the database is empty, migrate CSV data to SQLite.

3.3. The UI must load current positions from `live_trading/positions.json` (kept for compatibility with the existing position manager).

3.4. Metrics must be calculated via SQL aggregation queries for each time period.

3.5. If no historical data exists, the UI must show "No data" gracefully.

3.6. New trades must be written to both SQLite (primary) and CSV (backup/compatibility).
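FR 3.4's per-period aggregation can be done in a single query. A sketch against an in-memory database with an assumed minimal `trades` schema (the real `trading.db` columns may differ):

```python
import sqlite3

# Assumed minimal schema for illustration only.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE trades (closed_at TEXT, pnl REAL)")
conn.executemany(
    "INSERT INTO trades VALUES (?, ?)",
    [("2026-01-15", 30.00), ("2026-01-16", 45.23), ("2026-01-16", -10.00)],
)

# Period PnL and win rate for a given day via SQL aggregation (FR 3.4).
# In SQLite, (pnl > 0) evaluates to 0/1, so AVG of it is the win rate.
total_pnl, win_rate = conn.execute(
    "SELECT SUM(pnl), AVG(pnl > 0) FROM trades WHERE closed_at >= ?",
    ("2026-01-16",),
).fetchone()
print(total_pnl, win_rate)
```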
### FR4: Real-Time Updates

4.1. The UI must refresh every 1 second.
4.2. Position unrealized PnL must update based on latest price data.
4.3. New log entries must appear in real-time.
4.4. Metrics must recalculate when trades are opened/closed.

### FR5: Log Panel

5.1. The log panel must display log entries with timestamp, level, and message.

5.2. Log entries must be color-coded by level:
- ERROR: Red
- WARNING: Yellow
- INFO: White/Default
- DEBUG: Gray (if shown)

5.3. The log panel must support filtering by log type:
- All logs (default)
- Errors only
- Trades only (entries containing "position", "trade", "order")
- Signals only (entries containing "signal", "z_score", "prob")

5.4. Filter switching must be available via keyboard shortcut (`f` to cycle filters).

### FR6: Keyboard Controls

6.1. `q` or `Ctrl+C` - Graceful shutdown
6.2. `r` - Force refresh data
6.3. `1` - Switch to General tab
6.4. `2` - Switch to Monthly tab
6.5. `3` - Switch to Weekly tab
6.6. `4` - Switch to Daily tab
6.7. `f` - Cycle log filter
6.8. Arrow keys - Scroll logs (if supported)

### FR7: Color Scheme

7.1. Use dark theme as base.

7.2. PnL values must be colored:
- Positive: Green
- Negative: Red
- Zero/Neutral: White

7.3. Position side must be colored:
- Long: Green
- Short: Red

7.4. Use consistent color coding for emphasis and warnings.
## Non-Goals (Out of Scope)

1. **Mouse support** - This is a keyboard-driven terminal UI
2. **Trade execution from UI** - The UI is read-only; trades are executed by the bot
3. **Configuration editing** - Config changes require restarting the bot
4. **Multi-exchange support** - Only OKX is supported
5. **Charts/graphs** - Text-based metrics only (no ASCII charts in v1)
6. **Sound alerts** - No audio notifications
7. **Remote access** - Local terminal only

## Design Considerations

### Technology Choice: Rich

Use the [Rich](https://github.com/Textualize/rich) Python library for the terminal UI:
- Rich provides `Live` display for real-time updates
- Rich `Layout` for split-screen design
- Rich `Table` for metrics display
- Rich `Panel` for bordered sections
- Rich `Text` for colored output

Alternative considered: **Textual** (also by Will McGugan) provides more advanced TUI features but adds complexity. Rich is simpler and sufficient for this use case.
### UI Mockup

```
+==============================================================================+
| REGIME REVERSION STRATEGY - LIVE TRADING [DEMO]                     ETH/USDT |
+==============================================================================+
| [1:General] [2:Monthly] [3:Weekly] [4:Daily]                                 |
+------------------------------------------------------------------------------+
| PERFORMANCE                         | CURRENT POSITION                       |
| Total PnL:     $1,234.56            | Side:       LONG                       |
| Today PnL:     $45.23               | Entry:      $3,245.50                  |
| Win Rate:      67.5%                | Current:    $3,289.00                  |
| Total Trades:  24                   | Unrealized: +$43.50 (+1.34%)           |
| Avg Duration:  4.2h                 | Size:       $500.00                    |
| Max Drawdown:  -$156.00             | SL: $3,050.00 (-6.0%) TP: $3,408 (+5%) |
|                                     |                                        |
| ACCOUNT                             | STRATEGY STATE                         |
| Balance:    $5,432.10               | Z-Score:     1.45                      |
| Available:  $4,932.10               | Probability: 0.72                      |
| Leverage:   2x                      | Funding:     0.0012                    |
+------------------------------------------------------------------------------+
| LOGS [Filter: All]                                          Press 'f' cycle  |
+------------------------------------------------------------------------------+
| 14:32:15 [INFO] Trading Cycle Start: 2026-01-16T14:32:15+00:00               |
| 14:32:16 [INFO] Signal: entry long (prob=0.72, z=-1.45, reason=z_score...)   |
| 14:32:17 [INFO] Executing LONG entry: 0.1540 ETH @ 3245.50 ($500.00)         |
| 14:32:18 [INFO] Position opened: ETH/USDT:USDT_20260116_143217               |
| 14:32:18 [INFO] Portfolio: 1 positions, exposure=$500.00, unrealized=$0.00   |
| 14:32:18 [INFO] --- Cycle completed in 3.2s ---                              |
| 14:32:18 [INFO] Sleeping for 60 minutes...                                   |
|                                                                              |
+------------------------------------------------------------------------------+
| [q]Quit [r]Refresh [1-4]Tabs [f]Filter                                       |
+==============================================================================+
```
### File Structure

```
live_trading/
    ui/
        __init__.py
        dashboard.py       # Main UI orchestration and threading
        panels.py          # Panel components (metrics, logs, position)
        state.py           # Thread-safe shared state
        log_handler.py     # Custom logging handler for UI queue
        keyboard.py        # Keyboard input handling
    db/
        __init__.py
        database.py        # SQLite connection and queries
        models.py          # Data models (Trade, DailySummary, Session)
        migrations.py      # CSV migration and schema setup
        metrics.py         # Metrics aggregation queries
    trading.db             # SQLite database file (created at runtime)
```

## Technical Considerations

### Integration with Existing Code

1. **Logging Integration**: Add a custom `logging.Handler` that captures log records and forwards them to the UI log panel while the existing file handler keeps writing to disk.

2. **Data Access**: The UI needs read access to:
   - `PositionManager` for current positions
   - `TradingConfig` for settings display
   - The SQLite database for historical metrics
   - Real-time prices from `DataFeed`

3. **Main Loop Modification**: `LiveTradingBot.run()` must launch the UI in a **separate thread** alongside the trading loop: the UI thread handles rendering and keyboard input while the main thread executes trading logic.

4. **Graceful Shutdown**: Ensure the `SIGINT`/`SIGTERM` handlers work with the UI layer and terminate the UI thread cleanly.

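The logging integration described in point 1 can be sketched as follows. This is a minimal illustration, not the project's actual implementation; the class name `UILogHandler` and the queue wiring are assumptions. It only uses the standard-library `logging` and `queue` modules, and it deliberately never blocks, so a slow UI cannot stall the trading thread.

```python
import logging
import queue


class UILogHandler(logging.Handler):
    """Forward formatted log records to a thread-safe queue for the UI panel.

    This handler only enqueues; a file handler registered alongside it
    keeps writing to disk as before.
    """

    def __init__(self, log_queue: "queue.Queue[str]") -> None:
        super().__init__()
        self.log_queue = log_queue

    def emit(self, record: logging.LogRecord) -> None:
        try:
            msg = self.format(record)
            try:
                # Never block the trading thread on a full queue.
                self.log_queue.put_nowait(msg)
            except queue.Full:
                pass  # drop the record rather than stall trading
        except Exception:
            self.handleError(record)


# Usage sketch: attach next to the existing handlers.
log_queue: "queue.Queue[str]" = queue.Queue(maxsize=1000)
handler = UILogHandler(log_queue)
handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))
logging.getLogger().addHandler(handler)
```

The UI thread then drains `log_queue` on each refresh tick and appends the messages to its log panel.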
### Database Schema (SQLite)

Create `live_trading/trading.db` with the following schema:

```sql
-- Trade history table
CREATE TABLE trades (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    trade_id TEXT UNIQUE NOT NULL,
    symbol TEXT NOT NULL,
    side TEXT NOT NULL,              -- 'long' or 'short'
    entry_price REAL NOT NULL,
    exit_price REAL,
    size REAL NOT NULL,
    size_usdt REAL NOT NULL,
    pnl_usd REAL,
    pnl_pct REAL,
    entry_time TEXT NOT NULL,        -- ISO format
    exit_time TEXT,
    hold_duration_hours REAL,
    reason TEXT,                     -- 'stop_loss', 'take_profit', 'signal', etc.
    order_id_entry TEXT,
    order_id_exit TEXT,
    created_at TEXT DEFAULT CURRENT_TIMESTAMP
);

-- Daily summary table (for faster queries)
CREATE TABLE daily_summary (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    date TEXT UNIQUE NOT NULL,       -- YYYY-MM-DD
    total_trades INTEGER DEFAULT 0,
    winning_trades INTEGER DEFAULT 0,
    total_pnl_usd REAL DEFAULT 0,
    max_drawdown_usd REAL DEFAULT 0,
    updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);

-- Session metadata
CREATE TABLE sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    start_time TEXT NOT NULL,
    end_time TEXT,
    starting_balance REAL,
    ending_balance REAL,
    total_pnl REAL,
    total_trades INTEGER DEFAULT 0
);

-- Indexes for common queries
CREATE INDEX idx_trades_entry_time ON trades(entry_time);
CREATE INDEX idx_trades_exit_time ON trades(exit_time);
CREATE INDEX idx_daily_summary_date ON daily_summary(date);
```

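Applying the schema at startup can be done idempotently with `sqlite3.Connection.executescript`. The sketch below is an assumption about how `database.py` might do it, not the actual module: it shows only the `trades` table and one index for brevity, and it swaps the plain `CREATE TABLE` statements for `CREATE TABLE IF NOT EXISTS` so repeated startups are harmless. The helper name `init_db` is hypothetical.

```python
import sqlite3

# Abbreviated schema; the real script would include daily_summary and sessions too.
SCHEMA = """
CREATE TABLE IF NOT EXISTS trades (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    trade_id TEXT UNIQUE NOT NULL,
    symbol TEXT NOT NULL,
    side TEXT NOT NULL,
    entry_price REAL NOT NULL,
    exit_price REAL,
    size REAL NOT NULL,
    size_usdt REAL NOT NULL,
    pnl_usd REAL,
    pnl_pct REAL,
    entry_time TEXT NOT NULL,
    exit_time TEXT,
    hold_duration_hours REAL,
    reason TEXT,
    order_id_entry TEXT,
    order_id_exit TEXT,
    created_at TEXT DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades(entry_time);
"""


def init_db(path: str = "live_trading/trading.db") -> sqlite3.Connection:
    """Open (or create) the database and make sure the schema exists."""
    conn = sqlite3.connect(path)
    conn.executescript(SCHEMA)
    conn.commit()
    return conn
```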
### Migration from CSV

On first run with the new system:

1. Check whether `trading.db` exists
2. If not, create the database with the schema above
3. If `trade_log.csv` exists, migrate its rows into the `trades` table
4. Rebuild `daily_summary` from the migrated trades

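Steps 3 and 4 above could look roughly like this. The helper name `migrate_csv` and the assumption that the CSV header matches the `trades` column names are mine, not from the codebase; `INSERT OR IGNORE` keyed on the unique `trade_id` makes the migration safe to re-run.

```python
import csv
import os
import sqlite3


def migrate_csv(conn: sqlite3.Connection,
                csv_path: str = "live_trading/trade_log.csv") -> int:
    """Copy legacy CSV rows into trades, then rebuild daily_summary.

    Returns the number of CSV rows read (duplicates are silently ignored).
    """
    if not os.path.exists(csv_path):
        return 0
    rows_read = 0
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            conn.execute(
                "INSERT OR IGNORE INTO trades "
                "(trade_id, symbol, side, entry_price, size, size_usdt, "
                " entry_time, pnl_usd) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                (row["trade_id"], row["symbol"], row["side"],
                 float(row["entry_price"]), float(row["size"]),
                 float(row["size_usdt"]), row["entry_time"],
                 float(row["pnl_usd"]) if row.get("pnl_usd") else None),
            )
            rows_read += 1
    # Step 4: rebuild the daily summary from the migrated trades.
    conn.execute("DELETE FROM daily_summary")
    conn.execute(
        "INSERT INTO daily_summary (date, total_trades, winning_trades, total_pnl_usd) "
        "SELECT substr(entry_time, 1, 10), COUNT(*), "
        "       SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END), "
        "       COALESCE(SUM(pnl_usd), 0) "
        "FROM trades GROUP BY substr(entry_time, 1, 10)"
    )
    conn.commit()
    return rows_read
```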
### Dependencies

Add to `pyproject.toml`:

```toml
dependencies = [
    # ... existing deps
    "rich>=13.0.0",
]
```

Note: SQLite ships with Python's standard library (`sqlite3`); no additional dependency is needed.

### Performance Considerations

- The UI runs in a separate thread so rendering never blocks trading logic
- The in-memory log buffer is capped at 1000 entries to bound memory growth
- SQLite queries should use the indexes above for fast period-based aggregations
- Historical data loads once at startup; only incremental updates follow

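The 1000-entry cap on the log buffer needs no manual eviction logic if the panel stores its lines in a `collections.deque` with `maxlen` set; old entries fall off automatically. A minimal sketch (the variable name `log_buffer` is illustrative):

```python
from collections import deque

# Bounded in-memory log buffer: once full, appending evicts the oldest line.
log_buffer: "deque[str]" = deque(maxlen=1000)

for i in range(1500):
    log_buffer.append(f"line {i}")

assert len(log_buffer) == 1000      # capped at maxlen
assert log_buffer[0] == "line 500"  # oldest surviving entry
```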
### Threading Model

```
Main Thread                    UI Thread
     |                             |
     v                             v
[Trading Loop]           [Rich Live Display]
     |                             |
     +---> SharedState <-----------+
           (thread-safe)
     |                             |
     +---> LogQueue <--------------+
           (thread-safe)
```

- Use `threading.Lock` for shared-state access
- Use `queue.Queue` for log message passing
- The UI thread polls for updates every 1 second

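A minimal sketch of the two thread-safe channels in the diagram above. The `SharedState` API (`update`/`snapshot`) is an assumption about what `state.py` might expose, not its actual interface; the key idea is that the UI only ever reads a copied snapshot taken under the lock, so it renders a consistent view while the trading thread keeps writing.

```python
import queue
import threading
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SharedState:
    """State shared between the trading loop and the UI thread."""
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
    _data: "dict[str, Any]" = field(default_factory=dict)

    def update(self, **kwargs: Any) -> None:
        with self._lock:
            self._data.update(kwargs)

    def snapshot(self) -> "dict[str, Any]":
        with self._lock:
            return dict(self._data)  # copy: the UI renders a consistent view


state = SharedState()
log_queue: "queue.Queue[str]" = queue.Queue()

# Trading thread writes...
state.update(side="LONG", entry_price=3245.50, unrealized_pnl=43.50)
log_queue.put("14:32:18 [INFO] Position opened")

# ...and the UI thread polls roughly once per second.
view = state.snapshot()
```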
## Success Metrics

1. The UI starts successfully and displays all required metrics
2. The UI refreshes in real time (1-second interval) without degrading trading performance
3. All keyboard shortcuts function correctly
4. Historical data loads and displays accurately from SQLite
5. Log filtering works as expected
6. The UI gracefully handles edge cases (no data, no position, terminal resize)
7. CSV migration completes successfully on first run
8. Database queries complete within 100 ms

---

*Generated: 2026-01-16*
*Decisions: threading model, 1000-entry log buffer, SQLite database, no fallback mode*