Add daily model training scripts and terminal UI for live trading

- Introduced `train_daily.sh` for automating daily model retraining, including data download and model training steps.
- Added `install_cron.sh` for setting up a cron job to run the daily training script.
- Created `setup_schedule.sh` for configuring Systemd timers for daily training tasks.
- Implemented a terminal UI using Rich for real-time monitoring of trading performance, including metrics display and log handling.
- Updated `pyproject.toml` to include the `rich` dependency for UI functionality.
- Enhanced `.gitignore` to exclude model and log files.
- Added database support for trade persistence and metrics calculation.
- Updated README with installation and usage instructions for the new features.
This commit is contained in:
2026-01-18 11:08:57 +08:00
parent 35992ee374
commit b5550f4ff4
27 changed files with 3582 additions and 113 deletions

View File

@@ -60,7 +60,7 @@ class TradingConfig:
# Position sizing
max_position_usdt: float = -1.0 # Max position size in USDT. If <= 0, use all available funds
min_position_usdt: float = 10.0 # Min position size in USDT
min_position_usdt: float = 1.0 # Min position size in USDT
leverage: int = 1 # Leverage (1x = no leverage)
margin_mode: str = "cross" # "cross" or "isolated"

View File

@@ -0,0 +1,13 @@
"""Database module for live trading persistence."""
from .database import get_db, init_db
from .models import Trade, DailySummary, Session
from .metrics import MetricsCalculator
__all__ = [
"get_db",
"init_db",
"Trade",
"DailySummary",
"Session",
"MetricsCalculator",
]

325
live_trading/db/database.py Normal file
View File

@@ -0,0 +1,325 @@
"""SQLite database connection and operations."""
import sqlite3
import logging
from pathlib import Path
from typing import Optional
from contextlib import contextmanager
from .models import Trade, DailySummary, Session
logger = logging.getLogger(__name__)
# Database schema
SCHEMA = """
-- Trade history table
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trade_id TEXT UNIQUE NOT NULL,
symbol TEXT NOT NULL,
side TEXT NOT NULL,
entry_price REAL NOT NULL,
exit_price REAL,
size REAL NOT NULL,
size_usdt REAL NOT NULL,
pnl_usd REAL,
pnl_pct REAL,
entry_time TEXT NOT NULL,
exit_time TEXT,
hold_duration_hours REAL,
reason TEXT,
order_id_entry TEXT,
order_id_exit TEXT,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Daily summary table
CREATE TABLE IF NOT EXISTS daily_summary (
id INTEGER PRIMARY KEY AUTOINCREMENT,
date TEXT UNIQUE NOT NULL,
total_trades INTEGER DEFAULT 0,
winning_trades INTEGER DEFAULT 0,
total_pnl_usd REAL DEFAULT 0,
max_drawdown_usd REAL DEFAULT 0,
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
);
-- Session metadata
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
start_time TEXT NOT NULL,
end_time TEXT,
starting_balance REAL,
ending_balance REAL,
total_pnl REAL,
total_trades INTEGER DEFAULT 0
);
-- Indexes for common queries
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades(entry_time);
CREATE INDEX IF NOT EXISTS idx_trades_exit_time ON trades(exit_time);
CREATE INDEX IF NOT EXISTS idx_daily_summary_date ON daily_summary(date);
"""
_db_instance: Optional["TradingDatabase"] = None
class TradingDatabase:
"""SQLite database for trade persistence."""
def __init__(self, db_path: Path):
self.db_path = db_path
self._connection: Optional[sqlite3.Connection] = None
@property
def connection(self) -> sqlite3.Connection:
"""Get or create database connection."""
if self._connection is None:
self._connection = sqlite3.connect(
str(self.db_path),
check_same_thread=False,
)
self._connection.row_factory = sqlite3.Row
return self._connection
def init_schema(self) -> None:
"""Initialize database schema."""
with self.connection:
self.connection.executescript(SCHEMA)
logger.info(f"Database initialized at {self.db_path}")
def close(self) -> None:
"""Close database connection."""
if self._connection:
self._connection.close()
self._connection = None
@contextmanager
def transaction(self):
"""Context manager for database transactions."""
try:
yield self.connection
self.connection.commit()
except Exception:
self.connection.rollback()
raise
def insert_trade(self, trade: Trade) -> int:
"""
Insert a new trade record.
Args:
trade: Trade object to insert
Returns:
Row ID of inserted trade
"""
sql = """
INSERT INTO trades (
trade_id, symbol, side, entry_price, exit_price,
size, size_usdt, pnl_usd, pnl_pct, entry_time,
exit_time, hold_duration_hours, reason,
order_id_entry, order_id_exit
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
with self.transaction():
cursor = self.connection.execute(
sql,
(
trade.trade_id,
trade.symbol,
trade.side,
trade.entry_price,
trade.exit_price,
trade.size,
trade.size_usdt,
trade.pnl_usd,
trade.pnl_pct,
trade.entry_time,
trade.exit_time,
trade.hold_duration_hours,
trade.reason,
trade.order_id_entry,
trade.order_id_exit,
),
)
return cursor.lastrowid
def update_trade(self, trade_id: str, **kwargs) -> bool:
"""
Update an existing trade record.
Args:
trade_id: Trade ID to update
**kwargs: Fields to update
Returns:
True if trade was updated
"""
if not kwargs:
return False
set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
sql = f"UPDATE trades SET {set_clause} WHERE trade_id = ?"
with self.transaction():
cursor = self.connection.execute(
sql, (*kwargs.values(), trade_id)
)
return cursor.rowcount > 0
def get_trade(self, trade_id: str) -> Optional[Trade]:
"""Get a trade by ID."""
sql = "SELECT * FROM trades WHERE trade_id = ?"
row = self.connection.execute(sql, (trade_id,)).fetchone()
if row:
return Trade(**dict(row))
return None
def get_trades(
self,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
limit: Optional[int] = None,
) -> list[Trade]:
"""
Get trades within a time range.
Args:
start_time: ISO format start time filter
end_time: ISO format end time filter
limit: Maximum number of trades to return
Returns:
List of Trade objects
"""
conditions = []
params = []
if start_time:
conditions.append("entry_time >= ?")
params.append(start_time)
if end_time:
conditions.append("entry_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions) if conditions else "1=1"
limit_clause = f"LIMIT {limit}" if limit else ""
sql = f"""
SELECT * FROM trades
WHERE {where_clause}
ORDER BY entry_time DESC
{limit_clause}
"""
rows = self.connection.execute(sql, params).fetchall()
return [Trade(**dict(row)) for row in rows]
def get_all_trades(self) -> list[Trade]:
"""Get all trades."""
sql = "SELECT * FROM trades ORDER BY entry_time DESC"
rows = self.connection.execute(sql).fetchall()
return [Trade(**dict(row)) for row in rows]
def count_trades(self) -> int:
"""Get total number of trades."""
sql = "SELECT COUNT(*) FROM trades WHERE exit_time IS NOT NULL"
return self.connection.execute(sql).fetchone()[0]
def upsert_daily_summary(self, summary: DailySummary) -> None:
"""Insert or update daily summary."""
sql = """
INSERT INTO daily_summary (
date, total_trades, winning_trades, total_pnl_usd, max_drawdown_usd
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(date) DO UPDATE SET
total_trades = excluded.total_trades,
winning_trades = excluded.winning_trades,
total_pnl_usd = excluded.total_pnl_usd,
max_drawdown_usd = excluded.max_drawdown_usd,
updated_at = CURRENT_TIMESTAMP
"""
with self.transaction():
self.connection.execute(
sql,
(
summary.date,
summary.total_trades,
summary.winning_trades,
summary.total_pnl_usd,
summary.max_drawdown_usd,
),
)
def get_daily_summary(self, date: str) -> Optional[DailySummary]:
"""Get daily summary for a specific date."""
sql = "SELECT * FROM daily_summary WHERE date = ?"
row = self.connection.execute(sql, (date,)).fetchone()
if row:
return DailySummary(**dict(row))
return None
def insert_session(self, session: Session) -> int:
"""Insert a new session record."""
sql = """
INSERT INTO sessions (
start_time, end_time, starting_balance,
ending_balance, total_pnl, total_trades
) VALUES (?, ?, ?, ?, ?, ?)
"""
with self.transaction():
cursor = self.connection.execute(
sql,
(
session.start_time,
session.end_time,
session.starting_balance,
session.ending_balance,
session.total_pnl,
session.total_trades,
),
)
return cursor.lastrowid
def update_session(self, session_id: int, **kwargs) -> bool:
"""Update an existing session."""
if not kwargs:
return False
set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
sql = f"UPDATE sessions SET {set_clause} WHERE id = ?"
with self.transaction():
cursor = self.connection.execute(
sql, (*kwargs.values(), session_id)
)
return cursor.rowcount > 0
def get_latest_session(self) -> Optional[Session]:
"""Get the most recent session."""
sql = "SELECT * FROM sessions ORDER BY id DESC LIMIT 1"
row = self.connection.execute(sql).fetchone()
if row:
return Session(**dict(row))
return None
def init_db(db_path: Path) -> TradingDatabase:
"""
Initialize the database.
Args:
db_path: Path to the SQLite database file
Returns:
TradingDatabase instance
"""
global _db_instance
_db_instance = TradingDatabase(db_path)
_db_instance.init_schema()
return _db_instance
def get_db() -> Optional[TradingDatabase]:
"""Get the global database instance."""
return _db_instance

235
live_trading/db/metrics.py Normal file
View File

@@ -0,0 +1,235 @@
"""Metrics calculation from trade database."""
import logging
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from typing import Optional
from .database import TradingDatabase
logger = logging.getLogger(__name__)
@dataclass
class PeriodMetrics:
"""Trading metrics for a time period."""
period_name: str
start_time: Optional[str]
end_time: Optional[str]
total_pnl: float = 0.0
total_trades: int = 0
winning_trades: int = 0
losing_trades: int = 0
win_rate: float = 0.0
avg_trade_duration_hours: float = 0.0
max_drawdown: float = 0.0
max_drawdown_pct: float = 0.0
best_trade: float = 0.0
worst_trade: float = 0.0
avg_win: float = 0.0
avg_loss: float = 0.0
class MetricsCalculator:
"""Calculate trading metrics from database."""
def __init__(self, db: TradingDatabase):
self.db = db
def get_all_time_metrics(self) -> PeriodMetrics:
"""Get metrics for all trades ever."""
return self._calculate_metrics("All Time", None, None)
def get_monthly_metrics(self) -> PeriodMetrics:
"""Get metrics for current month."""
now = datetime.now(timezone.utc)
start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Monthly",
start.isoformat(),
now.isoformat(),
)
def get_weekly_metrics(self) -> PeriodMetrics:
"""Get metrics for current week (Monday to now)."""
now = datetime.now(timezone.utc)
days_since_monday = now.weekday()
start = now - timedelta(days=days_since_monday)
start = start.replace(hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Weekly",
start.isoformat(),
now.isoformat(),
)
def get_daily_metrics(self) -> PeriodMetrics:
"""Get metrics for today (UTC)."""
now = datetime.now(timezone.utc)
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
return self._calculate_metrics(
"Daily",
start.isoformat(),
now.isoformat(),
)
def _calculate_metrics(
self,
period_name: str,
start_time: Optional[str],
end_time: Optional[str],
) -> PeriodMetrics:
"""
Calculate metrics for a time period.
Args:
period_name: Name of the period
start_time: ISO format start time (None for all time)
end_time: ISO format end time (None for all time)
Returns:
PeriodMetrics object
"""
metrics = PeriodMetrics(
period_name=period_name,
start_time=start_time,
end_time=end_time,
)
# Build query conditions
conditions = ["exit_time IS NOT NULL"]
params = []
if start_time:
conditions.append("exit_time >= ?")
params.append(start_time)
if end_time:
conditions.append("exit_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions)
# Get aggregate metrics
sql = f"""
SELECT
COUNT(*) as total_trades,
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
SUM(CASE WHEN pnl_usd < 0 THEN 1 ELSE 0 END) as losing_trades,
COALESCE(SUM(pnl_usd), 0) as total_pnl,
COALESCE(AVG(hold_duration_hours), 0) as avg_duration,
COALESCE(MAX(pnl_usd), 0) as best_trade,
COALESCE(MIN(pnl_usd), 0) as worst_trade,
COALESCE(AVG(CASE WHEN pnl_usd > 0 THEN pnl_usd END), 0) as avg_win,
COALESCE(AVG(CASE WHEN pnl_usd < 0 THEN pnl_usd END), 0) as avg_loss
FROM trades
WHERE {where_clause}
"""
row = self.db.connection.execute(sql, params).fetchone()
if row and row["total_trades"] > 0:
metrics.total_trades = row["total_trades"]
metrics.winning_trades = row["winning_trades"] or 0
metrics.losing_trades = row["losing_trades"] or 0
metrics.total_pnl = row["total_pnl"]
metrics.avg_trade_duration_hours = row["avg_duration"]
metrics.best_trade = row["best_trade"]
metrics.worst_trade = row["worst_trade"]
metrics.avg_win = row["avg_win"]
metrics.avg_loss = row["avg_loss"]
if metrics.total_trades > 0:
metrics.win_rate = (
metrics.winning_trades / metrics.total_trades * 100
)
# Calculate max drawdown
metrics.max_drawdown = self._calculate_max_drawdown(
start_time, end_time
)
return metrics
def _calculate_max_drawdown(
self,
start_time: Optional[str],
end_time: Optional[str],
) -> float:
"""Calculate maximum drawdown for a period."""
conditions = ["exit_time IS NOT NULL"]
params = []
if start_time:
conditions.append("exit_time >= ?")
params.append(start_time)
if end_time:
conditions.append("exit_time <= ?")
params.append(end_time)
where_clause = " AND ".join(conditions)
sql = f"""
SELECT pnl_usd
FROM trades
WHERE {where_clause}
ORDER BY exit_time
"""
rows = self.db.connection.execute(sql, params).fetchall()
if not rows:
return 0.0
cumulative = 0.0
peak = 0.0
max_drawdown = 0.0
for row in rows:
pnl = row["pnl_usd"] or 0.0
cumulative += pnl
peak = max(peak, cumulative)
drawdown = peak - cumulative
max_drawdown = max(max_drawdown, drawdown)
return max_drawdown
def has_monthly_data(self) -> bool:
"""Check if we have data spanning more than current month."""
sql = """
SELECT MIN(exit_time) as first_trade
FROM trades
WHERE exit_time IS NOT NULL
"""
row = self.db.connection.execute(sql).fetchone()
if not row or not row["first_trade"]:
return False
first_trade = datetime.fromisoformat(row["first_trade"])
now = datetime.now(timezone.utc)
month_start = now.replace(day=1, hour=0, minute=0, second=0)
return first_trade < month_start
def has_weekly_data(self) -> bool:
"""Check if we have data spanning more than current week."""
sql = """
SELECT MIN(exit_time) as first_trade
FROM trades
WHERE exit_time IS NOT NULL
"""
row = self.db.connection.execute(sql).fetchone()
if not row or not row["first_trade"]:
return False
first_trade = datetime.fromisoformat(row["first_trade"])
now = datetime.now(timezone.utc)
days_since_monday = now.weekday()
week_start = now - timedelta(days=days_since_monday)
week_start = week_start.replace(hour=0, minute=0, second=0)
return first_trade < week_start
def get_session_start_balance(self) -> Optional[float]:
"""Get starting balance from latest session."""
sql = "SELECT starting_balance FROM sessions ORDER BY id DESC LIMIT 1"
row = self.db.connection.execute(sql).fetchone()
return row["starting_balance"] if row else None

View File

@@ -0,0 +1,191 @@
"""Database migrations and CSV import."""
import csv
import logging
from pathlib import Path
from datetime import datetime
from .database import TradingDatabase
from .models import Trade, DailySummary
logger = logging.getLogger(__name__)
def migrate_csv_to_db(db: TradingDatabase, csv_path: Path) -> int:
"""
Migrate trades from CSV file to SQLite database.
Args:
db: TradingDatabase instance
csv_path: Path to trade_log.csv
Returns:
Number of trades migrated
"""
if not csv_path.exists():
logger.info("No CSV file to migrate")
return 0
# Check if database already has trades
existing_count = db.count_trades()
if existing_count > 0:
logger.info(
f"Database already has {existing_count} trades, skipping migration"
)
return 0
migrated = 0
try:
with open(csv_path, "r", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
trade = _csv_row_to_trade(row)
if trade:
try:
db.insert_trade(trade)
migrated += 1
except Exception as e:
logger.warning(
f"Failed to migrate trade {row.get('trade_id')}: {e}"
)
logger.info(f"Migrated {migrated} trades from CSV to database")
except Exception as e:
logger.error(f"CSV migration failed: {e}")
return migrated
def _csv_row_to_trade(row: dict) -> Trade | None:
"""Convert a CSV row to a Trade object."""
try:
return Trade(
trade_id=row["trade_id"],
symbol=row["symbol"],
side=row["side"],
entry_price=float(row["entry_price"]),
exit_price=_safe_float(row.get("exit_price")),
size=float(row["size"]),
size_usdt=float(row["size_usdt"]),
pnl_usd=_safe_float(row.get("pnl_usd")),
pnl_pct=_safe_float(row.get("pnl_pct")),
entry_time=row["entry_time"],
exit_time=row.get("exit_time") or None,
hold_duration_hours=_safe_float(row.get("hold_duration_hours")),
reason=row.get("reason") or None,
order_id_entry=row.get("order_id_entry") or None,
order_id_exit=row.get("order_id_exit") or None,
)
except (KeyError, ValueError) as e:
logger.warning(f"Invalid CSV row: {e}")
return None
def _safe_float(value: str | None) -> float | None:
"""Safely convert string to float."""
if value is None or value == "":
return None
try:
return float(value)
except ValueError:
return None
def rebuild_daily_summaries(db: TradingDatabase) -> int:
"""
Rebuild daily summary table from trades.
Args:
db: TradingDatabase instance
Returns:
Number of daily summaries created
"""
sql = """
SELECT
DATE(exit_time) as date,
COUNT(*) as total_trades,
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
SUM(pnl_usd) as total_pnl_usd
FROM trades
WHERE exit_time IS NOT NULL
GROUP BY DATE(exit_time)
ORDER BY date
"""
rows = db.connection.execute(sql).fetchall()
count = 0
for row in rows:
summary = DailySummary(
date=row["date"],
total_trades=row["total_trades"],
winning_trades=row["winning_trades"],
total_pnl_usd=row["total_pnl_usd"] or 0.0,
max_drawdown_usd=0.0, # Calculated separately
)
db.upsert_daily_summary(summary)
count += 1
# Calculate max drawdowns
_calculate_daily_drawdowns(db)
logger.info(f"Rebuilt {count} daily summaries")
return count
def _calculate_daily_drawdowns(db: TradingDatabase) -> None:
"""Calculate and update max drawdown for each day."""
sql = """
SELECT trade_id, DATE(exit_time) as date, pnl_usd
FROM trades
WHERE exit_time IS NOT NULL
ORDER BY exit_time
"""
rows = db.connection.execute(sql).fetchall()
# Track cumulative PnL and drawdown per day
daily_drawdowns: dict[str, float] = {}
cumulative_pnl = 0.0
peak_pnl = 0.0
for row in rows:
date = row["date"]
pnl = row["pnl_usd"] or 0.0
cumulative_pnl += pnl
peak_pnl = max(peak_pnl, cumulative_pnl)
drawdown = peak_pnl - cumulative_pnl
if date not in daily_drawdowns:
daily_drawdowns[date] = 0.0
daily_drawdowns[date] = max(daily_drawdowns[date], drawdown)
# Update daily summaries with drawdown
for date, drawdown in daily_drawdowns.items():
db.connection.execute(
"UPDATE daily_summary SET max_drawdown_usd = ? WHERE date = ?",
(drawdown, date),
)
db.connection.commit()
def run_migrations(db: TradingDatabase, csv_path: Path) -> None:
"""
Run all migrations.
Args:
db: TradingDatabase instance
csv_path: Path to trade_log.csv for migration
"""
logger.info("Running database migrations...")
# Migrate CSV data if exists
migrate_csv_to_db(db, csv_path)
# Rebuild daily summaries
rebuild_daily_summaries(db)
logger.info("Migrations complete")

69
live_trading/db/models.py Normal file
View File

@@ -0,0 +1,69 @@
"""Data models for trade persistence."""
from dataclasses import dataclass, asdict
from typing import Optional
from datetime import datetime
@dataclass
class Trade:
"""Represents a completed trade."""
trade_id: str
symbol: str
side: str
entry_price: float
size: float
size_usdt: float
entry_time: str
exit_price: Optional[float] = None
pnl_usd: Optional[float] = None
pnl_pct: Optional[float] = None
exit_time: Optional[str] = None
hold_duration_hours: Optional[float] = None
reason: Optional[str] = None
order_id_entry: Optional[str] = None
order_id_exit: Optional[str] = None
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@classmethod
def from_row(cls, row: tuple, columns: list[str]) -> "Trade":
"""Create Trade from database row."""
data = dict(zip(columns, row))
return cls(**data)
@dataclass
class DailySummary:
"""Daily trading summary."""
date: str
total_trades: int = 0
winning_trades: int = 0
total_pnl_usd: float = 0.0
max_drawdown_usd: float = 0.0
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@dataclass
class Session:
"""Trading session metadata."""
start_time: str
end_time: Optional[str] = None
starting_balance: Optional[float] = None
ending_balance: Optional[float] = None
total_pnl: Optional[float] = None
total_trades: int = 0
id: Optional[int] = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)

View File

@@ -39,18 +39,40 @@ class LiveRegimeStrategy:
self.paths = path_config
self.model: Optional[RandomForestClassifier] = None
self.feature_cols: Optional[list] = None
self.horizon: int = 102 # Default horizon
self._last_model_load_time: float = 0.0
self._load_or_train_model()
def reload_model_if_changed(self) -> None:
"""Check if model file has changed and reload if necessary."""
if not self.paths.model_path.exists():
return
try:
mtime = self.paths.model_path.stat().st_mtime
if mtime > self._last_model_load_time:
logger.info(f"Model file changed, reloading... (last: {self._last_model_load_time}, new: {mtime})")
self._load_or_train_model()
except Exception as e:
logger.warning(f"Error checking model file: {e}")
def _load_or_train_model(self) -> None:
"""Load pre-trained model or train a new one."""
if self.paths.model_path.exists():
try:
self._last_model_load_time = self.paths.model_path.stat().st_mtime
with open(self.paths.model_path, 'rb') as f:
saved = pickle.load(f)
self.model = saved['model']
self.feature_cols = saved['feature_cols']
logger.info(f"Loaded model from {self.paths.model_path}")
return
# Load horizon from metrics if available
if 'metrics' in saved and 'horizon' in saved['metrics']:
self.horizon = saved['metrics']['horizon']
logger.info(f"Loaded model from {self.paths.model_path} (horizon={self.horizon})")
else:
logger.info(f"Loaded model from {self.paths.model_path} (default horizon={self.horizon})")
return
except Exception as e:
logger.warning(f"Could not load model: {e}")
@@ -66,6 +88,7 @@ class LiveRegimeStrategy:
pickle.dump({
'model': self.model,
'feature_cols': self.feature_cols,
'metrics': {'horizon': self.horizon} # Save horizon
}, f)
logger.info(f"Saved model to {self.paths.model_path}")
except Exception as e:
@@ -81,7 +104,7 @@ class LiveRegimeStrategy:
logger.info(f"Training model on {len(features)} samples...")
z_thresh = self.config.z_entry_threshold
horizon = 102 # Optimal horizon from research
horizon = self.horizon
profit_target = 0.005 # 0.5% profit threshold
# Define targets

View File

@@ -11,14 +11,19 @@ Usage:
# Run with specific settings
uv run python -m live_trading.main --max-position 500 --leverage 2
# Run without UI (headless mode)
uv run python -m live_trading.main --no-ui
"""
import argparse
import logging
import queue
import signal
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
@@ -28,22 +33,47 @@ from live_trading.okx_client import OKXClient
from live_trading.data_feed import DataFeed
from live_trading.position_manager import PositionManager
from live_trading.live_regime_strategy import LiveRegimeStrategy
from live_trading.db.database import init_db, TradingDatabase
from live_trading.db.migrations import run_migrations
from live_trading.ui.state import SharedState, PositionState
from live_trading.ui.dashboard import TradingDashboard, setup_ui_logging
def setup_logging(log_dir: Path) -> logging.Logger:
"""Configure logging for the trading bot."""
def setup_logging(
log_dir: Path,
log_queue: Optional[queue.Queue] = None,
) -> logging.Logger:
"""
Configure logging for the trading bot.
Args:
log_dir: Directory for log files
log_queue: Optional queue for UI log handler
Returns:
Logger instance
"""
log_file = log_dir / "live_trading.log"
handlers = [
logging.FileHandler(log_file),
]
# Only add StreamHandler if no UI (log_queue is None)
if log_queue is None:
handlers.append(logging.StreamHandler(sys.stdout))
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler(sys.stdout),
],
force=True
handlers=handlers,
force=True,
)
# Add UI log handler if queue provided
if log_queue is not None:
setup_ui_logging(log_queue)
return logging.getLogger(__name__)
@@ -59,11 +89,15 @@ class LiveTradingBot:
self,
okx_config: OKXConfig,
trading_config: TradingConfig,
path_config: PathConfig
path_config: PathConfig,
database: Optional[TradingDatabase] = None,
shared_state: Optional[SharedState] = None,
):
self.okx_config = okx_config
self.trading_config = trading_config
self.path_config = path_config
self.db = database
self.state = shared_state
self.logger = logging.getLogger(__name__)
self.running = True
@@ -74,7 +108,7 @@ class LiveTradingBot:
self.okx_client = OKXClient(okx_config, trading_config)
self.data_feed = DataFeed(self.okx_client, trading_config, path_config)
self.position_manager = PositionManager(
self.okx_client, trading_config, path_config
self.okx_client, trading_config, path_config, database
)
self.strategy = LiveRegimeStrategy(trading_config, path_config)
@@ -82,6 +116,16 @@ class LiveTradingBot:
signal.signal(signal.SIGINT, self._handle_shutdown)
signal.signal(signal.SIGTERM, self._handle_shutdown)
# Initialize shared state if provided
if self.state:
mode = "DEMO" if okx_config.demo_mode else "LIVE"
self.state.set_mode(mode)
self.state.set_symbols(
trading_config.eth_symbol,
trading_config.btc_symbol,
)
self.state.update_account(0.0, 0.0, trading_config.leverage)
self._print_startup_banner()
def _print_startup_banner(self) -> None:
@@ -109,6 +153,8 @@ class LiveTradingBot:
"""Handle shutdown signals gracefully."""
self.logger.info("Shutdown signal received, stopping...")
self.running = False
if self.state:
self.state.stop()
def run_trading_cycle(self) -> None:
"""
@@ -118,10 +164,20 @@ class LiveTradingBot:
2. Update open positions
3. Generate trading signal
4. Execute trades if signal triggers
5. Update shared state for UI
"""
# Reload model if it has changed (e.g. daily training)
try:
self.strategy.reload_model_if_changed()
except Exception as e:
self.logger.warning(f"Failed to reload model: {e}")
cycle_start = datetime.now(timezone.utc)
self.logger.info(f"--- Trading Cycle Start: {cycle_start.isoformat()} ---")
if self.state:
self.state.set_last_cycle_time(cycle_start.isoformat())
try:
# 1. Fetch market data
features = self.data_feed.get_latest_data()
@@ -154,15 +210,22 @@ class LiveTradingBot:
funding = self.data_feed.get_current_funding_rates()
# 5. Generate trading signal
signal = self.strategy.generate_signal(features, funding)
sig = self.strategy.generate_signal(features, funding)
# 6. Execute trades based on signal
if signal['action'] == 'entry':
self._execute_entry(signal, eth_price)
elif signal['action'] == 'check_exit':
self._execute_exit(signal)
# 6. Update shared state with strategy info
self._update_strategy_state(sig, funding)
# 7. Log portfolio summary
# 7. Execute trades based on signal
if sig['action'] == 'entry':
self._execute_entry(sig, eth_price)
elif sig['action'] == 'check_exit':
self._execute_exit(sig)
# 8. Update shared state with position and account
self._update_position_state(eth_price)
self._update_account_state()
# 9. Log portfolio summary
summary = self.position_manager.get_portfolio_summary()
self.logger.info(
f"Portfolio: {summary['open_positions']} positions, "
@@ -178,6 +241,61 @@ class LiveTradingBot:
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
self.logger.info(f"--- Cycle completed in {cycle_duration:.1f}s ---")
def _update_strategy_state(self, sig: dict, funding: dict) -> None:
"""Update shared state with strategy information."""
if not self.state:
return
self.state.update_strategy(
z_score=sig.get('z_score', 0.0),
probability=sig.get('probability', 0.0),
funding_rate=funding.get('btc_funding', 0.0),
action=sig.get('action', 'hold'),
reason=sig.get('reason', ''),
)
def _update_position_state(self, current_price: float) -> None:
"""Update shared state with current position."""
if not self.state:
return
symbol = self.trading_config.eth_symbol
position = self.position_manager.get_position_for_symbol(symbol)
if position is None:
self.state.clear_position()
return
pos_state = PositionState(
trade_id=position.trade_id,
symbol=position.symbol,
side=position.side,
entry_price=position.entry_price,
current_price=position.current_price,
size=position.size,
size_usdt=position.size_usdt,
unrealized_pnl=position.unrealized_pnl,
unrealized_pnl_pct=position.unrealized_pnl_pct,
stop_loss_price=position.stop_loss_price,
take_profit_price=position.take_profit_price,
)
self.state.set_position(pos_state)
def _update_account_state(self) -> None:
"""Update shared state with account information."""
if not self.state:
return
try:
balance = self.okx_client.get_balance()
self.state.update_account(
balance=balance.get('total', 0.0),
available=balance.get('free', 0.0),
leverage=self.trading_config.leverage,
)
except Exception as e:
self.logger.warning(f"Failed to update account state: {e}")
def _execute_entry(self, signal: dict, current_price: float) -> None:
"""Execute entry trade."""
symbol = self.trading_config.eth_symbol
@@ -191,11 +309,15 @@ class LiveTradingBot:
# Get account balance
balance = self.okx_client.get_balance()
available_usdt = balance['free']
self.logger.info(f"Account balance: ${available_usdt:.2f} USDT available")
# Calculate position size
size_usdt = self.strategy.calculate_position_size(signal, available_usdt)
if size_usdt <= 0:
self.logger.info("Position size too small, skipping entry")
self.logger.info(
f"Position size too small (${size_usdt:.2f}), skipping entry. "
f"Min required: ${self.strategy.config.min_position_usdt:.2f}"
)
return
size_eth = size_usdt / current_price
@@ -290,22 +412,30 @@ class LiveTradingBot:
except Exception as e:
self.logger.error(f"Exit execution failed: {e}", exc_info=True)
def _is_running(self) -> bool:
"""Check if bot should continue running."""
if not self.running:
return False
if self.state and not self.state.is_running():
return False
return True
def run(self) -> None:
"""Main trading loop."""
self.logger.info("Starting trading loop...")
while self.running:
while self._is_running():
try:
self.run_trading_cycle()
if self.running:
if self._is_running():
sleep_seconds = self.trading_config.sleep_seconds
minutes = sleep_seconds // 60
self.logger.info(f"Sleeping for {minutes} minutes...")
# Sleep in smaller chunks to allow faster shutdown
for _ in range(sleep_seconds):
if not self.running:
if not self._is_running():
break
time.sleep(1)
@@ -319,6 +449,8 @@ class LiveTradingBot:
# Cleanup
self.logger.info("Shutting down...")
self.position_manager.save_positions()
if self.db:
self.db.close()
self.logger.info("Shutdown complete")
@@ -350,6 +482,11 @@ def parse_args():
action="store_true",
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
)
parser.add_argument(
"--no-ui",
action="store_true",
help="Run in headless mode without terminal UI"
)
return parser.parse_args()
@@ -370,19 +507,64 @@ def main():
if args.live:
okx_config.demo_mode = False
# Determine if UI should be enabled
use_ui = not args.no_ui and sys.stdin.isatty()
# Initialize database
db_path = path_config.base_dir / "live_trading" / "trading.db"
db = init_db(db_path)
# Run migrations (imports CSV if exists)
run_migrations(db, path_config.trade_log_file)
# Initialize UI components if enabled
log_queue: Optional[queue.Queue] = None
shared_state: Optional[SharedState] = None
dashboard: Optional[TradingDashboard] = None
if use_ui:
log_queue = queue.Queue(maxsize=1000)
shared_state = SharedState()
# Setup logging
logger = setup_logging(path_config.logs_dir)
logger = setup_logging(path_config.logs_dir, log_queue)
try:
# Create and run bot
bot = LiveTradingBot(okx_config, trading_config, path_config)
# Create bot
bot = LiveTradingBot(
okx_config,
trading_config,
path_config,
database=db,
shared_state=shared_state,
)
# Start dashboard if UI enabled
if use_ui and shared_state and log_queue:
dashboard = TradingDashboard(
state=shared_state,
db=db,
log_queue=log_queue,
on_quit=lambda: setattr(bot, 'running', False),
)
dashboard.start()
logger.info("Dashboard started")
# Run bot
bot.run()
except ValueError as e:
logger.error(f"Configuration error: {e}")
sys.exit(1)
except Exception as e:
logger.error(f"Fatal error: {e}", exc_info=True)
sys.exit(1)
finally:
# Cleanup
if dashboard:
dashboard.stop()
if db:
db.close()
if __name__ == "__main__":

View File

@@ -3,16 +3,21 @@ Position Manager for Live Trading.
Tracks open positions, manages risk, and handles SL/TP logic.
"""
import csv
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from dataclasses import dataclass, field, asdict
from typing import Optional, TYPE_CHECKING
from dataclasses import dataclass, asdict
from .okx_client import OKXClient
from .config import TradingConfig, PathConfig
if TYPE_CHECKING:
from .db.database import TradingDatabase
from .db.models import Trade
logger = logging.getLogger(__name__)
@@ -78,11 +83,13 @@ class PositionManager:
self,
okx_client: OKXClient,
trading_config: TradingConfig,
path_config: PathConfig
path_config: PathConfig,
database: Optional["TradingDatabase"] = None,
):
self.client = okx_client
self.config = trading_config
self.paths = path_config
self.db = database
self.positions: dict[str, Position] = {}
self.trade_log: list[dict] = []
self._load_positions()
@@ -249,16 +256,55 @@ class PositionManager:
return trade_record
def _append_trade_log(self, trade_record: dict) -> None:
"""Append trade record to CSV log file."""
import csv
"""Append trade record to CSV and SQLite database."""
# Write to CSV (backup/compatibility)
self._append_trade_csv(trade_record)
# Write to SQLite (primary)
self._append_trade_db(trade_record)
def _append_trade_csv(self, trade_record: dict) -> None:
"""Append trade record to CSV log file."""
file_exists = self.paths.trade_log_file.exists()
with open(self.paths.trade_log_file, 'a', newline='') as f:
writer = csv.DictWriter(f, fieldnames=trade_record.keys())
if not file_exists:
writer.writeheader()
writer.writerow(trade_record)
try:
with open(self.paths.trade_log_file, 'a', newline='') as f:
writer = csv.DictWriter(f, fieldnames=trade_record.keys())
if not file_exists:
writer.writeheader()
writer.writerow(trade_record)
except Exception as e:
logger.error(f"Failed to write trade to CSV: {e}")
def _append_trade_db(self, trade_record: dict) -> None:
"""Append trade record to SQLite database."""
if self.db is None:
return
try:
from .db.models import Trade
trade = Trade(
trade_id=trade_record['trade_id'],
symbol=trade_record['symbol'],
side=trade_record['side'],
entry_price=trade_record['entry_price'],
exit_price=trade_record.get('exit_price'),
size=trade_record['size'],
size_usdt=trade_record['size_usdt'],
pnl_usd=trade_record.get('pnl_usd'),
pnl_pct=trade_record.get('pnl_pct'),
entry_time=trade_record['entry_time'],
exit_time=trade_record.get('exit_time'),
hold_duration_hours=trade_record.get('hold_duration_hours'),
reason=trade_record.get('reason'),
order_id_entry=trade_record.get('order_id_entry'),
order_id_exit=trade_record.get('order_id_exit'),
)
self.db.insert_trade(trade)
logger.debug(f"Trade {trade.trade_id} saved to database")
except Exception as e:
logger.error(f"Failed to write trade to database: {e}")
def update_positions(self, current_prices: dict[str, float]) -> list[dict]:
"""

View File

@@ -0,0 +1,10 @@
"""Terminal UI module for live trading dashboard."""
from .dashboard import TradingDashboard
from .state import SharedState
from .log_handler import UILogHandler
__all__ = [
"TradingDashboard",
"SharedState",
"UILogHandler",
]

View File

@@ -0,0 +1,240 @@
"""Main trading dashboard UI orchestration."""
import logging
import queue
import threading
import time
from typing import Optional, Callable
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.panel import Panel
from rich.text import Text
from .state import SharedState
from .log_handler import LogBuffer, UILogHandler
from .keyboard import KeyboardHandler
from .panels import (
HeaderPanel,
TabBar,
LogPanel,
HelpBar,
build_summary_panel,
)
from ..db.database import TradingDatabase
from ..db.metrics import MetricsCalculator, PeriodMetrics
logger = logging.getLogger(__name__)
class TradingDashboard:
"""
Main trading dashboard orchestrator.
Runs in a separate thread and provides real-time UI updates
while the trading loop runs in the main thread.
"""
def __init__(
self,
state: SharedState,
db: TradingDatabase,
log_queue: queue.Queue,
on_quit: Optional[Callable] = None,
):
self.state = state
self.db = db
self.log_queue = log_queue
self.on_quit = on_quit
self.console = Console()
self.log_buffer = LogBuffer(max_entries=1000)
self.keyboard = KeyboardHandler()
self.metrics_calculator = MetricsCalculator(db)
self._running = False
self._thread: Optional[threading.Thread] = None
self._active_tab = 0
self._cached_metrics: dict[int, PeriodMetrics] = {}
self._last_metrics_refresh = 0.0
def start(self) -> None:
"""Start the dashboard in a separate thread."""
if self._running:
return
self._running = True
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
logger.debug("Dashboard thread started")
def stop(self) -> None:
"""Stop the dashboard."""
self._running = False
if self._thread:
self._thread.join(timeout=2.0)
logger.debug("Dashboard thread stopped")
def _run(self) -> None:
"""Main dashboard loop."""
try:
with self.keyboard:
with Live(
self._build_layout(),
console=self.console,
refresh_per_second=1,
screen=True,
) as live:
while self._running and self.state.is_running():
# Process keyboard input
action = self.keyboard.get_action(timeout=0.1)
if action:
self._handle_action(action)
# Drain log queue
self.log_buffer.drain_queue(self.log_queue)
# Refresh metrics periodically (every 5 seconds)
now = time.time()
if now - self._last_metrics_refresh > 5.0:
self._refresh_metrics()
self._last_metrics_refresh = now
# Update display
live.update(self._build_layout())
# Small sleep to prevent CPU spinning
time.sleep(0.1)
except Exception as e:
logger.error(f"Dashboard error: {e}", exc_info=True)
finally:
self._running = False
def _handle_action(self, action: str) -> None:
"""Handle keyboard action."""
if action == "quit":
logger.info("Quit requested from UI")
self.state.stop()
if self.on_quit:
self.on_quit()
elif action == "refresh":
self._refresh_metrics()
logger.debug("Manual refresh triggered")
elif action == "filter":
new_filter = self.log_buffer.cycle_filter()
logger.debug(f"Log filter changed to: {new_filter}")
elif action == "filter_trades":
self.log_buffer.set_filter(LogBuffer.FILTER_TRADES)
logger.debug("Log filter set to: trades")
elif action == "filter_all":
self.log_buffer.set_filter(LogBuffer.FILTER_ALL)
logger.debug("Log filter set to: all")
elif action == "filter_errors":
self.log_buffer.set_filter(LogBuffer.FILTER_ERRORS)
logger.debug("Log filter set to: errors")
elif action == "tab_general":
self._active_tab = 0
elif action == "tab_monthly":
if self._has_monthly_data():
self._active_tab = 1
elif action == "tab_weekly":
if self._has_weekly_data():
self._active_tab = 2
elif action == "tab_daily":
self._active_tab = 3
def _refresh_metrics(self) -> None:
"""Refresh metrics from database."""
try:
self._cached_metrics[0] = self.metrics_calculator.get_all_time_metrics()
self._cached_metrics[1] = self.metrics_calculator.get_monthly_metrics()
self._cached_metrics[2] = self.metrics_calculator.get_weekly_metrics()
self._cached_metrics[3] = self.metrics_calculator.get_daily_metrics()
except Exception as e:
logger.warning(f"Failed to refresh metrics: {e}")
def _has_monthly_data(self) -> bool:
"""Check if monthly tab should be shown."""
try:
return self.metrics_calculator.has_monthly_data()
except Exception:
return False
def _has_weekly_data(self) -> bool:
"""Check if weekly tab should be shown."""
try:
return self.metrics_calculator.has_weekly_data()
except Exception:
return False
def _build_layout(self) -> Layout:
"""Build the complete dashboard layout."""
layout = Layout()
# Calculate available height
term_height = self.console.height or 40
# Header takes 3 lines
# Help bar takes 1 line
# Summary panel takes about 12-14 lines
# Rest goes to logs
log_height = max(8, term_height - 20)
layout.split_column(
Layout(name="header", size=3),
Layout(name="summary", size=14),
Layout(name="logs", size=log_height),
Layout(name="help", size=1),
)
# Header
layout["header"].update(HeaderPanel(self.state).render())
# Summary panel with tabs
current_metrics = self._cached_metrics.get(self._active_tab)
tab_bar = TabBar(active_tab=self._active_tab)
layout["summary"].update(
build_summary_panel(
state=self.state,
metrics=current_metrics,
tab_bar=tab_bar,
has_monthly=self._has_monthly_data(),
has_weekly=self._has_weekly_data(),
)
)
# Log panel
layout["logs"].update(LogPanel(self.log_buffer).render(height=log_height))
# Help bar
layout["help"].update(HelpBar().render())
return layout
def setup_ui_logging(log_queue: queue.Queue) -> UILogHandler:
"""
Set up logging to capture messages for UI.
Args:
log_queue: Queue to send log messages to
Returns:
UILogHandler instance
"""
handler = UILogHandler(log_queue)
handler.setLevel(logging.INFO)
# Add handler to root logger
root_logger = logging.getLogger()
root_logger.addHandler(handler)
return handler

128
live_trading/ui/keyboard.py Normal file
View File

@@ -0,0 +1,128 @@
"""Keyboard input handling for terminal UI."""
import sys
import select
import termios
import tty
from typing import Optional, Callable
from dataclasses import dataclass
@dataclass
class KeyAction:
"""Represents a keyboard action."""
key: str
action: str
description: str
class KeyboardHandler:
"""
Non-blocking keyboard input handler.
Uses terminal raw mode to capture single keypresses
without waiting for Enter.
"""
# Key mappings
ACTIONS = {
"q": "quit",
"Q": "quit",
"\x03": "quit", # Ctrl+C
"r": "refresh",
"R": "refresh",
"f": "filter",
"F": "filter",
"t": "filter_trades",
"T": "filter_trades",
"l": "filter_all",
"L": "filter_all",
"e": "filter_errors",
"E": "filter_errors",
"1": "tab_general",
"2": "tab_monthly",
"3": "tab_weekly",
"4": "tab_daily",
}
def __init__(self):
self._old_settings = None
self._enabled = False
def enable(self) -> bool:
"""
Enable raw keyboard input mode.
Returns:
True if enabled successfully
"""
try:
if not sys.stdin.isatty():
return False
self._old_settings = termios.tcgetattr(sys.stdin)
tty.setcbreak(sys.stdin.fileno())
self._enabled = True
return True
except Exception:
return False
def disable(self) -> None:
"""Restore normal terminal mode."""
if self._enabled and self._old_settings:
try:
termios.tcsetattr(
sys.stdin,
termios.TCSADRAIN,
self._old_settings,
)
except Exception:
pass
self._enabled = False
def get_key(self, timeout: float = 0.1) -> Optional[str]:
"""
Get a keypress if available (non-blocking).
Args:
timeout: Seconds to wait for input
Returns:
Key character or None if no input
"""
if not self._enabled:
return None
try:
readable, _, _ = select.select([sys.stdin], [], [], timeout)
if readable:
return sys.stdin.read(1)
except Exception:
pass
return None
def get_action(self, timeout: float = 0.1) -> Optional[str]:
"""
Get action name for pressed key.
Args:
timeout: Seconds to wait for input
Returns:
Action name or None
"""
key = self.get_key(timeout)
if key:
return self.ACTIONS.get(key)
return None
def __enter__(self):
"""Context manager entry."""
self.enable()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.disable()
return False

View File

@@ -0,0 +1,178 @@
"""Custom logging handler for UI integration."""
import logging
import queue
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from collections import deque
@dataclass
class LogEntry:
"""A single log entry."""
timestamp: str
level: str
message: str
logger_name: str
@property
def level_color(self) -> str:
"""Get Rich color for log level."""
colors = {
"DEBUG": "dim",
"INFO": "white",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "bold red",
}
return colors.get(self.level, "white")
class UILogHandler(logging.Handler):
"""
Custom logging handler that sends logs to UI.
Uses a thread-safe queue to pass log entries from the trading
thread to the UI thread.
"""
def __init__(
self,
log_queue: queue.Queue,
max_entries: int = 1000,
):
super().__init__()
self.log_queue = log_queue
self.max_entries = max_entries
self.setFormatter(
logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
)
def emit(self, record: logging.LogRecord) -> None:
"""Emit a log record to the queue."""
try:
entry = LogEntry(
timestamp=datetime.fromtimestamp(record.created).strftime(
"%H:%M:%S"
),
level=record.levelname,
message=self.format_message(record),
logger_name=record.name,
)
# Non-blocking put, drop if queue is full
try:
self.log_queue.put_nowait(entry)
except queue.Full:
pass
except Exception:
self.handleError(record)
def format_message(self, record: logging.LogRecord) -> str:
"""Format the log message."""
return record.getMessage()
class LogBuffer:
"""
Thread-safe buffer for log entries with filtering support.
Maintains a fixed-size buffer of log entries and supports
filtering by log type.
"""
FILTER_ALL = "all"
FILTER_ERRORS = "errors"
FILTER_TRADES = "trades"
FILTER_SIGNALS = "signals"
FILTERS = [FILTER_ALL, FILTER_ERRORS, FILTER_TRADES, FILTER_SIGNALS]
def __init__(self, max_entries: int = 1000):
self.max_entries = max_entries
self._entries: deque[LogEntry] = deque(maxlen=max_entries)
self._current_filter = self.FILTER_ALL
def add(self, entry: LogEntry) -> None:
"""Add a log entry to the buffer."""
self._entries.append(entry)
def get_filtered(self, limit: int = 50) -> list[LogEntry]:
"""
Get filtered log entries.
Args:
limit: Maximum number of entries to return
Returns:
List of filtered LogEntry objects (most recent first)
"""
entries = list(self._entries)
if self._current_filter == self.FILTER_ERRORS:
entries = [e for e in entries if e.level in ("ERROR", "CRITICAL")]
elif self._current_filter == self.FILTER_TRADES:
# Key terms indicating actual trading activity
include_keywords = [
"order", "entry", "exit", "executed", "filled",
"opening", "closing", "position opened", "position closed"
]
# Terms to exclude (noise)
exclude_keywords = [
"sync complete", "0 positions", "portfolio: 0 positions"
]
entries = [
e for e in entries
if any(kw in e.message.lower() for kw in include_keywords)
and not any(ex in e.message.lower() for ex in exclude_keywords)
]
elif self._current_filter == self.FILTER_SIGNALS:
signal_keywords = ["signal", "z_score", "prob", "z="]
entries = [
e for e in entries
if any(kw in e.message.lower() for kw in signal_keywords)
]
# Return most recent entries
return list(reversed(entries[-limit:]))
def set_filter(self, filter_name: str) -> None:
"""Set a specific filter."""
if filter_name in self.FILTERS:
self._current_filter = filter_name
def cycle_filter(self) -> str:
"""Cycle to next filter and return its name."""
current_idx = self.FILTERS.index(self._current_filter)
next_idx = (current_idx + 1) % len(self.FILTERS)
self._current_filter = self.FILTERS[next_idx]
return self._current_filter
def get_current_filter(self) -> str:
"""Get current filter name."""
return self._current_filter
def clear(self) -> None:
"""Clear all log entries."""
self._entries.clear()
def drain_queue(self, log_queue: queue.Queue) -> int:
"""
Drain log entries from queue into buffer.
Args:
log_queue: Queue to drain from
Returns:
Number of entries drained
"""
count = 0
while True:
try:
entry = log_queue.get_nowait()
self.add(entry)
count += 1
except queue.Empty:
break
return count

399
live_trading/ui/panels.py Normal file
View File

@@ -0,0 +1,399 @@
"""UI panel components using Rich."""
from typing import Optional
from rich.console import Console, Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from rich.layout import Layout
from .state import SharedState, PositionState, StrategyState, AccountState
from .log_handler import LogBuffer, LogEntry
from ..db.metrics import PeriodMetrics
def format_pnl(value: float, include_sign: bool = True) -> Text:
"""Format PnL value with color."""
if value > 0:
sign = "+" if include_sign else ""
return Text(f"{sign}${value:.2f}", style="green")
elif value < 0:
return Text(f"${value:.2f}", style="red")
else:
return Text(f"${value:.2f}", style="white")
def format_pct(value: float, include_sign: bool = True) -> Text:
"""Format percentage value with color."""
if value > 0:
sign = "+" if include_sign else ""
return Text(f"{sign}{value:.2f}%", style="green")
elif value < 0:
return Text(f"{value:.2f}%", style="red")
else:
return Text(f"{value:.2f}%", style="white")
def format_side(side: str) -> Text:
"""Format position side with color."""
if side.lower() == "long":
return Text("LONG", style="bold green")
else:
return Text("SHORT", style="bold red")
class HeaderPanel:
"""Header panel with title and mode indicator."""
def __init__(self, state: SharedState):
self.state = state
def render(self) -> Panel:
"""Render the header panel."""
mode = self.state.get_mode()
eth_symbol, _ = self.state.get_symbols()
mode_style = "yellow" if mode == "DEMO" else "bold red"
mode_text = Text(f"[{mode}]", style=mode_style)
title = Text()
title.append("REGIME REVERSION STRATEGY - LIVE TRADING", style="bold white")
title.append(" ")
title.append(mode_text)
title.append(" ")
title.append(eth_symbol, style="cyan")
return Panel(title, style="blue", height=3)
class TabBar:
"""Tab bar for period selection."""
TABS = ["1:General", "2:Monthly", "3:Weekly", "4:Daily"]
def __init__(self, active_tab: int = 0):
self.active_tab = active_tab
def render(
self,
has_monthly: bool = True,
has_weekly: bool = True,
) -> Text:
"""Render the tab bar."""
text = Text()
text.append(" ")
for i, tab in enumerate(self.TABS):
# Check if tab should be shown
if i == 1 and not has_monthly:
continue
if i == 2 and not has_weekly:
continue
if i == self.active_tab:
text.append(f"[{tab}]", style="bold white on blue")
else:
text.append(f"[{tab}]", style="dim")
text.append(" ")
return text
class MetricsPanel:
"""Panel showing trading metrics."""
def __init__(self, metrics: Optional[PeriodMetrics] = None):
self.metrics = metrics
def render(self) -> Table:
"""Render metrics as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.metrics is None or self.metrics.total_trades == 0:
table.add_row("Status", Text("No trade data", style="dim"))
return table
m = self.metrics
table.add_row("Total PnL:", format_pnl(m.total_pnl))
table.add_row("Win Rate:", Text(f"{m.win_rate:.1f}%", style="white"))
table.add_row("Total Trades:", Text(str(m.total_trades), style="white"))
table.add_row(
"Win/Loss:",
Text(f"{m.winning_trades}/{m.losing_trades}", style="white"),
)
table.add_row(
"Avg Duration:",
Text(f"{m.avg_trade_duration_hours:.1f}h", style="white"),
)
table.add_row("Max Drawdown:", format_pnl(-m.max_drawdown))
table.add_row("Best Trade:", format_pnl(m.best_trade))
table.add_row("Worst Trade:", format_pnl(m.worst_trade))
return table
class PositionPanel:
"""Panel showing current position."""
def __init__(self, position: Optional[PositionState] = None):
self.position = position
def render(self) -> Table:
"""Render position as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.position is None:
table.add_row("Status", Text("No open position", style="dim"))
return table
p = self.position
table.add_row("Side:", format_side(p.side))
table.add_row("Entry:", Text(f"${p.entry_price:.2f}", style="white"))
table.add_row("Current:", Text(f"${p.current_price:.2f}", style="white"))
# Unrealized PnL
pnl_text = Text()
pnl_text.append_text(format_pnl(p.unrealized_pnl))
pnl_text.append(" (")
pnl_text.append_text(format_pct(p.unrealized_pnl_pct))
pnl_text.append(")")
table.add_row("Unrealized:", pnl_text)
table.add_row("Size:", Text(f"${p.size_usdt:.2f}", style="white"))
# SL/TP
if p.side == "long":
sl_dist = (p.stop_loss_price / p.entry_price - 1) * 100
tp_dist = (p.take_profit_price / p.entry_price - 1) * 100
else:
sl_dist = (1 - p.stop_loss_price / p.entry_price) * 100
tp_dist = (1 - p.take_profit_price / p.entry_price) * 100
sl_text = Text(f"${p.stop_loss_price:.2f} ({sl_dist:+.1f}%)", style="red")
tp_text = Text(f"${p.take_profit_price:.2f} ({tp_dist:+.1f}%)", style="green")
table.add_row("Stop Loss:", sl_text)
table.add_row("Take Profit:", tp_text)
return table
class AccountPanel:
"""Panel showing account information."""
def __init__(self, account: Optional[AccountState] = None):
self.account = account
def render(self) -> Table:
"""Render account info as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.account is None:
table.add_row("Status", Text("Loading...", style="dim"))
return table
a = self.account
table.add_row("Balance:", Text(f"${a.balance:.2f}", style="white"))
table.add_row("Available:", Text(f"${a.available:.2f}", style="white"))
table.add_row("Leverage:", Text(f"{a.leverage}x", style="cyan"))
return table
class StrategyPanel:
"""Panel showing strategy state."""
def __init__(self, strategy: Optional[StrategyState] = None):
self.strategy = strategy
def render(self) -> Table:
"""Render strategy state as a table."""
table = Table(
show_header=False,
show_edge=False,
box=None,
padding=(0, 1),
)
table.add_column("Label", style="dim")
table.add_column("Value")
if self.strategy is None:
table.add_row("Status", Text("Waiting...", style="dim"))
return table
s = self.strategy
# Z-score with color based on threshold
z_style = "white"
if abs(s.z_score) > 1.0:
z_style = "yellow"
if abs(s.z_score) > 1.5:
z_style = "bold yellow"
table.add_row("Z-Score:", Text(f"{s.z_score:.2f}", style=z_style))
# Probability with color
prob_style = "white"
if s.probability > 0.5:
prob_style = "green"
if s.probability > 0.7:
prob_style = "bold green"
table.add_row("Probability:", Text(f"{s.probability:.2f}", style=prob_style))
# Funding rate
funding_style = "green" if s.funding_rate >= 0 else "red"
table.add_row(
"Funding:",
Text(f"{s.funding_rate:.4f}", style=funding_style),
)
# Last action
action_style = "white"
if s.last_action == "entry":
action_style = "bold cyan"
elif s.last_action == "check_exit":
action_style = "yellow"
table.add_row("Last Action:", Text(s.last_action, style=action_style))
return table
class LogPanel:
"""Panel showing log entries."""
def __init__(self, log_buffer: LogBuffer):
self.log_buffer = log_buffer
def render(self, height: int = 10) -> Panel:
"""Render log panel."""
filter_name = self.log_buffer.get_current_filter().title()
entries = self.log_buffer.get_filtered(limit=height - 2)
lines = []
for entry in entries:
line = Text()
line.append(f"{entry.timestamp} ", style="dim")
line.append(f"[{entry.level}] ", style=entry.level_color)
line.append(entry.message)
lines.append(line)
if not lines:
lines.append(Text("No logs to display", style="dim"))
content = Group(*lines)
# Build "tabbed" title
tabs = []
# All Logs tab
if filter_name == "All":
tabs.append("[bold white on blue] [L]ogs [/]")
else:
tabs.append("[dim] [L]ogs [/]")
# Trades tab
if filter_name == "Trades":
tabs.append("[bold white on blue] [T]rades [/]")
else:
tabs.append("[dim] [T]rades [/]")
# Errors tab
if filter_name == "Errors":
tabs.append("[bold white on blue] [E]rrors [/]")
else:
tabs.append("[dim] [E]rrors [/]")
title = " ".join(tabs)
subtitle = "Press 'l', 't', 'e' to switch tabs"
return Panel(
content,
title=title,
subtitle=subtitle,
title_align="left",
subtitle_align="right",
border_style="blue",
)
class HelpBar:
"""Bottom help bar with keyboard shortcuts."""
def render(self) -> Text:
"""Render help bar."""
text = Text()
text.append(" [q]", style="bold")
text.append("Quit ", style="dim")
text.append("[r]", style="bold")
text.append("Refresh ", style="dim")
text.append("[1-4]", style="bold")
text.append("Tabs ", style="dim")
text.append("[l/t/e]", style="bold")
text.append("LogView", style="dim")
return text
def build_summary_panel(
state: SharedState,
metrics: Optional[PeriodMetrics],
tab_bar: TabBar,
has_monthly: bool,
has_weekly: bool,
) -> Panel:
"""Build the complete summary panel with all sections."""
# Create layout for summary content
layout = Layout()
# Tab bar at top
tabs = tab_bar.render(has_monthly, has_weekly)
# Create tables for each section
metrics_table = MetricsPanel(metrics).render()
position_table = PositionPanel(state.get_position()).render()
account_table = AccountPanel(state.get_account()).render()
strategy_table = StrategyPanel(state.get_strategy()).render()
# Build two-column layout
left_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
left_col.add_column("PERFORMANCE", style="bold cyan")
left_col.add_column("ACCOUNT", style="bold cyan")
left_col.add_row(metrics_table, account_table)
right_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
right_col.add_column("CURRENT POSITION", style="bold cyan")
right_col.add_column("STRATEGY STATE", style="bold cyan")
right_col.add_row(position_table, strategy_table)
# Combine into main table
main_table = Table(show_header=False, show_edge=False, box=None, expand=True)
main_table.add_column(ratio=1)
main_table.add_column(ratio=1)
main_table.add_row(left_col, right_col)
content = Group(tabs, Text(""), main_table)
return Panel(content, border_style="blue")

195
live_trading/ui/state.py Normal file
View File

@@ -0,0 +1,195 @@
"""Thread-safe shared state for UI and trading loop."""
import threading
from dataclasses import dataclass, field
from typing import Optional
from datetime import datetime, timezone
@dataclass
class PositionState:
"""Current position information."""
trade_id: str = ""
symbol: str = ""
side: str = ""
entry_price: float = 0.0
current_price: float = 0.0
size: float = 0.0
size_usdt: float = 0.0
unrealized_pnl: float = 0.0
unrealized_pnl_pct: float = 0.0
stop_loss_price: float = 0.0
take_profit_price: float = 0.0
@dataclass
class StrategyState:
"""Current strategy signal state."""
z_score: float = 0.0
probability: float = 0.0
funding_rate: float = 0.0
last_action: str = "hold"
last_reason: str = ""
last_signal_time: str = ""
@dataclass
class AccountState:
"""Account balance information."""
balance: float = 0.0
available: float = 0.0
leverage: int = 1
class SharedState:
"""
Thread-safe shared state between trading loop and UI.
All access to state fields should go through the getter/setter methods
which use a lock for thread safety.
"""
def __init__(self):
self._lock = threading.Lock()
self._position: Optional[PositionState] = None
self._strategy = StrategyState()
self._account = AccountState()
self._is_running = True
self._last_cycle_time: Optional[str] = None
self._mode = "DEMO"
self._eth_symbol = "ETH/USDT:USDT"
self._btc_symbol = "BTC/USDT:USDT"
# Position methods
def get_position(self) -> Optional[PositionState]:
"""Get current position state."""
with self._lock:
return self._position
def set_position(self, position: Optional[PositionState]) -> None:
"""Set current position state."""
with self._lock:
self._position = position
def update_position_price(self, current_price: float) -> None:
"""Update current price and recalculate PnL."""
with self._lock:
if self._position is None:
return
self._position.current_price = current_price
if self._position.side == "long":
pnl = (current_price - self._position.entry_price)
self._position.unrealized_pnl = pnl * self._position.size
pnl_pct = (current_price / self._position.entry_price - 1) * 100
else:
pnl = (self._position.entry_price - current_price)
self._position.unrealized_pnl = pnl * self._position.size
pnl_pct = (1 - current_price / self._position.entry_price) * 100
self._position.unrealized_pnl_pct = pnl_pct
def clear_position(self) -> None:
"""Clear current position."""
with self._lock:
self._position = None
# Strategy methods
def get_strategy(self) -> StrategyState:
"""Get current strategy state."""
with self._lock:
return StrategyState(
z_score=self._strategy.z_score,
probability=self._strategy.probability,
funding_rate=self._strategy.funding_rate,
last_action=self._strategy.last_action,
last_reason=self._strategy.last_reason,
last_signal_time=self._strategy.last_signal_time,
)
def update_strategy(
self,
z_score: float,
probability: float,
funding_rate: float,
action: str,
reason: str,
) -> None:
"""Update strategy state."""
with self._lock:
self._strategy.z_score = z_score
self._strategy.probability = probability
self._strategy.funding_rate = funding_rate
self._strategy.last_action = action
self._strategy.last_reason = reason
self._strategy.last_signal_time = datetime.now(
timezone.utc
).isoformat()
# Account methods
def get_account(self) -> AccountState:
"""Get current account state."""
with self._lock:
return AccountState(
balance=self._account.balance,
available=self._account.available,
leverage=self._account.leverage,
)
def update_account(
self,
balance: float,
available: float,
leverage: int,
) -> None:
"""Update account state."""
with self._lock:
self._account.balance = balance
self._account.available = available
self._account.leverage = leverage
# Control methods
def is_running(self) -> bool:
"""Check if trading loop is running."""
with self._lock:
return self._is_running
def stop(self) -> None:
"""Signal to stop trading loop."""
with self._lock:
self._is_running = False
def get_last_cycle_time(self) -> Optional[str]:
"""Get last trading cycle time."""
with self._lock:
return self._last_cycle_time
def set_last_cycle_time(self, time_str: str) -> None:
"""Set last trading cycle time."""
with self._lock:
self._last_cycle_time = time_str
# Config methods
def get_mode(self) -> str:
"""Get trading mode (DEMO/LIVE)."""
with self._lock:
return self._mode
def set_mode(self, mode: str) -> None:
"""Set trading mode."""
with self._lock:
self._mode = mode
def get_symbols(self) -> tuple[str, str]:
"""Get trading symbols (eth, btc)."""
with self._lock:
return self._eth_symbol, self._btc_symbol
def set_symbols(self, eth_symbol: str, btc_symbol: str) -> None:
"""Set trading symbols."""
with self._lock:
self._eth_symbol = eth_symbol
self._btc_symbol = btc_symbol