Add daily model training scripts and terminal UI for live trading

- Introduced `train_daily.sh` for automating daily model retraining, including data download and model training steps.
- Added `install_cron.sh` for setting up a cron job to run the daily training script.
- Created `setup_schedule.sh` for configuring Systemd timers for daily training tasks.
- Implemented a terminal UI using Rich for real-time monitoring of trading performance, including metrics display and log handling.
- Updated `pyproject.toml` to include the `rich` dependency for UI functionality.
- Enhanced `.gitignore` to exclude model and log files.
- Added database support for trade persistence and metrics calculation.
- Updated README with installation and usage instructions for the new features.
This commit is contained in:
2026-01-18 11:08:57 +08:00
parent 35992ee374
commit b5550f4ff4
27 changed files with 3582 additions and 113 deletions

4
.gitignore vendored
View File

@@ -175,3 +175,7 @@ data/backtest_runs.db
.gitignore
live_trading/regime_model.pkl
live_trading/positions.json
*.pkl
*.db

304
README.md
View File

@@ -1,82 +1,262 @@
### lowkey_backtest — Supertrend Backtester
# Lowkey Backtest
### Overview
Backtest a simple, long-only strategy driven by a meta Supertrend signal on aggregated OHLCV data. The script:
- Loads 1-minute BTC/USD data from `../data/btcusd_1-min_data.csv`
- Aggregates to multiple timeframes (e.g., `5min`, `15min`, `30min`, `1h`, `4h`, `1d`)
- Computes three Supertrend variants and creates a meta signal when all agree
- Executes entries/exits at the aggregated bar open price
- Applies OKX spot fee assumptions (taker by default)
- Evaluates stop-loss using intra-bar 1-minute data
- Writes detailed trade logs and a summary CSV
A backtesting framework supporting multiple market types (spot, perpetual) with realistic trading simulation including leverage, funding, and shorts.
## Requirements
### Requirements
- Python 3.12+
- Dependencies: `pandas`, `numpy`, `ta`
- Package management: `uv`
- Package manager: `uv`
Install dependencies with uv:
## Installation
```bash
uv sync
# If a dependency is missing, add it explicitly and sync
uv add pandas numpy ta
uv sync
```
### Data
- Expected CSV location: `../data/btcusd_1-min_data.csv` (relative to the repo root)
- Required columns: `Timestamp`, `Open`, `High`, `Low`, `Close`, `Volume`
- `Timestamp` should be UNIX seconds; zero-volume rows are ignored
## Quick Reference
### Quickstart
Run the backtest with defaults:
| Command | Description |
|---------|-------------|
| `uv run python main.py download -p BTC-USDT` | Download data |
| `uv run python main.py backtest -s meta_st -p BTC-USDT` | Run backtest |
| `uv run python main.py wfa -s regime -p BTC-USDT` | Walk-forward analysis |
| `uv run python train_model.py --download` | Train/retrain ML model |
| `uv run python research/regime_detection.py` | Run research script |
---
## Backtest CLI
The main entry point is `main.py` which provides three commands: `download`, `backtest`, and `wfa`.
### Download Data
Download historical OHLCV data from exchanges.
```bash
uv run python main.py
uv run python main.py download -p BTC-USDT -t 1h
```
Outputs:
- Per-run trade logs in `backtest_logs/` named like `trade_log_<TIMEFRAME>_sl<STOPLOSS>.csv`
- Run-level summary in `backtest_summary.csv`
**Options:**
- `-p, --pair` (required): Trading pair (e.g., `BTC-USDT`, `ETH-USDT`)
- `-t, --timeframe`: Timeframe (default: `1m`)
- `-e, --exchange`: Exchange (default: `okx`)
- `-m, --market`: Market type: `spot` or `perpetual` (default: `spot`)
- `--start`: Start date in `YYYY-MM-DD` format
### Configuring a Run
Adjust parameters directly in `main.py`:
- Date range (in `load_data`): `load_data('2021-11-01', '2024-10-16')`
- Timeframes to test (any subset of `"5min", "15min", "30min", "1h", "4h", "1d"`):
- `timeframes = ["5min", "15min", "30min", "1h", "4h", "1d"]`
- Stop-loss percentages:
- `stoplosses = [0.03, 0.05, 0.1]`
- Supertrend settings (in `add_supertrend_indicators`): `(period, multiplier)` pairs `(12, 3.0)`, `(10, 1.0)`, `(11, 2.0)`
- Fee model (in `calculate_okx_taker_maker_fee`): taker `0.0010`, maker `0.0008`
**Examples:**
```bash
# Download 1-hour spot data
uv run python main.py download -p ETH-USDT -t 1h
### What the Backtester Does
- Aggregation: Resamples 1-minute data to your selected timeframe using OHLCV rules
- Supertrend signals: Computes three Supertrends and derives a meta trend of `+1` (bullish) or `-1` (bearish) when all agree; otherwise `0`
- Trade logic (long-only):
- Entry when the meta trend changes to bullish; uses aggregated bar open price
- Exit when the meta trend changes to bearish; uses aggregated bar open price
- Stop-loss: For each aggregated bar, scans corresponding 1-minute closes to detect stop-loss and exits using a realistic fill (threshold or next 1-minute open if gapped)
- Performance metrics: total return, max drawdown, Sharpe (daily, factor 252), win rate, number of trades, final/initial equity, and total fees
### Important: Lookahead Bias Note
The current implementation uses the meta Supertrend signal of the same bar for entries/exits, which introduces lookahead bias. To avoid this, lag the signal by one bar inside `backtest()` in `main.py`:
```python
# Replace the current line
meta_trend_signal = meta_trend
# With a one-bar lag to remove lookahead
# meta_trend_signal = np.roll(meta_trend, 1)
# meta_trend_signal[0] = 0
# Download perpetual data from a specific date
uv run python main.py download -p BTC-USDT -m perpetual --start 2024-01-01
```
### Outputs
- `backtest_logs/trade_log_<TIMEFRAME>_sl<STOPLOSS>.csv`: trade-by-trade records including type (`buy`, `sell`, `stop_loss`, `forced_close`), timestamps, prices, balances, PnL, and fees
- `backtest_summary.csv`: one row per (timeframe, stop-loss) combination with `timeframe`, `stop_loss`, `total_return`, `max_drawdown`, `sharpe_ratio`, `win_rate`, `num_trades`, `final_equity`, `initial_equity`, `num_stop_losses`, `total_fees`
### Run Backtest
### Troubleshooting
- CSV not found: Ensure the dataset is located at `../data/btcusd_1-min_data.csv`
- Missing packages: Run `uv add pandas numpy ta` then `uv sync`
- Memory/performance: Large date ranges on 1-minute data can be heavy; narrow the date span or test fewer timeframes
Run a backtest with a specific strategy.
```bash
uv run python main.py backtest -s <strategy> -p <pair> [options]
```
**Available Strategies:**
- `meta_st` - Meta Supertrend (triple supertrend consensus)
- `regime` - Regime Reversion (ML-based spread trading)
- `rsi` - RSI overbought/oversold
- `macross` - Moving Average Crossover
**Options:**
- `-s, --strategy` (required): Strategy name
- `-p, --pair` (required): Trading pair
- `-t, --timeframe`: Timeframe (default: `1m`)
- `--start`: Start date
- `--end`: End date
- `-g, --grid`: Run grid search optimization
- `--plot`: Show equity curve plot
- `--sl`: Stop loss percentage
- `--tp`: Take profit percentage
- `--trail`: Enable trailing stop
- `--fees`: Override fee rate
- `--slippage`: Slippage (default: `0.001`)
- `-l, --leverage`: Leverage multiplier
**Examples:**
```bash
# Basic backtest with Meta Supertrend
uv run python main.py backtest -s meta_st -p BTC-USDT -t 1h
# Backtest with date range and plot
uv run python main.py backtest -s meta_st -p BTC-USDT --start 2024-01-01 --end 2024-12-31 --plot
# Grid search optimization
uv run python main.py backtest -s meta_st -p BTC-USDT -t 4h -g
# Backtest with risk parameters
uv run python main.py backtest -s meta_st -p BTC-USDT --sl 0.05 --tp 0.10 --trail
# Regime strategy on ETH/BTC spread
uv run python main.py backtest -s regime -p ETH-USDT -t 1h
```
### Walk-Forward Analysis (WFA)
Run walk-forward optimization to avoid overfitting.
```bash
uv run python main.py wfa -s <strategy> -p <pair> [options]
```
**Options:**
- `-s, --strategy` (required): Strategy name
- `-p, --pair` (required): Trading pair
- `-t, --timeframe`: Timeframe (default: `1d`)
- `-w, --windows`: Number of walk-forward windows (default: `10`)
- `--plot`: Show WFA results plot
**Examples:**
```bash
# Walk-forward analysis with 10 windows
uv run python main.py wfa -s meta_st -p BTC-USDT -t 1d -w 10
# WFA with plot output
uv run python main.py wfa -s regime -p ETH-USDT --plot
```
---
## Research Scripts
Research scripts are located in the `research/` directory for experimental analysis.
### Regime Detection Research
Tests multiple holding horizons for the regime reversion strategy using walk-forward training.
```bash
uv run python research/regime_detection.py
```
**Options:**
- `--days DAYS`: Number of days of historical data (default: 90)
- `--start DATE`: Start date (YYYY-MM-DD), overrides `--days`
- `--end DATE`: End date (YYYY-MM-DD), defaults to now
- `--output PATH`: Output CSV path
**Examples:**
```bash
# Use last 90 days (default)
uv run python research/regime_detection.py
# Use last 180 days
uv run python research/regime_detection.py --days 180
# Specific date range
uv run python research/regime_detection.py --start 2025-07-01 --end 2025-12-31
```
**What it does:**
- Loads BTC and ETH hourly data
- Calculates spread features (Z-score, RSI, volume ratios)
- Trains RandomForest classifier with walk-forward methodology
- Tests horizons from 6h to 150h
- Outputs best parameters by F1 score, Net PnL, and MAE
**Output:**
- Console: Summary of results for each horizon
- File: `research/horizon_optimization_results.csv`
---
## ML Model Training
The `regime` strategy uses a RandomForest classifier that can be trained with new data.
### Train Model
Train or retrain the ML model with latest data:
```bash
uv run python train_model.py [options]
```
**Options:**
- `--days DAYS`: Days of historical data (default: 90)
- `--pair PAIR`: Base pair for context (default: BTC-USDT)
- `--spread-pair PAIR`: Trading pair (default: ETH-USDT)
- `--timeframe TF`: Timeframe (default: 1h)
- `--market TYPE`: Market type: `spot` or `perpetual` (default: perpetual)
- `--output PATH`: Model output path (default: `data/regime_model.pkl`)
- `--train-ratio R`: Train/test split ratio (default: 0.7)
- `--horizon H`: Prediction horizon in bars (default: 102)
- `--download`: Download latest data before training
- `--dry-run`: Run without saving model
**Examples:**
```bash
# Train with last 90 days of data
uv run python train_model.py
# Download fresh data and train
uv run python train_model.py --download
# Train with 180 days of data
uv run python train_model.py --days 180
# Train on spot market data
uv run python train_model.py --market spot
# Dry run to see metrics without saving
uv run python train_model.py --dry-run
```
### Daily Retraining (Cron)
To automate daily model retraining, add a cron job:
```bash
# Edit crontab
crontab -e
# Add entry to retrain daily at 00:30 UTC
30 0 * * * cd /path/to/lowkey_backtest_live && uv run python train_model.py --download >> logs/training.log 2>&1
```
### Model Files
| File | Description |
|------|-------------|
| `data/regime_model.pkl` | Current production model |
| `data/regime_model_YYYYMMDD_HHMMSS.pkl` | Versioned model snapshots |
The model file contains:
- Trained RandomForest classifier
- Feature column names
- Training metrics (F1 score, sample counts)
- Training timestamp
---
## Output Files
| Location | Description |
|----------|-------------|
| `backtest_logs/` | Trade logs and WFA results |
| `research/` | Research output files |
| `data/` | Downloaded OHLCV data and ML models |
| `data/regime_model.pkl` | Trained ML model for regime strategy |
---
## Running Tests
```bash
uv run pytest tests/
```
Run a specific test file:
```bash
uv run pytest tests/test_data_manager.py
```

29
install_cron.sh Executable file
View File

@@ -0,0 +1,29 @@
#!/bin/bash
# Install cron job for daily model training
# Runs daily at 00:30
PROJECT_DIR="/home/tamaya/Documents/Work/TCP/lowkey_backtest_live"
SCRIPT_PATH="$PROJECT_DIR/train_daily.sh"
LOG_PATH="$PROJECT_DIR/logs/training.log"
# Check if script exists
if [ ! -f "$SCRIPT_PATH" ]; then
echo "Error: $SCRIPT_PATH not found!"
exit 1
fi
# Make executable
chmod +x "$SCRIPT_PATH"
# Prepare cron entry
# 30 0 * * * = 00:30 daily
CRON_CMD="30 0 * * * $SCRIPT_PATH >> $LOG_PATH 2>&1"
# Check if job already exists
(crontab -l 2>/dev/null | grep -F "$SCRIPT_PATH") && echo "Cron job already exists." && exit 0
# Add to crontab
(crontab -l 2>/dev/null; echo "$CRON_CMD") | crontab -
echo "Cron job installed successfully:"
echo "$CRON_CMD"

View File

@@ -60,7 +60,7 @@ class TradingConfig:
# Position sizing
max_position_usdt: float = -1.0 # Max position size in USDT. If <= 0, use all available funds
min_position_usdt: float = 10.0 # Min position size in USDT
min_position_usdt: float = 1.0 # Min position size in USDT
leverage: int = 1 # Leverage (1x = no leverage)
margin_mode: str = "cross" # "cross" or "isolated"

View File

@@ -0,0 +1,13 @@
"""Database module for live trading persistence."""
from .database import get_db, init_db
from .models import Trade, DailySummary, Session
from .metrics import MetricsCalculator
__all__ = [
"get_db",
"init_db",
"Trade",
"DailySummary",
"Session",
"MetricsCalculator",
]

325
live_trading/db/database.py Normal file
View File

@@ -0,0 +1,325 @@
"""SQLite database connection and operations."""
import sqlite3
import logging
from pathlib import Path
from typing import Optional
from contextlib import contextmanager
from .models import Trade, DailySummary, Session
logger = logging.getLogger(__name__)
# Database schema
SCHEMA = """
-- Trade history table
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trade_id TEXT UNIQUE NOT NULL,
symbol TEXT NOT NULL,
side TEXT NOT NULL,
entry_price REAL NOT NULL,
exit_price REAL,
size REAL NOT NULL,
size_usdt REAL NOT NULL,
pnl_usd REAL,
pnl_pct REAL,
entry_time TEXT NOT NULL,
exit_time TEXT,
hold_duration_hours REAL,
reason TEXT,
order_id_entry TEXT,
order_id_exit TEXT,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Daily summary table
CREATE TABLE IF NOT EXISTS daily_summary (
id INTEGER PRIMARY KEY AUTOINCREMENT,
date TEXT UNIQUE NOT NULL,
total_trades INTEGER DEFAULT 0,
winning_trades INTEGER DEFAULT 0,
total_pnl_usd REAL DEFAULT 0,
max_drawdown_usd REAL DEFAULT 0,
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Session metadata
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
start_time TEXT NOT NULL,
end_time TEXT,
starting_balance REAL,
ending_balance REAL,
total_pnl REAL,
total_trades INTEGER DEFAULT 0
);
-- Indexes for common queries
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades(entry_time);
CREATE INDEX IF NOT EXISTS idx_trades_exit_time ON trades(exit_time);
CREATE INDEX IF NOT EXISTS idx_daily_summary_date ON daily_summary(date);
"""
_db_instance: Optional["TradingDatabase"] = None
class TradingDatabase:
"""SQLite database for trade persistence."""
def __init__(self, db_path: Path):
self.db_path = db_path
self._connection: Optional[sqlite3.Connection] = None
@property
def connection(self) -> sqlite3.Connection:
"""Get or create database connection."""
if self._connection is None:
self._connection = sqlite3.connect(
str(self.db_path),
check_same_thread=False,
)
self._connection.row_factory = sqlite3.Row
return self._connection
def init_schema(self) -> None:
"""Initialize database schema."""
with self.connection:
self.connection.executescript(SCHEMA)
logger.info(f"Database initialized at {self.db_path}")
def close(self) -> None:
"""Close database connection."""
if self._connection:
self._connection.close()
self._connection = None
@contextmanager
def transaction(self):
"""Context manager for database transactions."""
try:
yield self.connection
self.connection.commit()
except Exception:
self.connection.rollback()
raise
def insert_trade(self, trade: Trade) -> int:
"""
Insert a new trade record.
Args:
trade: Trade object to insert
Returns:
Row ID of inserted trade
"""
sql = """
INSERT INTO trades (
trade_id, symbol, side, entry_price, exit_price,
size, size_usdt, pnl_usd, pnl_pct, entry_time,
exit_time, hold_duration_hours, reason,
order_id_entry, order_id_exit
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
with self.transaction():
cursor = self.connection.execute(
sql,
(
trade.trade_id,
trade.symbol,
trade.side,
trade.entry_price,
trade.exit_price,
trade.size,
trade.size_usdt,
trade.pnl_usd,
trade.pnl_pct,
trade.entry_time,
trade.exit_time,
trade.hold_duration_hours,
trade.reason,
trade.order_id_entry,
trade.order_id_exit,
),
)
return cursor.lastrowid
def update_trade(self, trade_id: str, **kwargs) -> bool:
"""
Update an existing trade record.
Args:
trade_id: Trade ID to update
**kwargs: Fields to update
Returns:
True if trade was updated
"""
if not kwargs:
return False
set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
sql = f"UPDATE trades SET {set_clause} WHERE trade_id = ?"
with self.transaction():
cursor = self.connection.execute(
sql, (*kwargs.values(), trade_id)
)
return cursor.rowcount > 0
def get_trade(self, trade_id: str) -> Optional[Trade]:
"""Get a trade by ID."""
sql = "SELECT * FROM trades WHERE trade_id = ?"
row = self.connection.execute(sql, (trade_id,)).fetchone()
if row:
return Trade(**dict(row))
return None
def get_trades(
self,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
limit: Optional[int] = None,
) -> list[Trade]:
"""
Get trades within a time range.
Args:
start_time: ISO format start time filter
end_time: ISO format end time filter
limit: Maximum number of trades to return
Returns:
List of Trade objects
"""
conditions = []
params = []
if start_time:
conditions.append("entry_time >= ?")
params.append(start_time)
if end_time:
conditions.append("entry_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions) if conditions else "1=1"
limit_clause = f"LIMIT {limit}" if limit else ""
sql = f"""
SELECT * FROM trades
WHERE {where_clause}
ORDER BY entry_time DESC
{limit_clause}
"""
rows = self.connection.execute(sql, params).fetchall()
return [Trade(**dict(row)) for row in rows]
def get_all_trades(self) -> list[Trade]:
"""Get all trades."""
sql = "SELECT * FROM trades ORDER BY entry_time DESC"
rows = self.connection.execute(sql).fetchall()
return [Trade(**dict(row)) for row in rows]
def count_trades(self) -> int:
"""Get total number of trades."""
sql = "SELECT COUNT(*) FROM trades WHERE exit_time IS NOT NULL"
return self.connection.execute(sql).fetchone()[0]
def upsert_daily_summary(self, summary: DailySummary) -> None:
"""Insert or update daily summary."""
sql = """
INSERT INTO daily_summary (
date, total_trades, winning_trades, total_pnl_usd, max_drawdown_usd
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(date) DO UPDATE SET
total_trades = excluded.total_trades,
winning_trades = excluded.winning_trades,
total_pnl_usd = excluded.total_pnl_usd,
max_drawdown_usd = excluded.max_drawdown_usd,
updated_at = CURRENT_TIMESTAMP
"""
with self.transaction():
self.connection.execute(
sql,
(
summary.date,
summary.total_trades,
summary.winning_trades,
summary.total_pnl_usd,
summary.max_drawdown_usd,
),
)
def get_daily_summary(self, date: str) -> Optional[DailySummary]:
"""Get daily summary for a specific date."""
sql = "SELECT * FROM daily_summary WHERE date = ?"
row = self.connection.execute(sql, (date,)).fetchone()
if row:
return DailySummary(**dict(row))
return None
def insert_session(self, session: Session) -> int:
"""Insert a new session record."""
sql = """
INSERT INTO sessions (
start_time, end_time, starting_balance,
ending_balance, total_pnl, total_trades
) VALUES (?, ?, ?, ?, ?, ?)
"""
with self.transaction():
cursor = self.connection.execute(
sql,
(
session.start_time,
session.end_time,
session.starting_balance,
session.ending_balance,
session.total_pnl,
session.total_trades,
),
)
return cursor.lastrowid
def update_session(self, session_id: int, **kwargs) -> bool:
"""Update an existing session."""
if not kwargs:
return False
set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
sql = f"UPDATE sessions SET {set_clause} WHERE id = ?"
with self.transaction():
cursor = self.connection.execute(
sql, (*kwargs.values(), session_id)
)
return cursor.rowcount > 0
def get_latest_session(self) -> Optional[Session]:
"""Get the most recent session."""
sql = "SELECT * FROM sessions ORDER BY id DESC LIMIT 1"
row = self.connection.execute(sql).fetchone()
if row:
return Session(**dict(row))
return None
def init_db(db_path: Path) -> TradingDatabase:
"""
Initialize the database.
Args:
db_path: Path to the SQLite database file
Returns:
TradingDatabase instance
"""
global _db_instance
_db_instance = TradingDatabase(db_path)
_db_instance.init_schema()
return _db_instance
def get_db() -> Optional[TradingDatabase]:
"""Get the global database instance."""
return _db_instance

235
live_trading/db/metrics.py Normal file
View File

@@ -0,0 +1,235 @@
"""Metrics calculation from trade database."""
import logging
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from typing import Optional
from .database import TradingDatabase
logger = logging.getLogger(__name__)
@dataclass
class PeriodMetrics:
"""Trading metrics for a time period."""
period_name: str
start_time: Optional[str]
end_time: Optional[str]
total_pnl: float = 0.0
total_trades: int = 0
winning_trades: int = 0
losing_trades: int = 0
win_rate: float = 0.0
avg_trade_duration_hours: float = 0.0
max_drawdown: float = 0.0
max_drawdown_pct: float = 0.0
best_trade: float = 0.0
worst_trade: float = 0.0
avg_win: float = 0.0
avg_loss: float = 0.0
class MetricsCalculator:
"""Calculate trading metrics from database."""
def __init__(self, db: TradingDatabase):
self.db = db
def get_all_time_metrics(self) -> PeriodMetrics:
"""Get metrics for all trades ever."""
return self._calculate_metrics("All Time", None, None)
def get_monthly_metrics(self) -> PeriodMetrics:
"""Get metrics for current month."""
now = datetime.now(timezone.utc)
start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Monthly",
start.isoformat(),
now.isoformat(),
)
def get_weekly_metrics(self) -> PeriodMetrics:
"""Get metrics for current week (Monday to now)."""
now = datetime.now(timezone.utc)
days_since_monday = now.weekday()
start = now - timedelta(days=days_since_monday)
start = start.replace(hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Weekly",
start.isoformat(),
now.isoformat(),
)
def get_daily_metrics(self) -> PeriodMetrics:
"""Get metrics for today (UTC)."""
now = datetime.now(timezone.utc)
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Daily",
start.isoformat(),
now.isoformat(),
)
def _calculate_metrics(
self,
period_name: str,
start_time: Optional[str],
end_time: Optional[str],
) -> PeriodMetrics:
"""
Calculate metrics for a time period.
Args:
period_name: Name of the period
start_time: ISO format start time (None for all time)
end_time: ISO format end time (None for all time)
Returns:
PeriodMetrics object
"""
metrics = PeriodMetrics(
period_name=period_name,
start_time=start_time,
end_time=end_time,
)
# Build query conditions
conditions = ["exit_time IS NOT NULL"]
params = []
if start_time:
conditions.append("exit_time >= ?")
params.append(start_time)
if end_time:
conditions.append("exit_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions)
# Get aggregate metrics
sql = f"""
SELECT
COUNT(*) as total_trades,
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
SUM(CASE WHEN pnl_usd < 0 THEN 1 ELSE 0 END) as losing_trades,
COALESCE(SUM(pnl_usd), 0) as total_pnl,
COALESCE(AVG(hold_duration_hours), 0) as avg_duration,
COALESCE(MAX(pnl_usd), 0) as best_trade,
COALESCE(MIN(pnl_usd), 0) as worst_trade,
COALESCE(AVG(CASE WHEN pnl_usd > 0 THEN pnl_usd END), 0) as avg_win,
COALESCE(AVG(CASE WHEN pnl_usd < 0 THEN pnl_usd END), 0) as avg_loss
FROM trades
WHERE {where_clause}
"""
row = self.db.connection.execute(sql, params).fetchone()
if row and row["total_trades"] > 0:
metrics.total_trades = row["total_trades"]
metrics.winning_trades = row["winning_trades"] or 0
metrics.losing_trades = row["losing_trades"] or 0
metrics.total_pnl = row["total_pnl"]
metrics.avg_trade_duration_hours = row["avg_duration"]
metrics.best_trade = row["best_trade"]
metrics.worst_trade = row["worst_trade"]
metrics.avg_win = row["avg_win"]
metrics.avg_loss = row["avg_loss"]
if metrics.total_trades > 0:
metrics.win_rate = (
metrics.winning_trades / metrics.total_trades * 100
)
# Calculate max drawdown
metrics.max_drawdown = self._calculate_max_drawdown(
start_time, end_time
)
return metrics
def _calculate_max_drawdown(
self,
start_time: Optional[str],
end_time: Optional[str],
) -> float:
"""Calculate maximum drawdown for a period."""
conditions = ["exit_time IS NOT NULL"]
params = []
if start_time:
conditions.append("exit_time >= ?")
params.append(start_time)
if end_time:
conditions.append("exit_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions)
sql = f"""
SELECT pnl_usd
FROM trades
WHERE {where_clause}
ORDER BY exit_time
"""
rows = self.db.connection.execute(sql, params).fetchall()
if not rows:
return 0.0
cumulative = 0.0
peak = 0.0
max_drawdown = 0.0
for row in rows:
pnl = row["pnl_usd"] or 0.0
cumulative += pnl
peak = max(peak, cumulative)
drawdown = peak - cumulative
max_drawdown = max(max_drawdown, drawdown)
return max_drawdown
def has_monthly_data(self) -> bool:
"""Check if we have data spanning more than current month."""
sql = """
SELECT MIN(exit_time) as first_trade
FROM trades
WHERE exit_time IS NOT NULL
"""
row = self.db.connection.execute(sql).fetchone()
if not row or not row["first_trade"]:
return False
first_trade = datetime.fromisoformat(row["first_trade"])
now = datetime.now(timezone.utc)
month_start = now.replace(day=1, hour=0, minute=0, second=0)
return first_trade < month_start
def has_weekly_data(self) -> bool:
"""Check if we have data spanning more than current week."""
sql = """
SELECT MIN(exit_time) as first_trade
FROM trades
WHERE exit_time IS NOT NULL
"""
row = self.db.connection.execute(sql).fetchone()
if not row or not row["first_trade"]:
return False
first_trade = datetime.fromisoformat(row["first_trade"])
now = datetime.now(timezone.utc)
days_since_monday = now.weekday()
week_start = now - timedelta(days=days_since_monday)
week_start = week_start.replace(hour=0, minute=0, second=0)
return first_trade < week_start
def get_session_start_balance(self) -> Optional[float]:
"""Get starting balance from latest session."""
sql = "SELECT starting_balance FROM sessions ORDER BY id DESC LIMIT 1"
row = self.db.connection.execute(sql).fetchone()
return row["starting_balance"] if row else None

View File

@@ -0,0 +1,191 @@
"""Database migrations and CSV import."""
import csv
import logging
from pathlib import Path
from datetime import datetime
from .database import TradingDatabase
from .models import Trade, DailySummary
logger = logging.getLogger(__name__)
def migrate_csv_to_db(db: TradingDatabase, csv_path: Path) -> int:
"""
Migrate trades from CSV file to SQLite database.
Args:
db: TradingDatabase instance
csv_path: Path to trade_log.csv
Returns:
Number of trades migrated
"""
if not csv_path.exists():
logger.info("No CSV file to migrate")
return 0
# Check if database already has trades
existing_count = db.count_trades()
if existing_count > 0:
logger.info(
f"Database already has {existing_count} trades, skipping migration"
)
return 0
migrated = 0
try:
with open(csv_path, "r", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
trade = _csv_row_to_trade(row)
if trade:
try:
db.insert_trade(trade)
migrated += 1
except Exception as e:
logger.warning(
f"Failed to migrate trade {row.get('trade_id')}: {e}"
)
logger.info(f"Migrated {migrated} trades from CSV to database")
except Exception as e:
logger.error(f"CSV migration failed: {e}")
return migrated
def _csv_row_to_trade(row: dict) -> Trade | None:
"""Convert a CSV row to a Trade object."""
try:
return Trade(
trade_id=row["trade_id"],
symbol=row["symbol"],
side=row["side"],
entry_price=float(row["entry_price"]),
exit_price=_safe_float(row.get("exit_price")),
size=float(row["size"]),
size_usdt=float(row["size_usdt"]),
pnl_usd=_safe_float(row.get("pnl_usd")),
pnl_pct=_safe_float(row.get("pnl_pct")),
entry_time=row["entry_time"],
exit_time=row.get("exit_time") or None,
hold_duration_hours=_safe_float(row.get("hold_duration_hours")),
reason=row.get("reason") or None,
order_id_entry=row.get("order_id_entry") or None,
order_id_exit=row.get("order_id_exit") or None,
)
except (KeyError, ValueError) as e:
logger.warning(f"Invalid CSV row: {e}")
return None
def _safe_float(value: str | None) -> float | None:
"""Safely convert string to float."""
if value is None or value == "":
return None
try:
return float(value)
except ValueError:
return None
def rebuild_daily_summaries(db: TradingDatabase) -> int:
"""
Rebuild daily summary table from trades.
Args:
db: TradingDatabase instance
Returns:
Number of daily summaries created
"""
sql = """
SELECT
DATE(exit_time) as date,
COUNT(*) as total_trades,
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
SUM(pnl_usd) as total_pnl_usd
FROM trades
WHERE exit_time IS NOT NULL
GROUP BY DATE(exit_time)
ORDER BY date
"""
rows = db.connection.execute(sql).fetchall()
count = 0
for row in rows:
summary = DailySummary(
date=row["date"],
total_trades=row["total_trades"],
winning_trades=row["winning_trades"],
total_pnl_usd=row["total_pnl_usd"] or 0.0,
max_drawdown_usd=0.0, # Calculated separately
)
db.upsert_daily_summary(summary)
count += 1
# Calculate max drawdowns
_calculate_daily_drawdowns(db)
logger.info(f"Rebuilt {count} daily summaries")
return count
def _calculate_daily_drawdowns(db: TradingDatabase) -> None:
"""Calculate and update max drawdown for each day."""
sql = """
SELECT trade_id, DATE(exit_time) as date, pnl_usd
FROM trades
WHERE exit_time IS NOT NULL
ORDER BY exit_time
"""
rows = db.connection.execute(sql).fetchall()
# Track cumulative PnL and drawdown per day
daily_drawdowns: dict[str, float] = {}
cumulative_pnl = 0.0
peak_pnl = 0.0
for row in rows:
date = row["date"]
pnl = row["pnl_usd"] or 0.0
cumulative_pnl += pnl
peak_pnl = max(peak_pnl, cumulative_pnl)
drawdown = peak_pnl - cumulative_pnl
if date not in daily_drawdowns:
daily_drawdowns[date] = 0.0
daily_drawdowns[date] = max(daily_drawdowns[date], drawdown)
# Update daily summaries with drawdown
for date, drawdown in daily_drawdowns.items():
db.connection.execute(
"UPDATE daily_summary SET max_drawdown_usd = ? WHERE date = ?",
(drawdown, date),
)
db.connection.commit()
def run_migrations(db: TradingDatabase, csv_path: Path) -> None:
"""
Run all migrations.
Args:
db: TradingDatabase instance
csv_path: Path to trade_log.csv for migration
"""
logger.info("Running database migrations...")
# Migrate CSV data if exists
migrate_csv_to_db(db, csv_path)
# Rebuild daily summaries
rebuild_daily_summaries(db)
logger.info("Migrations complete")

69
live_trading/db/models.py Normal file
View File

@@ -0,0 +1,69 @@
"""Data models for trade persistence."""
from dataclasses import dataclass, asdict
from typing import Optional
from datetime import datetime
@dataclass
class Trade:
"""Represents a completed trade."""
trade_id: str
symbol: str
side: str
entry_price: float
size: float
size_usdt: float
entry_time: str
exit_price: Optional[float] = None
pnl_usd: Optional[float] = None
pnl_pct: Optional[float] = None
exit_time: Optional[str] = None
hold_duration_hours: Optional[float] = None
reason: Optional[str] = None
order_id_entry: Optional[str] = None
order_id_exit: Optional[str] = None
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@classmethod
def from_row(cls, row: tuple, columns: list[str]) -> "Trade":
"""Create Trade from database row."""
data = dict(zip(columns, row))
return cls(**data)
@dataclass
class DailySummary:
"""Daily trading summary."""
date: str
total_trades: int = 0
winning_trades: int = 0
total_pnl_usd: float = 0.0
max_drawdown_usd: float = 0.0
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@dataclass
class Session:
"""Trading session metadata."""
start_time: str
end_time: Optional[str] = None
starting_balance: Optional[float] = None
ending_balance: Optional[float] = None
total_pnl: Optional[float] = None
total_trades: int = 0
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)

View File

@@ -39,17 +39,39 @@ class LiveRegimeStrategy:
self.paths = path_config
self.model: Optional[RandomForestClassifier] = None
self.feature_cols: Optional[list] = None
self.horizon: int = 102 # Default horizon
self._last_model_load_time: float = 0.0
self._load_or_train_model()
def reload_model_if_changed(self) -> None:
"""Check if model file has changed and reload if necessary."""
if not self.paths.model_path.exists():
return
try:
mtime = self.paths.model_path.stat().st_mtime
if mtime > self._last_model_load_time:
logger.info(f"Model file changed, reloading... (last: {self._last_model_load_time}, new: {mtime})")
self._load_or_train_model()
except Exception as e:
logger.warning(f"Error checking model file: {e}")
def _load_or_train_model(self) -> None:
"""Load pre-trained model or train a new one."""
if self.paths.model_path.exists():
try:
self._last_model_load_time = self.paths.model_path.stat().st_mtime
with open(self.paths.model_path, 'rb') as f:
saved = pickle.load(f)
self.model = saved['model']
self.feature_cols = saved['feature_cols']
logger.info(f"Loaded model from {self.paths.model_path}")
# Load horizon from metrics if available
if 'metrics' in saved and 'horizon' in saved['metrics']:
self.horizon = saved['metrics']['horizon']
logger.info(f"Loaded model from {self.paths.model_path} (horizon={self.horizon})")
else:
logger.info(f"Loaded model from {self.paths.model_path} (default horizon={self.horizon})")
return
except Exception as e:
logger.warning(f"Could not load model: {e}")
@@ -66,6 +88,7 @@ class LiveRegimeStrategy:
pickle.dump({
'model': self.model,
'feature_cols': self.feature_cols,
'metrics': {'horizon': self.horizon} # Save horizon
}, f)
logger.info(f"Saved model to {self.paths.model_path}")
except Exception as e:
@@ -81,7 +104,7 @@ class LiveRegimeStrategy:
logger.info(f"Training model on {len(features)} samples...")
z_thresh = self.config.z_entry_threshold
horizon = 102 # Optimal horizon from research
horizon = self.horizon
profit_target = 0.005 # 0.5% profit threshold
# Define targets

View File

@@ -11,14 +11,19 @@ Usage:
# Run with specific settings
uv run python -m live_trading.main --max-position 500 --leverage 2
# Run without UI (headless mode)
uv run python -m live_trading.main --no-ui
"""
import argparse
import logging
import queue
import signal
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
@@ -28,22 +33,47 @@ from live_trading.okx_client import OKXClient
from live_trading.data_feed import DataFeed
from live_trading.position_manager import PositionManager
from live_trading.live_regime_strategy import LiveRegimeStrategy
from live_trading.db.database import init_db, TradingDatabase
from live_trading.db.migrations import run_migrations
from live_trading.ui.state import SharedState, PositionState
from live_trading.ui.dashboard import TradingDashboard, setup_ui_logging
def setup_logging(log_dir: Path) -> logging.Logger:
"""Configure logging for the trading bot."""
def setup_logging(
log_dir: Path,
log_queue: Optional[queue.Queue] = None,
) -> logging.Logger:
"""
Configure logging for the trading bot.
Args:
log_dir: Directory for log files
log_queue: Optional queue for UI log handler
Returns:
Logger instance
"""
log_file = log_dir / "live_trading.log"
handlers = [
logging.FileHandler(log_file),
]
# Only add StreamHandler if no UI (log_queue is None)
if log_queue is None:
handlers.append(logging.StreamHandler(sys.stdout))
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler(sys.stdout),
],
force=True
handlers=handlers,
force=True,
)
# Add UI log handler if queue provided
if log_queue is not None:
setup_ui_logging(log_queue)
return logging.getLogger(__name__)
@@ -59,11 +89,15 @@ class LiveTradingBot:
self,
okx_config: OKXConfig,
trading_config: TradingConfig,
path_config: PathConfig
path_config: PathConfig,
database: Optional[TradingDatabase] = None,
shared_state: Optional[SharedState] = None,
):
self.okx_config = okx_config
self.trading_config = trading_config
self.path_config = path_config
self.db = database
self.state = shared_state
self.logger = logging.getLogger(__name__)
self.running = True
@@ -74,7 +108,7 @@ class LiveTradingBot:
self.okx_client = OKXClient(okx_config, trading_config)
self.data_feed = DataFeed(self.okx_client, trading_config, path_config)
self.position_manager = PositionManager(
self.okx_client, trading_config, path_config
self.okx_client, trading_config, path_config, database
)
self.strategy = LiveRegimeStrategy(trading_config, path_config)
@@ -82,6 +116,16 @@ class LiveTradingBot:
signal.signal(signal.SIGINT, self._handle_shutdown)
signal.signal(signal.SIGTERM, self._handle_shutdown)
# Initialize shared state if provided
if self.state:
mode = "DEMO" if okx_config.demo_mode else "LIVE"
self.state.set_mode(mode)
self.state.set_symbols(
trading_config.eth_symbol,
trading_config.btc_symbol,
)
self.state.update_account(0.0, 0.0, trading_config.leverage)
self._print_startup_banner()
def _print_startup_banner(self) -> None:
@@ -109,6 +153,8 @@ class LiveTradingBot:
"""Handle shutdown signals gracefully."""
self.logger.info("Shutdown signal received, stopping...")
self.running = False
if self.state:
self.state.stop()
def run_trading_cycle(self) -> None:
"""
@@ -118,10 +164,20 @@ class LiveTradingBot:
2. Update open positions
3. Generate trading signal
4. Execute trades if signal triggers
5. Update shared state for UI
"""
# Reload model if it has changed (e.g. daily training)
try:
self.strategy.reload_model_if_changed()
except Exception as e:
self.logger.warning(f"Failed to reload model: {e}")
cycle_start = datetime.now(timezone.utc)
self.logger.info(f"--- Trading Cycle Start: {cycle_start.isoformat()} ---")
if self.state:
self.state.set_last_cycle_time(cycle_start.isoformat())
try:
# 1. Fetch market data
features = self.data_feed.get_latest_data()
@@ -154,15 +210,22 @@ class LiveTradingBot:
funding = self.data_feed.get_current_funding_rates()
# 5. Generate trading signal
signal = self.strategy.generate_signal(features, funding)
sig = self.strategy.generate_signal(features, funding)
# 6. Execute trades based on signal
if signal['action'] == 'entry':
self._execute_entry(signal, eth_price)
elif signal['action'] == 'check_exit':
self._execute_exit(signal)
# 6. Update shared state with strategy info
self._update_strategy_state(sig, funding)
# 7. Log portfolio summary
# 7. Execute trades based on signal
if sig['action'] == 'entry':
self._execute_entry(sig, eth_price)
elif sig['action'] == 'check_exit':
self._execute_exit(sig)
# 8. Update shared state with position and account
self._update_position_state(eth_price)
self._update_account_state()
# 9. Log portfolio summary
summary = self.position_manager.get_portfolio_summary()
self.logger.info(
f"Portfolio: {summary['open_positions']} positions, "
@@ -178,6 +241,61 @@ class LiveTradingBot:
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
self.logger.info(f"--- Cycle completed in {cycle_duration:.1f}s ---")
def _update_strategy_state(self, sig: dict, funding: dict) -> None:
"""Update shared state with strategy information."""
if not self.state:
return
self.state.update_strategy(
z_score=sig.get('z_score', 0.0),
probability=sig.get('probability', 0.0),
funding_rate=funding.get('btc_funding', 0.0),
action=sig.get('action', 'hold'),
reason=sig.get('reason', ''),
)
def _update_position_state(self, current_price: float) -> None:
"""Update shared state with current position."""
if not self.state:
return
symbol = self.trading_config.eth_symbol
position = self.position_manager.get_position_for_symbol(symbol)
if position is None:
self.state.clear_position()
return
pos_state = PositionState(
trade_id=position.trade_id,
symbol=position.symbol,
side=position.side,
entry_price=position.entry_price,
current_price=position.current_price,
size=position.size,
size_usdt=position.size_usdt,
unrealized_pnl=position.unrealized_pnl,
unrealized_pnl_pct=position.unrealized_pnl_pct,
stop_loss_price=position.stop_loss_price,
take_profit_price=position.take_profit_price,
)
self.state.set_position(pos_state)
def _update_account_state(self) -> None:
"""Update shared state with account information."""
if not self.state:
return
try:
balance = self.okx_client.get_balance()
self.state.update_account(
balance=balance.get('total', 0.0),
available=balance.get('free', 0.0),
leverage=self.trading_config.leverage,
)
except Exception as e:
self.logger.warning(f"Failed to update account state: {e}")
def _execute_entry(self, signal: dict, current_price: float) -> None:
"""Execute entry trade."""
symbol = self.trading_config.eth_symbol
@@ -191,11 +309,15 @@ class LiveTradingBot:
# Get account balance
balance = self.okx_client.get_balance()
available_usdt = balance['free']
self.logger.info(f"Account balance: ${available_usdt:.2f} USDT available")
# Calculate position size
size_usdt = self.strategy.calculate_position_size(signal, available_usdt)
if size_usdt <= 0:
self.logger.info("Position size too small, skipping entry")
self.logger.info(
f"Position size too small (${size_usdt:.2f}), skipping entry. "
f"Min required: ${self.strategy.config.min_position_usdt:.2f}"
)
return
size_eth = size_usdt / current_price
@@ -290,22 +412,30 @@ class LiveTradingBot:
except Exception as e:
self.logger.error(f"Exit execution failed: {e}", exc_info=True)
def _is_running(self) -> bool:
"""Check if bot should continue running."""
if not self.running:
return False
if self.state and not self.state.is_running():
return False
return True
def run(self) -> None:
"""Main trading loop."""
self.logger.info("Starting trading loop...")
while self.running:
while self._is_running():
try:
self.run_trading_cycle()
if self.running:
if self._is_running():
sleep_seconds = self.trading_config.sleep_seconds
minutes = sleep_seconds // 60
self.logger.info(f"Sleeping for {minutes} minutes...")
# Sleep in smaller chunks to allow faster shutdown
for _ in range(sleep_seconds):
if not self.running:
if not self._is_running():
break
time.sleep(1)
@@ -319,6 +449,8 @@ class LiveTradingBot:
# Cleanup
self.logger.info("Shutting down...")
self.position_manager.save_positions()
if self.db:
self.db.close()
self.logger.info("Shutdown complete")
@@ -350,6 +482,11 @@ def parse_args():
action="store_true",
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
)
parser.add_argument(
"--no-ui",
action="store_true",
help="Run in headless mode without terminal UI"
)
return parser.parse_args()
@@ -370,19 +507,64 @@ def main():
if args.live:
okx_config.demo_mode = False
# Determine if UI should be enabled
use_ui = not args.no_ui and sys.stdin.isatty()
# Initialize database
db_path = path_config.base_dir / "live_trading" / "trading.db"
db = init_db(db_path)
# Run migrations (imports CSV if exists)
run_migrations(db, path_config.trade_log_file)
# Initialize UI components if enabled
log_queue: Optional[queue.Queue] = None
shared_state: Optional[SharedState] = None
dashboard: Optional[TradingDashboard] = None
if use_ui:
log_queue = queue.Queue(maxsize=1000)
shared_state = SharedState()
# Setup logging
logger = setup_logging(path_config.logs_dir)
logger = setup_logging(path_config.logs_dir, log_queue)
try:
# Create and run bot
bot = LiveTradingBot(okx_config, trading_config, path_config)
# Create bot
bot = LiveTradingBot(
okx_config,
trading_config,
path_config,
database=db,
shared_state=shared_state,
)
# Start dashboard if UI enabled
if use_ui and shared_state and log_queue:
dashboard = TradingDashboard(
state=shared_state,
db=db,
log_queue=log_queue,
on_quit=lambda: setattr(bot, 'running', False),
)
dashboard.start()
logger.info("Dashboard started")
# Run bot
bot.run()
except ValueError as e:
logger.error(f"Configuration error: {e}")
sys.exit(1)
except Exception as e:
logger.error(f"Fatal error: {e}", exc_info=True)
sys.exit(1)
finally:
# Cleanup
if dashboard:
dashboard.stop()
if db:
db.close()
if __name__ == "__main__":

View File

@@ -3,16 +3,21 @@ Position Manager for Live Trading.
Tracks open positions, manages risk, and handles SL/TP logic.
"""
import csv
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from dataclasses import dataclass, field, asdict
from typing import Optional, TYPE_CHECKING
from dataclasses import dataclass, asdict
from .okx_client import OKXClient
from .config import TradingConfig, PathConfig
if TYPE_CHECKING:
from .db.database import TradingDatabase
from .db.models import Trade
logger = logging.getLogger(__name__)
@@ -78,11 +83,13 @@ class PositionManager:
self,
okx_client: OKXClient,
trading_config: TradingConfig,
path_config: PathConfig
path_config: PathConfig,
database: Optional["TradingDatabase"] = None,
):
self.client = okx_client
self.config = trading_config
self.paths = path_config
self.db = database
self.positions: dict[str, Position] = {}
self.trade_log: list[dict] = []
self._load_positions()
@@ -249,16 +256,55 @@ class PositionManager:
return trade_record
def _append_trade_log(self, trade_record: dict) -> None:
"""Append trade record to CSV log file."""
import csv
"""Append trade record to CSV and SQLite database."""
# Write to CSV (backup/compatibility)
self._append_trade_csv(trade_record)
# Write to SQLite (primary)
self._append_trade_db(trade_record)
def _append_trade_csv(self, trade_record: dict) -> None:
"""Append trade record to CSV log file."""
file_exists = self.paths.trade_log_file.exists()
try:
with open(self.paths.trade_log_file, 'a', newline='') as f:
writer = csv.DictWriter(f, fieldnames=trade_record.keys())
if not file_exists:
writer.writeheader()
writer.writerow(trade_record)
except Exception as e:
logger.error(f"Failed to write trade to CSV: {e}")
def _append_trade_db(self, trade_record: dict) -> None:
"""Append trade record to SQLite database."""
if self.db is None:
return
try:
from .db.models import Trade
trade = Trade(
trade_id=trade_record['trade_id'],
symbol=trade_record['symbol'],
side=trade_record['side'],
entry_price=trade_record['entry_price'],
exit_price=trade_record.get('exit_price'),
size=trade_record['size'],
size_usdt=trade_record['size_usdt'],
pnl_usd=trade_record.get('pnl_usd'),
pnl_pct=trade_record.get('pnl_pct'),
entry_time=trade_record['entry_time'],
exit_time=trade_record.get('exit_time'),
hold_duration_hours=trade_record.get('hold_duration_hours'),
reason=trade_record.get('reason'),
order_id_entry=trade_record.get('order_id_entry'),
order_id_exit=trade_record.get('order_id_exit'),
)
self.db.insert_trade(trade)
logger.debug(f"Trade {trade.trade_id} saved to database")
except Exception as e:
logger.error(f"Failed to write trade to database: {e}")
def update_positions(self, current_prices: dict[str, float]) -> list[dict]:
"""

View File

@@ -0,0 +1,10 @@
"""Terminal UI module for live trading dashboard."""
from .dashboard import TradingDashboard
from .state import SharedState
from .log_handler import UILogHandler
__all__ = [
"TradingDashboard",
"SharedState",
"UILogHandler",
]

View File

@@ -0,0 +1,240 @@
"""Main trading dashboard UI orchestration."""
import logging
import queue
import threading
import time
from typing import Optional, Callable
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.panel import Panel
from rich.text import Text
from .state import SharedState
from .log_handler import LogBuffer, UILogHandler
from .keyboard import KeyboardHandler
from .panels import (
HeaderPanel,
TabBar,
LogPanel,
HelpBar,
build_summary_panel,
)
from ..db.database import TradingDatabase
from ..db.metrics import MetricsCalculator, PeriodMetrics
logger = logging.getLogger(__name__)
class TradingDashboard:
"""
Main trading dashboard orchestrator.
Runs in a separate thread and provides real-time UI updates
while the trading loop runs in the main thread.
"""
def __init__(
self,
state: SharedState,
db: TradingDatabase,
log_queue: queue.Queue,
on_quit: Optional[Callable] = None,
):
self.state = state
self.db = db
self.log_queue = log_queue
self.on_quit = on_quit
self.console = Console()
self.log_buffer = LogBuffer(max_entries=1000)
self.keyboard = KeyboardHandler()
self.metrics_calculator = MetricsCalculator(db)
self._running = False
self._thread: Optional[threading.Thread] = None
self._active_tab = 0
self._cached_metrics: dict[int, PeriodMetrics] = {}
self._last_metrics_refresh = 0.0
def start(self) -> None:
"""Start the dashboard in a separate thread."""
if self._running:
return
self._running = True
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
logger.debug("Dashboard thread started")
def stop(self) -> None:
"""Stop the dashboard."""
self._running = False
if self._thread:
self._thread.join(timeout=2.0)
logger.debug("Dashboard thread stopped")
def _run(self) -> None:
"""Main dashboard loop."""
try:
with self.keyboard:
with Live(
self._build_layout(),
console=self.console,
refresh_per_second=1,
screen=True,
) as live:
while self._running and self.state.is_running():
# Process keyboard input
action = self.keyboard.get_action(timeout=0.1)
if action:
self._handle_action(action)
# Drain log queue
self.log_buffer.drain_queue(self.log_queue)
# Refresh metrics periodically (every 5 seconds)
now = time.time()
if now - self._last_metrics_refresh > 5.0:
self._refresh_metrics()
self._last_metrics_refresh = now
# Update display
live.update(self._build_layout())
# Small sleep to prevent CPU spinning
time.sleep(0.1)
except Exception as e:
logger.error(f"Dashboard error: {e}", exc_info=True)
finally:
self._running = False
def _handle_action(self, action: str) -> None:
"""Handle keyboard action."""
if action == "quit":
logger.info("Quit requested from UI")
self.state.stop()
if self.on_quit:
self.on_quit()
elif action == "refresh":
self._refresh_metrics()
logger.debug("Manual refresh triggered")
elif action == "filter":
new_filter = self.log_buffer.cycle_filter()
logger.debug(f"Log filter changed to: {new_filter}")
elif action == "filter_trades":
self.log_buffer.set_filter(LogBuffer.FILTER_TRADES)
logger.debug("Log filter set to: trades")
elif action == "filter_all":
self.log_buffer.set_filter(LogBuffer.FILTER_ALL)
logger.debug("Log filter set to: all")
elif action == "filter_errors":
self.log_buffer.set_filter(LogBuffer.FILTER_ERRORS)
logger.debug("Log filter set to: errors")
elif action == "tab_general":
self._active_tab = 0
elif action == "tab_monthly":
if self._has_monthly_data():
self._active_tab = 1
elif action == "tab_weekly":
if self._has_weekly_data():
self._active_tab = 2
elif action == "tab_daily":
self._active_tab = 3
def _refresh_metrics(self) -> None:
"""Refresh metrics from database."""
try:
self._cached_metrics[0] = self.metrics_calculator.get_all_time_metrics()
self._cached_metrics[1] = self.metrics_calculator.get_monthly_metrics()
self._cached_metrics[2] = self.metrics_calculator.get_weekly_metrics()
self._cached_metrics[3] = self.metrics_calculator.get_daily_metrics()
except Exception as e:
logger.warning(f"Failed to refresh metrics: {e}")
def _has_monthly_data(self) -> bool:
"""Check if monthly tab should be shown."""
try:
return self.metrics_calculator.has_monthly_data()
except Exception:
return False
def _has_weekly_data(self) -> bool:
"""Check if weekly tab should be shown."""
try:
return self.metrics_calculator.has_weekly_data()
except Exception:
return False
def _build_layout(self) -> Layout:
"""Build the complete dashboard layout."""
layout = Layout()
# Calculate available height
term_height = self.console.height or 40
# Header takes 3 lines
# Help bar takes 1 line
# Summary panel takes about 12-14 lines
# Rest goes to logs
log_height = max(8, term_height - 20)
layout.split_column(
Layout(name="header", size=3),
Layout(name="summary", size=14),
Layout(name="logs", size=log_height),
Layout(name="help", size=1),
)
# Header
layout["header"].update(HeaderPanel(self.state).render())
# Summary panel with tabs
current_metrics = self._cached_metrics.get(self._active_tab)
tab_bar = TabBar(active_tab=self._active_tab)
layout["summary"].update(
build_summary_panel(
state=self.state,
metrics=current_metrics,
tab_bar=tab_bar,
has_monthly=self._has_monthly_data(),
has_weekly=self._has_weekly_data(),
)
)
# Log panel
layout["logs"].update(LogPanel(self.log_buffer).render(height=log_height))
# Help bar
layout["help"].update(HelpBar().render())
return layout
def setup_ui_logging(log_queue: queue.Queue) -> UILogHandler:
"""
Set up logging to capture messages for UI.
Args:
log_queue: Queue to send log messages to
Returns:
UILogHandler instance
"""
handler = UILogHandler(log_queue)
handler.setLevel(logging.INFO)
# Add handler to root logger
root_logger = logging.getLogger()
root_logger.addHandler(handler)
return handler

128
live_trading/ui/keyboard.py Normal file
View File

@@ -0,0 +1,128 @@
"""Keyboard input handling for terminal UI."""
import sys
import select
import termios
import tty
from typing import Optional, Callable
from dataclasses import dataclass
@dataclass
class KeyAction:
"""Represents a keyboard action."""
key: str
action: str
description: str
class KeyboardHandler:
"""
Non-blocking keyboard input handler.
Uses terminal raw mode to capture single keypresses
without waiting for Enter.
"""
# Key mappings
ACTIONS = {
"q": "quit",
"Q": "quit",
"\x03": "quit", # Ctrl+C
"r": "refresh",
"R": "refresh",
"f": "filter",
"F": "filter",
"t": "filter_trades",
"T": "filter_trades",
"l": "filter_all",
"L": "filter_all",
"e": "filter_errors",
"E": "filter_errors",
"1": "tab_general",
"2": "tab_monthly",
"3": "tab_weekly",
"4": "tab_daily",
}
def __init__(self):
self._old_settings = None
self._enabled = False
def enable(self) -> bool:
"""
Enable raw keyboard input mode.
Returns:
True if enabled successfully
"""
try:
if not sys.stdin.isatty():
return False
self._old_settings = termios.tcgetattr(sys.stdin)
tty.setcbreak(sys.stdin.fileno())
self._enabled = True
return True
except Exception:
return False
def disable(self) -> None:
"""Restore normal terminal mode."""
if self._enabled and self._old_settings:
try:
termios.tcsetattr(
sys.stdin,
termios.TCSADRAIN,
self._old_settings,
)
except Exception:
pass
self._enabled = False
def get_key(self, timeout: float = 0.1) -> Optional[str]:
"""
Get a keypress if available (non-blocking).
Args:
timeout: Seconds to wait for input
Returns:
Key character or None if no input
"""
if not self._enabled:
return None
try:
readable, _, _ = select.select([sys.stdin], [], [], timeout)
if readable:
return sys.stdin.read(1)
except Exception:
pass
return None
def get_action(self, timeout: float = 0.1) -> Optional[str]:
"""
Get action name for pressed key.
Args:
timeout: Seconds to wait for input
Returns:
Action name or None
"""
key = self.get_key(timeout)
if key:
return self.ACTIONS.get(key)
return None
def __enter__(self):
"""Context manager entry."""
self.enable()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.disable()
return False

View File

@@ -0,0 +1,178 @@
"""Custom logging handler for UI integration."""
import logging
import queue
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from collections import deque
@dataclass
class LogEntry:
"""A single log entry."""
timestamp: str
level: str
message: str
logger_name: str
@property
def level_color(self) -> str:
"""Get Rich color for log level."""
colors = {
"DEBUG": "dim",
"INFO": "white",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "bold red",
}
return colors.get(self.level, "white")
class UILogHandler(logging.Handler):
"""
Custom logging handler that sends logs to UI.
Uses a thread-safe queue to pass log entries from the trading
thread to the UI thread.
"""
def __init__(
self,
log_queue: queue.Queue,
max_entries: int = 1000,
):
super().__init__()
self.log_queue = log_queue
self.max_entries = max_entries
self.setFormatter(
logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
)
def emit(self, record: logging.LogRecord) -> None:
"""Emit a log record to the queue."""
try:
entry = LogEntry(
timestamp=datetime.fromtimestamp(record.created).strftime(
"%H:%M:%S"
),
level=record.levelname,
message=self.format_message(record),
logger_name=record.name,
)
# Non-blocking put, drop if queue is full
try:
self.log_queue.put_nowait(entry)
except queue.Full:
pass
except Exception:
self.handleError(record)
def format_message(self, record: logging.LogRecord) -> str:
"""Format the log message."""
return record.getMessage()
class LogBuffer:
"""
Thread-safe buffer for log entries with filtering support.
Maintains a fixed-size buffer of log entries and supports
filtering by log type.
"""
FILTER_ALL = "all"
FILTER_ERRORS = "errors"
FILTER_TRADES = "trades"
FILTER_SIGNALS = "signals"
FILTERS = [FILTER_ALL, FILTER_ERRORS, FILTER_TRADES, FILTER_SIGNALS]
def __init__(self, max_entries: int = 1000):
self.max_entries = max_entries
self._entries: deque[LogEntry] = deque(maxlen=max_entries)
self._current_filter = self.FILTER_ALL
def add(self, entry: LogEntry) -> None:
"""Add a log entry to the buffer."""
self._entries.append(entry)
def get_filtered(self, limit: int = 50) -> list[LogEntry]:
"""
Get filtered log entries.
Args:
limit: Maximum number of entries to return
Returns:
List of filtered LogEntry objects (most recent first)
"""
entries = list(self._entries)
if self._current_filter == self.FILTER_ERRORS:
entries = [e for e in entries if e.level in ("ERROR", "CRITICAL")]
elif self._current_filter == self.FILTER_TRADES:
# Key terms indicating actual trading activity
include_keywords = [
"order", "entry", "exit", "executed", "filled",
"opening", "closing", "position opened", "position closed"
]
# Terms to exclude (noise)
exclude_keywords = [
"sync complete", "0 positions", "portfolio: 0 positions"
]
entries = [
e for e in entries
if any(kw in e.message.lower() for kw in include_keywords)
and not any(ex in e.message.lower() for ex in exclude_keywords)
]
elif self._current_filter == self.FILTER_SIGNALS:
signal_keywords = ["signal", "z_score", "prob", "z="]
entries = [
e for e in entries
if any(kw in e.message.lower() for kw in signal_keywords)
]
# Return most recent entries
return list(reversed(entries[-limit:]))
def set_filter(self, filter_name: str) -> None:
"""Set a specific filter."""
if filter_name in self.FILTERS:
self._current_filter = filter_name
def cycle_filter(self) -> str:
"""Cycle to next filter and return its name."""
current_idx = self.FILTERS.index(self._current_filter)
next_idx = (current_idx + 1) % len(self.FILTERS)
self._current_filter = self.FILTERS[next_idx]
return self._current_filter
def get_current_filter(self) -> str:
"""Get current filter name."""
return self._current_filter
def clear(self) -> None:
"""Clear all log entries."""
self._entries.clear()
def drain_queue(self, log_queue: queue.Queue) -> int:
"""
Drain log entries from queue into buffer.
Args:
log_queue: Queue to drain from
Returns:
Number of entries drained
"""
count = 0
while True:
try:
entry = log_queue.get_nowait()
self.add(entry)
count += 1
except queue.Empty:
break
return count

399
live_trading/ui/panels.py Normal file
View File

@@ -0,0 +1,399 @@
"""UI panel components using Rich."""
from typing import Optional
from rich.console import Console, Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from rich.layout import Layout
from .state import SharedState, PositionState, StrategyState, AccountState
from .log_handler import LogBuffer, LogEntry
from ..db.metrics import PeriodMetrics
def format_pnl(value: float, include_sign: bool = True) -> Text:
"""Format PnL value with color."""
if value > 0:
sign = "+" if include_sign else ""
return Text(f"{sign}${value:.2f}", style="green")
elif value < 0:
return Text(f"${value:.2f}", style="red")
else:
return Text(f"${value:.2f}", style="white")
def format_pct(value: float, include_sign: bool = True) -> Text:
"""Format percentage value with color."""
if value > 0:
sign = "+" if include_sign else ""
return Text(f"{sign}{value:.2f}%", style="green")
elif value < 0:
return Text(f"{value:.2f}%", style="red")
else:
return Text(f"{value:.2f}%", style="white")
def format_side(side: str) -> Text:
"""Format position side with color."""
if side.lower() == "long":
return Text("LONG", style="bold green")
else:
return Text("SHORT", style="bold red")
class HeaderPanel:
"""Header panel with title and mode indicator."""
def __init__(self, state: SharedState):
self.state = state
def render(self) -> Panel:
"""Render the header panel."""
mode = self.state.get_mode()
eth_symbol, _ = self.state.get_symbols()
mode_style = "yellow" if mode == "DEMO" else "bold red"
mode_text = Text(f"[{mode}]", style=mode_style)
title = Text()
title.append("REGIME REVERSION STRATEGY - LIVE TRADING", style="bold white")
title.append(" ")
title.append(mode_text)
title.append(" ")
title.append(eth_symbol, style="cyan")
return Panel(title, style="blue", height=3)
class TabBar:
"""Tab bar for period selection."""
TABS = ["1:General", "2:Monthly", "3:Weekly", "4:Daily"]
def __init__(self, active_tab: int = 0):
self.active_tab = active_tab
def render(
self,
has_monthly: bool = True,
has_weekly: bool = True,
) -> Text:
"""Render the tab bar."""
text = Text()
text.append(" ")
for i, tab in enumerate(self.TABS):
# Check if tab should be shown
if i == 1 and not has_monthly:
continue
if i == 2 and not has_weekly:
continue
if i == self.active_tab:
text.append(f"[{tab}]", style="bold white on blue")
else:
text.append(f"[{tab}]", style="dim")
text.append(" ")
return text
class MetricsPanel:
"""Panel showing trading metrics."""
def __init__(self, metrics: Optional[PeriodMetrics] = None):
self.metrics = metrics
def render(self) -> Table:
"""Render metrics as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.metrics is None or self.metrics.total_trades == 0:
table.add_row("Status", Text("No trade data", style="dim"))
return table
m = self.metrics
table.add_row("Total PnL:", format_pnl(m.total_pnl))
table.add_row("Win Rate:", Text(f"{m.win_rate:.1f}%", style="white"))
table.add_row("Total Trades:", Text(str(m.total_trades), style="white"))
table.add_row(
"Win/Loss:",
Text(f"{m.winning_trades}/{m.losing_trades}", style="white"),
)
table.add_row(
"Avg Duration:",
Text(f"{m.avg_trade_duration_hours:.1f}h", style="white"),
)
table.add_row("Max Drawdown:", format_pnl(-m.max_drawdown))
table.add_row("Best Trade:", format_pnl(m.best_trade))
table.add_row("Worst Trade:", format_pnl(m.worst_trade))
return table
class PositionPanel:
"""Panel showing current position."""
def __init__(self, position: Optional[PositionState] = None):
self.position = position
def render(self) -> Table:
"""Render position as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.position is None:
table.add_row("Status", Text("No open position", style="dim"))
return table
p = self.position
table.add_row("Side:", format_side(p.side))
table.add_row("Entry:", Text(f"${p.entry_price:.2f}", style="white"))
table.add_row("Current:", Text(f"${p.current_price:.2f}", style="white"))
# Unrealized PnL
pnl_text = Text()
pnl_text.append_text(format_pnl(p.unrealized_pnl))
pnl_text.append(" (")
pnl_text.append_text(format_pct(p.unrealized_pnl_pct))
pnl_text.append(")")
table.add_row("Unrealized:", pnl_text)
table.add_row("Size:", Text(f"${p.size_usdt:.2f}", style="white"))
# SL/TP
if p.side == "long":
sl_dist = (p.stop_loss_price / p.entry_price - 1) * 100
tp_dist = (p.take_profit_price / p.entry_price - 1) * 100
else:
sl_dist = (1 - p.stop_loss_price / p.entry_price) * 100
tp_dist = (1 - p.take_profit_price / p.entry_price) * 100
sl_text = Text(f"${p.stop_loss_price:.2f} ({sl_dist:+.1f}%)", style="red")
tp_text = Text(f"${p.take_profit_price:.2f} ({tp_dist:+.1f}%)", style="green")
table.add_row("Stop Loss:", sl_text)
table.add_row("Take Profit:", tp_text)
return table
class AccountPanel:
"""Panel showing account information."""
def __init__(self, account: Optional[AccountState] = None):
self.account = account
def render(self) -> Table:
"""Render account info as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.account is None:
table.add_row("Status", Text("Loading...", style="dim"))
return table
a = self.account
table.add_row("Balance:", Text(f"${a.balance:.2f}", style="white"))
table.add_row("Available:", Text(f"${a.available:.2f}", style="white"))
table.add_row("Leverage:", Text(f"{a.leverage}x", style="cyan"))
return table
class StrategyPanel:
"""Panel showing strategy state."""
def __init__(self, strategy: Optional[StrategyState] = None):
self.strategy = strategy
def render(self) -> Table:
"""Render strategy state as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.strategy is None:
table.add_row("Status", Text("Waiting...", style="dim"))
return table
s = self.strategy
# Z-score with color based on threshold
z_style = "white"
if abs(s.z_score) > 1.0:
z_style = "yellow"
if abs(s.z_score) > 1.5:
z_style = "bold yellow"
table.add_row("Z-Score:", Text(f"{s.z_score:.2f}", style=z_style))
# Probability with color
prob_style = "white"
if s.probability > 0.5:
prob_style = "green"
if s.probability > 0.7:
prob_style = "bold green"
table.add_row("Probability:", Text(f"{s.probability:.2f}", style=prob_style))
# Funding rate
funding_style = "green" if s.funding_rate >= 0 else "red"
table.add_row(
"Funding:",
Text(f"{s.funding_rate:.4f}", style=funding_style),
)
# Last action
action_style = "white"
if s.last_action == "entry":
action_style = "bold cyan"
elif s.last_action == "check_exit":
action_style = "yellow"
table.add_row("Last Action:", Text(s.last_action, style=action_style))
return table
class LogPanel:
"""Panel showing log entries."""
def __init__(self, log_buffer: LogBuffer):
self.log_buffer = log_buffer
def render(self, height: int = 10) -> Panel:
"""Render log panel."""
filter_name = self.log_buffer.get_current_filter().title()
entries = self.log_buffer.get_filtered(limit=height - 2)
lines = []
for entry in entries:
line = Text()
line.append(f"{entry.timestamp} ", style="dim")
line.append(f"[{entry.level}] ", style=entry.level_color)
line.append(entry.message)
lines.append(line)
if not lines:
lines.append(Text("No logs to display", style="dim"))
content = Group(*lines)
# Build "tabbed" title
tabs = []
# All Logs tab
if filter_name == "All":
tabs.append("[bold white on blue] [L]ogs [/]")
else:
tabs.append("[dim] [L]ogs [/]")
# Trades tab
if filter_name == "Trades":
tabs.append("[bold white on blue] [T]rades [/]")
else:
tabs.append("[dim] [T]rades [/]")
# Errors tab
if filter_name == "Errors":
tabs.append("[bold white on blue] [E]rrors [/]")
else:
tabs.append("[dim] [E]rrors [/]")
title = " ".join(tabs)
subtitle = "Press 'l', 't', 'e' to switch tabs"
return Panel(
content,
title=title,
subtitle=subtitle,
title_align="left",
subtitle_align="right",
border_style="blue",
)
class HelpBar:
"""Bottom help bar with keyboard shortcuts."""
def render(self) -> Text:
"""Render help bar."""
text = Text()
text.append(" [q]", style="bold")
text.append("Quit ", style="dim")
text.append("[r]", style="bold")
text.append("Refresh ", style="dim")
text.append("[1-4]", style="bold")
text.append("Tabs ", style="dim")
text.append("[l/t/e]", style="bold")
text.append("LogView", style="dim")
return text
def build_summary_panel(
state: SharedState,
metrics: Optional[PeriodMetrics],
tab_bar: TabBar,
has_monthly: bool,
has_weekly: bool,
) -> Panel:
"""Build the complete summary panel with all sections."""
# Create layout for summary content
layout = Layout()
# Tab bar at top
tabs = tab_bar.render(has_monthly, has_weekly)
# Create tables for each section
metrics_table = MetricsPanel(metrics).render()
position_table = PositionPanel(state.get_position()).render()
account_table = AccountPanel(state.get_account()).render()
strategy_table = StrategyPanel(state.get_strategy()).render()
# Build two-column layout
left_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
left_col.add_column("PERFORMANCE", style="bold cyan")
left_col.add_column("ACCOUNT", style="bold cyan")
left_col.add_row(metrics_table, account_table)
right_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
right_col.add_column("CURRENT POSITION", style="bold cyan")
right_col.add_column("STRATEGY STATE", style="bold cyan")
right_col.add_row(position_table, strategy_table)
# Combine into main table
main_table = Table(show_header=False, show_edge=False, box=None, expand=True)
main_table.add_column(ratio=1)
main_table.add_column(ratio=1)
main_table.add_row(left_col, right_col)
content = Group(tabs, Text(""), main_table)
return Panel(content, border_style="blue")

195
live_trading/ui/state.py Normal file
View File

@@ -0,0 +1,195 @@
"""Thread-safe shared state for UI and trading loop."""
import threading
from dataclasses import dataclass, field
from typing import Optional
from datetime import datetime, timezone
@dataclass
class PositionState:
"""Current position information."""
trade_id: str = ""
symbol: str = ""
side: str = ""
entry_price: float = 0.0
current_price: float = 0.0
size: float = 0.0
size_usdt: float = 0.0
unrealized_pnl: float = 0.0
unrealized_pnl_pct: float = 0.0
stop_loss_price: float = 0.0
take_profit_price: float = 0.0
@dataclass
class StrategyState:
"""Current strategy signal state."""
z_score: float = 0.0
probability: float = 0.0
funding_rate: float = 0.0
last_action: str = "hold"
last_reason: str = ""
last_signal_time: str = ""
@dataclass
class AccountState:
"""Account balance information."""
balance: float = 0.0
available: float = 0.0
leverage: int = 1
class SharedState:
"""
Thread-safe shared state between trading loop and UI.
All access to state fields should go through the getter/setter methods
which use a lock for thread safety.
"""
def __init__(self):
self._lock = threading.Lock()
self._position: Optional[PositionState] = None
self._strategy = StrategyState()
self._account = AccountState()
self._is_running = True
self._last_cycle_time: Optional[str] = None
self._mode = "DEMO"
self._eth_symbol = "ETH/USDT:USDT"
self._btc_symbol = "BTC/USDT:USDT"
# Position methods
def get_position(self) -> Optional[PositionState]:
"""Get current position state."""
with self._lock:
return self._position
def set_position(self, position: Optional[PositionState]) -> None:
"""Set current position state."""
with self._lock:
self._position = position
def update_position_price(self, current_price: float) -> None:
"""Update current price and recalculate PnL."""
with self._lock:
if self._position is None:
return
self._position.current_price = current_price
if self._position.side == "long":
pnl = (current_price - self._position.entry_price)
self._position.unrealized_pnl = pnl * self._position.size
pnl_pct = (current_price / self._position.entry_price - 1) * 100
else:
pnl = (self._position.entry_price - current_price)
self._position.unrealized_pnl = pnl * self._position.size
pnl_pct = (1 - current_price / self._position.entry_price) * 100
self._position.unrealized_pnl_pct = pnl_pct
def clear_position(self) -> None:
"""Clear current position."""
with self._lock:
self._position = None
# Strategy methods
def get_strategy(self) -> StrategyState:
"""Get current strategy state."""
with self._lock:
return StrategyState(
z_score=self._strategy.z_score,
probability=self._strategy.probability,
funding_rate=self._strategy.funding_rate,
last_action=self._strategy.last_action,
last_reason=self._strategy.last_reason,
last_signal_time=self._strategy.last_signal_time,
)
def update_strategy(
self,
z_score: float,
probability: float,
funding_rate: float,
action: str,
reason: str,
) -> None:
"""Update strategy state."""
with self._lock:
self._strategy.z_score = z_score
self._strategy.probability = probability
self._strategy.funding_rate = funding_rate
self._strategy.last_action = action
self._strategy.last_reason = reason
self._strategy.last_signal_time = datetime.now(
timezone.utc
).isoformat()
# Account methods
def get_account(self) -> AccountState:
"""Get current account state."""
with self._lock:
return AccountState(
balance=self._account.balance,
available=self._account.available,
leverage=self._account.leverage,
)
def update_account(
self,
balance: float,
available: float,
leverage: int,
) -> None:
"""Update account state."""
with self._lock:
self._account.balance = balance
self._account.available = available
self._account.leverage = leverage
# Control methods
def is_running(self) -> bool:
"""Check if trading loop is running."""
with self._lock:
return self._is_running
def stop(self) -> None:
"""Signal to stop trading loop."""
with self._lock:
self._is_running = False
def get_last_cycle_time(self) -> Optional[str]:
"""Get last trading cycle time."""
with self._lock:
return self._last_cycle_time
def set_last_cycle_time(self, time_str: str) -> None:
"""Set last trading cycle time."""
with self._lock:
self._last_cycle_time = time_str
# Config methods
def get_mode(self) -> str:
"""Get trading mode (DEMO/LIVE)."""
with self._lock:
return self._mode
def set_mode(self, mode: str) -> None:
"""Set trading mode."""
with self._lock:
self._mode = mode
def get_symbols(self) -> tuple[str, str]:
"""Get trading symbols (eth, btc)."""
with self._lock:
return self._eth_symbol, self._btc_symbol
def set_symbols(self, eth_symbol: str, btc_symbol: str) -> None:
"""Set trading symbols."""
with self._lock:
self._eth_symbol = eth_symbol
self._btc_symbol = btc_symbol

View File

@@ -15,6 +15,8 @@ dependencies = [
"plotly>=5.24.0",
"requests>=2.32.5",
"python-dotenv>=1.2.1",
# Terminal UI
"rich>=13.0.0",
# API dependencies
"fastapi>=0.115.0",
"uvicorn[standard]>=0.34.0",

View File

@@ -3,7 +3,16 @@ Regime Detection Research Script with Walk-Forward Training.
Tests multiple holding horizons to find optimal parameters
without look-ahead bias.
Usage:
uv run python research/regime_detection.py [options]
Options:
--days DAYS Number of days of data (default: 90)
--start DATE Start date (YYYY-MM-DD), overrides --days
--end DATE End date (YYYY-MM-DD), defaults to now
"""
import argparse
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -25,18 +34,36 @@ TRAIN_RATIO = 0.7 # 70% train, 30% test
PROFIT_THRESHOLD = 0.005 # 0.5% profit target
Z_WINDOW = 24
FEE_RATE = 0.001 # 0.1% round-trip fee
DEFAULT_DAYS = 90 # Default lookback period in days
def load_data():
"""Load and align BTC/ETH data."""
def load_data(days: int = DEFAULT_DAYS, start_date: str = None, end_date: str = None):
"""
Load and align BTC/ETH data.
Args:
days: Number of days of historical data (default: 90)
start_date: Optional start date (YYYY-MM-DD), overrides days
end_date: Optional end date (YYYY-MM-DD), defaults to now
Returns:
Tuple of (df_btc, df_eth) DataFrames
"""
dm = DataManager()
df_btc = dm.load_data("okx", "BTC-USDT", "1h", MarketType.SPOT)
df_eth = dm.load_data("okx", "ETH-USDT", "1h", MarketType.SPOT)
# Filter to Oct-Dec 2025
start = pd.Timestamp("2025-10-01", tz="UTC")
end = pd.Timestamp("2025-12-31", tz="UTC")
# Determine date range
if end_date:
end = pd.Timestamp(end_date, tz="UTC")
else:
end = pd.Timestamp.now(tz="UTC")
if start_date:
start = pd.Timestamp(start_date, tz="UTC")
else:
start = end - pd.Timedelta(days=days)
df_btc = df_btc[(df_btc.index >= start) & (df_btc.index <= end)]
df_eth = df_eth[(df_eth.index >= start) & (df_eth.index <= end)]
@@ -46,7 +73,7 @@ def load_data():
df_btc = df_btc.loc[common]
df_eth = df_eth.loc[common]
logger.info(f"Loaded {len(common)} aligned hourly bars")
logger.info(f"Loaded {len(common)} aligned hourly bars from {start} to {end}")
return df_btc, df_eth
@@ -168,7 +195,10 @@ def calculate_mae(features, predictions, test_idx, horizon):
def calculate_net_profit(features, predictions, test_idx, horizon):
"""Calculate estimated net profit including fees."""
"""
Calculate estimated net profit including fees.
Enforces 'one trade at a time' to avoid inflating returns with overlapping signals.
"""
test_features = features.loc[test_idx]
spread = test_features['spread']
z_score = test_features['z_score']
@@ -176,7 +206,14 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
total_pnl = 0.0
n_trades = 0
# Track when we are free to trade again
next_trade_idx = 0
for i, (idx, pred) in enumerate(zip(test_idx, predictions)):
# Skip if we are still in a trade
if i < next_trade_idx:
continue
if pred != 1:
continue
@@ -204,6 +241,10 @@ def calculate_net_profit(features, predictions, test_idx, horizon):
total_pnl += net_pnl
n_trades += 1
# Set next available trade index (simple non-overlapping logic)
# We assume we hold for 'horizon' bars
next_trade_idx = i + horizon
return total_pnl, n_trades
@@ -295,10 +336,54 @@ def test_horizons(features, horizons):
return results
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Regime detection research - test multiple horizons"
)
parser.add_argument(
"--days",
type=int,
default=DEFAULT_DAYS,
help=f"Number of days of data (default: {DEFAULT_DAYS})"
)
parser.add_argument(
"--start",
type=str,
default=None,
help="Start date (YYYY-MM-DD), overrides --days"
)
parser.add_argument(
"--end",
type=str,
default=None,
help="End date (YYYY-MM-DD), defaults to now"
)
parser.add_argument(
"--output",
type=str,
default="research/horizon_optimization_results.csv",
help="Output CSV path"
)
parser.add_argument(
"--output-horizon",
type=str,
default=None,
help="Path to save the best horizon (integer) to a file"
)
return parser.parse_args()
def main():
"""Main research function."""
# Load data
df_btc, df_eth = load_data()
args = parse_args()
# Load data with dynamic date range
df_btc, df_eth = load_data(
days=args.days,
start_date=args.start,
end_date=args.end
)
cq_df = load_cryptoquant_data()
# Calculate features
@@ -312,7 +397,7 @@ def main():
if not results:
print("No valid results!")
return
return None
# Find best by different metrics
results_df = pd.DataFrame(results)
@@ -331,9 +416,15 @@ def main():
print(f"Lowest MAE: {lowest_mae['horizon']:.0f}h (MAE={lowest_mae['avg_mae']:.2f}%)")
# Save results
output_path = "research/horizon_optimization_results.csv"
results_df.to_csv(output_path, index=False)
print(f"\nResults saved to {output_path}")
results_df.to_csv(args.output, index=False)
print(f"\nResults saved to {args.output}")
# Save best horizon if requested
if args.output_horizon:
best_h = int(best_pnl['horizon'])
with open(args.output_horizon, 'w') as f:
f.write(str(best_h))
print(f"Best horizon {best_h}h saved to {args.output_horizon}")
return results_df

16
setup_schedule.sh Executable file
View File

@@ -0,0 +1,16 @@
#!/bin/bash
# Setup script for Systemd Timer (Daily Training)
SERVICE_FILE="tasks/lowkey-training.service"
TIMER_FILE="tasks/lowkey-training.timer"
SYSTEMD_DIR="/etc/systemd/system"
echo "To install the daily training schedule, please run the following commands:"
echo ""
echo "sudo cp $SERVICE_FILE $SYSTEMD_DIR/"
echo "sudo cp $TIMER_FILE $SYSTEMD_DIR/"
echo "sudo systemctl daemon-reload"
echo "sudo systemctl enable --now lowkey-training.timer"
echo ""
echo "To check the status:"
echo "systemctl list-timers --all | grep lowkey"

View File

@@ -0,0 +1,15 @@
[Unit]
Description=Lowkey Backtest Daily Model Training
After=network.target
[Service]
Type=oneshot
WorkingDirectory=/home/tamaya/Documents/Work/TCP/lowkey_backtest_live
ExecStart=/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/train_daily.sh
User=tamaya
Group=tamaya
StandardOutput=append:/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/logs/training.log
StandardError=append:/home/tamaya/Documents/Work/TCP/lowkey_backtest_live/logs/training.log
[Install]
WantedBy=multi-user.target

View File

@@ -0,0 +1,10 @@
[Unit]
Description=Run Lowkey Backtest Training Daily
[Timer]
OnCalendar=*-*-* 00:30:00
Persistent=true
Unit=lowkey-training.service
[Install]
WantedBy=timers.target

351
tasks/prd-terminal-ui.md Normal file
View File

@@ -0,0 +1,351 @@
# PRD: Terminal UI for Live Trading Bot
## Introduction/Overview
The live trading bot currently uses basic console logging for output, making it difficult to monitor trading activity, track performance, and understand the system state at a glance. This feature introduces a Rich-based terminal UI that provides a professional, real-time dashboard for monitoring the live trading bot.
The UI will display a horizontal split layout with a **summary panel** at the top (with tabbed time-period views) and a **scrollable log panel** at the bottom. The interface will update every second and support keyboard navigation.
## Goals
1. Provide real-time visibility into trading performance (PnL, win rate, trade count)
2. Enable monitoring of current position state (entry, SL/TP, unrealized PnL)
3. Display strategy signals (Z-score, model probability) for transparency
4. Support historical performance tracking across time periods (daily, weekly, monthly, all-time)
5. Improve operational experience with keyboard shortcuts and log filtering
6. Create a responsive design that works across different terminal sizes
## User Stories
1. **As a trader**, I want to see my total PnL and daily PnL at a glance so I can quickly assess performance.
2. **As a trader**, I want to see my current position details (entry price, unrealized PnL, SL/TP levels) so I can monitor risk.
3. **As a trader**, I want to view performance metrics by time period (daily, weekly, monthly) so I can track trends.
4. **As a trader**, I want to filter logs by type (errors, trades, signals) so I can focus on relevant information.
5. **As a trader**, I want keyboard shortcuts to navigate the UI without using a mouse.
6. **As a trader**, I want the UI to show strategy state (Z-score, probability) so I understand why signals are generated.
## Functional Requirements
### FR1: Layout Structure
1.1. The UI must use a horizontal split layout with the summary panel at the top and logs panel at the bottom.
1.2. The summary panel must contain tabbed views accessible via number keys:
- Tab 1 (`1`): **General** - Overall metrics since bot started
- Tab 2 (`2`): **Monthly** - Current month metrics (shown only if data spans > 1 month)
- Tab 3 (`3`): **Weekly** - Current week metrics (shown only if data spans > 1 week)
- Tab 4 (`4`): **Daily** - Today's metrics
1.3. The logs panel must be a scrollable area showing recent log entries.
1.4. The UI must be responsive and adapt to terminal size (minimum 80x24).
### FR2: Metrics Display
The summary panel must display the following metrics:
**Performance Metrics:**
2.1. Total PnL (USD) - cumulative profit/loss since tracking began
2.2. Period PnL (USD) - profit/loss for selected time period (daily/weekly/monthly)
2.3. Win Rate (%) - percentage of winning trades
2.4. Total Number of Trades
2.5. Average Trade Duration (hours)
2.6. Max Drawdown (USD and %)
**Current Position (if open):**
2.7. Symbol and side (long/short)
2.8. Entry price
2.9. Current price
2.10. Unrealized PnL (USD and %)
2.11. Stop-loss price and distance (%)
2.12. Take-profit price and distance (%)
2.13. Position size (USD)
**Account Status:**
2.14. Account balance / available margin (USDT)
2.15. Current leverage setting
**Strategy State:**
2.16. Current Z-score
2.17. Model probability
2.18. Current funding rate (BTC)
2.19. Last signal action and reason
### FR3: Historical Data Loading
3.1. On startup, the system must initialize SQLite database at `live_trading/trading.db`.
3.2. If `trade_log.csv` exists and database is empty, migrate CSV data to SQLite.
3.3. The UI must load current positions from `live_trading/positions.json` (kept for compatibility with existing position manager).
3.4. Metrics must be calculated via SQL aggregation queries for each time period.
3.5. If no historical data exists, the UI must show "No data" gracefully.
3.6. New trades must be written to both SQLite (primary) and CSV (backup/compatibility).
### FR4: Real-Time Updates
4.1. The UI must refresh every 1 second.
4.2. Position unrealized PnL must update based on latest price data.
4.3. New log entries must appear in real-time.
4.4. Metrics must recalculate when trades are opened/closed.
### FR5: Log Panel
5.1. The log panel must display log entries with timestamp, level, and message.
5.2. Log entries must be color-coded by level:
- ERROR: Red
- WARNING: Yellow
- INFO: White/Default
- DEBUG: Gray (if shown)
5.3. The log panel must support filtering by log type:
- All logs (default)
- Errors only
- Trades only (entries containing "position", "trade", "order")
- Signals only (entries containing "signal", "z_score", "prob")
5.4. Filter switching must be available via keyboard shortcut (`f` to cycle filters).
### FR6: Keyboard Controls
6.1. `q` or `Ctrl+C` - Graceful shutdown
6.2. `r` - Force refresh data
6.3. `1` - Switch to General tab
6.4. `2` - Switch to Monthly tab
6.5. `3` - Switch to Weekly tab
6.6. `4` - Switch to Daily tab
6.7. `f` - Cycle log filter
6.8. Arrow keys - Scroll logs (if supported)
### FR7: Color Scheme
7.1. Use dark theme as base.
7.2. PnL values must be colored:
- Positive: Green
- Negative: Red
- Zero/Neutral: White
7.3. Position side must be colored:
- Long: Green
- Short: Red
7.4. Use consistent color coding for emphasis and warnings.
## Non-Goals (Out of Scope)
1. **Mouse support** - This is a keyboard-driven terminal UI
2. **Trade execution from UI** - The UI is read-only; trades are executed by the bot
3. **Configuration editing** - Config changes require restarting the bot
4. **Multi-exchange support** - Only OKX is supported
5. **Charts/graphs** - Text-based metrics only (no ASCII charts in v1)
6. **Sound alerts** - No audio notifications
7. **Remote access** - Local terminal only
## Design Considerations
### Technology Choice: Rich
Use the [Rich](https://github.com/Textualize/rich) Python library for terminal UI:
- Rich provides `Live` display for real-time updates
- Rich `Layout` for split-screen design
- Rich `Table` for metrics display
- Rich `Panel` for bordered sections
- Rich `Text` for colored output
Alternative considered: **Textual** (also by Will McGugan) provides more advanced TUI features but adds complexity. Rich is simpler and sufficient for this use case.
### UI Mockup
```
+==============================================================================+
| REGIME REVERSION STRATEGY - LIVE TRADING [DEMO] ETH/USDT |
+==============================================================================+
| [1:General] [2:Monthly] [3:Weekly] [4:Daily] |
+------------------------------------------------------------------------------+
| PERFORMANCE | CURRENT POSITION |
| Total PnL: $1,234.56 | Side: LONG |
| Today PnL: $45.23 | Entry: $3,245.50 |
| Win Rate: 67.5% | Current: $3,289.00 |
| Total Trades: 24 | Unrealized: +$43.50 (+1.34%) |
| Avg Duration: 4.2h | Size: $500.00 |
| Max Drawdown: -$156.00 | SL: $3,050.00 (-6.0%) TP: $3,408.00 (+5%)|
| | |
| ACCOUNT | STRATEGY STATE |
| Balance: $5,432.10 | Z-Score: 1.45 |
| Available: $4,932.10 | Probability: 0.72 |
| Leverage: 2x | Funding: 0.0012 |
+------------------------------------------------------------------------------+
| LOGS [Filter: All] Press 'f' cycle |
+------------------------------------------------------------------------------+
| 14:32:15 [INFO] Trading Cycle Start: 2026-01-16T14:32:15+00:00 |
| 14:32:16 [INFO] Signal: entry long (prob=0.72, z=-1.45, reason=z_score...) |
| 14:32:17 [INFO] Executing LONG entry: 0.1540 ETH @ 3245.50 ($500.00) |
| 14:32:18 [INFO] Position opened: ETH/USDT:USDT_20260116_143217 |
| 14:32:18 [INFO] Portfolio: 1 positions, exposure=$500.00, unrealized=$0.00 |
| 14:32:18 [INFO] --- Cycle completed in 3.2s --- |
| 14:32:18 [INFO] Sleeping for 60 minutes... |
| |
+------------------------------------------------------------------------------+
| [q]Quit [r]Refresh [1-4]Tabs [f]Filter |
+==============================================================================+
```
### File Structure
```
live_trading/
ui/
__init__.py
dashboard.py # Main UI orchestration and threading
panels.py # Panel components (metrics, logs, position)
state.py # Thread-safe shared state
log_handler.py # Custom logging handler for UI queue
keyboard.py # Keyboard input handling
db/
__init__.py
database.py # SQLite connection and queries
models.py # Data models (Trade, DailySummary, Session)
migrations.py # CSV migration and schema setup
metrics.py # Metrics aggregation queries
trading.db # SQLite database file (created at runtime)
```
## Technical Considerations
### Integration with Existing Code
1. **Logging Integration**: Create a custom `logging.Handler` that captures log messages and forwards them to the UI log panel while still writing to file.
2. **Data Access**: The UI needs access to:
- `PositionManager` for current positions
- `TradingConfig` for settings display
- SQLite database for historical metrics
- Real-time data from `DataFeed` for current prices
3. **Main Loop Modification**: The `LiveTradingBot.run()` method needs modification to run the UI in a **separate thread** alongside the trading loop. The UI thread handles rendering and keyboard input while the main thread executes trading logic.
4. **Graceful Shutdown**: Ensure `SIGINT`/`SIGTERM` handlers work with the UI layer and properly terminate the UI thread.
### Database Schema (SQLite)
Create `live_trading/trading.db` with the following schema:
```sql
-- Trade history table
CREATE TABLE trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trade_id TEXT UNIQUE NOT NULL,
symbol TEXT NOT NULL,
side TEXT NOT NULL, -- 'long' or 'short'
entry_price REAL NOT NULL,
exit_price REAL,
size REAL NOT NULL,
size_usdt REAL NOT NULL,
pnl_usd REAL,
pnl_pct REAL,
entry_time TEXT NOT NULL, -- ISO format
exit_time TEXT,
hold_duration_hours REAL,
reason TEXT, -- 'stop_loss', 'take_profit', 'signal', etc.
order_id_entry TEXT,
order_id_exit TEXT,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Daily summary table (for faster queries)
CREATE TABLE daily_summary (
id INTEGER PRIMARY KEY AUTOINCREMENT,
date TEXT UNIQUE NOT NULL, -- YYYY-MM-DD
total_trades INTEGER DEFAULT 0,
winning_trades INTEGER DEFAULT 0,
total_pnl_usd REAL DEFAULT 0,
max_drawdown_usd REAL DEFAULT 0,
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Session metadata
CREATE TABLE sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
start_time TEXT NOT NULL,
end_time TEXT,
starting_balance REAL,
ending_balance REAL,
total_pnl REAL,
total_trades INTEGER DEFAULT 0
);
-- Indexes for common queries
CREATE INDEX idx_trades_entry_time ON trades(entry_time);
CREATE INDEX idx_trades_exit_time ON trades(exit_time);
CREATE INDEX idx_daily_summary_date ON daily_summary(date);
```
### Migration from CSV
On first run with the new system:
1. Check if `trading.db` exists
2. If not, create database with schema
3. If `trade_log.csv` exists, migrate data to `trades` table
4. Rebuild `daily_summary` from migrated trades
### Dependencies
Add to `pyproject.toml`:
```toml
dependencies = [
# ... existing deps
"rich>=13.0.0",
]
```
Note: SQLite is part of Python's standard library (`sqlite3`), no additional dependency needed.
### Performance Considerations
- UI runs in a separate thread to avoid blocking trading logic
- Log buffer limited to 1000 entries in memory to prevent growth
- SQLite queries should use indexes for fast period-based aggregations
- Historical data loading happens once at startup, incremental updates thereafter
### Threading Model
```
Main Thread UI Thread
| |
v v
[Trading Loop] [Rich Live Display]
| |
+---> SharedState <------------+
(thread-safe)
| |
+---> LogQueue <---------------+
(thread-safe)
```
- Use `threading.Lock` for shared state access
- Use `queue.Queue` for log message passing
- UI thread polls for updates every 1 second
## Success Metrics
1. UI starts successfully and displays all required metrics
2. UI updates in real-time (1-second refresh) without impacting trading performance
3. All keyboard shortcuts function correctly
4. Historical data loads and displays accurately from SQLite
5. Log filtering works as expected
6. UI gracefully handles edge cases (no data, no position, terminal resize)
7. CSV migration completes successfully on first run
8. Database queries complete within 100ms
---
*Generated: 2026-01-16*
*Decisions: Threading model, 1000 log buffer, SQLite database, no fallback mode*

50
train_daily.sh Executable file
View File

@@ -0,0 +1,50 @@
#!/bin/bash
# Daily ML Model Training Script
#
# Downloads fresh data and retrains the regime detection model.
# Can be run manually or scheduled via cron.
#
# Usage:
# ./train_daily.sh # Full workflow
# ./train_daily.sh --skip-research # Skip research validation
set -e # Exit on error
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
LOG_DIR="logs"
mkdir -p "$LOG_DIR"
TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S")
echo "[$TIMESTAMP] Starting daily training..."
# 1. Download fresh data
echo "Downloading BTC-USDT 1h data..."
uv run python main.py download -p BTC-USDT -t 1h
echo "Downloading ETH-USDT 1h data..."
uv run python main.py download -p ETH-USDT -t 1h
# 2. Research optimization (find best horizon)
echo "Running research optimization..."
uv run python research/regime_detection.py --output-horizon data/optimal_horizon.txt
# 3. Read best horizon
if [[ -f "data/optimal_horizon.txt" ]]; then
BEST_HORIZON=$(cat data/optimal_horizon.txt)
echo "Found optimal horizon: ${BEST_HORIZON} bars"
else
BEST_HORIZON=102
echo "Warning: Could not find optimal horizon file. Using default: ${BEST_HORIZON}"
fi
# 4. Train model
echo "Training ML model with horizon ${BEST_HORIZON}..."
uv run python train_model.py --horizon "$BEST_HORIZON"
# 5. Cleanup
rm -f data/optimal_horizon.txt
TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S")
echo "[$TIMESTAMP] Daily training complete."

451
train_model.py Normal file
View File

@@ -0,0 +1,451 @@
"""
ML Model Training Script.
Trains the regime detection Random Forest model on historical data.
Can be run manually or scheduled via cron for daily retraining.
Usage:
uv run python train_model.py [options]
Options:
--days DAYS Number of days of historical data to use (default: 90)
--pair PAIR Trading pair for context (default: BTC-USDT)
--timeframe TF Timeframe (default: 1h)
--output PATH Output model path (default: data/regime_model.pkl)
--train-ratio R Train/test split ratio (default: 0.7)
--dry-run Run without saving model
"""
import argparse
import pickle
import sys
from datetime import datetime, timedelta
from pathlib import Path
import numpy as np
import pandas as pd
import ta
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, f1_score
from engine.data_manager import DataManager
from engine.market import MarketType
from engine.logging_config import get_logger
logger = get_logger(__name__)
# Default configuration (from research optimization)
DEFAULT_HORIZON = 102 # 4.25 days - optimal from research
DEFAULT_Z_WINDOW = 24 # 24h rolling window
DEFAULT_PROFIT_TARGET = 0.005 # 0.5% profit threshold
DEFAULT_Z_THRESHOLD = 1.0 # Z-score entry threshold
DEFAULT_TRAIN_RATIO = 0.7 # 70% train / 30% test
FEE_RATE = 0.001 # 0.1% round-trip fee
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Train the regime detection ML model"
)
parser.add_argument(
"--days",
type=int,
default=90,
help="Number of days of historical data to use (default: 90)"
)
parser.add_argument(
"--pair",
type=str,
default="BTC-USDT",
help="Base pair for context data (default: BTC-USDT)"
)
parser.add_argument(
"--spread-pair",
type=str,
default="ETH-USDT",
help="Spread pair to trade (default: ETH-USDT)"
)
parser.add_argument(
"--timeframe",
type=str,
default="1h",
help="Timeframe (default: 1h)"
)
parser.add_argument(
"--market",
type=str,
choices=["spot", "perpetual"],
default="perpetual",
help="Market type (default: perpetual)"
)
parser.add_argument(
"--output",
type=str,
default="data/regime_model.pkl",
help="Output model path (default: data/regime_model.pkl)"
)
parser.add_argument(
"--train-ratio",
type=float,
default=DEFAULT_TRAIN_RATIO,
help="Train/test split ratio (default: 0.7)"
)
parser.add_argument(
"--horizon",
type=int,
default=DEFAULT_HORIZON,
help="Prediction horizon in bars (default: 102)"
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Run without saving model"
)
parser.add_argument(
"--download",
action="store_true",
help="Download latest data before training"
)
return parser.parse_args()
def download_data(dm: DataManager, pair: str, timeframe: str, market_type: MarketType):
"""Download latest data for a pair."""
logger.info(f"Downloading latest data for {pair}...")
try:
dm.download_data("okx", pair, timeframe, market_type)
logger.info(f"Downloaded {pair} data")
except Exception as e:
logger.error(f"Failed to download {pair}: {e}")
raise
def load_data(
dm: DataManager,
base_pair: str,
spread_pair: str,
timeframe: str,
market_type: MarketType,
days: int
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Load and align historical data for both pairs."""
df_base = dm.load_data("okx", base_pair, timeframe, market_type)
df_spread = dm.load_data("okx", spread_pair, timeframe, market_type)
# Filter to last N days
end_date = pd.Timestamp.now(tz="UTC")
start_date = end_date - timedelta(days=days)
df_base = df_base[(df_base.index >= start_date) & (df_base.index <= end_date)]
df_spread = df_spread[(df_spread.index >= start_date) & (df_spread.index <= end_date)]
# Align indices
common = df_base.index.intersection(df_spread.index)
df_base = df_base.loc[common]
df_spread = df_spread.loc[common]
logger.info(
f"Loaded {len(common)} bars from {common.min()} to {common.max()}"
)
return df_base, df_spread
def load_cryptoquant_data() -> pd.DataFrame | None:
"""Load CryptoQuant on-chain data if available."""
try:
cq_path = Path("data/cq_training_data.csv")
if not cq_path.exists():
logger.info("CryptoQuant data not found, skipping on-chain features")
return None
cq_df = pd.read_csv(cq_path, index_col='timestamp', parse_dates=True)
if cq_df.index.tz is None:
cq_df.index = cq_df.index.tz_localize('UTC')
logger.info(f"Loaded CryptoQuant data: {len(cq_df)} rows")
return cq_df
except Exception as e:
logger.warning(f"Could not load CryptoQuant data: {e}")
return None
def calculate_features(
df_base: pd.DataFrame,
df_spread: pd.DataFrame,
cq_df: pd.DataFrame | None = None,
z_window: int = DEFAULT_Z_WINDOW
) -> pd.DataFrame:
"""Calculate all features for the model."""
spread = df_spread['close'] / df_base['close']
# Z-Score
rolling_mean = spread.rolling(window=z_window).mean()
rolling_std = spread.rolling(window=z_window).std()
z_score = (spread - rolling_mean) / rolling_std
# Technicals
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
spread_roc = spread.pct_change(periods=5) * 100
spread_change_1h = spread.pct_change(periods=1)
# Volume
vol_ratio = df_spread['volume'] / df_base['volume']
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
# Volatility
ret_base = df_base['close'].pct_change()
ret_spread = df_spread['close'].pct_change()
vol_base = ret_base.rolling(window=z_window).std()
vol_spread = ret_spread.rolling(window=z_window).std()
vol_spread_ratio = vol_spread / vol_base
features = pd.DataFrame(index=spread.index)
features['spread'] = spread
features['z_score'] = z_score
features['spread_rsi'] = spread_rsi
features['spread_roc'] = spread_roc
features['spread_change_1h'] = spread_change_1h
features['vol_ratio'] = vol_ratio
features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
features['vol_diff_ratio'] = vol_spread_ratio
# Add CQ features if available
if cq_df is not None:
cq_aligned = cq_df.reindex(features.index, method='ffill')
if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
cq_aligned['funding_diff'] = (
cq_aligned['eth_funding'] - cq_aligned['btc_funding']
)
if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
cq_aligned['inflow_ratio'] = (
cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
)
features = features.join(cq_aligned)
return features.dropna()
def calculate_targets(
features: pd.DataFrame,
horizon: int,
profit_target: float = DEFAULT_PROFIT_TARGET,
z_threshold: float = DEFAULT_Z_THRESHOLD
) -> tuple[np.ndarray, pd.Series]:
"""Calculate target labels for training."""
spread = features['spread']
z_score = features['z_score']
# For Short (Z > threshold): Did spread drop below target?
future_min = spread.rolling(window=horizon).min().shift(-horizon)
target_short = spread * (1 - profit_target)
success_short = (z_score > z_threshold) & (future_min < target_short)
# For Long (Z < -threshold): Did spread rise above target?
future_max = spread.rolling(window=horizon).max().shift(-horizon)
target_long = spread * (1 + profit_target)
success_long = (z_score < -z_threshold) & (future_max > target_long)
targets = np.select([success_short, success_long], [1, 1], default=0)
# Create valid mask (rows with complete future data)
valid_mask = future_min.notna() & future_max.notna()
return targets, valid_mask
def train_model(
features: pd.DataFrame,
train_ratio: float = DEFAULT_TRAIN_RATIO,
horizon: int = DEFAULT_HORIZON
) -> tuple[RandomForestClassifier, list[str], dict]:
"""
Train Random Forest model with walk-forward split.
Args:
features: DataFrame with calculated features
train_ratio: Fraction of data to use for training
horizon: Prediction horizon in bars
Returns:
Tuple of (trained model, feature columns, metrics dict)
"""
# Calculate targets
targets, valid_mask = calculate_targets(features, horizon)
# Walk-forward split
n_samples = len(features)
train_size = int(n_samples * train_ratio)
train_features = features.iloc[:train_size]
test_features = features.iloc[train_size:]
train_targets = targets[:train_size]
test_targets = targets[train_size:]
train_valid = valid_mask.iloc[:train_size]
test_valid = valid_mask.iloc[train_size:]
# Prepare training data
exclude = ['spread']
feature_cols = [c for c in features.columns if c not in exclude]
X_train = train_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
X_train_valid = X_train[train_valid]
y_train_valid = train_targets[train_valid]
if len(X_train_valid) < 100:
raise ValueError(
f"Not enough training data: {len(X_train_valid)} samples (need >= 100)"
)
logger.info(f"Training on {len(X_train_valid)} samples...")
# Train model
model = RandomForestClassifier(
n_estimators=300,
max_depth=5,
min_samples_leaf=30,
class_weight={0: 1, 1: 3},
random_state=42
)
model.fit(X_train_valid, y_train_valid)
# Evaluate on test set
X_test = test_features[feature_cols].fillna(0).replace([np.inf, -np.inf], 0)
predictions = model.predict(X_test)
# Only evaluate on valid test rows
test_valid_mask = test_valid.values
y_test_valid = test_targets[test_valid_mask]
pred_valid = predictions[test_valid_mask]
# Calculate metrics
f1 = f1_score(y_test_valid, pred_valid, zero_division=0)
metrics = {
'train_samples': len(X_train_valid),
'test_samples': len(X_test),
'f1_score': f1,
'train_end': train_features.index[-1].isoformat(),
'test_start': test_features.index[0].isoformat(),
'horizon': horizon,
'feature_cols': feature_cols,
}
logger.info(f"Model trained. F1 Score: {f1:.3f}")
logger.info(
f"Train period: {train_features.index[0]} to {train_features.index[-1]}"
)
logger.info(
f"Test period: {test_features.index[0]} to {test_features.index[-1]}"
)
return model, feature_cols, metrics
def save_model(
model: RandomForestClassifier,
feature_cols: list[str],
metrics: dict,
output_path: str,
versioned: bool = True
):
"""
Save trained model to file.
Args:
model: Trained model
feature_cols: List of feature column names
metrics: Training metrics
output_path: Output file path
versioned: If True, also save a timestamped version
"""
output = Path(output_path)
output.parent.mkdir(parents=True, exist_ok=True)
data = {
'model': model,
'feature_cols': feature_cols,
'metrics': metrics,
'trained_at': datetime.now().isoformat(),
}
# Save main model file
with open(output, 'wb') as f:
pickle.dump(data, f)
logger.info(f"Saved model to {output}")
# Save versioned copy
if versioned:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
versioned_path = output.parent / f"regime_model_{timestamp}.pkl"
with open(versioned_path, 'wb') as f:
pickle.dump(data, f)
logger.info(f"Saved versioned model to {versioned_path}")
def main():
"""Main training function."""
args = parse_args()
market_type = MarketType.PERPETUAL if args.market == "perpetual" else MarketType.SPOT
dm = DataManager()
# Download latest data if requested
if args.download:
download_data(dm, args.pair, args.timeframe, market_type)
download_data(dm, args.spread_pair, args.timeframe, market_type)
# Load data
try:
df_base, df_spread = load_data(
dm, args.pair, args.spread_pair, args.timeframe, market_type, args.days
)
except Exception as e:
logger.error(f"Failed to load data: {e}")
logger.info("Try running with --download flag to fetch latest data")
sys.exit(1)
# Load on-chain data
cq_df = load_cryptoquant_data()
# Calculate features
features = calculate_features(df_base, df_spread, cq_df)
logger.info(
f"Calculated {len(features)} feature rows with {len(features.columns)} columns"
)
if len(features) < 200:
logger.error(f"Not enough data: {len(features)} rows (need >= 200)")
sys.exit(1)
# Train model
try:
model, feature_cols, metrics = train_model(
features, args.train_ratio, args.horizon
)
except ValueError as e:
logger.error(f"Training failed: {e}")
sys.exit(1)
# Print metrics summary
print("\n" + "=" * 60)
print("TRAINING COMPLETE")
print("=" * 60)
print(f"Train samples: {metrics['train_samples']}")
print(f"Test samples: {metrics['test_samples']}")
print(f"F1 Score: {metrics['f1_score']:.3f}")
print(f"Horizon: {metrics['horizon']} bars")
print(f"Features: {len(feature_cols)}")
print("=" * 60)
# Save model
if not args.dry_run:
save_model(model, feature_cols, metrics, args.output)
else:
logger.info("Dry run - model not saved")
return 0
if __name__ == "__main__":
sys.exit(main())

36
uv.lock generated
View File

@@ -981,6 +981,7 @@ dependencies = [
{ name = "plotly" },
{ name = "python-dotenv" },
{ name = "requests" },
{ name = "rich" },
{ name = "scikit-learn" },
{ name = "sqlalchemy" },
{ name = "ta" },
@@ -1004,6 +1005,7 @@ requires-dist = [
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
{ name = "python-dotenv", specifier = ">=1.2.1" },
{ name = "requests", specifier = ">=2.32.5" },
{ name = "rich", specifier = ">=13.0.0" },
{ name = "scikit-learn", specifier = ">=1.6.0" },
{ name = "sqlalchemy", specifier = ">=2.0.0" },
{ name = "ta", specifier = ">=0.11.0" },
@@ -1012,6 +1014,18 @@ requires-dist = [
]
provides-extras = ["dev"]
[[package]]
name = "markdown-it-py"
version = "4.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mdurl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
]
[[package]]
name = "matplotlib"
version = "3.10.8"
@@ -1078,6 +1092,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" },
]
[[package]]
name = "mdurl"
version = "0.1.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
]
[[package]]
name = "multidict"
version = "6.7.0"
@@ -1912,6 +1935,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
]
[[package]]
name = "rich"
version = "14.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markdown-it-py" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/fb/d2/8920e102050a0de7bfabeb4c4614a49248cf8d5d7a8d01885fbb24dc767a/rich-14.2.0.tar.gz", hash = "sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4", size = 219990, upload-time = "2025-10-09T14:16:53.064Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/25/7a/b0178788f8dc6cafce37a212c99565fa1fe7872c70c6c9c1e1a372d9d88f/rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd", size = 243393, upload-time = "2025-10-09T14:16:51.245Z" },
]
[[package]]
name = "schedule"
version = "1.2.2"