Add daily model training scripts and terminal UI for live trading
- Introduced `train_daily.sh` for automating daily model retraining, including data download and model training steps.
- Added `install_cron.sh` for setting up a cron job to run the daily training script.
- Created `setup_schedule.sh` for configuring systemd timers for daily training tasks.
- Implemented a terminal UI using Rich for real-time monitoring of trading performance, including metrics display and log handling.
- Updated `pyproject.toml` to include the `rich` dependency for UI functionality.
- Enhanced `.gitignore` to exclude model and log files.
- Added database support for trade persistence and metrics calculation.
- Updated README with installation and usage instructions for the new features.
**`.gitignore`** (vendored, 4 changed lines)
```diff
@@ -175,3 +175,7 @@ data/backtest_runs.db
 .gitignore
 live_trading/regime_model.pkl
 live_trading/positions.json
+
+
+*.pkl
+*.db
```
**`README.md`** (304 changed lines)
# Lowkey Backtest

A backtesting framework supporting multiple market types (spot, perpetual) with realistic trading simulation, including leverage, funding, and shorts.
## Requirements

- Python 3.12+
- Dependencies: `pandas`, `numpy`, `ta`
- Package manager: `uv`

## Installation

Install dependencies with uv:

```bash
uv sync
# If a dependency is missing, add it explicitly and sync
uv add pandas numpy ta
uv sync
```
## Quick Reference
| Command | Description |
|---------|-------------|
| `uv run python main.py download -p BTC-USDT` | Download data |
| `uv run python main.py backtest -s meta_st -p BTC-USDT` | Run backtest |
| `uv run python main.py wfa -s regime -p BTC-USDT` | Walk-forward analysis |
| `uv run python train_model.py --download` | Train/retrain ML model |
| `uv run python research/regime_detection.py` | Run research script |

---
## Backtest CLI

The main entry point is `main.py`, which provides three commands: `download`, `backtest`, and `wfa`.

### Download Data

Download historical OHLCV data from exchanges.

```bash
uv run python main.py download -p BTC-USDT -t 1h
```
**Options:**

- `-p, --pair` (required): Trading pair (e.g., `BTC-USDT`, `ETH-USDT`)
- `-t, --timeframe`: Timeframe (default: `1m`)
- `-e, --exchange`: Exchange (default: `okx`)
- `-m, --market`: Market type: `spot` or `perpetual` (default: `spot`)
- `--start`: Start date in `YYYY-MM-DD` format
**Examples:**

```bash
# Download 1-hour spot data
uv run python main.py download -p ETH-USDT -t 1h

# Download perpetual data from a specific date
uv run python main.py download -p BTC-USDT -m perpetual --start 2024-01-01
```
### Run Backtest

Run a backtest with a specific strategy.

```bash
uv run python main.py backtest -s <strategy> -p <pair> [options]
```
**Available Strategies:**

- `meta_st` - Meta Supertrend (triple Supertrend consensus)
- `regime` - Regime Reversion (ML-based spread trading)
- `rsi` - RSI overbought/oversold
- `macross` - Moving Average Crossover

**Options:**

- `-s, --strategy` (required): Strategy name
- `-p, --pair` (required): Trading pair
- `-t, --timeframe`: Timeframe (default: `1m`)
- `--start`: Start date
- `--end`: End date
- `-g, --grid`: Run grid search optimization
- `--plot`: Show equity curve plot
- `--sl`: Stop-loss percentage
- `--tp`: Take-profit percentage
- `--trail`: Enable trailing stop
- `--fees`: Override fee rate
- `--slippage`: Slippage (default: `0.001`)
- `-l, --leverage`: Leverage multiplier
**Examples:**

```bash
# Basic backtest with Meta Supertrend
uv run python main.py backtest -s meta_st -p BTC-USDT -t 1h

# Backtest with date range and plot
uv run python main.py backtest -s meta_st -p BTC-USDT --start 2024-01-01 --end 2024-12-31 --plot

# Grid search optimization
uv run python main.py backtest -s meta_st -p BTC-USDT -t 4h -g

# Backtest with risk parameters
uv run python main.py backtest -s meta_st -p BTC-USDT --sl 0.05 --tp 0.10 --trail

# Regime strategy on the ETH/BTC spread
uv run python main.py backtest -s regime -p ETH-USDT -t 1h
```
### Walk-Forward Analysis (WFA)

Run walk-forward optimization to avoid overfitting.

```bash
uv run python main.py wfa -s <strategy> -p <pair> [options]
```

**Options:**

- `-s, --strategy` (required): Strategy name
- `-p, --pair` (required): Trading pair
- `-t, --timeframe`: Timeframe (default: `1d`)
- `-w, --windows`: Number of walk-forward windows (default: `10`)
- `--plot`: Show WFA results plot

**Examples:**

```bash
# Walk-forward analysis with 10 windows
uv run python main.py wfa -s meta_st -p BTC-USDT -t 1d -w 10

# WFA with plot output
uv run python main.py wfa -s regime -p ETH-USDT --plot
```

---
## Research Scripts

Research scripts live in the `research/` directory for experimental analysis.

### Regime Detection Research

Tests multiple holding horizons for the regime reversion strategy using walk-forward training.

```bash
uv run python research/regime_detection.py
```

**Options:**

- `--days DAYS`: Number of days of historical data (default: 90)
- `--start DATE`: Start date (`YYYY-MM-DD`), overrides `--days`
- `--end DATE`: End date (`YYYY-MM-DD`), defaults to now
- `--output PATH`: Output CSV path

**Examples:**

```bash
# Use last 90 days (default)
uv run python research/regime_detection.py

# Use last 180 days
uv run python research/regime_detection.py --days 180

# Specific date range
uv run python research/regime_detection.py --start 2025-07-01 --end 2025-12-31
```

**What it does:**

- Loads BTC and ETH hourly data
- Calculates spread features (Z-score, RSI, volume ratios)
- Trains a RandomForest classifier with walk-forward methodology
- Tests horizons from 6h to 150h
- Outputs the best parameters by F1 score, net PnL, and MAE

**Output:**

- Console: Summary of results for each horizon
- File: `research/horizon_optimization_results.csv`
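The walk-forward methodology above trains on one slice of history and evaluates on the slice that follows, rolling forward window by window. A minimal sketch of that split logic (the window count and train fraction here are illustrative defaults, not the script's actual parameters):

```python
def walk_forward_splits(n_samples: int, n_windows: int, train_frac: float = 0.7):
    """Yield (train_indices, test_indices) pairs, one per rolling window."""
    window = n_samples // n_windows
    for w in range(n_windows):
        start = w * window
        split = start + int(window * train_frac)  # train on the first 70% of the window
        end = start + window
        yield range(start, split), range(split, end)


# 1000 hourly bars, 10 windows -> each window trains on 70 bars, tests on 30
splits = list(walk_forward_splits(1000, 10))
```

Each test slice is strictly later than its training slice, so no future data leaks into the fit.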
---

## ML Model Training

The `regime` strategy uses a RandomForest classifier that can be trained with new data.

### Train Model

Train or retrain the ML model with the latest data:

```bash
uv run python train_model.py [options]
```

**Options:**

- `--days DAYS`: Days of historical data (default: 90)
- `--pair PAIR`: Base pair for context (default: `BTC-USDT`)
- `--spread-pair PAIR`: Trading pair (default: `ETH-USDT`)
- `--timeframe TF`: Timeframe (default: `1h`)
- `--market TYPE`: Market type: `spot` or `perpetual` (default: `perpetual`)
- `--output PATH`: Model output path (default: `data/regime_model.pkl`)
- `--train-ratio R`: Train/test split ratio (default: 0.7)
- `--horizon H`: Prediction horizon in bars (default: 102)
- `--download`: Download latest data before training
- `--dry-run`: Run without saving the model

**Examples:**

```bash
# Train with last 90 days of data
uv run python train_model.py

# Download fresh data and train
uv run python train_model.py --download

# Train with 180 days of data
uv run python train_model.py --days 180

# Train on spot market data
uv run python train_model.py --market spot

# Dry run to see metrics without saving
uv run python train_model.py --dry-run
```
### Daily Retraining (Cron)

To automate daily model retraining, add a cron job:

```bash
# Edit the crontab
crontab -e

# Add an entry to retrain daily at 00:30 UTC
30 0 * * * cd /path/to/lowkey_backtest_live && uv run python train_model.py --download >> logs/training.log 2>&1
```
### Model Files

| File | Description |
|------|-------------|
| `data/regime_model.pkl` | Current production model |
| `data/regime_model_YYYYMMDD_HHMMSS.pkl` | Versioned model snapshots |

The model file contains:

- The trained RandomForest classifier
- Feature column names
- Training metrics (F1 score, sample counts)
- Training timestamp
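The bundle layout described above can be illustrated with a plain pickle round-trip. The dictionary keys and the snapshot-naming line below are assumptions for illustration; `train_model.py` may use different names:

```python
import pickle
import tempfile
from datetime import datetime, timezone
from pathlib import Path

# Hypothetical bundle mirroring the documented contents (classifier omitted for brevity)
bundle = {
    "feature_columns": ["zscore", "rsi", "volume_ratio"],
    "metrics": {"f1": 0.61, "n_train": 1500, "n_test": 640},
    "trained_at": datetime.now(timezone.utc).isoformat(),
}

# Versioned snapshot name matching the table above
stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
path = Path(tempfile.gettempdir()) / f"regime_model_{stamp}.pkl"
path.write_bytes(pickle.dumps(bundle))

loaded = pickle.loads(path.read_bytes())
```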
---

## Output Files

| Location | Description |
|----------|-------------|
| `backtest_logs/` | Trade logs and WFA results |
| `research/` | Research output files |
| `data/` | Downloaded OHLCV data and ML models |
| `data/regime_model.pkl` | Trained ML model for the regime strategy |

---

## Running Tests

```bash
uv run pytest tests/
```

Run a specific test file:

```bash
uv run pytest tests/test_data_manager.py
```
---

**`install_cron.sh`** (new executable file, 29 lines)
```bash
#!/bin/bash
# Install cron job for daily model training
# Runs daily at 00:30

PROJECT_DIR="/home/tamaya/Documents/Work/TCP/lowkey_backtest_live"
SCRIPT_PATH="$PROJECT_DIR/train_daily.sh"
LOG_PATH="$PROJECT_DIR/logs/training.log"

# Check if the script exists
if [ ! -f "$SCRIPT_PATH" ]; then
    echo "Error: $SCRIPT_PATH not found!"
    exit 1
fi

# Make it executable
chmod +x "$SCRIPT_PATH"

# Prepare the cron entry
# 30 0 * * * = 00:30 daily
CRON_CMD="30 0 * * * $SCRIPT_PATH >> $LOG_PATH 2>&1"

# Check if the job already exists
(crontab -l 2>/dev/null | grep -F "$SCRIPT_PATH") && echo "Cron job already exists." && exit 0

# Add to crontab
(crontab -l 2>/dev/null; echo "$CRON_CMD") | crontab -

echo "Cron job installed successfully:"
echo "$CRON_CMD"
```
```diff
@@ -60,7 +60,7 @@ class TradingConfig:
 
     # Position sizing
     max_position_usdt: float = -1.0  # Max position size in USDT. If <= 0, use all available funds
-    min_position_usdt: float = 10.0  # Min position size in USDT
+    min_position_usdt: float = 1.0  # Min position size in USDT
     leverage: int = 1  # Leverage (1x = no leverage)
     margin_mode: str = "cross"  # "cross" or "isolated"
```
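The sizing rule implied by those fields can be read as: a non-positive `max_position_usdt` means deploy all available funds, and anything below `min_position_usdt` is skipped. A minimal sketch of that logic (the function name and exact clamping behavior are assumptions, not the live-trading code):

```python
from typing import Optional


def position_size_usdt(
    available: float, max_position: float, min_position: float
) -> Optional[float]:
    """Return the USDT size to deploy, or None if it falls below the minimum."""
    size = available if max_position <= 0 else min(max_position, available)
    return size if size >= min_position else None
```

With the defaults above (`max_position_usdt = -1.0`, `min_position_usdt = 1.0`), a 500 USDT balance would be deployed in full.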
**`live_trading/db/__init__.py`** (new file, 13 lines)
```python
"""Database module for live trading persistence."""
from .database import get_db, init_db
from .models import Trade, DailySummary, Session
from .metrics import MetricsCalculator

__all__ = [
    "get_db",
    "init_db",
    "Trade",
    "DailySummary",
    "Session",
    "MetricsCalculator",
]
```
**`live_trading/db/database.py`** (new file, 325 lines)
```python
"""SQLite database connection and operations."""
import sqlite3
import logging
from pathlib import Path
from typing import Optional
from contextlib import contextmanager

from .models import Trade, DailySummary, Session

logger = logging.getLogger(__name__)

# Database schema
SCHEMA = """
-- Trade history table
CREATE TABLE IF NOT EXISTS trades (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    trade_id TEXT UNIQUE NOT NULL,
    symbol TEXT NOT NULL,
    side TEXT NOT NULL,
    entry_price REAL NOT NULL,
    exit_price REAL,
    size REAL NOT NULL,
    size_usdt REAL NOT NULL,
    pnl_usd REAL,
    pnl_pct REAL,
    entry_time TEXT NOT NULL,
    exit_time TEXT,
    hold_duration_hours REAL,
    reason TEXT,
    order_id_entry TEXT,
    order_id_exit TEXT,
    created_at TEXT DEFAULT CURRENT_TIMESTAMP
);

-- Daily summary table
CREATE TABLE IF NOT EXISTS daily_summary (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    date TEXT UNIQUE NOT NULL,
    total_trades INTEGER DEFAULT 0,
    winning_trades INTEGER DEFAULT 0,
    total_pnl_usd REAL DEFAULT 0,
    max_drawdown_usd REAL DEFAULT 0,
    updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);

-- Session metadata
CREATE TABLE IF NOT EXISTS sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    start_time TEXT NOT NULL,
    end_time TEXT,
    starting_balance REAL,
    ending_balance REAL,
    total_pnl REAL,
    total_trades INTEGER DEFAULT 0
);

-- Indexes for common queries
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades(entry_time);
CREATE INDEX IF NOT EXISTS idx_trades_exit_time ON trades(exit_time);
CREATE INDEX IF NOT EXISTS idx_daily_summary_date ON daily_summary(date);
"""

_db_instance: Optional["TradingDatabase"] = None


class TradingDatabase:
    """SQLite database for trade persistence."""

    def __init__(self, db_path: Path):
        self.db_path = db_path
        self._connection: Optional[sqlite3.Connection] = None

    @property
    def connection(self) -> sqlite3.Connection:
        """Get or create the database connection."""
        if self._connection is None:
            self._connection = sqlite3.connect(
                str(self.db_path),
                check_same_thread=False,
            )
            self._connection.row_factory = sqlite3.Row
        return self._connection

    def init_schema(self) -> None:
        """Initialize the database schema."""
        with self.connection:
            self.connection.executescript(SCHEMA)
        logger.info(f"Database initialized at {self.db_path}")

    def close(self) -> None:
        """Close the database connection."""
        if self._connection:
            self._connection.close()
            self._connection = None

    @contextmanager
    def transaction(self):
        """Context manager for database transactions."""
        try:
            yield self.connection
            self.connection.commit()
        except Exception:
            self.connection.rollback()
            raise

    def insert_trade(self, trade: Trade) -> int:
        """
        Insert a new trade record.

        Args:
            trade: Trade object to insert

        Returns:
            Row ID of the inserted trade
        """
        sql = """
            INSERT INTO trades (
                trade_id, symbol, side, entry_price, exit_price,
                size, size_usdt, pnl_usd, pnl_pct, entry_time,
                exit_time, hold_duration_hours, reason,
                order_id_entry, order_id_exit
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        with self.transaction():
            cursor = self.connection.execute(
                sql,
                (
                    trade.trade_id,
                    trade.symbol,
                    trade.side,
                    trade.entry_price,
                    trade.exit_price,
                    trade.size,
                    trade.size_usdt,
                    trade.pnl_usd,
                    trade.pnl_pct,
                    trade.entry_time,
                    trade.exit_time,
                    trade.hold_duration_hours,
                    trade.reason,
                    trade.order_id_entry,
                    trade.order_id_exit,
                ),
            )
            return cursor.lastrowid

    def update_trade(self, trade_id: str, **kwargs) -> bool:
        """
        Update an existing trade record.

        Args:
            trade_id: Trade ID to update
            **kwargs: Fields to update

        Returns:
            True if the trade was updated
        """
        if not kwargs:
            return False

        set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
        sql = f"UPDATE trades SET {set_clause} WHERE trade_id = ?"

        with self.transaction():
            cursor = self.connection.execute(
                sql, (*kwargs.values(), trade_id)
            )
            return cursor.rowcount > 0

    def get_trade(self, trade_id: str) -> Optional[Trade]:
        """Get a trade by ID."""
        sql = "SELECT * FROM trades WHERE trade_id = ?"
        row = self.connection.execute(sql, (trade_id,)).fetchone()
        if row:
            return Trade(**dict(row))
        return None

    def get_trades(
        self,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> list[Trade]:
        """
        Get trades within a time range.

        Args:
            start_time: ISO-format start time filter
            end_time: ISO-format end time filter
            limit: Maximum number of trades to return

        Returns:
            List of Trade objects
        """
        conditions = []
        params = []

        if start_time:
            conditions.append("entry_time >= ?")
            params.append(start_time)
        if end_time:
            conditions.append("entry_time <= ?")
            params.append(end_time)

        where_clause = " AND ".join(conditions) if conditions else "1=1"
        limit_clause = f"LIMIT {limit}" if limit else ""

        sql = f"""
            SELECT * FROM trades
            WHERE {where_clause}
            ORDER BY entry_time DESC
            {limit_clause}
        """

        rows = self.connection.execute(sql, params).fetchall()
        return [Trade(**dict(row)) for row in rows]

    def get_all_trades(self) -> list[Trade]:
        """Get all trades."""
        sql = "SELECT * FROM trades ORDER BY entry_time DESC"
        rows = self.connection.execute(sql).fetchall()
        return [Trade(**dict(row)) for row in rows]

    def count_trades(self) -> int:
        """Get the total number of closed trades."""
        sql = "SELECT COUNT(*) FROM trades WHERE exit_time IS NOT NULL"
        return self.connection.execute(sql).fetchone()[0]

    def upsert_daily_summary(self, summary: DailySummary) -> None:
        """Insert or update a daily summary."""
        sql = """
            INSERT INTO daily_summary (
                date, total_trades, winning_trades, total_pnl_usd, max_drawdown_usd
            ) VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(date) DO UPDATE SET
                total_trades = excluded.total_trades,
                winning_trades = excluded.winning_trades,
                total_pnl_usd = excluded.total_pnl_usd,
                max_drawdown_usd = excluded.max_drawdown_usd,
                updated_at = CURRENT_TIMESTAMP
        """
        with self.transaction():
            self.connection.execute(
                sql,
                (
                    summary.date,
                    summary.total_trades,
                    summary.winning_trades,
                    summary.total_pnl_usd,
                    summary.max_drawdown_usd,
                ),
            )

    def get_daily_summary(self, date: str) -> Optional[DailySummary]:
        """Get the daily summary for a specific date."""
        sql = "SELECT * FROM daily_summary WHERE date = ?"
        row = self.connection.execute(sql, (date,)).fetchone()
        if row:
            return DailySummary(**dict(row))
        return None

    def insert_session(self, session: Session) -> int:
        """Insert a new session record."""
        sql = """
            INSERT INTO sessions (
                start_time, end_time, starting_balance,
                ending_balance, total_pnl, total_trades
            ) VALUES (?, ?, ?, ?, ?, ?)
        """
        with self.transaction():
            cursor = self.connection.execute(
                sql,
                (
                    session.start_time,
                    session.end_time,
                    session.starting_balance,
                    session.ending_balance,
                    session.total_pnl,
                    session.total_trades,
                ),
            )
            return cursor.lastrowid

    def update_session(self, session_id: int, **kwargs) -> bool:
        """Update an existing session."""
        if not kwargs:
            return False

        set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
        sql = f"UPDATE sessions SET {set_clause} WHERE id = ?"

        with self.transaction():
            cursor = self.connection.execute(
                sql, (*kwargs.values(), session_id)
            )
            return cursor.rowcount > 0

    def get_latest_session(self) -> Optional[Session]:
        """Get the most recent session."""
        sql = "SELECT * FROM sessions ORDER BY id DESC LIMIT 1"
        row = self.connection.execute(sql).fetchone()
        if row:
            return Session(**dict(row))
        return None


def init_db(db_path: Path) -> TradingDatabase:
    """
    Initialize the database.

    Args:
        db_path: Path to the SQLite database file

    Returns:
        TradingDatabase instance
    """
    global _db_instance
    _db_instance = TradingDatabase(db_path)
    _db_instance.init_schema()
    return _db_instance


def get_db() -> Optional[TradingDatabase]:
    """Get the global database instance."""
    return _db_instance
```
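The `transaction` context manager commits on success and rolls back on any exception, which pairs with the `UNIQUE` constraint on `trade_id` to keep duplicate inserts out. The pattern can be exercised with plain `sqlite3` (schema trimmed to three columns for illustration):

```python
import sqlite3
from contextlib import contextmanager

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE trades (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    trade_id TEXT UNIQUE NOT NULL,
    symbol TEXT NOT NULL,
    pnl_usd REAL
);
""")


@contextmanager
def transaction(c):
    # Commit on success, roll back on any error -- same shape as TradingDatabase.transaction
    try:
        yield c
        c.commit()
    except Exception:
        c.rollback()
        raise


with transaction(conn):
    conn.execute(
        "INSERT INTO trades (trade_id, symbol, pnl_usd) VALUES (?, ?, ?)",
        ("t1", "ETH-USDT", 12.5),
    )

# A duplicate trade_id violates UNIQUE; the transaction rolls back
try:
    with transaction(conn):
        conn.execute(
            "INSERT INTO trades (trade_id, symbol, pnl_usd) VALUES (?, ?, ?)",
            ("t1", "ETH-USDT", 0.0),
        )
except sqlite3.IntegrityError:
    pass

count = conn.execute("SELECT COUNT(*) FROM trades").fetchone()[0]
```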
**`live_trading/db/metrics.py`** (new file, 235 lines)
|
||||
"""Metrics calculation from trade database."""
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from .database import TradingDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PeriodMetrics:
|
||||
"""Trading metrics for a time period."""
|
||||
|
||||
period_name: str
|
||||
start_time: Optional[str]
|
||||
end_time: Optional[str]
|
||||
total_pnl: float = 0.0
|
||||
total_trades: int = 0
|
||||
winning_trades: int = 0
|
||||
losing_trades: int = 0
|
||||
win_rate: float = 0.0
|
||||
avg_trade_duration_hours: float = 0.0
|
||||
max_drawdown: float = 0.0
|
||||
max_drawdown_pct: float = 0.0
|
||||
best_trade: float = 0.0
|
||||
worst_trade: float = 0.0
|
||||
avg_win: float = 0.0
|
||||
avg_loss: float = 0.0
|
||||
|
||||
|
||||
class MetricsCalculator:
|
||||
"""Calculate trading metrics from database."""
|
||||
|
||||
def __init__(self, db: TradingDatabase):
|
||||
self.db = db
|
||||
|
||||
def get_all_time_metrics(self) -> PeriodMetrics:
|
||||
"""Get metrics for all trades ever."""
|
||||
return self._calculate_metrics("All Time", None, None)
|
||||
|
||||
def get_monthly_metrics(self) -> PeriodMetrics:
|
||||
"""Get metrics for current month."""
|
||||
now = datetime.now(timezone.utc)
|
||||
start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
return self._calculate_metrics(
|
||||
"Monthly",
|
||||
start.isoformat(),
|
||||
now.isoformat(),
|
||||
)
|
||||
|
||||
def get_weekly_metrics(self) -> PeriodMetrics:
|
||||
"""Get metrics for current week (Monday to now)."""
|
||||
now = datetime.now(timezone.utc)
|
||||
days_since_monday = now.weekday()
|
||||
start = now - timedelta(days=days_since_monday)
|
||||
start = start.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
return self._calculate_metrics(
|
||||
"Weekly",
|
||||
start.isoformat(),
|
||||
now.isoformat(),
|
||||
)
|
||||
|
||||
def get_daily_metrics(self) -> PeriodMetrics:
|
||||
"""Get metrics for today (UTC)."""
|
||||
now = datetime.now(timezone.utc)
|
||||
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
return self._calculate_metrics(
|
||||
"Daily",
|
||||
start.isoformat(),
|
||||
now.isoformat(),
|
||||
)
|
||||
|
||||
def _calculate_metrics(
|
||||
self,
|
||||
period_name: str,
|
||||
start_time: Optional[str],
|
||||
end_time: Optional[str],
|
||||
) -> PeriodMetrics:
|
||||
"""
|
||||
Calculate metrics for a time period.
|
||||
|
||||
Args:
|
||||
period_name: Name of the period
|
||||
start_time: ISO format start time (None for all time)
|
||||
end_time: ISO format end time (None for all time)
|
||||
|
||||
Returns:
|
||||
PeriodMetrics object
|
||||
"""
|
||||
metrics = PeriodMetrics(
|
||||
period_name=period_name,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
# Build query conditions
|
||||
conditions = ["exit_time IS NOT NULL"]
|
||||
params = []
|
||||
|
||||
if start_time:
|
||||
conditions.append("exit_time >= ?")
|
||||
params.append(start_time)
|
||||
if end_time:
|
||||
conditions.append("exit_time <= ?")
|
||||
params.append(end_time)
|
||||
|
||||
where_clause = " AND ".join(conditions)
|
||||
|
||||
# Get aggregate metrics
|
||||
sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_trades,
|
||||
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
|
||||
SUM(CASE WHEN pnl_usd < 0 THEN 1 ELSE 0 END) as losing_trades,
|
||||
COALESCE(SUM(pnl_usd), 0) as total_pnl,
|
||||
COALESCE(AVG(hold_duration_hours), 0) as avg_duration,
|
||||
COALESCE(MAX(pnl_usd), 0) as best_trade,
|
||||
COALESCE(MIN(pnl_usd), 0) as worst_trade,
|
||||
COALESCE(AVG(CASE WHEN pnl_usd > 0 THEN pnl_usd END), 0) as avg_win,
|
||||
COALESCE(AVG(CASE WHEN pnl_usd < 0 THEN pnl_usd END), 0) as avg_loss
|
||||
FROM trades
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
|
||||
row = self.db.connection.execute(sql, params).fetchone()
|
||||
|
||||
if row and row["total_trades"] > 0:
|
||||
metrics.total_trades = row["total_trades"]
|
||||
metrics.winning_trades = row["winning_trades"] or 0
|
||||
metrics.losing_trades = row["losing_trades"] or 0
|
||||
metrics.total_pnl = row["total_pnl"]
|
||||
metrics.avg_trade_duration_hours = row["avg_duration"]
|
||||
metrics.best_trade = row["best_trade"]
|
||||
metrics.worst_trade = row["worst_trade"]
|
||||
metrics.avg_win = row["avg_win"]
|
||||
metrics.avg_loss = row["avg_loss"]
|
||||
|
||||
if metrics.total_trades > 0:
|
||||
metrics.win_rate = (
|
||||
metrics.winning_trades / metrics.total_trades * 100
|
||||
)
|
||||
|
||||
# Calculate max drawdown
|
||||
metrics.max_drawdown = self._calculate_max_drawdown(
|
||||
start_time, end_time
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
def _calculate_max_drawdown(
|
||||
self,
|
||||
start_time: Optional[str],
|
||||
end_time: Optional[str],
|
||||
) -> float:
|
||||
"""Calculate maximum drawdown for a period."""
|
||||
conditions = ["exit_time IS NOT NULL"]
|
||||
params = []
|
||||
|
||||
if start_time:
|
||||
conditions.append("exit_time >= ?")
|
||||
params.append(start_time)
|
||||
if end_time:
|
||||
conditions.append("exit_time <= ?")
|
||||
params.append(end_time)
|
||||
|
||||
where_clause = " AND ".join(conditions)
|
||||
|
||||
sql = f"""
|
||||
SELECT pnl_usd
|
||||
FROM trades
|
||||
WHERE {where_clause}
|
||||
ORDER BY exit_time
|
||||
"""
|
||||
|
||||
rows = self.db.connection.execute(sql, params).fetchall()
|
||||
|
||||
if not rows:
|
||||
return 0.0
|
||||
|
||||
cumulative = 0.0
|
||||
peak = 0.0
|
||||
max_drawdown = 0.0
|
||||
|
||||
for row in rows:
|
||||
pnl = row["pnl_usd"] or 0.0
|
||||
cumulative += pnl
|
||||
peak = max(peak, cumulative)
|
||||
drawdown = peak - cumulative
|
||||
max_drawdown = max(max_drawdown, drawdown)
|
||||
|
||||
return max_drawdown
|
||||
|
||||
    def has_monthly_data(self) -> bool:
        """Check if we have data spanning more than current month."""
        sql = """
            SELECT MIN(exit_time) as first_trade
            FROM trades
            WHERE exit_time IS NOT NULL
        """
        row = self.db.connection.execute(sql).fetchone()
        if not row or not row["first_trade"]:
            return False

        first_trade = datetime.fromisoformat(row["first_trade"])
        now = datetime.now(timezone.utc)
        month_start = now.replace(day=1, hour=0, minute=0, second=0)

        return first_trade < month_start
    def has_weekly_data(self) -> bool:
        """Check if we have data spanning more than current week."""
        sql = """
            SELECT MIN(exit_time) as first_trade
            FROM trades
            WHERE exit_time IS NOT NULL
        """
        row = self.db.connection.execute(sql).fetchone()
        if not row or not row["first_trade"]:
            return False

        first_trade = datetime.fromisoformat(row["first_trade"])
        now = datetime.now(timezone.utc)
        days_since_monday = now.weekday()
        week_start = now - timedelta(days=days_since_monday)
        week_start = week_start.replace(hour=0, minute=0, second=0)

        return first_trade < week_start
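The Monday-anchored week start used by `has_weekly_data` relies on `weekday()` being 0 for Monday; the computation in isolation (helper name illustrative):

```python
from datetime import datetime, timedelta, timezone

def week_start(now: datetime) -> datetime:
    """Midnight of the Monday of `now`'s week (weekday() is 0 for Monday)."""
    monday = now - timedelta(days=now.weekday())
    return monday.replace(hour=0, minute=0, second=0, microsecond=0)

# Thursday 2024-05-16 15:30 UTC -> Monday 2024-05-13 00:00 UTC.
print(week_start(datetime(2024, 5, 16, 15, 30, tzinfo=timezone.utc)))
```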
    def get_session_start_balance(self) -> Optional[float]:
        """Get starting balance from latest session."""
        sql = "SELECT starting_balance FROM sessions ORDER BY id DESC LIMIT 1"
        row = self.db.connection.execute(sql).fetchone()
        return row["starting_balance"] if row else None
191
live_trading/db/migrations.py
Normal file
@@ -0,0 +1,191 @@
"""Database migrations and CSV import."""
import csv
import logging
from pathlib import Path
from datetime import datetime

from .database import TradingDatabase
from .models import Trade, DailySummary

logger = logging.getLogger(__name__)


def migrate_csv_to_db(db: TradingDatabase, csv_path: Path) -> int:
    """
    Migrate trades from CSV file to SQLite database.

    Args:
        db: TradingDatabase instance
        csv_path: Path to trade_log.csv

    Returns:
        Number of trades migrated
    """
    if not csv_path.exists():
        logger.info("No CSV file to migrate")
        return 0

    # Check if database already has trades
    existing_count = db.count_trades()
    if existing_count > 0:
        logger.info(
            f"Database already has {existing_count} trades, skipping migration"
        )
        return 0

    migrated = 0
    try:
        with open(csv_path, "r", newline="") as f:
            reader = csv.DictReader(f)

            for row in reader:
                trade = _csv_row_to_trade(row)
                if trade:
                    try:
                        db.insert_trade(trade)
                        migrated += 1
                    except Exception as e:
                        logger.warning(
                            f"Failed to migrate trade {row.get('trade_id')}: {e}"
                        )

        logger.info(f"Migrated {migrated} trades from CSV to database")

    except Exception as e:
        logger.error(f"CSV migration failed: {e}")

    return migrated

def _csv_row_to_trade(row: dict) -> Trade | None:
    """Convert a CSV row to a Trade object."""
    try:
        return Trade(
            trade_id=row["trade_id"],
            symbol=row["symbol"],
            side=row["side"],
            entry_price=float(row["entry_price"]),
            exit_price=_safe_float(row.get("exit_price")),
            size=float(row["size"]),
            size_usdt=float(row["size_usdt"]),
            pnl_usd=_safe_float(row.get("pnl_usd")),
            pnl_pct=_safe_float(row.get("pnl_pct")),
            entry_time=row["entry_time"],
            exit_time=row.get("exit_time") or None,
            hold_duration_hours=_safe_float(row.get("hold_duration_hours")),
            reason=row.get("reason") or None,
            order_id_entry=row.get("order_id_entry") or None,
            order_id_exit=row.get("order_id_exit") or None,
        )
    except (KeyError, ValueError) as e:
        logger.warning(f"Invalid CSV row: {e}")
        return None

def _safe_float(value: str | None) -> float | None:
    """Safely convert string to float."""
    if value is None or value == "":
        return None
    try:
        return float(value)
    except ValueError:
        return None

def rebuild_daily_summaries(db: TradingDatabase) -> int:
    """
    Rebuild daily summary table from trades.

    Args:
        db: TradingDatabase instance

    Returns:
        Number of daily summaries created
    """
    sql = """
        SELECT
            DATE(exit_time) as date,
            COUNT(*) as total_trades,
            SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
            SUM(pnl_usd) as total_pnl_usd
        FROM trades
        WHERE exit_time IS NOT NULL
        GROUP BY DATE(exit_time)
        ORDER BY date
    """

    rows = db.connection.execute(sql).fetchall()
    count = 0

    for row in rows:
        summary = DailySummary(
            date=row["date"],
            total_trades=row["total_trades"],
            winning_trades=row["winning_trades"],
            total_pnl_usd=row["total_pnl_usd"] or 0.0,
            max_drawdown_usd=0.0,  # Calculated separately
        )
        db.upsert_daily_summary(summary)
        count += 1

    # Calculate max drawdowns
    _calculate_daily_drawdowns(db)

    logger.info(f"Rebuilt {count} daily summaries")
    return count

def _calculate_daily_drawdowns(db: TradingDatabase) -> None:
    """Calculate and update max drawdown for each day."""
    sql = """
        SELECT trade_id, DATE(exit_time) as date, pnl_usd
        FROM trades
        WHERE exit_time IS NOT NULL
        ORDER BY exit_time
    """

    rows = db.connection.execute(sql).fetchall()

    # Track cumulative PnL and drawdown per day
    daily_drawdowns: dict[str, float] = {}
    cumulative_pnl = 0.0
    peak_pnl = 0.0

    for row in rows:
        date = row["date"]
        pnl = row["pnl_usd"] or 0.0

        cumulative_pnl += pnl
        peak_pnl = max(peak_pnl, cumulative_pnl)
        drawdown = peak_pnl - cumulative_pnl

        if date not in daily_drawdowns:
            daily_drawdowns[date] = 0.0
        daily_drawdowns[date] = max(daily_drawdowns[date], drawdown)

    # Update daily summaries with drawdown
    for date, drawdown in daily_drawdowns.items():
        db.connection.execute(
            "UPDATE daily_summary SET max_drawdown_usd = ? WHERE date = ?",
            (drawdown, date),
        )
    db.connection.commit()

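`_calculate_daily_drawdowns` tracks one global equity curve but records, per calendar day, the deepest drop observed on that day. A database-free sketch of that bookkeeping (function name hypothetical):

```python
def daily_drawdowns(trades: list[tuple[str, float]]) -> dict[str, float]:
    """Map each date to the worst global drawdown seen on that date.

    `trades` is a list of (date, pnl) pairs ordered by exit time.
    """
    out: dict[str, float] = {}
    cumulative = peak = 0.0
    for date, pnl in trades:
        cumulative += pnl
        peak = max(peak, cumulative)
        drawdown = peak - cumulative          # drop from the global peak
        out[date] = max(out.get(date, 0.0), drawdown)
    return out

print(daily_drawdowns([("d1", 10.0), ("d1", -4.0), ("d2", -3.0), ("d2", 9.0)]))
# {'d1': 4.0, 'd2': 7.0}
```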
def run_migrations(db: TradingDatabase, csv_path: Path) -> None:
    """
    Run all migrations.

    Args:
        db: TradingDatabase instance
        csv_path: Path to trade_log.csv for migration
    """
    logger.info("Running database migrations...")

    # Migrate CSV data if exists
    migrate_csv_to_db(db, csv_path)

    # Rebuild daily summaries
    rebuild_daily_summaries(db)

    logger.info("Migrations complete")
69
live_trading/db/models.py
Normal file
@@ -0,0 +1,69 @@
"""Data models for trade persistence."""
from dataclasses import dataclass, asdict
from typing import Optional
from datetime import datetime


@dataclass
class Trade:
    """Represents a completed trade."""

    trade_id: str
    symbol: str
    side: str
    entry_price: float
    size: float
    size_usdt: float
    entry_time: str
    exit_price: Optional[float] = None
    pnl_usd: Optional[float] = None
    pnl_pct: Optional[float] = None
    exit_time: Optional[str] = None
    hold_duration_hours: Optional[float] = None
    reason: Optional[str] = None
    order_id_entry: Optional[str] = None
    order_id_exit: Optional[str] = None
    id: Optional[int] = None

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)

    @classmethod
    def from_row(cls, row: tuple, columns: list[str]) -> "Trade":
        """Create Trade from database row."""
        data = dict(zip(columns, row))
        return cls(**data)

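`Trade.from_row` just zips column names onto a row tuple and feeds the resulting dict to the constructor; the pattern works with any dataclass. A minimal demo (the `Row` dataclass here is illustrative, not from the commit):

```python
from dataclasses import dataclass

@dataclass
class Row:
    trade_id: str
    pnl_usd: float

def from_row(cls, row: tuple, columns: list[str]):
    # zip pairs each column name with its value, then ** unpacks into the ctor
    return cls(**dict(zip(columns, row)))

r = from_row(Row, ("t-1", 12.5), ["trade_id", "pnl_usd"])
print(r)  # Row(trade_id='t-1', pnl_usd=12.5)
```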
@dataclass
class DailySummary:
    """Daily trading summary."""

    date: str
    total_trades: int = 0
    winning_trades: int = 0
    total_pnl_usd: float = 0.0
    max_drawdown_usd: float = 0.0
    id: Optional[int] = None

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)


@dataclass
class Session:
    """Trading session metadata."""

    start_time: str
    end_time: Optional[str] = None
    starting_balance: Optional[float] = None
    ending_balance: Optional[float] = None
    total_pnl: Optional[float] = None
    total_trades: int = 0
    id: Optional[int] = None

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)
@@ -39,18 +39,40 @@ class LiveRegimeStrategy:
         self.paths = path_config
         self.model: Optional[RandomForestClassifier] = None
         self.feature_cols: Optional[list] = None
+        self.horizon: int = 102  # Default horizon
+        self._last_model_load_time: float = 0.0
         self._load_or_train_model()
 
+    def reload_model_if_changed(self) -> None:
+        """Check if model file has changed and reload if necessary."""
+        if not self.paths.model_path.exists():
+            return
+
+        try:
+            mtime = self.paths.model_path.stat().st_mtime
+            if mtime > self._last_model_load_time:
+                logger.info(f"Model file changed, reloading... (last: {self._last_model_load_time}, new: {mtime})")
+                self._load_or_train_model()
+        except Exception as e:
+            logger.warning(f"Error checking model file: {e}")
+
     def _load_or_train_model(self) -> None:
         """Load pre-trained model or train a new one."""
         if self.paths.model_path.exists():
             try:
+                self._last_model_load_time = self.paths.model_path.stat().st_mtime
                 with open(self.paths.model_path, 'rb') as f:
                     saved = pickle.load(f)
                 self.model = saved['model']
                 self.feature_cols = saved['feature_cols']
-                logger.info(f"Loaded model from {self.paths.model_path}")
-                return
+
+                # Load horizon from metrics if available
+                if 'metrics' in saved and 'horizon' in saved['metrics']:
+                    self.horizon = saved['metrics']['horizon']
+                    logger.info(f"Loaded model from {self.paths.model_path} (horizon={self.horizon})")
+                else:
+                    logger.info(f"Loaded model from {self.paths.model_path} (default horizon={self.horizon})")
+                return
             except Exception as e:
                 logger.warning(f"Could not load model: {e}")
 
@@ -66,6 +88,7 @@ class LiveRegimeStrategy:
             pickle.dump({
                 'model': self.model,
                 'feature_cols': self.feature_cols,
+                'metrics': {'horizon': self.horizon}  # Save horizon
             }, f)
             logger.info(f"Saved model to {self.paths.model_path}")
         except Exception as e:
@@ -81,7 +104,7 @@ class LiveRegimeStrategy:
         logger.info(f"Training model on {len(features)} samples...")
 
         z_thresh = self.config.z_entry_threshold
-        horizon = 102  # Optimal horizon from research
+        horizon = self.horizon
         profit_target = 0.005  # 0.5% profit threshold
 
         # Define targets
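The mtime-based hot-reload check in `reload_model_if_changed` (reload only when the file's modification time advances past the last load) can be exercised standalone with a temp file; all names below are illustrative, not the commit's code:

```python
import os
import tempfile

class Reloader:
    """Reload a file only when its mtime advances past the last load."""

    def __init__(self, path: str):
        self.path = path
        self.last_load = 0.0
        self.loads = 0

    def maybe_reload(self) -> bool:
        mtime = os.stat(self.path).st_mtime
        if mtime > self.last_load:
            self.last_load = mtime
            self.loads += 1
            return True
        return False

with tempfile.NamedTemporaryFile(delete=False) as f:
    path = f.name
r = Reloader(path)
first, second = r.maybe_reload(), r.maybe_reload()
os.utime(path, (0, r.last_load + 10))  # simulate a newer file on disk
third = r.maybe_reload()
os.unlink(path)
print(first, second, third)  # True False True
```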
@@ -11,14 +11,19 @@ Usage:
 
     # Run with specific settings
     uv run python -m live_trading.main --max-position 500 --leverage 2
+
+    # Run without UI (headless mode)
+    uv run python -m live_trading.main --no-ui
 """
 import argparse
 import logging
+import queue
 import signal
 import sys
 import time
 from datetime import datetime, timezone
 from pathlib import Path
+from typing import Optional
 
 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).parent.parent))
@@ -28,22 +33,47 @@ from live_trading.okx_client import OKXClient
 from live_trading.data_feed import DataFeed
 from live_trading.position_manager import PositionManager
 from live_trading.live_regime_strategy import LiveRegimeStrategy
+from live_trading.db.database import init_db, TradingDatabase
+from live_trading.db.migrations import run_migrations
+from live_trading.ui.state import SharedState, PositionState
+from live_trading.ui.dashboard import TradingDashboard, setup_ui_logging
 
 
-def setup_logging(log_dir: Path) -> logging.Logger:
-    """Configure logging for the trading bot."""
+def setup_logging(
+    log_dir: Path,
+    log_queue: Optional[queue.Queue] = None,
+) -> logging.Logger:
+    """
+    Configure logging for the trading bot.
+
+    Args:
+        log_dir: Directory for log files
+        log_queue: Optional queue for UI log handler
+
+    Returns:
+        Logger instance
+    """
     log_file = log_dir / "live_trading.log"
 
+    handlers = [
+        logging.FileHandler(log_file),
+    ]
+
+    # Only add StreamHandler if no UI (log_queue is None)
+    if log_queue is None:
+        handlers.append(logging.StreamHandler(sys.stdout))
+
     logging.basicConfig(
         level=logging.INFO,
         format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
-        handlers=[
-            logging.FileHandler(log_file),
-            logging.StreamHandler(sys.stdout),
-        ],
-        force=True
+        handlers=handlers,
+        force=True,
     )
 
+    # Add UI log handler if queue provided
+    if log_queue is not None:
+        setup_ui_logging(log_queue)
+
     return logging.getLogger(__name__)
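Routing log records into a queue so a UI thread can drain them is the standard-library `QueueHandler` pattern; the commit's `UILogHandler` is custom, but a minimal stdlib version of the same idea looks like this:

```python
import logging
import queue
from logging.handlers import QueueHandler

log_queue: queue.Queue = queue.Queue(maxsize=1000)

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(QueueHandler(log_queue))  # records go to the queue, not stdout

logger.info("hello from the bot")

record = log_queue.get_nowait()
print(record.getMessage())  # hello from the bot
```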
@@ -59,11 +89,15 @@ class LiveTradingBot:
         self,
         okx_config: OKXConfig,
         trading_config: TradingConfig,
-        path_config: PathConfig
+        path_config: PathConfig,
+        database: Optional[TradingDatabase] = None,
+        shared_state: Optional[SharedState] = None,
     ):
         self.okx_config = okx_config
         self.trading_config = trading_config
         self.path_config = path_config
+        self.db = database
+        self.state = shared_state
 
         self.logger = logging.getLogger(__name__)
         self.running = True
@@ -74,7 +108,7 @@ class LiveTradingBot:
         self.okx_client = OKXClient(okx_config, trading_config)
         self.data_feed = DataFeed(self.okx_client, trading_config, path_config)
         self.position_manager = PositionManager(
-            self.okx_client, trading_config, path_config
+            self.okx_client, trading_config, path_config, database
         )
         self.strategy = LiveRegimeStrategy(trading_config, path_config)
 
@@ -82,6 +116,16 @@ class LiveTradingBot:
         signal.signal(signal.SIGINT, self._handle_shutdown)
         signal.signal(signal.SIGTERM, self._handle_shutdown)
 
+        # Initialize shared state if provided
+        if self.state:
+            mode = "DEMO" if okx_config.demo_mode else "LIVE"
+            self.state.set_mode(mode)
+            self.state.set_symbols(
+                trading_config.eth_symbol,
+                trading_config.btc_symbol,
+            )
+            self.state.update_account(0.0, 0.0, trading_config.leverage)
+
         self._print_startup_banner()
 
     def _print_startup_banner(self) -> None:
@@ -109,6 +153,8 @@ class LiveTradingBot:
         """Handle shutdown signals gracefully."""
         self.logger.info("Shutdown signal received, stopping...")
         self.running = False
+        if self.state:
+            self.state.stop()
 
     def run_trading_cycle(self) -> None:
         """
@@ -118,10 +164,20 @@ class LiveTradingBot:
         2. Update open positions
         3. Generate trading signal
         4. Execute trades if signal triggers
+        5. Update shared state for UI
         """
+        # Reload model if it has changed (e.g. daily training)
+        try:
+            self.strategy.reload_model_if_changed()
+        except Exception as e:
+            self.logger.warning(f"Failed to reload model: {e}")
+
         cycle_start = datetime.now(timezone.utc)
         self.logger.info(f"--- Trading Cycle Start: {cycle_start.isoformat()} ---")
 
+        if self.state:
+            self.state.set_last_cycle_time(cycle_start.isoformat())
+
         try:
             # 1. Fetch market data
             features = self.data_feed.get_latest_data()
@@ -154,15 +210,22 @@ class LiveTradingBot:
             funding = self.data_feed.get_current_funding_rates()
 
             # 5. Generate trading signal
-            signal = self.strategy.generate_signal(features, funding)
+            sig = self.strategy.generate_signal(features, funding)
 
-            # 6. Execute trades based on signal
-            if signal['action'] == 'entry':
-                self._execute_entry(signal, eth_price)
-            elif signal['action'] == 'check_exit':
-                self._execute_exit(signal)
+            # 6. Update shared state with strategy info
+            self._update_strategy_state(sig, funding)
 
-            # 7. Log portfolio summary
+            # 7. Execute trades based on signal
+            if sig['action'] == 'entry':
+                self._execute_entry(sig, eth_price)
+            elif sig['action'] == 'check_exit':
+                self._execute_exit(sig)
+
+            # 8. Update shared state with position and account
+            self._update_position_state(eth_price)
+            self._update_account_state()
+
+            # 9. Log portfolio summary
             summary = self.position_manager.get_portfolio_summary()
             self.logger.info(
                 f"Portfolio: {summary['open_positions']} positions, "
@@ -178,6 +241,61 @@ class LiveTradingBot:
         cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
         self.logger.info(f"--- Cycle completed in {cycle_duration:.1f}s ---")
 
+    def _update_strategy_state(self, sig: dict, funding: dict) -> None:
+        """Update shared state with strategy information."""
+        if not self.state:
+            return
+
+        self.state.update_strategy(
+            z_score=sig.get('z_score', 0.0),
+            probability=sig.get('probability', 0.0),
+            funding_rate=funding.get('btc_funding', 0.0),
+            action=sig.get('action', 'hold'),
+            reason=sig.get('reason', ''),
+        )
+
+    def _update_position_state(self, current_price: float) -> None:
+        """Update shared state with current position."""
+        if not self.state:
+            return
+
+        symbol = self.trading_config.eth_symbol
+        position = self.position_manager.get_position_for_symbol(symbol)
+
+        if position is None:
+            self.state.clear_position()
+            return
+
+        pos_state = PositionState(
+            trade_id=position.trade_id,
+            symbol=position.symbol,
+            side=position.side,
+            entry_price=position.entry_price,
+            current_price=position.current_price,
+            size=position.size,
+            size_usdt=position.size_usdt,
+            unrealized_pnl=position.unrealized_pnl,
+            unrealized_pnl_pct=position.unrealized_pnl_pct,
+            stop_loss_price=position.stop_loss_price,
+            take_profit_price=position.take_profit_price,
+        )
+        self.state.set_position(pos_state)
+
+    def _update_account_state(self) -> None:
+        """Update shared state with account information."""
+        if not self.state:
+            return
+
+        try:
+            balance = self.okx_client.get_balance()
+            self.state.update_account(
+                balance=balance.get('total', 0.0),
+                available=balance.get('free', 0.0),
+                leverage=self.trading_config.leverage,
+            )
+        except Exception as e:
+            self.logger.warning(f"Failed to update account state: {e}")
+
     def _execute_entry(self, signal: dict, current_price: float) -> None:
         """Execute entry trade."""
         symbol = self.trading_config.eth_symbol
@@ -191,11 +309,15 @@ class LiveTradingBot:
         # Get account balance
         balance = self.okx_client.get_balance()
         available_usdt = balance['free']
+        self.logger.info(f"Account balance: ${available_usdt:.2f} USDT available")
 
         # Calculate position size
         size_usdt = self.strategy.calculate_position_size(signal, available_usdt)
         if size_usdt <= 0:
-            self.logger.info("Position size too small, skipping entry")
+            self.logger.info(
+                f"Position size too small (${size_usdt:.2f}), skipping entry. "
+                f"Min required: ${self.strategy.config.min_position_usdt:.2f}"
+            )
             return
 
         size_eth = size_usdt / current_price
@@ -290,22 +412,30 @@ class LiveTradingBot:
         except Exception as e:
             self.logger.error(f"Exit execution failed: {e}", exc_info=True)
 
+    def _is_running(self) -> bool:
+        """Check if bot should continue running."""
+        if not self.running:
+            return False
+        if self.state and not self.state.is_running():
+            return False
+        return True
+
     def run(self) -> None:
         """Main trading loop."""
         self.logger.info("Starting trading loop...")
 
-        while self.running:
+        while self._is_running():
             try:
                 self.run_trading_cycle()
 
-                if self.running:
+                if self._is_running():
                     sleep_seconds = self.trading_config.sleep_seconds
                     minutes = sleep_seconds // 60
                     self.logger.info(f"Sleeping for {minutes} minutes...")
 
                     # Sleep in smaller chunks to allow faster shutdown
                     for _ in range(sleep_seconds):
-                        if not self.running:
+                        if not self._is_running():
                             break
                         time.sleep(1)
 
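Sleeping in one-second chunks keeps shutdown latency near a second regardless of the cycle interval. The same responsiveness can be had without polling by waiting on a `threading.Event`; this is a generic alternative sketch, not the bot's code:

```python
import threading
import time

stop = threading.Event()

def interruptible_sleep(seconds: float) -> bool:
    """Sleep up to `seconds`; return True if interrupted by stop.set()."""
    # Event.wait returns True as soon as the event is set, else False on timeout.
    return stop.wait(timeout=seconds)

t0 = time.monotonic()
threading.Timer(0.2, stop.set).start()  # request shutdown after 0.2s
interrupted = interruptible_sleep(60.0)
elapsed = time.monotonic() - t0
print(interrupted, elapsed < 5.0)  # True True
```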
@@ -319,6 +449,8 @@ class LiveTradingBot:
         # Cleanup
         self.logger.info("Shutting down...")
         self.position_manager.save_positions()
+        if self.db:
+            self.db.close()
         self.logger.info("Shutdown complete")
 
 
@@ -350,6 +482,11 @@ def parse_args():
         action="store_true",
         help="Use live trading mode (requires OKX_DEMO_MODE=false)"
     )
+    parser.add_argument(
+        "--no-ui",
+        action="store_true",
+        help="Run in headless mode without terminal UI"
+    )
     return parser.parse_args()
 
 
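`store_true` flags default to `False` and flip to `True` only when the flag is present, so `--no-ui` costs nothing when omitted (argparse also converts the dash to an underscore in the attribute name):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no-ui", action="store_true",
                    help="Run in headless mode without terminal UI")

print(parser.parse_args([]).no_ui)           # False
print(parser.parse_args(["--no-ui"]).no_ui)  # True
```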
@@ -370,19 +507,64 @@ def main():
     if args.live:
         okx_config.demo_mode = False
 
+    # Determine if UI should be enabled
+    use_ui = not args.no_ui and sys.stdin.isatty()
+
+    # Initialize database
+    db_path = path_config.base_dir / "live_trading" / "trading.db"
+    db = init_db(db_path)
+
+    # Run migrations (imports CSV if exists)
+    run_migrations(db, path_config.trade_log_file)
+
+    # Initialize UI components if enabled
+    log_queue: Optional[queue.Queue] = None
+    shared_state: Optional[SharedState] = None
+    dashboard: Optional[TradingDashboard] = None
+
+    if use_ui:
+        log_queue = queue.Queue(maxsize=1000)
+        shared_state = SharedState()
+
     # Setup logging
-    logger = setup_logging(path_config.logs_dir)
+    logger = setup_logging(path_config.logs_dir, log_queue)
 
     try:
-        # Create and run bot
-        bot = LiveTradingBot(okx_config, trading_config, path_config)
+        # Create bot
+        bot = LiveTradingBot(
+            okx_config,
+            trading_config,
+            path_config,
+            database=db,
+            shared_state=shared_state,
+        )
+
+        # Start dashboard if UI enabled
+        if use_ui and shared_state and log_queue:
+            dashboard = TradingDashboard(
+                state=shared_state,
+                db=db,
+                log_queue=log_queue,
+                on_quit=lambda: setattr(bot, 'running', False),
+            )
+            dashboard.start()
+            logger.info("Dashboard started")
+
+        # Run bot
         bot.run()
 
     except ValueError as e:
         logger.error(f"Configuration error: {e}")
         sys.exit(1)
     except Exception as e:
         logger.error(f"Fatal error: {e}", exc_info=True)
         sys.exit(1)
+    finally:
+        # Cleanup
+        if dashboard:
+            dashboard.stop()
+        if db:
+            db.close()
 
 
 if __name__ == "__main__":
@@ -3,16 +3,21 @@ Position Manager for Live Trading.
 
 Tracks open positions, manages risk, and handles SL/TP logic.
 """
+import csv
 import json
 import logging
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Optional
-from dataclasses import dataclass, field, asdict
+from typing import Optional, TYPE_CHECKING
+from dataclasses import dataclass, asdict
 
 from .okx_client import OKXClient
 from .config import TradingConfig, PathConfig
 
+if TYPE_CHECKING:
+    from .db.database import TradingDatabase
+    from .db.models import Trade
+
 logger = logging.getLogger(__name__)
 
 
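The `TYPE_CHECKING` guard keeps `TradingDatabase` available to annotations (written as strings) without importing the db module at runtime, which avoids a circular import between `position_manager` and `db`. The general shape, with a stand-in import:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    from decimal import Decimal  # stand-in for a heavy or circular import

def record(amount: "Decimal") -> str:
    # The quoted annotation is a forward reference the checker resolves;
    # at runtime it is just a string and is never enforced.
    return f"recorded {amount}"

print(record(10))  # recorded 10
```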
@@ -78,11 +83,13 @@ class PositionManager:
         self,
         okx_client: OKXClient,
         trading_config: TradingConfig,
-        path_config: PathConfig
+        path_config: PathConfig,
+        database: Optional["TradingDatabase"] = None,
     ):
         self.client = okx_client
         self.config = trading_config
         self.paths = path_config
+        self.db = database
         self.positions: dict[str, Position] = {}
         self.trade_log: list[dict] = []
         self._load_positions()
@@ -249,16 +256,55 @@ class PositionManager:
         return trade_record
 
     def _append_trade_log(self, trade_record: dict) -> None:
-        """Append trade record to CSV log file."""
-        import csv
+        """Append trade record to CSV and SQLite database."""
+        # Write to CSV (backup/compatibility)
+        self._append_trade_csv(trade_record)
 
+        # Write to SQLite (primary)
+        self._append_trade_db(trade_record)
+
+    def _append_trade_csv(self, trade_record: dict) -> None:
+        """Append trade record to CSV log file."""
         file_exists = self.paths.trade_log_file.exists()
 
-        with open(self.paths.trade_log_file, 'a', newline='') as f:
-            writer = csv.DictWriter(f, fieldnames=trade_record.keys())
-            if not file_exists:
-                writer.writeheader()
-            writer.writerow(trade_record)
+        try:
+            with open(self.paths.trade_log_file, 'a', newline='') as f:
+                writer = csv.DictWriter(f, fieldnames=trade_record.keys())
+                if not file_exists:
+                    writer.writeheader()
+                writer.writerow(trade_record)
+        except Exception as e:
+            logger.error(f"Failed to write trade to CSV: {e}")
+
+    def _append_trade_db(self, trade_record: dict) -> None:
+        """Append trade record to SQLite database."""
+        if self.db is None:
+            return
+
+        try:
+            from .db.models import Trade
+
+            trade = Trade(
+                trade_id=trade_record['trade_id'],
+                symbol=trade_record['symbol'],
+                side=trade_record['side'],
+                entry_price=trade_record['entry_price'],
+                exit_price=trade_record.get('exit_price'),
+                size=trade_record['size'],
+                size_usdt=trade_record['size_usdt'],
+                pnl_usd=trade_record.get('pnl_usd'),
+                pnl_pct=trade_record.get('pnl_pct'),
+                entry_time=trade_record['entry_time'],
+                exit_time=trade_record.get('exit_time'),
+                hold_duration_hours=trade_record.get('hold_duration_hours'),
+                reason=trade_record.get('reason'),
+                order_id_entry=trade_record.get('order_id_entry'),
+                order_id_exit=trade_record.get('order_id_exit'),
+            )
+            self.db.insert_trade(trade)
+            logger.debug(f"Trade {trade.trade_id} saved to database")
+        except Exception as e:
+            logger.error(f"Failed to write trade to database: {e}")
 
     def update_positions(self, current_prices: dict[str, float]) -> list[dict]:
         """
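The header-once append pattern in `_append_trade_csv` is standard `csv.DictWriter` usage: check for the file before opening in append mode, and write the header only on the first append. Standalone, against a temp file:

```python
import csv
import tempfile
from pathlib import Path

path = Path(tempfile.mkdtemp()) / "trade_log.csv"
records = [{"trade_id": "t-1", "pnl_usd": "3.0"},
           {"trade_id": "t-2", "pnl_usd": "-1.5"}]

for rec in records:
    file_exists = path.exists()
    with open(path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=rec.keys())
        if not file_exists:
            writer.writeheader()  # only on the first append
        writer.writerow(rec)

print(path.read_text().splitlines())  # header row followed by two data rows
```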
10
live_trading/ui/__init__.py
Normal file
@@ -0,0 +1,10 @@
"""Terminal UI module for live trading dashboard."""
from .dashboard import TradingDashboard
from .state import SharedState
from .log_handler import UILogHandler

__all__ = [
    "TradingDashboard",
    "SharedState",
    "UILogHandler",
]
240
live_trading/ui/dashboard.py
Normal file
@@ -0,0 +1,240 @@
"""Main trading dashboard UI orchestration."""
import logging
import queue
import threading
import time
from typing import Optional, Callable

from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.panel import Panel
from rich.text import Text

from .state import SharedState
from .log_handler import LogBuffer, UILogHandler
from .keyboard import KeyboardHandler
from .panels import (
    HeaderPanel,
    TabBar,
    LogPanel,
    HelpBar,
    build_summary_panel,
)
from ..db.database import TradingDatabase
from ..db.metrics import MetricsCalculator, PeriodMetrics

logger = logging.getLogger(__name__)


class TradingDashboard:
    """
    Main trading dashboard orchestrator.

    Runs in a separate thread and provides real-time UI updates
    while the trading loop runs in the main thread.
    """

    def __init__(
        self,
        state: SharedState,
        db: TradingDatabase,
        log_queue: queue.Queue,
        on_quit: Optional[Callable] = None,
    ):
        self.state = state
        self.db = db
        self.log_queue = log_queue
        self.on_quit = on_quit

        self.console = Console()
        self.log_buffer = LogBuffer(max_entries=1000)
        self.keyboard = KeyboardHandler()
        self.metrics_calculator = MetricsCalculator(db)

        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._active_tab = 0
        self._cached_metrics: dict[int, PeriodMetrics] = {}
        self._last_metrics_refresh = 0.0
    def start(self) -> None:
        """Start the dashboard in a separate thread."""
        if self._running:
            return

        self._running = True
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        logger.debug("Dashboard thread started")

    def stop(self) -> None:
        """Stop the dashboard."""
        self._running = False
        if self._thread:
            self._thread.join(timeout=2.0)
        logger.debug("Dashboard thread stopped")
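The start/stop shape above, a daemon worker plus a flag and a bounded `join`, can be shown with a trivial worker (class and sleep intervals are illustrative):

```python
import threading
import time

class Worker:
    def __init__(self):
        self._running = False
        self._thread = None
        self.ticks = 0

    def start(self):
        if self._running:
            return
        self._running = True
        # daemon=True: the thread will not block interpreter exit
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self):
        self._running = False
        if self._thread:
            self._thread.join(timeout=2.0)  # bounded wait, never hangs shutdown

    def _run(self):
        while self._running:
            self.ticks += 1
            time.sleep(0.01)

w = Worker()
w.start()
time.sleep(0.1)
w.stop()
print(w.ticks > 0, w._thread.is_alive())  # True False
```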
    def _run(self) -> None:
        """Main dashboard loop."""
        try:
            with self.keyboard:
                with Live(
                    self._build_layout(),
                    console=self.console,
                    refresh_per_second=1,
                    screen=True,
                ) as live:
                    while self._running and self.state.is_running():
                        # Process keyboard input
                        action = self.keyboard.get_action(timeout=0.1)
                        if action:
                            self._handle_action(action)

                        # Drain log queue
                        self.log_buffer.drain_queue(self.log_queue)

                        # Refresh metrics periodically (every 5 seconds)
                        now = time.time()
                        if now - self._last_metrics_refresh > 5.0:
                            self._refresh_metrics()
                            self._last_metrics_refresh = now

                        # Update display
                        live.update(self._build_layout())

                        # Small sleep to prevent CPU spinning
                        time.sleep(0.1)

        except Exception as e:
            logger.error(f"Dashboard error: {e}", exc_info=True)
        finally:
            self._running = False
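The 5-second metrics throttle in the loop above is just a timestamp comparison. In isolation, with an injectable clock so it can be tested deterministically (class name hypothetical):

```python
class Throttle:
    """Allow an action at most once per `interval` seconds."""

    def __init__(self, interval: float, clock):
        self.interval = interval
        self.clock = clock  # injectable time source, e.g. time.time
        self.last = 0.0

    def ready(self) -> bool:
        now = self.clock()
        if now - self.last > self.interval:
            self.last = now
            return True
        return False

fake_time = [100.0]
t = Throttle(5.0, clock=lambda: fake_time[0])
a = t.ready()          # True: first call fires (100 - 0 > 5)
fake_time[0] = 102.0
b = t.ready()          # False: only 2s since last fire
fake_time[0] = 106.0
c = t.ready()          # True: 6s since last fire
print(a, b, c)         # True False True
```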
    def _handle_action(self, action: str) -> None:
        """Handle keyboard action."""
        if action == "quit":
            logger.info("Quit requested from UI")
            self.state.stop()
            if self.on_quit:
                self.on_quit()

        elif action == "refresh":
            self._refresh_metrics()
            logger.debug("Manual refresh triggered")

        elif action == "filter":
            new_filter = self.log_buffer.cycle_filter()
            logger.debug(f"Log filter changed to: {new_filter}")

        elif action == "filter_trades":
            self.log_buffer.set_filter(LogBuffer.FILTER_TRADES)
            logger.debug("Log filter set to: trades")

        elif action == "filter_all":
            self.log_buffer.set_filter(LogBuffer.FILTER_ALL)
            logger.debug("Log filter set to: all")

        elif action == "filter_errors":
            self.log_buffer.set_filter(LogBuffer.FILTER_ERRORS)
            logger.debug("Log filter set to: errors")

        elif action == "tab_general":
            self._active_tab = 0
        elif action == "tab_monthly":
            if self._has_monthly_data():
                self._active_tab = 1
        elif action == "tab_weekly":
            if self._has_weekly_data():
                self._active_tab = 2
        elif action == "tab_daily":
            self._active_tab = 3
def _refresh_metrics(self) -> None:
|
||||
"""Refresh metrics from database."""
|
||||
try:
|
||||
self._cached_metrics[0] = self.metrics_calculator.get_all_time_metrics()
|
||||
self._cached_metrics[1] = self.metrics_calculator.get_monthly_metrics()
|
||||
self._cached_metrics[2] = self.metrics_calculator.get_weekly_metrics()
|
||||
self._cached_metrics[3] = self.metrics_calculator.get_daily_metrics()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to refresh metrics: {e}")
|
||||
|
||||
def _has_monthly_data(self) -> bool:
|
||||
"""Check if monthly tab should be shown."""
|
||||
try:
|
||||
return self.metrics_calculator.has_monthly_data()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _has_weekly_data(self) -> bool:
|
||||
"""Check if weekly tab should be shown."""
|
||||
try:
|
||||
return self.metrics_calculator.has_weekly_data()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _build_layout(self) -> Layout:
|
||||
"""Build the complete dashboard layout."""
|
||||
layout = Layout()
|
||||
|
||||
# Calculate available height
|
||||
term_height = self.console.height or 40
|
||||
|
||||
# Header takes 3 lines
|
||||
# Help bar takes 1 line
|
||||
# Summary panel takes about 12-14 lines
|
||||
# Rest goes to logs
|
||||
log_height = max(8, term_height - 20)
|
||||
|
||||
layout.split_column(
|
||||
Layout(name="header", size=3),
|
||||
Layout(name="summary", size=14),
|
||||
Layout(name="logs", size=log_height),
|
||||
Layout(name="help", size=1),
|
||||
)
|
||||
|
||||
# Header
|
||||
layout["header"].update(HeaderPanel(self.state).render())
|
||||
|
||||
# Summary panel with tabs
|
||||
current_metrics = self._cached_metrics.get(self._active_tab)
|
||||
tab_bar = TabBar(active_tab=self._active_tab)
|
||||
|
||||
layout["summary"].update(
|
||||
build_summary_panel(
|
||||
state=self.state,
|
||||
metrics=current_metrics,
|
||||
tab_bar=tab_bar,
|
||||
has_monthly=self._has_monthly_data(),
|
||||
has_weekly=self._has_weekly_data(),
|
||||
)
|
||||
)
|
||||
|
||||
# Log panel
|
||||
layout["logs"].update(LogPanel(self.log_buffer).render(height=log_height))
|
||||
|
||||
# Help bar
|
||||
layout["help"].update(HelpBar().render())
|
||||
|
||||
return layout
|
||||
|
||||
|
||||
def setup_ui_logging(log_queue: queue.Queue) -> UILogHandler:
|
||||
"""
|
||||
Set up logging to capture messages for UI.
|
||||
|
||||
Args:
|
||||
log_queue: Queue to send log messages to
|
||||
|
||||
Returns:
|
||||
UILogHandler instance
|
||||
"""
|
||||
handler = UILogHandler(log_queue)
|
||||
handler.setLevel(logging.INFO)
|
||||
|
||||
# Add handler to root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.addHandler(handler)
|
||||
|
||||
return handler
|
||||
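The cadence in `_run` is worth noting: the loop ticks every ~100 ms to stay responsive to keypresses, but the database is only hit when more than 5 seconds have elapsed since the last refresh. A minimal sketch of that pattern, using a simulated clock (`run_loop` is a hypothetical helper, not part of the codebase):

```python
def run_loop(steps: int, refresh_interval: float = 5.0) -> int:
    """Count DB refreshes performed by a fast polling loop.

    Each step simulates one 100 ms UI tick; a refresh fires only when
    more than `refresh_interval` seconds have passed since the last one.
    """
    refreshes = 0
    last_refresh = 0.0
    now = 0.0
    for _ in range(steps):
        now += 0.1  # simulated 100 ms per tick
        if now - last_refresh > refresh_interval:
            refreshes += 1
            last_refresh = now
    return refreshes
```

Over 10 simulated seconds (100 ticks) this yields a single refresh, so the UI stays at ~10 Hz responsiveness while the database sees ~0.2 queries/second.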
128
live_trading/ui/keyboard.py
Normal file
@@ -0,0 +1,128 @@
"""Keyboard input handling for terminal UI."""
import sys
import select
import termios
import tty
from typing import Optional
from dataclasses import dataclass


@dataclass
class KeyAction:
    """Represents a keyboard action."""

    key: str
    action: str
    description: str


class KeyboardHandler:
    """
    Non-blocking keyboard input handler.

    Uses terminal raw mode to capture single keypresses
    without waiting for Enter.
    """

    # Key mappings
    ACTIONS = {
        "q": "quit",
        "Q": "quit",
        "\x03": "quit",  # Ctrl+C
        "r": "refresh",
        "R": "refresh",
        "f": "filter",
        "F": "filter",
        "t": "filter_trades",
        "T": "filter_trades",
        "l": "filter_all",
        "L": "filter_all",
        "e": "filter_errors",
        "E": "filter_errors",
        "1": "tab_general",
        "2": "tab_monthly",
        "3": "tab_weekly",
        "4": "tab_daily",
    }

    def __init__(self):
        self._old_settings = None
        self._enabled = False

    def enable(self) -> bool:
        """
        Enable raw keyboard input mode.

        Returns:
            True if enabled successfully
        """
        try:
            if not sys.stdin.isatty():
                return False

            self._old_settings = termios.tcgetattr(sys.stdin)
            tty.setcbreak(sys.stdin.fileno())
            self._enabled = True
            return True
        except Exception:
            return False

    def disable(self) -> None:
        """Restore normal terminal mode."""
        if self._enabled and self._old_settings:
            try:
                termios.tcsetattr(
                    sys.stdin,
                    termios.TCSADRAIN,
                    self._old_settings,
                )
            except Exception:
                pass
        self._enabled = False

    def get_key(self, timeout: float = 0.1) -> Optional[str]:
        """
        Get a keypress if available (non-blocking).

        Args:
            timeout: Seconds to wait for input

        Returns:
            Key character or None if no input
        """
        if not self._enabled:
            return None

        try:
            readable, _, _ = select.select([sys.stdin], [], [], timeout)
            if readable:
                return sys.stdin.read(1)
        except Exception:
            pass

        return None

    def get_action(self, timeout: float = 0.1) -> Optional[str]:
        """
        Get the action name for a pressed key.

        Args:
            timeout: Seconds to wait for input

        Returns:
            Action name or None
        """
        key = self.get_key(timeout)
        if key:
            return self.ACTIONS.get(key)
        return None

    def __enter__(self):
        """Context manager entry."""
        self.enable()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.disable()
        return False
178
live_trading/ui/log_handler.py
Normal file
@@ -0,0 +1,178 @@
"""Custom logging handler for UI integration."""
import logging
import queue
from dataclasses import dataclass
from datetime import datetime
from collections import deque


@dataclass
class LogEntry:
    """A single log entry."""

    timestamp: str
    level: str
    message: str
    logger_name: str

    @property
    def level_color(self) -> str:
        """Get Rich color for log level."""
        colors = {
            "DEBUG": "dim",
            "INFO": "white",
            "WARNING": "yellow",
            "ERROR": "red",
            "CRITICAL": "bold red",
        }
        return colors.get(self.level, "white")


class UILogHandler(logging.Handler):
    """
    Custom logging handler that sends logs to the UI.

    Uses a thread-safe queue to pass log entries from the trading
    thread to the UI thread.
    """

    def __init__(
        self,
        log_queue: queue.Queue,
        max_entries: int = 1000,
    ):
        super().__init__()
        self.log_queue = log_queue
        self.max_entries = max_entries
        self.setFormatter(
            logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
        )

    def emit(self, record: logging.LogRecord) -> None:
        """Emit a log record to the queue."""
        try:
            entry = LogEntry(
                timestamp=datetime.fromtimestamp(record.created).strftime("%H:%M:%S"),
                level=record.levelname,
                message=self.format_message(record),
                logger_name=record.name,
            )
            # Non-blocking put; drop the entry if the queue is full
            try:
                self.log_queue.put_nowait(entry)
            except queue.Full:
                pass
        except Exception:
            self.handleError(record)

    def format_message(self, record: logging.LogRecord) -> str:
        """Format the log message."""
        return record.getMessage()


class LogBuffer:
    """
    Thread-safe buffer for log entries with filtering support.

    Maintains a fixed-size buffer of log entries and supports
    filtering by log type.
    """

    FILTER_ALL = "all"
    FILTER_ERRORS = "errors"
    FILTER_TRADES = "trades"
    FILTER_SIGNALS = "signals"

    FILTERS = [FILTER_ALL, FILTER_ERRORS, FILTER_TRADES, FILTER_SIGNALS]

    def __init__(self, max_entries: int = 1000):
        self.max_entries = max_entries
        self._entries: deque[LogEntry] = deque(maxlen=max_entries)
        self._current_filter = self.FILTER_ALL

    def add(self, entry: LogEntry) -> None:
        """Add a log entry to the buffer."""
        self._entries.append(entry)

    def get_filtered(self, limit: int = 50) -> list[LogEntry]:
        """
        Get filtered log entries.

        Args:
            limit: Maximum number of entries to return

        Returns:
            List of filtered LogEntry objects (most recent first)
        """
        entries = list(self._entries)

        if self._current_filter == self.FILTER_ERRORS:
            entries = [e for e in entries if e.level in ("ERROR", "CRITICAL")]
        elif self._current_filter == self.FILTER_TRADES:
            # Key terms indicating actual trading activity
            include_keywords = [
                "order", "entry", "exit", "executed", "filled",
                "opening", "closing", "position opened", "position closed",
            ]
            # Terms to exclude (noise)
            exclude_keywords = [
                "sync complete", "0 positions", "portfolio: 0 positions",
            ]

            entries = [
                e for e in entries
                if any(kw in e.message.lower() for kw in include_keywords)
                and not any(ex in e.message.lower() for ex in exclude_keywords)
            ]
        elif self._current_filter == self.FILTER_SIGNALS:
            signal_keywords = ["signal", "z_score", "prob", "z="]
            entries = [
                e for e in entries
                if any(kw in e.message.lower() for kw in signal_keywords)
            ]

        # Return the most recent entries
        return list(reversed(entries[-limit:]))

    def set_filter(self, filter_name: str) -> None:
        """Set a specific filter."""
        if filter_name in self.FILTERS:
            self._current_filter = filter_name

    def cycle_filter(self) -> str:
        """Cycle to the next filter and return its name."""
        current_idx = self.FILTERS.index(self._current_filter)
        next_idx = (current_idx + 1) % len(self.FILTERS)
        self._current_filter = self.FILTERS[next_idx]
        return self._current_filter

    def get_current_filter(self) -> str:
        """Get the current filter name."""
        return self._current_filter

    def clear(self) -> None:
        """Clear all log entries."""
        self._entries.clear()

    def drain_queue(self, log_queue: queue.Queue) -> int:
        """
        Drain log entries from the queue into the buffer.

        Args:
            log_queue: Queue to drain from

        Returns:
            Number of entries drained
        """
        count = 0
        while True:
            try:
                entry = log_queue.get_nowait()
                self.add(entry)
                count += 1
            except queue.Empty:
                break
        return count
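The handoff `UILogHandler` and `LogBuffer.drain_queue` implement is a standard producer/consumer bridge: the trading thread's logging calls push onto a bounded queue (dropping on overflow rather than blocking), and the UI thread drains it on each tick. The pattern in isolation, stripped of the Rich-specific bits (`QueueHandler` here is a minimal stand-in, not the class from the codebase):

```python
import logging
import queue


class QueueHandler(logging.Handler):
    """Minimal version of the UILogHandler pattern: push records to a queue."""

    def __init__(self, q: queue.Queue):
        super().__init__()
        self.q = q

    def emit(self, record: logging.LogRecord) -> None:
        try:
            # Never block the logging (producer) thread; drop on overflow
            self.q.put_nowait((record.levelname, record.getMessage()))
        except queue.Full:
            pass


q: queue.Queue = queue.Queue(maxsize=100)
log = logging.getLogger("demo")
log.setLevel(logging.INFO)
log.addHandler(QueueHandler(q))
log.propagate = False

log.info("order filled")
log.error("exchange timeout")

# Consumer side: drain everything currently queued, like LogBuffer.drain_queue
drained = []
while True:
    try:
        drained.append(q.get_nowait())
    except queue.Empty:
        break
```

Because `emit` uses `put_nowait`, a stalled UI can never back-pressure the trading loop; at worst, log lines are lost, which is the right trade-off for a live trader.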
399
live_trading/ui/panels.py
Normal file
@@ -0,0 +1,399 @@
"""UI panel components using Rich."""
from typing import Optional

from rich.console import Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from .state import SharedState, PositionState, StrategyState, AccountState
from .log_handler import LogBuffer, LogEntry
from ..db.metrics import PeriodMetrics


def format_pnl(value: float, include_sign: bool = True) -> Text:
    """Format a PnL value with color."""
    if value > 0:
        sign = "+" if include_sign else ""
        return Text(f"{sign}${value:.2f}", style="green")
    elif value < 0:
        return Text(f"${value:.2f}", style="red")
    else:
        return Text(f"${value:.2f}", style="white")


def format_pct(value: float, include_sign: bool = True) -> Text:
    """Format a percentage value with color."""
    if value > 0:
        sign = "+" if include_sign else ""
        return Text(f"{sign}{value:.2f}%", style="green")
    elif value < 0:
        return Text(f"{value:.2f}%", style="red")
    else:
        return Text(f"{value:.2f}%", style="white")


def format_side(side: str) -> Text:
    """Format a position side with color."""
    if side.lower() == "long":
        return Text("LONG", style="bold green")
    else:
        return Text("SHORT", style="bold red")


class HeaderPanel:
    """Header panel with title and mode indicator."""

    def __init__(self, state: SharedState):
        self.state = state

    def render(self) -> Panel:
        """Render the header panel."""
        mode = self.state.get_mode()
        eth_symbol, _ = self.state.get_symbols()

        mode_style = "yellow" if mode == "DEMO" else "bold red"
        mode_text = Text(f"[{mode}]", style=mode_style)

        title = Text()
        title.append("REGIME REVERSION STRATEGY - LIVE TRADING", style="bold white")
        title.append(" ")
        title.append(mode_text)
        title.append(" ")
        title.append(eth_symbol, style="cyan")

        return Panel(title, style="blue", height=3)


class TabBar:
    """Tab bar for period selection."""

    TABS = ["1:General", "2:Monthly", "3:Weekly", "4:Daily"]

    def __init__(self, active_tab: int = 0):
        self.active_tab = active_tab

    def render(
        self,
        has_monthly: bool = True,
        has_weekly: bool = True,
    ) -> Text:
        """Render the tab bar."""
        text = Text()
        text.append(" ")

        for i, tab in enumerate(self.TABS):
            # Check if tab should be shown
            if i == 1 and not has_monthly:
                continue
            if i == 2 and not has_weekly:
                continue

            if i == self.active_tab:
                text.append(f"[{tab}]", style="bold white on blue")
            else:
                text.append(f"[{tab}]", style="dim")
            text.append(" ")

        return text


class MetricsPanel:
    """Panel showing trading metrics."""

    def __init__(self, metrics: Optional[PeriodMetrics] = None):
        self.metrics = metrics

    def render(self) -> Table:
        """Render metrics as a table."""
        table = Table(
            show_header=False,
            show_edge=False,
            box=None,
            padding=(0, 1),
        )
        table.add_column("Label", style="dim")
        table.add_column("Value")

        if self.metrics is None or self.metrics.total_trades == 0:
            table.add_row("Status", Text("No trade data", style="dim"))
            return table

        m = self.metrics

        table.add_row("Total PnL:", format_pnl(m.total_pnl))
        table.add_row("Win Rate:", Text(f"{m.win_rate:.1f}%", style="white"))
        table.add_row("Total Trades:", Text(str(m.total_trades), style="white"))
        table.add_row(
            "Win/Loss:",
            Text(f"{m.winning_trades}/{m.losing_trades}", style="white"),
        )
        table.add_row(
            "Avg Duration:",
            Text(f"{m.avg_trade_duration_hours:.1f}h", style="white"),
        )
        table.add_row("Max Drawdown:", format_pnl(-m.max_drawdown))
        table.add_row("Best Trade:", format_pnl(m.best_trade))
        table.add_row("Worst Trade:", format_pnl(m.worst_trade))

        return table


class PositionPanel:
    """Panel showing the current position."""

    def __init__(self, position: Optional[PositionState] = None):
        self.position = position

    def render(self) -> Table:
        """Render the position as a table."""
        table = Table(
            show_header=False,
            show_edge=False,
            box=None,
            padding=(0, 1),
        )
        table.add_column("Label", style="dim")
        table.add_column("Value")

        if self.position is None:
            table.add_row("Status", Text("No open position", style="dim"))
            return table

        p = self.position

        table.add_row("Side:", format_side(p.side))
        table.add_row("Entry:", Text(f"${p.entry_price:.2f}", style="white"))
        table.add_row("Current:", Text(f"${p.current_price:.2f}", style="white"))

        # Unrealized PnL
        pnl_text = Text()
        pnl_text.append_text(format_pnl(p.unrealized_pnl))
        pnl_text.append(" (")
        pnl_text.append_text(format_pct(p.unrealized_pnl_pct))
        pnl_text.append(")")
        table.add_row("Unrealized:", pnl_text)

        table.add_row("Size:", Text(f"${p.size_usdt:.2f}", style="white"))

        # SL/TP distances relative to entry
        if p.side == "long":
            sl_dist = (p.stop_loss_price / p.entry_price - 1) * 100
            tp_dist = (p.take_profit_price / p.entry_price - 1) * 100
        else:
            sl_dist = (1 - p.stop_loss_price / p.entry_price) * 100
            tp_dist = (1 - p.take_profit_price / p.entry_price) * 100

        sl_text = Text(f"${p.stop_loss_price:.2f} ({sl_dist:+.1f}%)", style="red")
        tp_text = Text(f"${p.take_profit_price:.2f} ({tp_dist:+.1f}%)", style="green")

        table.add_row("Stop Loss:", sl_text)
        table.add_row("Take Profit:", tp_text)

        return table


class AccountPanel:
    """Panel showing account information."""

    def __init__(self, account: Optional[AccountState] = None):
        self.account = account

    def render(self) -> Table:
        """Render account info as a table."""
        table = Table(
            show_header=False,
            show_edge=False,
            box=None,
            padding=(0, 1),
        )
        table.add_column("Label", style="dim")
        table.add_column("Value")

        if self.account is None:
            table.add_row("Status", Text("Loading...", style="dim"))
            return table

        a = self.account

        table.add_row("Balance:", Text(f"${a.balance:.2f}", style="white"))
        table.add_row("Available:", Text(f"${a.available:.2f}", style="white"))
        table.add_row("Leverage:", Text(f"{a.leverage}x", style="cyan"))

        return table


class StrategyPanel:
    """Panel showing strategy state."""

    def __init__(self, strategy: Optional[StrategyState] = None):
        self.strategy = strategy

    def render(self) -> Table:
        """Render strategy state as a table."""
        table = Table(
            show_header=False,
            show_edge=False,
            box=None,
            padding=(0, 1),
        )
        table.add_column("Label", style="dim")
        table.add_column("Value")

        if self.strategy is None:
            table.add_row("Status", Text("Waiting...", style="dim"))
            return table

        s = self.strategy

        # Z-score with color based on threshold
        z_style = "white"
        if abs(s.z_score) > 1.0:
            z_style = "yellow"
        if abs(s.z_score) > 1.5:
            z_style = "bold yellow"
        table.add_row("Z-Score:", Text(f"{s.z_score:.2f}", style=z_style))

        # Probability with color
        prob_style = "white"
        if s.probability > 0.5:
            prob_style = "green"
        if s.probability > 0.7:
            prob_style = "bold green"
        table.add_row("Probability:", Text(f"{s.probability:.2f}", style=prob_style))

        # Funding rate
        funding_style = "green" if s.funding_rate >= 0 else "red"
        table.add_row(
            "Funding:",
            Text(f"{s.funding_rate:.4f}", style=funding_style),
        )

        # Last action
        action_style = "white"
        if s.last_action == "entry":
            action_style = "bold cyan"
        elif s.last_action == "check_exit":
            action_style = "yellow"
        table.add_row("Last Action:", Text(s.last_action, style=action_style))

        return table


class LogPanel:
    """Panel showing log entries."""

    def __init__(self, log_buffer: LogBuffer):
        self.log_buffer = log_buffer

    def render(self, height: int = 10) -> Panel:
        """Render the log panel."""
        filter_name = self.log_buffer.get_current_filter().title()
        entries = self.log_buffer.get_filtered(limit=height - 2)

        lines = []
        for entry in entries:
            line = Text()
            line.append(f"{entry.timestamp} ", style="dim")
            line.append(f"[{entry.level}] ", style=entry.level_color)
            line.append(entry.message)
            lines.append(line)

        if not lines:
            lines.append(Text("No logs to display", style="dim"))

        content = Group(*lines)

        # Build a "tabbed" title
        tabs = []

        # All Logs tab
        if filter_name == "All":
            tabs.append("[bold white on blue] [L]ogs [/]")
        else:
            tabs.append("[dim] [L]ogs [/]")

        # Trades tab
        if filter_name == "Trades":
            tabs.append("[bold white on blue] [T]rades [/]")
        else:
            tabs.append("[dim] [T]rades [/]")

        # Errors tab
        if filter_name == "Errors":
            tabs.append("[bold white on blue] [E]rrors [/]")
        else:
            tabs.append("[dim] [E]rrors [/]")

        title = " ".join(tabs)
        subtitle = "Press 'l', 't', 'e' to switch tabs"

        return Panel(
            content,
            title=title,
            subtitle=subtitle,
            title_align="left",
            subtitle_align="right",
            border_style="blue",
        )


class HelpBar:
    """Bottom help bar with keyboard shortcuts."""

    def render(self) -> Text:
        """Render the help bar."""
        text = Text()
        text.append(" [q]", style="bold")
        text.append("Quit ", style="dim")
        text.append("[r]", style="bold")
        text.append("Refresh ", style="dim")
        text.append("[1-4]", style="bold")
        text.append("Tabs ", style="dim")
        text.append("[l/t/e]", style="bold")
        text.append("LogView", style="dim")

        return text


def build_summary_panel(
    state: SharedState,
    metrics: Optional[PeriodMetrics],
    tab_bar: TabBar,
    has_monthly: bool,
    has_weekly: bool,
) -> Panel:
    """Build the complete summary panel with all sections."""
    # Tab bar at top
    tabs = tab_bar.render(has_monthly, has_weekly)

    # Create tables for each section
    metrics_table = MetricsPanel(metrics).render()
    position_table = PositionPanel(state.get_position()).render()
    account_table = AccountPanel(state.get_account()).render()
    strategy_table = StrategyPanel(state.get_strategy()).render()

    # Build two-column layout
    left_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
    left_col.add_column("PERFORMANCE", style="bold cyan")
    left_col.add_column("ACCOUNT", style="bold cyan")
    left_col.add_row(metrics_table, account_table)

    right_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
    right_col.add_column("CURRENT POSITION", style="bold cyan")
    right_col.add_column("STRATEGY STATE", style="bold cyan")
    right_col.add_row(position_table, strategy_table)

    # Combine into main table
    main_table = Table(show_header=False, show_edge=False, box=None, expand=True)
    main_table.add_column(ratio=1)
    main_table.add_column(ratio=1)
    main_table.add_row(left_col, right_col)

    content = Group(tabs, Text(""), main_table)

    return Panel(content, border_style="blue")
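The SL/TP arithmetic in `PositionPanel.render` is direction-aware: for both sides, a negative distance means the level sits on the losing side of entry and a positive one on the winning side. Extracted as a plain function for clarity (`sl_tp_distances` is a hypothetical helper mirroring the panel's formulas, not part of the codebase):

```python
def sl_tp_distances(
    side: str, entry: float, sl: float, tp: float
) -> tuple[float, float]:
    """Signed % distance of stop-loss / take-profit from entry.

    Uses the same formulas as PositionPanel: for shorts the sign is
    flipped so that SL is always negative and TP always positive.
    """
    if side == "long":
        return ((sl / entry - 1) * 100, (tp / entry - 1) * 100)
    return ((1 - sl / entry) * 100, (1 - tp / entry) * 100)
```

For example, a long from $2000 with SL $1900 / TP $2200 reports (-5%, +10%); a short from $2000 with SL $2100 / TP $1800 reports the same (-5%, +10%), which is what makes the red/green styling in the panel consistent across both sides.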
195
live_trading/ui/state.py
Normal file
@@ -0,0 +1,195 @@
"""Thread-safe shared state for UI and trading loop."""
import threading
from dataclasses import dataclass
from typing import Optional
from datetime import datetime, timezone


@dataclass
class PositionState:
    """Current position information."""

    trade_id: str = ""
    symbol: str = ""
    side: str = ""
    entry_price: float = 0.0
    current_price: float = 0.0
    size: float = 0.0
    size_usdt: float = 0.0
    unrealized_pnl: float = 0.0
    unrealized_pnl_pct: float = 0.0
    stop_loss_price: float = 0.0
    take_profit_price: float = 0.0


@dataclass
class StrategyState:
    """Current strategy signal state."""

    z_score: float = 0.0
    probability: float = 0.0
    funding_rate: float = 0.0
    last_action: str = "hold"
    last_reason: str = ""
    last_signal_time: str = ""


@dataclass
class AccountState:
    """Account balance information."""

    balance: float = 0.0
    available: float = 0.0
    leverage: int = 1


class SharedState:
    """
    Thread-safe shared state between trading loop and UI.

    All access to state fields should go through the getter/setter methods,
    which use a lock for thread safety.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._position: Optional[PositionState] = None
        self._strategy = StrategyState()
        self._account = AccountState()
        self._is_running = True
        self._last_cycle_time: Optional[str] = None
        self._mode = "DEMO"
        self._eth_symbol = "ETH/USDT:USDT"
        self._btc_symbol = "BTC/USDT:USDT"

    # Position methods
    def get_position(self) -> Optional[PositionState]:
        """Get current position state."""
        with self._lock:
            return self._position

    def set_position(self, position: Optional[PositionState]) -> None:
        """Set current position state."""
        with self._lock:
            self._position = position

    def update_position_price(self, current_price: float) -> None:
        """Update current price and recalculate PnL."""
        with self._lock:
            if self._position is None:
                return

            self._position.current_price = current_price

            if self._position.side == "long":
                pnl = current_price - self._position.entry_price
                self._position.unrealized_pnl = pnl * self._position.size
                pnl_pct = (current_price / self._position.entry_price - 1) * 100
            else:
                pnl = self._position.entry_price - current_price
                self._position.unrealized_pnl = pnl * self._position.size
                pnl_pct = (1 - current_price / self._position.entry_price) * 100

            self._position.unrealized_pnl_pct = pnl_pct

    def clear_position(self) -> None:
        """Clear current position."""
        with self._lock:
            self._position = None

    # Strategy methods
    def get_strategy(self) -> StrategyState:
        """Get a copy of the current strategy state."""
        with self._lock:
            return StrategyState(
                z_score=self._strategy.z_score,
                probability=self._strategy.probability,
                funding_rate=self._strategy.funding_rate,
                last_action=self._strategy.last_action,
                last_reason=self._strategy.last_reason,
                last_signal_time=self._strategy.last_signal_time,
            )

    def update_strategy(
        self,
        z_score: float,
        probability: float,
        funding_rate: float,
        action: str,
        reason: str,
    ) -> None:
        """Update strategy state."""
        with self._lock:
            self._strategy.z_score = z_score
            self._strategy.probability = probability
            self._strategy.funding_rate = funding_rate
            self._strategy.last_action = action
            self._strategy.last_reason = reason
            self._strategy.last_signal_time = datetime.now(timezone.utc).isoformat()

    # Account methods
    def get_account(self) -> AccountState:
        """Get a copy of the current account state."""
        with self._lock:
            return AccountState(
                balance=self._account.balance,
                available=self._account.available,
                leverage=self._account.leverage,
            )

    def update_account(
        self,
        balance: float,
        available: float,
        leverage: int,
    ) -> None:
        """Update account state."""
        with self._lock:
            self._account.balance = balance
            self._account.available = available
            self._account.leverage = leverage

    # Control methods
    def is_running(self) -> bool:
        """Check if the trading loop is running."""
        with self._lock:
            return self._is_running

    def stop(self) -> None:
        """Signal the trading loop to stop."""
        with self._lock:
            self._is_running = False

    def get_last_cycle_time(self) -> Optional[str]:
        """Get the last trading cycle time."""
        with self._lock:
            return self._last_cycle_time

    def set_last_cycle_time(self, time_str: str) -> None:
        """Set the last trading cycle time."""
        with self._lock:
            self._last_cycle_time = time_str

    # Config methods
    def get_mode(self) -> str:
        """Get trading mode (DEMO/LIVE)."""
        with self._lock:
            return self._mode

    def set_mode(self, mode: str) -> None:
        """Set trading mode."""
        with self._lock:
            self._mode = mode

    def get_symbols(self) -> tuple[str, str]:
        """Get trading symbols (eth, btc)."""
        with self._lock:
            return self._eth_symbol, self._btc_symbol

    def set_symbols(self, eth_symbol: str, btc_symbol: str) -> None:
        """Set trading symbols."""
        with self._lock:
            self._eth_symbol = eth_symbol
            self._btc_symbol = btc_symbol
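`SharedState` relies on a single discipline: every read and write acquires the same lock, and getters for mutable records (`get_strategy`, `get_account`) hand back copies so the UI thread can never observe a half-updated struct. The core of that pattern, reduced to a counter hammered by several threads (`Counter` is an illustrative stand-in, not from the codebase):

```python
import threading


class Counter:
    """Same locking discipline as SharedState: all access goes through one lock."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._value = 0

    def increment(self) -> None:
        # += on an int is read-modify-write, so it needs the lock
        with self._lock:
            self._value += 1

    def get(self) -> int:
        with self._lock:
            return self._value


c = Counter()
threads = [
    threading.Thread(target=lambda: [c.increment() for _ in range(10_000)])
    for _ in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Without the lock, interleaved read-modify-write cycles could lose increments; with it, four threads of 10,000 increments always total exactly 40,000, which is the same guarantee the trading loop and dashboard get when they race on position and account updates.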
@@ -15,6 +15,8 @@ dependencies = [
    "plotly>=5.24.0",
    "requests>=2.32.5",
    "python-dotenv>=1.2.1",
    # Terminal UI
    "rich>=13.0.0",
    # API dependencies
    "fastapi>=0.115.0",
    "uvicorn[standard]>=0.34.0",
@@ -3,7 +3,16 @@ Regime Detection Research Script with Walk-Forward Training.
 
 Tests multiple holding horizons to find optimal parameters
 without look-ahead bias.
+
+Usage:
+    uv run python research/regime_detection.py [options]
+
+Options:
+    --days DAYS       Number of days of data (default: 90)
+    --start DATE      Start date (YYYY-MM-DD), overrides --days
+    --end DATE        End date (YYYY-MM-DD), defaults to now
 """
+import argparse
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

@@ -25,18 +34,36 @@ TRAIN_RATIO = 0.7  # 70% train, 30% test
 PROFIT_THRESHOLD = 0.005  # 0.5% profit target
 Z_WINDOW = 24
 FEE_RATE = 0.001  # 0.1% round-trip fee
+DEFAULT_DAYS = 90  # Default lookback period in days
 
 
-def load_data():
-    """Load and align BTC/ETH data."""
+def load_data(days: int = DEFAULT_DAYS, start_date: str = None, end_date: str = None):
+    """
+    Load and align BTC/ETH data.
+
+    Args:
+        days: Number of days of historical data (default: 90)
+        start_date: Optional start date (YYYY-MM-DD), overrides days
+        end_date: Optional end date (YYYY-MM-DD), defaults to now
+
+    Returns:
+        Tuple of (df_btc, df_eth) DataFrames
+    """
     dm = DataManager()
 
     df_btc = dm.load_data("okx", "BTC-USDT", "1h", MarketType.SPOT)
     df_eth = dm.load_data("okx", "ETH-USDT", "1h", MarketType.SPOT)
 
-    # Filter to Oct-Dec 2025
-    start = pd.Timestamp("2025-10-01", tz="UTC")
-    end = pd.Timestamp("2025-12-31", tz="UTC")
+    # Determine date range
+    if end_date:
+        end = pd.Timestamp(end_date, tz="UTC")
+    else:
+        end = pd.Timestamp.now(tz="UTC")
+
+    if start_date:
+        start = pd.Timestamp(start_date, tz="UTC")
+    else:
+        start = end - pd.Timedelta(days=days)
 
     df_btc = df_btc[(df_btc.index >= start) & (df_btc.index <= end)]
     df_eth = df_eth[(df_eth.index >= start) & (df_eth.index <= end)]

@@ -46,7 +73,7 @@ def load_data():
     df_btc = df_btc.loc[common]
     df_eth = df_eth.loc[common]
 
-    logger.info(f"Loaded {len(common)} aligned hourly bars")
+    logger.info(f"Loaded {len(common)} aligned hourly bars from {start} to {end}")
     return df_btc, df_eth
 
 

@@ -168,7 +195,10 @@ def calculate_mae(features, predictions, test_idx, horizon):
 
 
 def calculate_net_profit(features, predictions, test_idx, horizon):
-    """Calculate estimated net profit including fees."""
+    """
+    Calculate estimated net profit including fees.
+    Enforces 'one trade at a time' to avoid inflating returns with overlapping signals.
+    """
     test_features = features.loc[test_idx]
     spread = test_features['spread']
     z_score = test_features['z_score']

@@ -176,7 +206,14 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
     total_pnl = 0.0
     n_trades = 0
 
+    # Track when we are free to trade again
+    next_trade_idx = 0
+
     for i, (idx, pred) in enumerate(zip(test_idx, predictions)):
+        # Skip if we are still in a trade
+        if i < next_trade_idx:
+            continue
+
         if pred != 1:
             continue
 

@@ -204,6 +241,10 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
         total_pnl += net_pnl
         n_trades += 1
 
+        # Set next available trade index (simple non-overlapping logic)
+        # We assume we hold for 'horizon' bars
+        next_trade_idx = i + horizon
+
     return total_pnl, n_trades
 
 

@@ -295,10 +336,54 @@ def test_horizons(features, horizons):
     return results
 
 
+def parse_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Regime detection research - test multiple horizons"
+    )
+    parser.add_argument(
+        "--days",
+        type=int,
+        default=DEFAULT_DAYS,
+        help=f"Number of days of data (default: {DEFAULT_DAYS})"
+    )
+    parser.add_argument(
+        "--start",
+        type=str,
+        default=None,
+        help="Start date (YYYY-MM-DD), overrides --days"
+    )
+    parser.add_argument(
+        "--end",
+        type=str,
+        default=None,
+        help="End date (YYYY-MM-DD), defaults to now"
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="research/horizon_optimization_results.csv",
+        help="Output CSV path"
+    )
+    parser.add_argument(
+        "--output-horizon",
+        type=str,
+        default=None,
+        help="Path to save the best horizon (integer) to a file"
+    )
+    return parser.parse_args()
+
+
 def main():
     """Main research function."""
-    # Load data
-    df_btc, df_eth = load_data()
+    args = parse_args()
+
+    # Load data with dynamic date range
+    df_btc, df_eth = load_data(
+        days=args.days,
+        start_date=args.start,
+        end_date=args.end
+    )
     cq_df = load_cryptoquant_data()
 
     # Calculate features

@@ -312,7 +397,7 @@ def main():
 
     if not results:
         print("No valid results!")
-        return
+        return None
 
     # Find best by different metrics
     results_df = pd.DataFrame(results)

@@ -331,9 +416,15 @@ def main():
     print(f"Lowest MAE: {lowest_mae['horizon']:.0f}h (MAE={lowest_mae['avg_mae']:.2f}%)")
 
     # Save results
-    output_path = "research/horizon_optimization_results.csv"
-    results_df.to_csv(output_path, index=False)
-    print(f"\nResults saved to {output_path}")
+    results_df.to_csv(args.output, index=False)
+    print(f"\nResults saved to {args.output}")
+
+    # Save best horizon if requested
+    if args.output_horizon:
+        best_h = int(best_pnl['horizon'])
+        with open(args.output_horizon, 'w') as f:
+            f.write(str(best_h))
+        print(f"Best horizon {best_h}h saved to {args.output_horizon}")
 
     return results_df
 
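The non-overlapping entry logic added to `calculate_net_profit` is easiest to see on a toy prediction array. A standalone sketch (the helper name, predictions, and horizon below are invented for illustration; this is not project code):

```python
def count_non_overlapping_trades(predictions: list[int], horizon: int) -> int:
    """Count entries while skipping bars that fall inside an open trade."""
    n_trades = 0
    next_trade_idx = 0  # first bar at which we may enter again
    for i, pred in enumerate(predictions):
        if i < next_trade_idx:  # still holding a position
            continue
        if pred != 1:           # no entry signal on this bar
            continue
        n_trades += 1
        next_trade_idx = i + horizon  # assume we hold for 'horizon' bars
    return n_trades


# Signals at bars 0, 1, 2, 5 with horizon=3: the signals at bars 1 and 2
# fall inside the first trade's holding window, so only 2 trades count.
print(count_non_overlapping_trades([1, 1, 1, 0, 0, 1], horizon=3))  # → 2
```

Without this guard, every overlapping signal would book its own PnL over the same price path, which is what the hunk's docstring means by "inflating returns".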
16
setup_schedule.sh
Executable file
@@ -0,0 +1,16 @@
#!/bin/bash
# Setup script for Systemd Timer (Daily Training)

SERVICE_FILE="tasks/lowkey-training.service"
TIMER_FILE="tasks/lowkey-training.timer"
SYSTEMD_DIR="/etc/systemd/system"

echo "To install the daily training schedule, please run the following commands:"
echo ""
echo "sudo cp $SERVICE_FILE $SYSTEMD_DIR/"
echo "sudo cp $TIMER_FILE $SYSTEMD_DIR/"
echo "sudo systemctl daemon-reload"
echo "sudo systemctl enable --now lowkey-training.timer"
echo ""
echo "To check the status:"
echo "systemctl list-timers --all | grep lowkey"
15
tasks/lowkey-training.service
Normal file
@@ -0,0 +1,15 @@
[Unit]
Description=Lowkey Backtest Daily Model Training
After=network.target

[Service]
Type=oneshot
WorkingDirectory=/home/tamaya/Documents/Work/TCP/lowkey_backtest_live
ExecStart=/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/train_daily.sh
User=tamaya
Group=tamaya
StandardOutput=append:/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/logs/training.log
StandardError=append:/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/logs/training.log

[Install]
WantedBy=multi-user.target
10
tasks/lowkey-training.timer
Normal file
@@ -0,0 +1,10 @@
[Unit]
Description=Run Lowkey Backtest Training Daily

[Timer]
OnCalendar=*-*-* 00:30:00
Persistent=true
Unit=lowkey-training.service

[Install]
WantedBy=timers.target
351
tasks/prd-terminal-ui.md
Normal file
@@ -0,0 +1,351 @@
# PRD: Terminal UI for Live Trading Bot

## Introduction/Overview

The live trading bot currently uses basic console logging for output, making it difficult to monitor trading activity, track performance, and understand the system state at a glance. This feature introduces a Rich-based terminal UI that provides a professional, real-time dashboard for monitoring the live trading bot.

The UI will display a horizontal split layout with a **summary panel** at the top (with tabbed time-period views) and a **scrollable log panel** at the bottom. The interface will update every second and support keyboard navigation.

## Goals

1. Provide real-time visibility into trading performance (PnL, win rate, trade count)
2. Enable monitoring of current position state (entry, SL/TP, unrealized PnL)
3. Display strategy signals (Z-score, model probability) for transparency
4. Support historical performance tracking across time periods (daily, weekly, monthly, all-time)
5. Improve operational experience with keyboard shortcuts and log filtering
6. Create a responsive design that works across different terminal sizes

## User Stories

1. **As a trader**, I want to see my total PnL and daily PnL at a glance so I can quickly assess performance.
2. **As a trader**, I want to see my current position details (entry price, unrealized PnL, SL/TP levels) so I can monitor risk.
3. **As a trader**, I want to view performance metrics by time period (daily, weekly, monthly) so I can track trends.
4. **As a trader**, I want to filter logs by type (errors, trades, signals) so I can focus on relevant information.
5. **As a trader**, I want keyboard shortcuts to navigate the UI without using a mouse.
6. **As a trader**, I want the UI to show strategy state (Z-score, probability) so I understand why signals are generated.

## Functional Requirements

### FR1: Layout Structure

1.1. The UI must use a horizontal split layout with the summary panel at the top and logs panel at the bottom.

1.2. The summary panel must contain tabbed views accessible via number keys:
- Tab 1 (`1`): **General** - Overall metrics since bot started
- Tab 2 (`2`): **Monthly** - Current month metrics (shown only if data spans > 1 month)
- Tab 3 (`3`): **Weekly** - Current week metrics (shown only if data spans > 1 week)
- Tab 4 (`4`): **Daily** - Today's metrics

1.3. The logs panel must be a scrollable area showing recent log entries.

1.4. The UI must be responsive and adapt to terminal size (minimum 80x24).

### FR2: Metrics Display

The summary panel must display the following metrics:

**Performance Metrics:**
2.1. Total PnL (USD) - cumulative profit/loss since tracking began
2.2. Period PnL (USD) - profit/loss for selected time period (daily/weekly/monthly)
2.3. Win Rate (%) - percentage of winning trades
2.4. Total Number of Trades
2.5. Average Trade Duration (hours)
2.6. Max Drawdown (USD and %)

**Current Position (if open):**
2.7. Symbol and side (long/short)
2.8. Entry price
2.9. Current price
2.10. Unrealized PnL (USD and %)
2.11. Stop-loss price and distance (%)
2.12. Take-profit price and distance (%)
2.13. Position size (USD)

**Account Status:**
2.14. Account balance / available margin (USDT)
2.15. Current leverage setting

**Strategy State:**
2.16. Current Z-score
2.17. Model probability
2.18. Current funding rate (BTC)
2.19. Last signal action and reason
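Win rate (2.3) and max drawdown (2.6) can both be derived from the closed-trade PnL series alone, so the UI never needs tick-level equity history for them. A minimal sketch (the trade PnL values are invented; function names are illustrative, not existing project code):

```python
def win_rate(pnls: list[float]) -> float:
    """Fraction of closed trades with positive PnL, in percent."""
    if not pnls:
        return 0.0
    wins = sum(1 for p in pnls if p > 0)
    return 100.0 * wins / len(pnls)


def max_drawdown(pnls: list[float]) -> float:
    """Largest peak-to-trough drop of the cumulative PnL curve (USD)."""
    equity = 0.0
    peak = 0.0
    worst = 0.0
    for p in pnls:
        equity += p
        peak = max(peak, equity)          # running high-water mark
        worst = max(worst, peak - equity)  # deepest drop from the peak so far
    return worst


pnls = [50.0, -20.0, 30.0, -80.0, 40.0]
print(win_rate(pnls))      # → 60.0
print(max_drawdown(pnls))  # → 80.0 (peak 60 after trade 3, trough -20 after trade 4)
```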

### FR3: Historical Data Loading

3.1. On startup, the system must initialize SQLite database at `live_trading/trading.db`.

3.2. If `trade_log.csv` exists and database is empty, migrate CSV data to SQLite.

3.3. The UI must load current positions from `live_trading/positions.json` (kept for compatibility with existing position manager).

3.4. Metrics must be calculated via SQL aggregation queries for each time period.

3.5. If no historical data exists, the UI must show "No data" gracefully.

3.6. New trades must be written to both SQLite (primary) and CSV (backup/compatibility).

### FR4: Real-Time Updates

4.1. The UI must refresh every 1 second.

4.2. Position unrealized PnL must update based on latest price data.

4.3. New log entries must appear in real-time.

4.4. Metrics must recalculate when trades are opened/closed.

### FR5: Log Panel

5.1. The log panel must display log entries with timestamp, level, and message.

5.2. Log entries must be color-coded by level:
- ERROR: Red
- WARNING: Yellow
- INFO: White/Default
- DEBUG: Gray (if shown)

5.3. The log panel must support filtering by log type:
- All logs (default)
- Errors only
- Trades only (entries containing "position", "trade", "order")
- Signals only (entries containing "signal", "z_score", "prob")

5.4. Filter switching must be available via keyboard shortcut (`f` to cycle filters).

### FR6: Keyboard Controls

6.1. `q` or `Ctrl+C` - Graceful shutdown
6.2. `r` - Force refresh data
6.3. `1` - Switch to General tab
6.4. `2` - Switch to Monthly tab
6.5. `3` - Switch to Weekly tab
6.6. `4` - Switch to Daily tab
6.7. `f` - Cycle log filter
6.8. Arrow keys - Scroll logs (if supported)

### FR7: Color Scheme

7.1. Use dark theme as base.

7.2. PnL values must be colored:
- Positive: Green
- Negative: Red
- Zero/Neutral: White

7.3. Position side must be colored:
- Long: Green
- Short: Red

7.4. Use consistent color coding for emphasis and warnings.

## Non-Goals (Out of Scope)

1. **Mouse support** - This is a keyboard-driven terminal UI
2. **Trade execution from UI** - The UI is read-only; trades are executed by the bot
3. **Configuration editing** - Config changes require restarting the bot
4. **Multi-exchange support** - Only OKX is supported
5. **Charts/graphs** - Text-based metrics only (no ASCII charts in v1)
6. **Sound alerts** - No audio notifications
7. **Remote access** - Local terminal only

## Design Considerations

### Technology Choice: Rich

Use the [Rich](https://github.com/Textualize/rich) Python library for terminal UI:
- Rich provides `Live` display for real-time updates
- Rich `Layout` for split-screen design
- Rich `Table` for metrics display
- Rich `Panel` for bordered sections
- Rich `Text` for colored output

Alternative considered: **Textual** (also by Will McGugan) provides more advanced TUI features but adds complexity. Rich is simpler and sufficient for this use case.
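The split-screen pattern these bullets describe can be sketched in a few lines of Rich (panel contents here are placeholders, not the real dashboard):

```python
from rich.layout import Layout
from rich.live import Live
from rich.panel import Panel


def make_layout() -> Layout:
    """Top summary panel over a bottom log panel, as in the mockup below."""
    layout = Layout()
    layout.split_column(
        Layout(name="summary", ratio=2),
        Layout(name="logs", ratio=3),
    )
    layout["summary"].update(Panel("metrics go here", title="SUMMARY"))
    layout["logs"].update(Panel("log lines go here", title="LOGS"))
    return layout


def run_dashboard(seconds: float) -> None:
    """Render the layout live; in the bot the trading loop runs alongside."""
    import time
    # refresh_per_second=1 matches the 1-second update requirement (FR 4.1)
    with Live(make_layout(), refresh_per_second=1, screen=True):
        time.sleep(seconds)
```

In the real dashboard the UI thread would rebuild the panels from shared state on each refresh instead of sleeping.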

### UI Mockup

```
+==============================================================================+
| REGIME REVERSION STRATEGY - LIVE TRADING [DEMO]                     ETH/USDT |
+==============================================================================+
| [1:General] [2:Monthly] [3:Weekly] [4:Daily]                                 |
+------------------------------------------------------------------------------+
| PERFORMANCE              | CURRENT POSITION                                  |
| Total PnL:    $1,234.56  | Side: LONG                                        |
| Today PnL:    $45.23     | Entry: $3,245.50                                  |
| Win Rate:     67.5%      | Current: $3,289.00                                |
| Total Trades: 24         | Unrealized: +$43.50 (+1.34%)                      |
| Avg Duration: 4.2h       | Size: $500.00                                     |
| Max Drawdown: -$156.00   | SL: $3,050.00 (-6.0%)  TP: $3,408.00 (+5%)        |
|                          |                                                   |
| ACCOUNT                  | STRATEGY STATE                                    |
| Balance:   $5,432.10     | Z-Score: 1.45                                     |
| Available: $4,932.10     | Probability: 0.72                                 |
| Leverage:  2x            | Funding: 0.0012                                   |
+------------------------------------------------------------------------------+
| LOGS                                      [Filter: All]      Press 'f' cycle |
+------------------------------------------------------------------------------+
| 14:32:15 [INFO] Trading Cycle Start: 2026-01-16T14:32:15+00:00               |
| 14:32:16 [INFO] Signal: entry long (prob=0.72, z=-1.45, reason=z_score...)   |
| 14:32:17 [INFO] Executing LONG entry: 0.1540 ETH @ 3245.50 ($500.00)         |
| 14:32:18 [INFO] Position opened: ETH/USDT:USDT_20260116_143217               |
| 14:32:18 [INFO] Portfolio: 1 positions, exposure=$500.00, unrealized=$0.00   |
| 14:32:18 [INFO] --- Cycle completed in 3.2s ---                              |
| 14:32:18 [INFO] Sleeping for 60 minutes...                                   |
|                                                                              |
+------------------------------------------------------------------------------+
| [q]Quit  [r]Refresh  [1-4]Tabs  [f]Filter                                    |
+==============================================================================+
```

### File Structure

```
live_trading/
  ui/
    __init__.py
    dashboard.py       # Main UI orchestration and threading
    panels.py          # Panel components (metrics, logs, position)
    state.py           # Thread-safe shared state
    log_handler.py     # Custom logging handler for UI queue
    keyboard.py        # Keyboard input handling
  db/
    __init__.py
    database.py        # SQLite connection and queries
    models.py          # Data models (Trade, DailySummary, Session)
    migrations.py      # CSV migration and schema setup
    metrics.py         # Metrics aggregation queries
  trading.db           # SQLite database file (created at runtime)
```

## Technical Considerations

### Integration with Existing Code

1. **Logging Integration**: Create a custom `logging.Handler` that captures log messages and forwards them to the UI log panel while still writing to file.

2. **Data Access**: The UI needs access to:
   - `PositionManager` for current positions
   - `TradingConfig` for settings display
   - SQLite database for historical metrics
   - Real-time data from `DataFeed` for current prices

3. **Main Loop Modification**: The `LiveTradingBot.run()` method needs modification to run the UI in a **separate thread** alongside the trading loop. The UI thread handles rendering and keyboard input while the main thread executes trading logic.

4. **Graceful Shutdown**: Ensure `SIGINT`/`SIGTERM` handlers work with the UI layer and properly terminate the UI thread.
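The logging integration in point 1 can be a small `logging.Handler` subclass that mirrors formatted records into a bounded queue the UI thread drains. A sketch (the class name `UILogHandler` and the logger name are illustrative, not existing project code):

```python
import logging
import queue


class UILogHandler(logging.Handler):
    """Forward formatted log records to a thread-safe queue for the UI."""

    def __init__(self, log_queue: queue.Queue):
        super().__init__()
        self.log_queue = log_queue

    def emit(self, record: logging.LogRecord) -> None:
        try:
            self.log_queue.put_nowait((record.levelname, self.format(record)))
        except queue.Full:
            pass  # drop rather than block the trading thread


# Bounded queue matches the 1000-entry in-memory buffer decision
log_queue: queue.Queue = queue.Queue(maxsize=1000)
logger = logging.getLogger("live_trading")
logger.addHandler(UILogHandler(log_queue))  # UI copy
# The existing FileHandler stays attached, so file logging is unchanged.

logger.warning("demo message")
print(log_queue.get_nowait())  # → ('WARNING', 'demo message')
```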

### Database Schema (SQLite)

Create `live_trading/trading.db` with the following schema:

```sql
-- Trade history table
CREATE TABLE trades (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    trade_id TEXT UNIQUE NOT NULL,
    symbol TEXT NOT NULL,
    side TEXT NOT NULL,                -- 'long' or 'short'
    entry_price REAL NOT NULL,
    exit_price REAL,
    size REAL NOT NULL,
    size_usdt REAL NOT NULL,
    pnl_usd REAL,
    pnl_pct REAL,
    entry_time TEXT NOT NULL,          -- ISO format
    exit_time TEXT,
    hold_duration_hours REAL,
    reason TEXT,                       -- 'stop_loss', 'take_profit', 'signal', etc.
    order_id_entry TEXT,
    order_id_exit TEXT,
    created_at TEXT DEFAULT CURRENT_TIMESTAMP
);

-- Daily summary table (for faster queries)
CREATE TABLE daily_summary (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    date TEXT UNIQUE NOT NULL,         -- YYYY-MM-DD
    total_trades INTEGER DEFAULT 0,
    winning_trades INTEGER DEFAULT 0,
    total_pnl_usd REAL DEFAULT 0,
    max_drawdown_usd REAL DEFAULT 0,
    updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);

-- Session metadata
CREATE TABLE sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    start_time TEXT NOT NULL,
    end_time TEXT,
    starting_balance REAL,
    ending_balance REAL,
    total_pnl REAL,
    total_trades INTEGER DEFAULT 0
);

-- Indexes for common queries
CREATE INDEX idx_trades_entry_time ON trades(entry_time);
CREATE INDEX idx_trades_exit_time ON trades(exit_time);
CREATE INDEX idx_daily_summary_date ON daily_summary(date);
```
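With this schema, the per-period aggregation that FR 3.4 calls for reduces to a single query over `trades`. A sketch using the stdlib `sqlite3`, with an in-memory database, a reduced three-column version of the table, and invented sample rows:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Reduced 'trades' table for the sketch; the real one has more columns.
conn.execute(
    "CREATE TABLE trades (trade_id TEXT UNIQUE NOT NULL,"
    " pnl_usd REAL, exit_time TEXT)"
)
conn.executemany(
    "INSERT INTO trades VALUES (?, ?, ?)",
    [("t1", 50.0, "2026-01-16T10:00:00"),
     ("t2", -20.0, "2026-01-16T14:00:00"),
     ("t3", 30.0, "2026-01-15T09:00:00")],
)

# One day's metrics: trade count, winning trades, total PnL.
# ISO-8601 timestamps sort lexicographically, so range filters work as TEXT.
row = conn.execute(
    """
    SELECT COUNT(*),
           SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END),
           SUM(pnl_usd)
    FROM trades
    WHERE exit_time >= ? AND exit_time < ?
    """,
    ("2026-01-16T00:00:00", "2026-01-17T00:00:00"),
).fetchone()
print(row)  # → (2, 1, 30.0)
```

The `idx_trades_exit_time` index above is what keeps this range scan within the 100 ms budget as the table grows.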

### Migration from CSV

On first run with the new system:
1. Check if `trading.db` exists
2. If not, create database with schema
3. If `trade_log.csv` exists, migrate data to `trades` table
4. Rebuild `daily_summary` from migrated trades
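Steps 1-3 above can be sketched with the stdlib `csv` and `sqlite3` modules. This is an illustration only: the CSV column names (`trade_id`, `pnl_usd`, `entry_time`) and the reduced table are assumptions, not the actual `trade_log.csv` layout:

```python
import csv
import sqlite3
from pathlib import Path


def migrate_csv(db_path: str, csv_path: str) -> int:
    """Copy rows from the CSV into 'trades' if the table is empty."""
    conn = sqlite3.connect(db_path)
    # Reduced schema for the sketch; the real schema has more columns.
    conn.execute(
        "CREATE TABLE IF NOT EXISTS trades (trade_id TEXT UNIQUE NOT NULL,"
        " pnl_usd REAL, entry_time TEXT)"
    )
    (count,) = conn.execute("SELECT COUNT(*) FROM trades").fetchone()
    if count or not Path(csv_path).exists():
        return 0  # already migrated, or nothing to migrate

    with open(csv_path, newline="") as f:
        rows = [(r["trade_id"], float(r["pnl_usd"]), r["entry_time"])
                for r in csv.DictReader(f)]
    with conn:  # single transaction; OR IGNORE keeps re-runs idempotent
        conn.executemany("INSERT OR IGNORE INTO trades VALUES (?, ?, ?)", rows)
    return len(rows)
```

Step 4 would then rebuild `daily_summary` with a `GROUP BY` over the migrated rows.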

### Dependencies

Add to `pyproject.toml`:
```toml
dependencies = [
    # ... existing deps
    "rich>=13.0.0",
]
```

Note: SQLite is part of Python's standard library (`sqlite3`), no additional dependency needed.

### Performance Considerations

- UI runs in a separate thread to avoid blocking trading logic
- Log buffer limited to 1000 entries in memory to prevent unbounded growth
- SQLite queries should use indexes for fast period-based aggregations
- Historical data loading happens once at startup, incremental updates thereafter

### Threading Model

```
Main Thread                 UI Thread
    |                           |
    v                           v
[Trading Loop]          [Rich Live Display]
    |                           |
    +---> SharedState <---------+
    |     (thread-safe)         |
    |                           |
    +---> LogQueue <------------+
          (thread-safe)
```

- Use `threading.Lock` for shared state access
- Use `queue.Queue` for log message passing
- UI thread polls for updates every 1 second

## Success Metrics

1. UI starts successfully and displays all required metrics
2. UI updates in real-time (1-second refresh) without impacting trading performance
3. All keyboard shortcuts function correctly
4. Historical data loads and displays accurately from SQLite
5. Log filtering works as expected
6. UI gracefully handles edge cases (no data, no position, terminal resize)
7. CSV migration completes successfully on first run
8. Database queries complete within 100ms

---

*Generated: 2026-01-16*
*Decisions: Threading model, 1000 log buffer, SQLite database, no fallback mode*
50
train_daily.sh
Executable file
@@ -0,0 +1,50 @@
#!/bin/bash
# Daily ML Model Training Script
#
# Downloads fresh data and retrains the regime detection model.
# Can be run manually or scheduled via cron.
#
# Usage:
#   ./train_daily.sh                  # Full workflow
#   ./train_daily.sh --skip-research  # Skip research validation

set -e  # Exit on error

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"

LOG_DIR="logs"
mkdir -p "$LOG_DIR"

TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S")
echo "[$TIMESTAMP] Starting daily training..."

# 1. Download fresh data
echo "Downloading BTC-USDT 1h data..."
uv run python main.py download -p BTC-USDT -t 1h

echo "Downloading ETH-USDT 1h data..."
uv run python main.py download -p ETH-USDT -t 1h

# 2. Research optimization (find best horizon)
echo "Running research optimization..."
uv run python research/regime_detection.py --output-horizon data/optimal_horizon.txt

# 3. Read best horizon
if [[ -f "data/optimal_horizon.txt" ]]; then
    BEST_HORIZON=$(cat data/optimal_horizon.txt)
    echo "Found optimal horizon: ${BEST_HORIZON} bars"
else
    BEST_HORIZON=102
    echo "Warning: Could not find optimal horizon file. Using default: ${BEST_HORIZON}"
fi

# 4. Train model
echo "Training ML model with horizon ${BEST_HORIZON}..."
uv run python train_model.py --horizon "$BEST_HORIZON"

# 5. Cleanup
rm -f data/optimal_horizon.txt

TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S")
echo "[$TIMESTAMP] Daily training complete."
451
train_model.py
Normal file
451
train_model.py
Normal file
@@ -0,0 +1,451 @@
|
||||
"""
|
||||
ML Model Training Script.
|
||||
|
||||
Trains the regime detection Random Forest model on historical data.
|
||||
Can be run manually or scheduled via cron for daily retraining.
|
||||
|
||||
Usage:
|
||||
uv run python train_model.py [options]
|
||||
|
||||
Options:
|
||||
--days DAYS Number of days of historical data to use (default: 90)
|
||||
--pair PAIR Trading pair for context (default: BTC-USDT)
|
||||
--timeframe TF Timeframe (default: 1h)
|
||||
--output PATH Output model path (default: data/regime_model.pkl)
|
||||
--train-ratio R Train/test split ratio (default: 0.7)
|
||||
--dry-run Run without saving model
|
||||
"""
|
||||
import argparse
|
||||
import pickle
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import ta
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import classification_report, f1_score
|
||||
|
||||
from engine.data_manager import DataManager
|
||||
from engine.market import MarketType
|
||||
from engine.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Default configuration (from research optimization)
|
||||
DEFAULT_HORIZON = 102 # 4.25 days - optimal from research
|
||||
DEFAULT_Z_WINDOW = 24 # 24h rolling window
|
||||
DEFAULT_PROFIT_TARGET = 0.005 # 0.5% profit threshold
|
||||
DEFAULT_Z_THRESHOLD = 1.0 # Z-score entry threshold
|
||||
DEFAULT_TRAIN_RATIO = 0.7 # 70% train / 30% test
|
||||
FEE_RATE = 0.001 # 0.1% round-trip fee
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train the regime detection ML model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--days",
|
||||
type=int,
|
||||
default=90,
|
||||
help="Number of days of historical data to use (default: 90)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pair",
|
||||
type=str,
|
||||
default="BTC-USDT",
|
||||
help="Base pair for context data (default: BTC-USDT)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--spread-pair",
|
||||
type=str,
|
||||
default="ETH-USDT",
|
||||
help="Spread pair to trade (default: ETH-USDT)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeframe",
|
||||
type=str,
|
||||
default="1h",
|
||||
help="Timeframe (default: 1h)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--market",
|
||||
type=str,
|
||||
choices=["spot", "perpetual"],
|
||||
default="perpetual",
|
||||
help="Market type (default: perpetual)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="data/regime_model.pkl",
|
||||
help="Output model path (default: data/regime_model.pkl)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train-ratio",
|
||||
type=float,
|
||||
default=DEFAULT_TRAIN_RATIO,
|
||||
help="Train/test split ratio (default: 0.7)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--horizon",
|
||||
type=int,
|
||||
default=DEFAULT_HORIZON,
|
||||
help="Prediction horizon in bars (default: 102)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Run without saving model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--download",
|
||||
action="store_true",
|
||||
help="Download latest data before training"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def download_data(dm: DataManager, pair: str, timeframe: str, market_type: MarketType):
|
||||
"""Download latest data for a pair."""
|
||||
logger.info(f"Downloading latest data for {pair}...")
|
||||
try:
|
||||
dm.download_data("okx", pair, timeframe, market_type)
|
||||
logger.info(f"Downloaded {pair} data")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download {pair}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def load_data(
|
||||
dm: DataManager,
|
||||
base_pair: str,
|
||||
spread_pair: str,
|
||||
timeframe: str,
|
||||
market_type: MarketType,
|
||||
days: int
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""Load and align historical data for both pairs."""
|
||||
df_base = dm.load_data("okx", base_pair, timeframe, market_type)
|
||||
df_spread = dm.load_data("okx", spread_pair, timeframe, market_type)
|
||||
|
||||
# Filter to last N days
|
||||
end_date = pd.Timestamp.now(tz="UTC")
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
df_base = df_base[(df_base.index >= start_date) & (df_base.index <= end_date)]
|
||||
df_spread = df_spread[(df_spread.index >= start_date) & (df_spread.index <= end_date)]
|
||||
|
||||
# Align indices
|
||||
common = df_base.index.intersection(df_spread.index)
|
||||
df_base = df_base.loc[common]
|
||||
df_spread = df_spread.loc[common]
|
||||
|
||||
logger.info(
|
||||
f"Loaded {len(common)} bars from {common.min()} to {common.max()}"
|
||||
)
|
||||
return df_base, df_spread
|
||||
|
||||
|
||||
def load_cryptoquant_data() -> pd.DataFrame | None:
|
||||
"""Load CryptoQuant on-chain data if available."""
|
||||
try:
|
||||
cq_path = Path("data/cq_training_data.csv")
|
||||
if not cq_path.exists():
|
||||
logger.info("CryptoQuant data not found, skipping on-chain features")
|
||||
return None
|
||||
|
||||
cq_df = pd.read_csv(cq_path, index_col='timestamp', parse_dates=True)
|
||||
if cq_df.index.tz is None:
|
||||
cq_df.index = cq_df.index.tz_localize('UTC')
|
||||
logger.info(f"Loaded CryptoQuant data: {len(cq_df)} rows")
|
||||
return cq_df
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load CryptoQuant data: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def calculate_features(
    df_base: pd.DataFrame,
    df_spread: pd.DataFrame,
    cq_df: pd.DataFrame | None = None,
    z_window: int = DEFAULT_Z_WINDOW
) -> pd.DataFrame:
    """Calculate all features for the model."""
    spread = df_spread['close'] / df_base['close']

    # Z-score
    rolling_mean = spread.rolling(window=z_window).mean()
    rolling_std = spread.rolling(window=z_window).std()
    z_score = (spread - rolling_mean) / rolling_std

    # Technicals
    spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
    spread_roc = spread.pct_change(periods=5) * 100
    spread_change_1h = spread.pct_change(periods=1)

    # Volume
    vol_ratio = df_spread['volume'] / df_base['volume']
    vol_ratio_ma = vol_ratio.rolling(window=12).mean()

    # Volatility
    ret_base = df_base['close'].pct_change()
    ret_spread = df_spread['close'].pct_change()
    vol_base = ret_base.rolling(window=z_window).std()
    vol_spread = ret_spread.rolling(window=z_window).std()
    vol_spread_ratio = vol_spread / vol_base

    features = pd.DataFrame(index=spread.index)
    features['spread'] = spread
    features['z_score'] = z_score
    features['spread_rsi'] = spread_rsi
    features['spread_roc'] = spread_roc
    features['spread_change_1h'] = spread_change_1h
    features['vol_ratio'] = vol_ratio
    features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
    features['vol_diff_ratio'] = vol_spread_ratio

    # Add CryptoQuant features if available
    if cq_df is not None:
        cq_aligned = cq_df.reindex(features.index, method='ffill')
        if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
            cq_aligned['funding_diff'] = (
                cq_aligned['eth_funding'] - cq_aligned['btc_funding']
            )
        if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
            cq_aligned['inflow_ratio'] = (
                cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
            )
        features = features.join(cq_aligned)

    return features.dropna()


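The z-score feature above normalizes the spread against its own rolling history. A minimal standalone sketch of that calculation (pandas only; the series values are illustrative, not real data):

```python
import pandas as pd

def rolling_z_score(series: pd.Series, window: int) -> pd.Series:
    """How many rolling standard deviations the latest value sits from the rolling mean."""
    mean = series.rolling(window=window).mean()
    std = series.rolling(window=window).std()
    return (series - mean) / std

# A steadily rising series: each new value sits one sample std above the window mean
s = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0])
z = rolling_z_score(s, window=3)
```

Note the first `window - 1` entries are NaN, which is why the feature pipeline ends with `dropna()`.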
def calculate_targets(
    features: pd.DataFrame,
    horizon: int,
    profit_target: float = DEFAULT_PROFIT_TARGET,
    z_threshold: float = DEFAULT_Z_THRESHOLD
) -> tuple[np.ndarray, pd.Series]:
    """Calculate target labels for training."""
    spread = features['spread']
    z_score = features['z_score']

    # For short setups (Z > threshold): did the spread drop below the target?
    future_min = spread.rolling(window=horizon).min().shift(-horizon)
    target_short = spread * (1 - profit_target)
    success_short = (z_score > z_threshold) & (future_min < target_short)

    # For long setups (Z < -threshold): did the spread rise above the target?
    future_max = spread.rolling(window=horizon).max().shift(-horizon)
    target_long = spread * (1 + profit_target)
    success_long = (z_score < -z_threshold) & (future_max > target_long)

    targets = np.select([success_short, success_long], [1, 1], default=0)

    # Valid mask: rows with complete future data
    valid_mask = future_min.notna() & future_max.notna()

    return targets, valid_mask


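The labeling trick above uses a backward-looking rolling window plus a negative shift to get a forward-looking extreme: `rolling(horizon).min().shift(-horizon)` at bar `t` is the minimum over bars `t+1 .. t+horizon`. A small sketch with made-up numbers:

```python
import pandas as pd

spread = pd.Series([10.0, 9.0, 8.0, 9.5, 11.0])
horizon = 2

# rolling(min) looks backward; shifting by -horizon re-anchors each value
# so that row t holds the minimum of the *next* `horizon` bars.
future_min = spread.rolling(window=horizon).min().shift(-horizon)
```

At bar 0 the next two bars are 9.0 and 8.0, so `future_min.iloc[0]` is 8.0; the last `horizon` rows have no complete future window and come out NaN, which is exactly what `valid_mask` screens out.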
def train_model(
    features: pd.DataFrame,
    train_ratio: float = DEFAULT_TRAIN_RATIO,
    horizon: int = DEFAULT_HORIZON
) -> tuple[RandomForestClassifier, list[str], dict]:
    """
    Train a Random Forest model with a walk-forward split.

    Args:
        features: DataFrame with calculated features
        train_ratio: Fraction of data to use for training
        horizon: Prediction horizon in bars

    Returns:
        Tuple of (trained model, feature columns, metrics dict)
    """
    # Calculate targets
    targets, valid_mask = calculate_targets(features, horizon)

    # Walk-forward split
    n_samples = len(features)
    train_size = int(n_samples * train_ratio)

    train_features = features.iloc[:train_size]
    test_features = features.iloc[train_size:]

    train_targets = targets[:train_size]
    test_targets = targets[train_size:]

    train_valid = valid_mask.iloc[:train_size]
    test_valid = valid_mask.iloc[train_size:]

    # Prepare training data
    exclude = ['spread']
    feature_cols = [c for c in features.columns if c not in exclude]

    X_train = train_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
    X_train_valid = X_train[train_valid]
    y_train_valid = train_targets[train_valid]

    if len(X_train_valid) < 100:
        raise ValueError(
            f"Not enough training data: {len(X_train_valid)} samples (need >= 100)"
        )

    logger.info(f"Training on {len(X_train_valid)} samples...")

    # Train model
    model = RandomForestClassifier(
        n_estimators=300,
        max_depth=5,
        min_samples_leaf=30,
        class_weight={0: 1, 1: 3},
        random_state=42
    )
    model.fit(X_train_valid, y_train_valid)

    # Evaluate on the test set
    X_test = test_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
    predictions = model.predict(X_test)

    # Only evaluate on valid test rows
    test_valid_mask = test_valid.values
    y_test_valid = test_targets[test_valid_mask]
    pred_valid = predictions[test_valid_mask]

    # Calculate metrics
    f1 = f1_score(y_test_valid, pred_valid, zero_division=0)

    metrics = {
        'train_samples': len(X_train_valid),
        'test_samples': len(X_test),
        'f1_score': f1,
        'train_end': train_features.index[-1].isoformat(),
        'test_start': test_features.index[0].isoformat(),
        'horizon': horizon,
        'feature_cols': feature_cols,
    }

    logger.info(f"Model trained. F1 Score: {f1:.3f}")
    logger.info(
        f"Train period: {train_features.index[0]} to {train_features.index[-1]}"
    )
    logger.info(
        f"Test period: {test_features.index[0]} to {test_features.index[-1]}"
    )

    return model, feature_cols, metrics


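The walk-forward split used above is a plain chronological cut: earlier rows train, later rows test, never shuffled, so no future information leaks into training. A minimal sketch of just that split (the DataFrame contents are illustrative):

```python
import pandas as pd

def walk_forward_split(features: pd.DataFrame, train_ratio: float):
    """Chronological split: the first train_ratio of rows train, the rest test."""
    train_size = int(len(features) * train_ratio)
    return features.iloc[:train_size], features.iloc[train_size:]

df = pd.DataFrame({'x': range(10)})
train, test = walk_forward_split(df, 0.7)
```

Every training row strictly precedes every test row, which is the property a shuffled `train_test_split` would destroy for time-series data.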
def save_model(
    model: RandomForestClassifier,
    feature_cols: list[str],
    metrics: dict,
    output_path: str,
    versioned: bool = True
):
    """
    Save the trained model to a file.

    Args:
        model: Trained model
        feature_cols: List of feature column names
        metrics: Training metrics
        output_path: Output file path
        versioned: If True, also save a timestamped copy
    """
    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)

    data = {
        'model': model,
        'feature_cols': feature_cols,
        'metrics': metrics,
        'trained_at': datetime.now().isoformat(),
    }

    # Save the main model file
    with open(output, 'wb') as f:
        pickle.dump(data, f)
    logger.info(f"Saved model to {output}")

    # Save a versioned copy
    if versioned:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        versioned_path = output.parent / f"regime_model_{timestamp}.pkl"
        with open(versioned_path, 'wb') as f:
            pickle.dump(data, f)
        logger.info(f"Saved versioned model to {versioned_path}")


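The save format above bundles the estimator, its feature column list, and its training metrics into one pickled dict, so the live trader can reload everything it needs from a single file. A roundtrip sketch with a stand-in dict in place of a fitted model (paths and values here are illustrative):

```python
import pickle
import tempfile
from datetime import datetime
from pathlib import Path

payload = {
    'model': {'dummy': True},  # stand-in for the fitted RandomForestClassifier
    'feature_cols': ['z_score', 'spread_rsi'],
    'metrics': {'f1_score': 0.42},
    'trained_at': datetime.now().isoformat(),
}

path = Path(tempfile.mkdtemp()) / "regime_model.pkl"
with open(path, 'wb') as f:
    pickle.dump(payload, f)

with open(path, 'rb') as f:
    loaded = pickle.load(f)
```

Keeping `feature_cols` alongside the model guards against silent column-order drift between training and inference.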
def main():
    """Main training function."""
    args = parse_args()

    market_type = MarketType.PERPETUAL if args.market == "perpetual" else MarketType.SPOT
    dm = DataManager()

    # Download the latest data if requested
    if args.download:
        download_data(dm, args.pair, args.timeframe, market_type)
        download_data(dm, args.spread_pair, args.timeframe, market_type)

    # Load data
    try:
        df_base, df_spread = load_data(
            dm, args.pair, args.spread_pair, args.timeframe, market_type, args.days
        )
    except Exception as e:
        logger.error(f"Failed to load data: {e}")
        logger.info("Try running with the --download flag to fetch the latest data")
        sys.exit(1)

    # Load on-chain data
    cq_df = load_cryptoquant_data()

    # Calculate features
    features = calculate_features(df_base, df_spread, cq_df)
    logger.info(
        f"Calculated {len(features)} feature rows with {len(features.columns)} columns"
    )

    if len(features) < 200:
        logger.error(f"Not enough data: {len(features)} rows (need >= 200)")
        sys.exit(1)

    # Train model
    try:
        model, feature_cols, metrics = train_model(
            features, args.train_ratio, args.horizon
        )
    except ValueError as e:
        logger.error(f"Training failed: {e}")
        sys.exit(1)

    # Print metrics summary
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE")
    print("=" * 60)
    print(f"Train samples: {metrics['train_samples']}")
    print(f"Test samples: {metrics['test_samples']}")
    print(f"F1 Score: {metrics['f1_score']:.3f}")
    print(f"Horizon: {metrics['horizon']} bars")
    print(f"Features: {len(feature_cols)}")
    print("=" * 60)

    # Save model
    if not args.dry_run:
        save_model(model, feature_cols, metrics, args.output)
    else:
        logger.info("Dry run - model not saved")

    return 0


if __name__ == "__main__":
    sys.exit(main())
36
uv.lock
generated
@@ -981,6 +981,7 @@ dependencies = [
    { name = "plotly" },
    { name = "python-dotenv" },
    { name = "requests" },
    { name = "rich" },
    { name = "scikit-learn" },
    { name = "sqlalchemy" },
    { name = "ta" },
@@ -1004,6 +1005,7 @@ requires-dist = [
    { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
    { name = "python-dotenv", specifier = ">=1.2.1" },
    { name = "requests", specifier = ">=2.32.5" },
    { name = "rich", specifier = ">=13.0.0" },
    { name = "scikit-learn", specifier = ">=1.6.0" },
    { name = "sqlalchemy", specifier = ">=2.0.0" },
    { name = "ta", specifier = ">=0.11.0" },
@@ -1012,6 +1014,18 @@ requires-dist = [
]
provides-extras = ["dev"]

[[package]]
name = "markdown-it-py"
version = "4.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "mdurl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
]

[[package]]
name = "matplotlib"
version = "3.10.8"
@@ -1078,6 +1092,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" },
]

[[package]]
name = "mdurl"
version = "0.1.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
]

[[package]]
name = "multidict"
version = "6.7.0"
@@ -1912,6 +1935,19 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
]

[[package]]
name = "rich"
version = "14.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "markdown-it-py" },
    { name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/fb/d2/8920e102050a0de7bfabeb4c4614a49248cf8d5d7a8d01885fbb24dc767a/rich-14.2.0.tar.gz", hash = "sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4", size = 219990, upload-time = "2025-10-09T14:16:53.064Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/25/7a/b0178788f8dc6cafce37a212c99565fa1fe7872c70c6c9c1e1a372d9d88f/rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd", size = 243393, upload-time = "2025-10-09T14:16:51.245Z" },
]

[[package]]
name = "schedule"
version = "1.2.2"