Add daily model training scripts and terminal UI for live trading
- Introduced `train_daily.sh` for automating daily model retraining, including data download and model training steps. - Added `install_cron.sh` for setting up a cron job to run the daily training script. - Created `setup_schedule.sh` for configuring Systemd timers for daily training tasks. - Implemented a terminal UI using Rich for real-time monitoring of trading performance, including metrics display and log handling. - Updated `pyproject.toml` to include the `rich` dependency for UI functionality. - Enhanced `.gitignore` to exclude model and log files. - Added database support for trade persistence and metrics calculation. - Updated README with installation and usage instructions for the new features.
This commit is contained in:
@@ -60,7 +60,7 @@ class TradingConfig:
|
||||
|
||||
# Position sizing
|
||||
max_position_usdt: float = -1.0 # Max position size in USDT. If <= 0, use all available funds
|
||||
min_position_usdt: float = 10.0 # Min position size in USDT
|
||||
min_position_usdt: float = 1.0 # Min position size in USDT
|
||||
leverage: int = 1 # Leverage (1x = no leverage)
|
||||
margin_mode: str = "cross" # "cross" or "isolated"
|
||||
|
||||
|
||||
13
live_trading/db/__init__.py
Normal file
13
live_trading/db/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Database module for live trading persistence."""
|
||||
from .database import get_db, init_db
|
||||
from .models import Trade, DailySummary, Session
|
||||
from .metrics import MetricsCalculator
|
||||
|
||||
__all__ = [
|
||||
"get_db",
|
||||
"init_db",
|
||||
"Trade",
|
||||
"DailySummary",
|
||||
"Session",
|
||||
"MetricsCalculator",
|
||||
]
|
||||
325
live_trading/db/database.py
Normal file
325
live_trading/db/database.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""SQLite database connection and operations."""
|
||||
import sqlite3
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from contextlib import contextmanager
|
||||
|
||||
from .models import Trade, DailySummary, Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Database schema
|
||||
SCHEMA = """
|
||||
-- Trade history table
|
||||
CREATE TABLE IF NOT EXISTS trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
trade_id TEXT UNIQUE NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
side TEXT NOT NULL,
|
||||
entry_price REAL NOT NULL,
|
||||
exit_price REAL,
|
||||
size REAL NOT NULL,
|
||||
size_usdt REAL NOT NULL,
|
||||
pnl_usd REAL,
|
||||
pnl_pct REAL,
|
||||
entry_time TEXT NOT NULL,
|
||||
exit_time TEXT,
|
||||
hold_duration_hours REAL,
|
||||
reason TEXT,
|
||||
order_id_entry TEXT,
|
||||
order_id_exit TEXT,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Daily summary table
|
||||
CREATE TABLE IF NOT EXISTS daily_summary (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
date TEXT UNIQUE NOT NULL,
|
||||
total_trades INTEGER DEFAULT 0,
|
||||
winning_trades INTEGER DEFAULT 0,
|
||||
total_pnl_usd REAL DEFAULT 0,
|
||||
max_drawdown_usd REAL DEFAULT 0,
|
||||
updated_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Session metadata
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
start_time TEXT NOT NULL,
|
||||
end_time TEXT,
|
||||
starting_balance REAL,
|
||||
ending_balance REAL,
|
||||
total_pnl REAL,
|
||||
total_trades INTEGER DEFAULT 0
|
||||
);
|
||||
|
||||
-- Indexes for common queries
|
||||
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades(entry_time);
|
||||
CREATE INDEX IF NOT EXISTS idx_trades_exit_time ON trades(exit_time);
|
||||
CREATE INDEX IF NOT EXISTS idx_daily_summary_date ON daily_summary(date);
|
||||
"""
|
||||
|
||||
_db_instance: Optional["TradingDatabase"] = None
|
||||
|
||||
|
||||
class TradingDatabase:
|
||||
"""SQLite database for trade persistence."""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self.db_path = db_path
|
||||
self._connection: Optional[sqlite3.Connection] = None
|
||||
|
||||
@property
|
||||
def connection(self) -> sqlite3.Connection:
|
||||
"""Get or create database connection."""
|
||||
if self._connection is None:
|
||||
self._connection = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
check_same_thread=False,
|
||||
)
|
||||
self._connection.row_factory = sqlite3.Row
|
||||
return self._connection
|
||||
|
||||
def init_schema(self) -> None:
|
||||
"""Initialize database schema."""
|
||||
with self.connection:
|
||||
self.connection.executescript(SCHEMA)
|
||||
logger.info(f"Database initialized at {self.db_path}")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close database connection."""
|
||||
if self._connection:
|
||||
self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
@contextmanager
|
||||
def transaction(self):
|
||||
"""Context manager for database transactions."""
|
||||
try:
|
||||
yield self.connection
|
||||
self.connection.commit()
|
||||
except Exception:
|
||||
self.connection.rollback()
|
||||
raise
|
||||
|
||||
def insert_trade(self, trade: Trade) -> int:
|
||||
"""
|
||||
Insert a new trade record.
|
||||
|
||||
Args:
|
||||
trade: Trade object to insert
|
||||
|
||||
Returns:
|
||||
Row ID of inserted trade
|
||||
"""
|
||||
sql = """
|
||||
INSERT INTO trades (
|
||||
trade_id, symbol, side, entry_price, exit_price,
|
||||
size, size_usdt, pnl_usd, pnl_pct, entry_time,
|
||||
exit_time, hold_duration_hours, reason,
|
||||
order_id_entry, order_id_exit
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self.transaction():
|
||||
cursor = self.connection.execute(
|
||||
sql,
|
||||
(
|
||||
trade.trade_id,
|
||||
trade.symbol,
|
||||
trade.side,
|
||||
trade.entry_price,
|
||||
trade.exit_price,
|
||||
trade.size,
|
||||
trade.size_usdt,
|
||||
trade.pnl_usd,
|
||||
trade.pnl_pct,
|
||||
trade.entry_time,
|
||||
trade.exit_time,
|
||||
trade.hold_duration_hours,
|
||||
trade.reason,
|
||||
trade.order_id_entry,
|
||||
trade.order_id_exit,
|
||||
),
|
||||
)
|
||||
return cursor.lastrowid
|
||||
|
||||
def update_trade(self, trade_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
Update an existing trade record.
|
||||
|
||||
Args:
|
||||
trade_id: Trade ID to update
|
||||
**kwargs: Fields to update
|
||||
|
||||
Returns:
|
||||
True if trade was updated
|
||||
"""
|
||||
if not kwargs:
|
||||
return False
|
||||
|
||||
set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
|
||||
sql = f"UPDATE trades SET {set_clause} WHERE trade_id = ?"
|
||||
|
||||
with self.transaction():
|
||||
cursor = self.connection.execute(
|
||||
sql, (*kwargs.values(), trade_id)
|
||||
)
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_trade(self, trade_id: str) -> Optional[Trade]:
|
||||
"""Get a trade by ID."""
|
||||
sql = "SELECT * FROM trades WHERE trade_id = ?"
|
||||
row = self.connection.execute(sql, (trade_id,)).fetchone()
|
||||
if row:
|
||||
return Trade(**dict(row))
|
||||
return None
|
||||
|
||||
def get_trades(
|
||||
self,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> list[Trade]:
|
||||
"""
|
||||
Get trades within a time range.
|
||||
|
||||
Args:
|
||||
start_time: ISO format start time filter
|
||||
end_time: ISO format end time filter
|
||||
limit: Maximum number of trades to return
|
||||
|
||||
Returns:
|
||||
List of Trade objects
|
||||
"""
|
||||
conditions = []
|
||||
params = []
|
||||
|
||||
if start_time:
|
||||
conditions.append("entry_time >= ?")
|
||||
params.append(start_time)
|
||||
if end_time:
|
||||
conditions.append("entry_time <= ?")
|
||||
params.append(end_time)
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "1=1"
|
||||
limit_clause = f"LIMIT {limit}" if limit else ""
|
||||
|
||||
sql = f"""
|
||||
SELECT * FROM trades
|
||||
WHERE {where_clause}
|
||||
ORDER BY entry_time DESC
|
||||
{limit_clause}
|
||||
"""
|
||||
|
||||
rows = self.connection.execute(sql, params).fetchall()
|
||||
return [Trade(**dict(row)) for row in rows]
|
||||
|
||||
def get_all_trades(self) -> list[Trade]:
|
||||
"""Get all trades."""
|
||||
sql = "SELECT * FROM trades ORDER BY entry_time DESC"
|
||||
rows = self.connection.execute(sql).fetchall()
|
||||
return [Trade(**dict(row)) for row in rows]
|
||||
|
||||
def count_trades(self) -> int:
|
||||
"""Get total number of trades."""
|
||||
sql = "SELECT COUNT(*) FROM trades WHERE exit_time IS NOT NULL"
|
||||
return self.connection.execute(sql).fetchone()[0]
|
||||
|
||||
def upsert_daily_summary(self, summary: DailySummary) -> None:
|
||||
"""Insert or update daily summary."""
|
||||
sql = """
|
||||
INSERT INTO daily_summary (
|
||||
date, total_trades, winning_trades, total_pnl_usd, max_drawdown_usd
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(date) DO UPDATE SET
|
||||
total_trades = excluded.total_trades,
|
||||
winning_trades = excluded.winning_trades,
|
||||
total_pnl_usd = excluded.total_pnl_usd,
|
||||
max_drawdown_usd = excluded.max_drawdown_usd,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"""
|
||||
with self.transaction():
|
||||
self.connection.execute(
|
||||
sql,
|
||||
(
|
||||
summary.date,
|
||||
summary.total_trades,
|
||||
summary.winning_trades,
|
||||
summary.total_pnl_usd,
|
||||
summary.max_drawdown_usd,
|
||||
),
|
||||
)
|
||||
|
||||
def get_daily_summary(self, date: str) -> Optional[DailySummary]:
|
||||
"""Get daily summary for a specific date."""
|
||||
sql = "SELECT * FROM daily_summary WHERE date = ?"
|
||||
row = self.connection.execute(sql, (date,)).fetchone()
|
||||
if row:
|
||||
return DailySummary(**dict(row))
|
||||
return None
|
||||
|
||||
def insert_session(self, session: Session) -> int:
|
||||
"""Insert a new session record."""
|
||||
sql = """
|
||||
INSERT INTO sessions (
|
||||
start_time, end_time, starting_balance,
|
||||
ending_balance, total_pnl, total_trades
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
with self.transaction():
|
||||
cursor = self.connection.execute(
|
||||
sql,
|
||||
(
|
||||
session.start_time,
|
||||
session.end_time,
|
||||
session.starting_balance,
|
||||
session.ending_balance,
|
||||
session.total_pnl,
|
||||
session.total_trades,
|
||||
),
|
||||
)
|
||||
return cursor.lastrowid
|
||||
|
||||
def update_session(self, session_id: int, **kwargs) -> bool:
|
||||
"""Update an existing session."""
|
||||
if not kwargs:
|
||||
return False
|
||||
|
||||
set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
|
||||
sql = f"UPDATE sessions SET {set_clause} WHERE id = ?"
|
||||
|
||||
with self.transaction():
|
||||
cursor = self.connection.execute(
|
||||
sql, (*kwargs.values(), session_id)
|
||||
)
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_latest_session(self) -> Optional[Session]:
|
||||
"""Get the most recent session."""
|
||||
sql = "SELECT * FROM sessions ORDER BY id DESC LIMIT 1"
|
||||
row = self.connection.execute(sql).fetchone()
|
||||
if row:
|
||||
return Session(**dict(row))
|
||||
return None
|
||||
|
||||
|
||||
def init_db(db_path: Path) -> TradingDatabase:
|
||||
"""
|
||||
Initialize the database.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database file
|
||||
|
||||
Returns:
|
||||
TradingDatabase instance
|
||||
"""
|
||||
global _db_instance
|
||||
_db_instance = TradingDatabase(db_path)
|
||||
_db_instance.init_schema()
|
||||
return _db_instance
|
||||
|
||||
|
||||
def get_db() -> Optional[TradingDatabase]:
|
||||
"""Get the global database instance."""
|
||||
return _db_instance
|
||||
235
live_trading/db/metrics.py
Normal file
235
live_trading/db/metrics.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""Metrics calculation from trade database."""
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from .database import TradingDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PeriodMetrics:
|
||||
"""Trading metrics for a time period."""
|
||||
|
||||
period_name: str
|
||||
start_time: Optional[str]
|
||||
end_time: Optional[str]
|
||||
total_pnl: float = 0.0
|
||||
total_trades: int = 0
|
||||
winning_trades: int = 0
|
||||
losing_trades: int = 0
|
||||
win_rate: float = 0.0
|
||||
avg_trade_duration_hours: float = 0.0
|
||||
max_drawdown: float = 0.0
|
||||
max_drawdown_pct: float = 0.0
|
||||
best_trade: float = 0.0
|
||||
worst_trade: float = 0.0
|
||||
avg_win: float = 0.0
|
||||
avg_loss: float = 0.0
|
||||
|
||||
|
||||
class MetricsCalculator:
|
||||
"""Calculate trading metrics from database."""
|
||||
|
||||
def __init__(self, db: TradingDatabase):
|
||||
self.db = db
|
||||
|
||||
def get_all_time_metrics(self) -> PeriodMetrics:
|
||||
"""Get metrics for all trades ever."""
|
||||
return self._calculate_metrics("All Time", None, None)
|
||||
|
||||
def get_monthly_metrics(self) -> PeriodMetrics:
|
||||
"""Get metrics for current month."""
|
||||
now = datetime.now(timezone.utc)
|
||||
start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
return self._calculate_metrics(
|
||||
"Monthly",
|
||||
start.isoformat(),
|
||||
now.isoformat(),
|
||||
)
|
||||
|
||||
def get_weekly_metrics(self) -> PeriodMetrics:
|
||||
"""Get metrics for current week (Monday to now)."""
|
||||
now = datetime.now(timezone.utc)
|
||||
days_since_monday = now.weekday()
|
||||
start = now - timedelta(days=days_since_monday)
|
||||
start = start.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
return self._calculate_metrics(
|
||||
"Weekly",
|
||||
start.isoformat(),
|
||||
now.isoformat(),
|
||||
)
|
||||
|
||||
def get_daily_metrics(self) -> PeriodMetrics:
|
||||
"""Get metrics for today (UTC)."""
|
||||
now = datetime.now(timezone.utc)
|
||||
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
return self._calculate_metrics(
|
||||
"Daily",
|
||||
start.isoformat(),
|
||||
now.isoformat(),
|
||||
)
|
||||
|
||||
def _calculate_metrics(
|
||||
self,
|
||||
period_name: str,
|
||||
start_time: Optional[str],
|
||||
end_time: Optional[str],
|
||||
) -> PeriodMetrics:
|
||||
"""
|
||||
Calculate metrics for a time period.
|
||||
|
||||
Args:
|
||||
period_name: Name of the period
|
||||
start_time: ISO format start time (None for all time)
|
||||
end_time: ISO format end time (None for all time)
|
||||
|
||||
Returns:
|
||||
PeriodMetrics object
|
||||
"""
|
||||
metrics = PeriodMetrics(
|
||||
period_name=period_name,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
# Build query conditions
|
||||
conditions = ["exit_time IS NOT NULL"]
|
||||
params = []
|
||||
|
||||
if start_time:
|
||||
conditions.append("exit_time >= ?")
|
||||
params.append(start_time)
|
||||
if end_time:
|
||||
conditions.append("exit_time <= ?")
|
||||
params.append(end_time)
|
||||
|
||||
where_clause = " AND ".join(conditions)
|
||||
|
||||
# Get aggregate metrics
|
||||
sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_trades,
|
||||
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
|
||||
SUM(CASE WHEN pnl_usd < 0 THEN 1 ELSE 0 END) as losing_trades,
|
||||
COALESCE(SUM(pnl_usd), 0) as total_pnl,
|
||||
COALESCE(AVG(hold_duration_hours), 0) as avg_duration,
|
||||
COALESCE(MAX(pnl_usd), 0) as best_trade,
|
||||
COALESCE(MIN(pnl_usd), 0) as worst_trade,
|
||||
COALESCE(AVG(CASE WHEN pnl_usd > 0 THEN pnl_usd END), 0) as avg_win,
|
||||
COALESCE(AVG(CASE WHEN pnl_usd < 0 THEN pnl_usd END), 0) as avg_loss
|
||||
FROM trades
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
|
||||
row = self.db.connection.execute(sql, params).fetchone()
|
||||
|
||||
if row and row["total_trades"] > 0:
|
||||
metrics.total_trades = row["total_trades"]
|
||||
metrics.winning_trades = row["winning_trades"] or 0
|
||||
metrics.losing_trades = row["losing_trades"] or 0
|
||||
metrics.total_pnl = row["total_pnl"]
|
||||
metrics.avg_trade_duration_hours = row["avg_duration"]
|
||||
metrics.best_trade = row["best_trade"]
|
||||
metrics.worst_trade = row["worst_trade"]
|
||||
metrics.avg_win = row["avg_win"]
|
||||
metrics.avg_loss = row["avg_loss"]
|
||||
|
||||
if metrics.total_trades > 0:
|
||||
metrics.win_rate = (
|
||||
metrics.winning_trades / metrics.total_trades * 100
|
||||
)
|
||||
|
||||
# Calculate max drawdown
|
||||
metrics.max_drawdown = self._calculate_max_drawdown(
|
||||
start_time, end_time
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
def _calculate_max_drawdown(
|
||||
self,
|
||||
start_time: Optional[str],
|
||||
end_time: Optional[str],
|
||||
) -> float:
|
||||
"""Calculate maximum drawdown for a period."""
|
||||
conditions = ["exit_time IS NOT NULL"]
|
||||
params = []
|
||||
|
||||
if start_time:
|
||||
conditions.append("exit_time >= ?")
|
||||
params.append(start_time)
|
||||
if end_time:
|
||||
conditions.append("exit_time <= ?")
|
||||
params.append(end_time)
|
||||
|
||||
where_clause = " AND ".join(conditions)
|
||||
|
||||
sql = f"""
|
||||
SELECT pnl_usd
|
||||
FROM trades
|
||||
WHERE {where_clause}
|
||||
ORDER BY exit_time
|
||||
"""
|
||||
|
||||
rows = self.db.connection.execute(sql, params).fetchall()
|
||||
|
||||
if not rows:
|
||||
return 0.0
|
||||
|
||||
cumulative = 0.0
|
||||
peak = 0.0
|
||||
max_drawdown = 0.0
|
||||
|
||||
for row in rows:
|
||||
pnl = row["pnl_usd"] or 0.0
|
||||
cumulative += pnl
|
||||
peak = max(peak, cumulative)
|
||||
drawdown = peak - cumulative
|
||||
max_drawdown = max(max_drawdown, drawdown)
|
||||
|
||||
return max_drawdown
|
||||
|
||||
def has_monthly_data(self) -> bool:
|
||||
"""Check if we have data spanning more than current month."""
|
||||
sql = """
|
||||
SELECT MIN(exit_time) as first_trade
|
||||
FROM trades
|
||||
WHERE exit_time IS NOT NULL
|
||||
"""
|
||||
row = self.db.connection.execute(sql).fetchone()
|
||||
if not row or not row["first_trade"]:
|
||||
return False
|
||||
|
||||
first_trade = datetime.fromisoformat(row["first_trade"])
|
||||
now = datetime.now(timezone.utc)
|
||||
month_start = now.replace(day=1, hour=0, minute=0, second=0)
|
||||
|
||||
return first_trade < month_start
|
||||
|
||||
def has_weekly_data(self) -> bool:
|
||||
"""Check if we have data spanning more than current week."""
|
||||
sql = """
|
||||
SELECT MIN(exit_time) as first_trade
|
||||
FROM trades
|
||||
WHERE exit_time IS NOT NULL
|
||||
"""
|
||||
row = self.db.connection.execute(sql).fetchone()
|
||||
if not row or not row["first_trade"]:
|
||||
return False
|
||||
|
||||
first_trade = datetime.fromisoformat(row["first_trade"])
|
||||
now = datetime.now(timezone.utc)
|
||||
days_since_monday = now.weekday()
|
||||
week_start = now - timedelta(days=days_since_monday)
|
||||
week_start = week_start.replace(hour=0, minute=0, second=0)
|
||||
|
||||
return first_trade < week_start
|
||||
|
||||
def get_session_start_balance(self) -> Optional[float]:
|
||||
"""Get starting balance from latest session."""
|
||||
sql = "SELECT starting_balance FROM sessions ORDER BY id DESC LIMIT 1"
|
||||
row = self.db.connection.execute(sql).fetchone()
|
||||
return row["starting_balance"] if row else None
|
||||
191
live_trading/db/migrations.py
Normal file
191
live_trading/db/migrations.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Database migrations and CSV import."""
|
||||
import csv
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from .database import TradingDatabase
|
||||
from .models import Trade, DailySummary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def migrate_csv_to_db(db: TradingDatabase, csv_path: Path) -> int:
|
||||
"""
|
||||
Migrate trades from CSV file to SQLite database.
|
||||
|
||||
Args:
|
||||
db: TradingDatabase instance
|
||||
csv_path: Path to trade_log.csv
|
||||
|
||||
Returns:
|
||||
Number of trades migrated
|
||||
"""
|
||||
if not csv_path.exists():
|
||||
logger.info("No CSV file to migrate")
|
||||
return 0
|
||||
|
||||
# Check if database already has trades
|
||||
existing_count = db.count_trades()
|
||||
if existing_count > 0:
|
||||
logger.info(
|
||||
f"Database already has {existing_count} trades, skipping migration"
|
||||
)
|
||||
return 0
|
||||
|
||||
migrated = 0
|
||||
try:
|
||||
with open(csv_path, "r", newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
for row in reader:
|
||||
trade = _csv_row_to_trade(row)
|
||||
if trade:
|
||||
try:
|
||||
db.insert_trade(trade)
|
||||
migrated += 1
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to migrate trade {row.get('trade_id')}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"Migrated {migrated} trades from CSV to database")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CSV migration failed: {e}")
|
||||
|
||||
return migrated
|
||||
|
||||
|
||||
def _csv_row_to_trade(row: dict) -> Trade | None:
|
||||
"""Convert a CSV row to a Trade object."""
|
||||
try:
|
||||
return Trade(
|
||||
trade_id=row["trade_id"],
|
||||
symbol=row["symbol"],
|
||||
side=row["side"],
|
||||
entry_price=float(row["entry_price"]),
|
||||
exit_price=_safe_float(row.get("exit_price")),
|
||||
size=float(row["size"]),
|
||||
size_usdt=float(row["size_usdt"]),
|
||||
pnl_usd=_safe_float(row.get("pnl_usd")),
|
||||
pnl_pct=_safe_float(row.get("pnl_pct")),
|
||||
entry_time=row["entry_time"],
|
||||
exit_time=row.get("exit_time") or None,
|
||||
hold_duration_hours=_safe_float(row.get("hold_duration_hours")),
|
||||
reason=row.get("reason") or None,
|
||||
order_id_entry=row.get("order_id_entry") or None,
|
||||
order_id_exit=row.get("order_id_exit") or None,
|
||||
)
|
||||
except (KeyError, ValueError) as e:
|
||||
logger.warning(f"Invalid CSV row: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _safe_float(value: str | None) -> float | None:
|
||||
"""Safely convert string to float."""
|
||||
if value is None or value == "":
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def rebuild_daily_summaries(db: TradingDatabase) -> int:
|
||||
"""
|
||||
Rebuild daily summary table from trades.
|
||||
|
||||
Args:
|
||||
db: TradingDatabase instance
|
||||
|
||||
Returns:
|
||||
Number of daily summaries created
|
||||
"""
|
||||
sql = """
|
||||
SELECT
|
||||
DATE(exit_time) as date,
|
||||
COUNT(*) as total_trades,
|
||||
SUM(CASE WHEN pnl_usd > 0 THEN 1 ELSE 0 END) as winning_trades,
|
||||
SUM(pnl_usd) as total_pnl_usd
|
||||
FROM trades
|
||||
WHERE exit_time IS NOT NULL
|
||||
GROUP BY DATE(exit_time)
|
||||
ORDER BY date
|
||||
"""
|
||||
|
||||
rows = db.connection.execute(sql).fetchall()
|
||||
count = 0
|
||||
|
||||
for row in rows:
|
||||
summary = DailySummary(
|
||||
date=row["date"],
|
||||
total_trades=row["total_trades"],
|
||||
winning_trades=row["winning_trades"],
|
||||
total_pnl_usd=row["total_pnl_usd"] or 0.0,
|
||||
max_drawdown_usd=0.0, # Calculated separately
|
||||
)
|
||||
db.upsert_daily_summary(summary)
|
||||
count += 1
|
||||
|
||||
# Calculate max drawdowns
|
||||
_calculate_daily_drawdowns(db)
|
||||
|
||||
logger.info(f"Rebuilt {count} daily summaries")
|
||||
return count
|
||||
|
||||
|
||||
def _calculate_daily_drawdowns(db: TradingDatabase) -> None:
|
||||
"""Calculate and update max drawdown for each day."""
|
||||
sql = """
|
||||
SELECT trade_id, DATE(exit_time) as date, pnl_usd
|
||||
FROM trades
|
||||
WHERE exit_time IS NOT NULL
|
||||
ORDER BY exit_time
|
||||
"""
|
||||
|
||||
rows = db.connection.execute(sql).fetchall()
|
||||
|
||||
# Track cumulative PnL and drawdown per day
|
||||
daily_drawdowns: dict[str, float] = {}
|
||||
cumulative_pnl = 0.0
|
||||
peak_pnl = 0.0
|
||||
|
||||
for row in rows:
|
||||
date = row["date"]
|
||||
pnl = row["pnl_usd"] or 0.0
|
||||
|
||||
cumulative_pnl += pnl
|
||||
peak_pnl = max(peak_pnl, cumulative_pnl)
|
||||
drawdown = peak_pnl - cumulative_pnl
|
||||
|
||||
if date not in daily_drawdowns:
|
||||
daily_drawdowns[date] = 0.0
|
||||
daily_drawdowns[date] = max(daily_drawdowns[date], drawdown)
|
||||
|
||||
# Update daily summaries with drawdown
|
||||
for date, drawdown in daily_drawdowns.items():
|
||||
db.connection.execute(
|
||||
"UPDATE daily_summary SET max_drawdown_usd = ? WHERE date = ?",
|
||||
(drawdown, date),
|
||||
)
|
||||
db.connection.commit()
|
||||
|
||||
|
||||
def run_migrations(db: TradingDatabase, csv_path: Path) -> None:
|
||||
"""
|
||||
Run all migrations.
|
||||
|
||||
Args:
|
||||
db: TradingDatabase instance
|
||||
csv_path: Path to trade_log.csv for migration
|
||||
"""
|
||||
logger.info("Running database migrations...")
|
||||
|
||||
# Migrate CSV data if exists
|
||||
migrate_csv_to_db(db, csv_path)
|
||||
|
||||
# Rebuild daily summaries
|
||||
rebuild_daily_summaries(db)
|
||||
|
||||
logger.info("Migrations complete")
|
||||
69
live_trading/db/models.py
Normal file
69
live_trading/db/models.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Data models for trade persistence."""
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class Trade:
|
||||
"""Represents a completed trade."""
|
||||
|
||||
trade_id: str
|
||||
symbol: str
|
||||
side: str
|
||||
entry_price: float
|
||||
size: float
|
||||
size_usdt: float
|
||||
entry_time: str
|
||||
exit_price: Optional[float] = None
|
||||
pnl_usd: Optional[float] = None
|
||||
pnl_pct: Optional[float] = None
|
||||
exit_time: Optional[str] = None
|
||||
hold_duration_hours: Optional[float] = None
|
||||
reason: Optional[str] = None
|
||||
order_id_entry: Optional[str] = None
|
||||
order_id_exit: Optional[str] = None
|
||||
id: Optional[int] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: tuple, columns: list[str]) -> "Trade":
|
||||
"""Create Trade from database row."""
|
||||
data = dict(zip(columns, row))
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DailySummary:
|
||||
"""Daily trading summary."""
|
||||
|
||||
date: str
|
||||
total_trades: int = 0
|
||||
winning_trades: int = 0
|
||||
total_pnl_usd: float = 0.0
|
||||
max_drawdown_usd: float = 0.0
|
||||
id: Optional[int] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""Trading session metadata."""
|
||||
|
||||
start_time: str
|
||||
end_time: Optional[str] = None
|
||||
starting_balance: Optional[float] = None
|
||||
ending_balance: Optional[float] = None
|
||||
total_pnl: Optional[float] = None
|
||||
total_trades: int = 0
|
||||
id: Optional[int] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
return asdict(self)
|
||||
@@ -39,18 +39,40 @@ class LiveRegimeStrategy:
|
||||
self.paths = path_config
|
||||
self.model: Optional[RandomForestClassifier] = None
|
||||
self.feature_cols: Optional[list] = None
|
||||
self.horizon: int = 102 # Default horizon
|
||||
self._last_model_load_time: float = 0.0
|
||||
self._load_or_train_model()
|
||||
|
||||
def reload_model_if_changed(self) -> None:
|
||||
"""Check if model file has changed and reload if necessary."""
|
||||
if not self.paths.model_path.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
mtime = self.paths.model_path.stat().st_mtime
|
||||
if mtime > self._last_model_load_time:
|
||||
logger.info(f"Model file changed, reloading... (last: {self._last_model_load_time}, new: {mtime})")
|
||||
self._load_or_train_model()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking model file: {e}")
|
||||
|
||||
def _load_or_train_model(self) -> None:
|
||||
"""Load pre-trained model or train a new one."""
|
||||
if self.paths.model_path.exists():
|
||||
try:
|
||||
self._last_model_load_time = self.paths.model_path.stat().st_mtime
|
||||
with open(self.paths.model_path, 'rb') as f:
|
||||
saved = pickle.load(f)
|
||||
self.model = saved['model']
|
||||
self.feature_cols = saved['feature_cols']
|
||||
logger.info(f"Loaded model from {self.paths.model_path}")
|
||||
return
|
||||
|
||||
# Load horizon from metrics if available
|
||||
if 'metrics' in saved and 'horizon' in saved['metrics']:
|
||||
self.horizon = saved['metrics']['horizon']
|
||||
logger.info(f"Loaded model from {self.paths.model_path} (horizon={self.horizon})")
|
||||
else:
|
||||
logger.info(f"Loaded model from {self.paths.model_path} (default horizon={self.horizon})")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load model: {e}")
|
||||
|
||||
@@ -66,6 +88,7 @@ class LiveRegimeStrategy:
|
||||
pickle.dump({
|
||||
'model': self.model,
|
||||
'feature_cols': self.feature_cols,
|
||||
'metrics': {'horizon': self.horizon} # Save horizon
|
||||
}, f)
|
||||
logger.info(f"Saved model to {self.paths.model_path}")
|
||||
except Exception as e:
|
||||
@@ -81,7 +104,7 @@ class LiveRegimeStrategy:
|
||||
logger.info(f"Training model on {len(features)} samples...")
|
||||
|
||||
z_thresh = self.config.z_entry_threshold
|
||||
horizon = 102 # Optimal horizon from research
|
||||
horizon = self.horizon
|
||||
profit_target = 0.005 # 0.5% profit threshold
|
||||
|
||||
# Define targets
|
||||
|
||||
@@ -11,14 +11,19 @@ Usage:
|
||||
|
||||
# Run with specific settings
|
||||
uv run python -m live_trading.main --max-position 500 --leverage 2
|
||||
|
||||
# Run without UI (headless mode)
|
||||
uv run python -m live_trading.main --no-ui
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import queue
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
@@ -28,22 +33,47 @@ from live_trading.okx_client import OKXClient
|
||||
from live_trading.data_feed import DataFeed
|
||||
from live_trading.position_manager import PositionManager
|
||||
from live_trading.live_regime_strategy import LiveRegimeStrategy
|
||||
from live_trading.db.database import init_db, TradingDatabase
|
||||
from live_trading.db.migrations import run_migrations
|
||||
from live_trading.ui.state import SharedState, PositionState
|
||||
from live_trading.ui.dashboard import TradingDashboard, setup_ui_logging
|
||||
|
||||
|
||||
def setup_logging(log_dir: Path) -> logging.Logger:
|
||||
"""Configure logging for the trading bot."""
|
||||
def setup_logging(
|
||||
log_dir: Path,
|
||||
log_queue: Optional[queue.Queue] = None,
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
Configure logging for the trading bot.
|
||||
|
||||
Args:
|
||||
log_dir: Directory for log files
|
||||
log_queue: Optional queue for UI log handler
|
||||
|
||||
Returns:
|
||||
Logger instance
|
||||
"""
|
||||
log_file = log_dir / "live_trading.log"
|
||||
|
||||
handlers = [
|
||||
logging.FileHandler(log_file),
|
||||
]
|
||||
|
||||
# Only add StreamHandler if no UI (log_queue is None)
|
||||
if log_queue is None:
|
||||
handlers.append(logging.StreamHandler(sys.stdout))
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(log_file),
|
||||
logging.StreamHandler(sys.stdout),
|
||||
],
|
||||
force=True
|
||||
handlers=handlers,
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Add UI log handler if queue provided
|
||||
if log_queue is not None:
|
||||
setup_ui_logging(log_queue)
|
||||
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -59,11 +89,15 @@ class LiveTradingBot:
|
||||
self,
|
||||
okx_config: OKXConfig,
|
||||
trading_config: TradingConfig,
|
||||
path_config: PathConfig
|
||||
path_config: PathConfig,
|
||||
database: Optional[TradingDatabase] = None,
|
||||
shared_state: Optional[SharedState] = None,
|
||||
):
|
||||
self.okx_config = okx_config
|
||||
self.trading_config = trading_config
|
||||
self.path_config = path_config
|
||||
self.db = database
|
||||
self.state = shared_state
|
||||
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.running = True
|
||||
@@ -74,7 +108,7 @@ class LiveTradingBot:
|
||||
self.okx_client = OKXClient(okx_config, trading_config)
|
||||
self.data_feed = DataFeed(self.okx_client, trading_config, path_config)
|
||||
self.position_manager = PositionManager(
|
||||
self.okx_client, trading_config, path_config
|
||||
self.okx_client, trading_config, path_config, database
|
||||
)
|
||||
self.strategy = LiveRegimeStrategy(trading_config, path_config)
|
||||
|
||||
@@ -82,6 +116,16 @@ class LiveTradingBot:
|
||||
signal.signal(signal.SIGINT, self._handle_shutdown)
|
||||
signal.signal(signal.SIGTERM, self._handle_shutdown)
|
||||
|
||||
# Initialize shared state if provided
|
||||
if self.state:
|
||||
mode = "DEMO" if okx_config.demo_mode else "LIVE"
|
||||
self.state.set_mode(mode)
|
||||
self.state.set_symbols(
|
||||
trading_config.eth_symbol,
|
||||
trading_config.btc_symbol,
|
||||
)
|
||||
self.state.update_account(0.0, 0.0, trading_config.leverage)
|
||||
|
||||
self._print_startup_banner()
|
||||
|
||||
def _print_startup_banner(self) -> None:
|
||||
@@ -109,6 +153,8 @@ class LiveTradingBot:
|
||||
"""Handle shutdown signals gracefully."""
|
||||
self.logger.info("Shutdown signal received, stopping...")
|
||||
self.running = False
|
||||
if self.state:
|
||||
self.state.stop()
|
||||
|
||||
def run_trading_cycle(self) -> None:
|
||||
"""
|
||||
@@ -118,10 +164,20 @@ class LiveTradingBot:
|
||||
2. Update open positions
|
||||
3. Generate trading signal
|
||||
4. Execute trades if signal triggers
|
||||
5. Update shared state for UI
|
||||
"""
|
||||
# Reload model if it has changed (e.g. daily training)
|
||||
try:
|
||||
self.strategy.reload_model_if_changed()
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to reload model: {e}")
|
||||
|
||||
cycle_start = datetime.now(timezone.utc)
|
||||
self.logger.info(f"--- Trading Cycle Start: {cycle_start.isoformat()} ---")
|
||||
|
||||
if self.state:
|
||||
self.state.set_last_cycle_time(cycle_start.isoformat())
|
||||
|
||||
try:
|
||||
# 1. Fetch market data
|
||||
features = self.data_feed.get_latest_data()
|
||||
@@ -154,15 +210,22 @@ class LiveTradingBot:
|
||||
funding = self.data_feed.get_current_funding_rates()
|
||||
|
||||
# 5. Generate trading signal
|
||||
signal = self.strategy.generate_signal(features, funding)
|
||||
sig = self.strategy.generate_signal(features, funding)
|
||||
|
||||
# 6. Execute trades based on signal
|
||||
if signal['action'] == 'entry':
|
||||
self._execute_entry(signal, eth_price)
|
||||
elif signal['action'] == 'check_exit':
|
||||
self._execute_exit(signal)
|
||||
# 6. Update shared state with strategy info
|
||||
self._update_strategy_state(sig, funding)
|
||||
|
||||
# 7. Log portfolio summary
|
||||
# 7. Execute trades based on signal
|
||||
if sig['action'] == 'entry':
|
||||
self._execute_entry(sig, eth_price)
|
||||
elif sig['action'] == 'check_exit':
|
||||
self._execute_exit(sig)
|
||||
|
||||
# 8. Update shared state with position and account
|
||||
self._update_position_state(eth_price)
|
||||
self._update_account_state()
|
||||
|
||||
# 9. Log portfolio summary
|
||||
summary = self.position_manager.get_portfolio_summary()
|
||||
self.logger.info(
|
||||
f"Portfolio: {summary['open_positions']} positions, "
|
||||
@@ -178,6 +241,61 @@ class LiveTradingBot:
|
||||
cycle_duration = (datetime.now(timezone.utc) - cycle_start).total_seconds()
|
||||
self.logger.info(f"--- Cycle completed in {cycle_duration:.1f}s ---")
|
||||
|
||||
def _update_strategy_state(self, sig: dict, funding: dict) -> None:
|
||||
"""Update shared state with strategy information."""
|
||||
if not self.state:
|
||||
return
|
||||
|
||||
self.state.update_strategy(
|
||||
z_score=sig.get('z_score', 0.0),
|
||||
probability=sig.get('probability', 0.0),
|
||||
funding_rate=funding.get('btc_funding', 0.0),
|
||||
action=sig.get('action', 'hold'),
|
||||
reason=sig.get('reason', ''),
|
||||
)
|
||||
|
||||
def _update_position_state(self, current_price: float) -> None:
|
||||
"""Update shared state with current position."""
|
||||
if not self.state:
|
||||
return
|
||||
|
||||
symbol = self.trading_config.eth_symbol
|
||||
position = self.position_manager.get_position_for_symbol(symbol)
|
||||
|
||||
if position is None:
|
||||
self.state.clear_position()
|
||||
return
|
||||
|
||||
pos_state = PositionState(
|
||||
trade_id=position.trade_id,
|
||||
symbol=position.symbol,
|
||||
side=position.side,
|
||||
entry_price=position.entry_price,
|
||||
current_price=position.current_price,
|
||||
size=position.size,
|
||||
size_usdt=position.size_usdt,
|
||||
unrealized_pnl=position.unrealized_pnl,
|
||||
unrealized_pnl_pct=position.unrealized_pnl_pct,
|
||||
stop_loss_price=position.stop_loss_price,
|
||||
take_profit_price=position.take_profit_price,
|
||||
)
|
||||
self.state.set_position(pos_state)
|
||||
|
||||
def _update_account_state(self) -> None:
|
||||
"""Update shared state with account information."""
|
||||
if not self.state:
|
||||
return
|
||||
|
||||
try:
|
||||
balance = self.okx_client.get_balance()
|
||||
self.state.update_account(
|
||||
balance=balance.get('total', 0.0),
|
||||
available=balance.get('free', 0.0),
|
||||
leverage=self.trading_config.leverage,
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to update account state: {e}")
|
||||
|
||||
def _execute_entry(self, signal: dict, current_price: float) -> None:
|
||||
"""Execute entry trade."""
|
||||
symbol = self.trading_config.eth_symbol
|
||||
@@ -191,11 +309,15 @@ class LiveTradingBot:
|
||||
# Get account balance
|
||||
balance = self.okx_client.get_balance()
|
||||
available_usdt = balance['free']
|
||||
|
||||
self.logger.info(f"Account balance: ${available_usdt:.2f} USDT available")
|
||||
|
||||
# Calculate position size
|
||||
size_usdt = self.strategy.calculate_position_size(signal, available_usdt)
|
||||
if size_usdt <= 0:
|
||||
self.logger.info("Position size too small, skipping entry")
|
||||
self.logger.info(
|
||||
f"Position size too small (${size_usdt:.2f}), skipping entry. "
|
||||
f"Min required: ${self.strategy.config.min_position_usdt:.2f}"
|
||||
)
|
||||
return
|
||||
|
||||
size_eth = size_usdt / current_price
|
||||
@@ -290,22 +412,30 @@ class LiveTradingBot:
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exit execution failed: {e}", exc_info=True)
|
||||
|
||||
def _is_running(self) -> bool:
|
||||
"""Check if bot should continue running."""
|
||||
if not self.running:
|
||||
return False
|
||||
if self.state and not self.state.is_running():
|
||||
return False
|
||||
return True
|
||||
|
||||
def run(self) -> None:
|
||||
"""Main trading loop."""
|
||||
self.logger.info("Starting trading loop...")
|
||||
|
||||
while self.running:
|
||||
while self._is_running():
|
||||
try:
|
||||
self.run_trading_cycle()
|
||||
|
||||
if self.running:
|
||||
if self._is_running():
|
||||
sleep_seconds = self.trading_config.sleep_seconds
|
||||
minutes = sleep_seconds // 60
|
||||
self.logger.info(f"Sleeping for {minutes} minutes...")
|
||||
|
||||
# Sleep in smaller chunks to allow faster shutdown
|
||||
for _ in range(sleep_seconds):
|
||||
if not self.running:
|
||||
if not self._is_running():
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
@@ -319,6 +449,8 @@ class LiveTradingBot:
|
||||
# Cleanup
|
||||
self.logger.info("Shutting down...")
|
||||
self.position_manager.save_positions()
|
||||
if self.db:
|
||||
self.db.close()
|
||||
self.logger.info("Shutdown complete")
|
||||
|
||||
|
||||
@@ -350,6 +482,11 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Use live trading mode (requires OKX_DEMO_MODE=false)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-ui",
|
||||
action="store_true",
|
||||
help="Run in headless mode without terminal UI"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -370,19 +507,64 @@ def main():
|
||||
if args.live:
|
||||
okx_config.demo_mode = False
|
||||
|
||||
# Determine if UI should be enabled
|
||||
use_ui = not args.no_ui and sys.stdin.isatty()
|
||||
|
||||
# Initialize database
|
||||
db_path = path_config.base_dir / "live_trading" / "trading.db"
|
||||
db = init_db(db_path)
|
||||
|
||||
# Run migrations (imports CSV if exists)
|
||||
run_migrations(db, path_config.trade_log_file)
|
||||
|
||||
# Initialize UI components if enabled
|
||||
log_queue: Optional[queue.Queue] = None
|
||||
shared_state: Optional[SharedState] = None
|
||||
dashboard: Optional[TradingDashboard] = None
|
||||
|
||||
if use_ui:
|
||||
log_queue = queue.Queue(maxsize=1000)
|
||||
shared_state = SharedState()
|
||||
|
||||
# Setup logging
|
||||
logger = setup_logging(path_config.logs_dir)
|
||||
logger = setup_logging(path_config.logs_dir, log_queue)
|
||||
|
||||
try:
|
||||
# Create and run bot
|
||||
bot = LiveTradingBot(okx_config, trading_config, path_config)
|
||||
# Create bot
|
||||
bot = LiveTradingBot(
|
||||
okx_config,
|
||||
trading_config,
|
||||
path_config,
|
||||
database=db,
|
||||
shared_state=shared_state,
|
||||
)
|
||||
|
||||
# Start dashboard if UI enabled
|
||||
if use_ui and shared_state and log_queue:
|
||||
dashboard = TradingDashboard(
|
||||
state=shared_state,
|
||||
db=db,
|
||||
log_queue=log_queue,
|
||||
on_quit=lambda: setattr(bot, 'running', False),
|
||||
)
|
||||
dashboard.start()
|
||||
logger.info("Dashboard started")
|
||||
|
||||
# Run bot
|
||||
bot.run()
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Configuration error: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
finally:
|
||||
# Cleanup
|
||||
if dashboard:
|
||||
dashboard.stop()
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -3,16 +3,21 @@ Position Manager for Live Trading.
|
||||
|
||||
Tracks open positions, manages risk, and handles SL/TP logic.
|
||||
"""
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
from .okx_client import OKXClient
|
||||
from .config import TradingConfig, PathConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db.database import TradingDatabase
|
||||
from .db.models import Trade
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -78,11 +83,13 @@ class PositionManager:
|
||||
self,
|
||||
okx_client: OKXClient,
|
||||
trading_config: TradingConfig,
|
||||
path_config: PathConfig
|
||||
path_config: PathConfig,
|
||||
database: Optional["TradingDatabase"] = None,
|
||||
):
|
||||
self.client = okx_client
|
||||
self.config = trading_config
|
||||
self.paths = path_config
|
||||
self.db = database
|
||||
self.positions: dict[str, Position] = {}
|
||||
self.trade_log: list[dict] = []
|
||||
self._load_positions()
|
||||
@@ -249,16 +256,55 @@ class PositionManager:
|
||||
return trade_record
|
||||
|
||||
def _append_trade_log(self, trade_record: dict) -> None:
|
||||
"""Append trade record to CSV log file."""
|
||||
import csv
|
||||
"""Append trade record to CSV and SQLite database."""
|
||||
# Write to CSV (backup/compatibility)
|
||||
self._append_trade_csv(trade_record)
|
||||
|
||||
# Write to SQLite (primary)
|
||||
self._append_trade_db(trade_record)
|
||||
|
||||
def _append_trade_csv(self, trade_record: dict) -> None:
|
||||
"""Append trade record to CSV log file."""
|
||||
file_exists = self.paths.trade_log_file.exists()
|
||||
|
||||
with open(self.paths.trade_log_file, 'a', newline='') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=trade_record.keys())
|
||||
if not file_exists:
|
||||
writer.writeheader()
|
||||
writer.writerow(trade_record)
|
||||
try:
|
||||
with open(self.paths.trade_log_file, 'a', newline='') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=trade_record.keys())
|
||||
if not file_exists:
|
||||
writer.writeheader()
|
||||
writer.writerow(trade_record)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write trade to CSV: {e}")
|
||||
|
||||
def _append_trade_db(self, trade_record: dict) -> None:
|
||||
"""Append trade record to SQLite database."""
|
||||
if self.db is None:
|
||||
return
|
||||
|
||||
try:
|
||||
from .db.models import Trade
|
||||
|
||||
trade = Trade(
|
||||
trade_id=trade_record['trade_id'],
|
||||
symbol=trade_record['symbol'],
|
||||
side=trade_record['side'],
|
||||
entry_price=trade_record['entry_price'],
|
||||
exit_price=trade_record.get('exit_price'),
|
||||
size=trade_record['size'],
|
||||
size_usdt=trade_record['size_usdt'],
|
||||
pnl_usd=trade_record.get('pnl_usd'),
|
||||
pnl_pct=trade_record.get('pnl_pct'),
|
||||
entry_time=trade_record['entry_time'],
|
||||
exit_time=trade_record.get('exit_time'),
|
||||
hold_duration_hours=trade_record.get('hold_duration_hours'),
|
||||
reason=trade_record.get('reason'),
|
||||
order_id_entry=trade_record.get('order_id_entry'),
|
||||
order_id_exit=trade_record.get('order_id_exit'),
|
||||
)
|
||||
self.db.insert_trade(trade)
|
||||
logger.debug(f"Trade {trade.trade_id} saved to database")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write trade to database: {e}")
|
||||
|
||||
def update_positions(self, current_prices: dict[str, float]) -> list[dict]:
|
||||
"""
|
||||
|
||||
10
live_trading/ui/__init__.py
Normal file
10
live_trading/ui/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Terminal UI module for live trading dashboard."""
|
||||
from .dashboard import TradingDashboard
|
||||
from .state import SharedState
|
||||
from .log_handler import UILogHandler
|
||||
|
||||
__all__ = [
|
||||
"TradingDashboard",
|
||||
"SharedState",
|
||||
"UILogHandler",
|
||||
]
|
||||
240
live_trading/ui/dashboard.py
Normal file
240
live_trading/ui/dashboard.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Main trading dashboard UI orchestration."""
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional, Callable
|
||||
|
||||
from rich.console import Console
|
||||
from rich.layout import Layout
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from .state import SharedState
|
||||
from .log_handler import LogBuffer, UILogHandler
|
||||
from .keyboard import KeyboardHandler
|
||||
from .panels import (
|
||||
HeaderPanel,
|
||||
TabBar,
|
||||
LogPanel,
|
||||
HelpBar,
|
||||
build_summary_panel,
|
||||
)
|
||||
from ..db.database import TradingDatabase
|
||||
from ..db.metrics import MetricsCalculator, PeriodMetrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TradingDashboard:
|
||||
"""
|
||||
Main trading dashboard orchestrator.
|
||||
|
||||
Runs in a separate thread and provides real-time UI updates
|
||||
while the trading loop runs in the main thread.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state: SharedState,
|
||||
db: TradingDatabase,
|
||||
log_queue: queue.Queue,
|
||||
on_quit: Optional[Callable] = None,
|
||||
):
|
||||
self.state = state
|
||||
self.db = db
|
||||
self.log_queue = log_queue
|
||||
self.on_quit = on_quit
|
||||
|
||||
self.console = Console()
|
||||
self.log_buffer = LogBuffer(max_entries=1000)
|
||||
self.keyboard = KeyboardHandler()
|
||||
self.metrics_calculator = MetricsCalculator(db)
|
||||
|
||||
self._running = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._active_tab = 0
|
||||
self._cached_metrics: dict[int, PeriodMetrics] = {}
|
||||
self._last_metrics_refresh = 0.0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the dashboard in a separate thread."""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
logger.debug("Dashboard thread started")
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dashboard."""
|
||||
self._running = False
|
||||
if self._thread:
|
||||
self._thread.join(timeout=2.0)
|
||||
logger.debug("Dashboard thread stopped")
|
||||
|
||||
def _run(self) -> None:
|
||||
"""Main dashboard loop."""
|
||||
try:
|
||||
with self.keyboard:
|
||||
with Live(
|
||||
self._build_layout(),
|
||||
console=self.console,
|
||||
refresh_per_second=1,
|
||||
screen=True,
|
||||
) as live:
|
||||
while self._running and self.state.is_running():
|
||||
# Process keyboard input
|
||||
action = self.keyboard.get_action(timeout=0.1)
|
||||
if action:
|
||||
self._handle_action(action)
|
||||
|
||||
# Drain log queue
|
||||
self.log_buffer.drain_queue(self.log_queue)
|
||||
|
||||
# Refresh metrics periodically (every 5 seconds)
|
||||
now = time.time()
|
||||
if now - self._last_metrics_refresh > 5.0:
|
||||
self._refresh_metrics()
|
||||
self._last_metrics_refresh = now
|
||||
|
||||
# Update display
|
||||
live.update(self._build_layout())
|
||||
|
||||
# Small sleep to prevent CPU spinning
|
||||
time.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Dashboard error: {e}", exc_info=True)
|
||||
finally:
|
||||
self._running = False
|
||||
|
||||
def _handle_action(self, action: str) -> None:
|
||||
"""Handle keyboard action."""
|
||||
if action == "quit":
|
||||
logger.info("Quit requested from UI")
|
||||
self.state.stop()
|
||||
if self.on_quit:
|
||||
self.on_quit()
|
||||
|
||||
elif action == "refresh":
|
||||
self._refresh_metrics()
|
||||
logger.debug("Manual refresh triggered")
|
||||
|
||||
elif action == "filter":
|
||||
new_filter = self.log_buffer.cycle_filter()
|
||||
logger.debug(f"Log filter changed to: {new_filter}")
|
||||
|
||||
elif action == "filter_trades":
|
||||
self.log_buffer.set_filter(LogBuffer.FILTER_TRADES)
|
||||
logger.debug("Log filter set to: trades")
|
||||
|
||||
elif action == "filter_all":
|
||||
self.log_buffer.set_filter(LogBuffer.FILTER_ALL)
|
||||
logger.debug("Log filter set to: all")
|
||||
|
||||
elif action == "filter_errors":
|
||||
self.log_buffer.set_filter(LogBuffer.FILTER_ERRORS)
|
||||
logger.debug("Log filter set to: errors")
|
||||
|
||||
elif action == "tab_general":
|
||||
self._active_tab = 0
|
||||
elif action == "tab_monthly":
|
||||
if self._has_monthly_data():
|
||||
self._active_tab = 1
|
||||
elif action == "tab_weekly":
|
||||
if self._has_weekly_data():
|
||||
self._active_tab = 2
|
||||
elif action == "tab_daily":
|
||||
self._active_tab = 3
|
||||
|
||||
def _refresh_metrics(self) -> None:
|
||||
"""Refresh metrics from database."""
|
||||
try:
|
||||
self._cached_metrics[0] = self.metrics_calculator.get_all_time_metrics()
|
||||
self._cached_metrics[1] = self.metrics_calculator.get_monthly_metrics()
|
||||
self._cached_metrics[2] = self.metrics_calculator.get_weekly_metrics()
|
||||
self._cached_metrics[3] = self.metrics_calculator.get_daily_metrics()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to refresh metrics: {e}")
|
||||
|
||||
def _has_monthly_data(self) -> bool:
|
||||
"""Check if monthly tab should be shown."""
|
||||
try:
|
||||
return self.metrics_calculator.has_monthly_data()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _has_weekly_data(self) -> bool:
|
||||
"""Check if weekly tab should be shown."""
|
||||
try:
|
||||
return self.metrics_calculator.has_weekly_data()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _build_layout(self) -> Layout:
|
||||
"""Build the complete dashboard layout."""
|
||||
layout = Layout()
|
||||
|
||||
# Calculate available height
|
||||
term_height = self.console.height or 40
|
||||
|
||||
# Header takes 3 lines
|
||||
# Help bar takes 1 line
|
||||
# Summary panel takes about 12-14 lines
|
||||
# Rest goes to logs
|
||||
log_height = max(8, term_height - 20)
|
||||
|
||||
layout.split_column(
|
||||
Layout(name="header", size=3),
|
||||
Layout(name="summary", size=14),
|
||||
Layout(name="logs", size=log_height),
|
||||
Layout(name="help", size=1),
|
||||
)
|
||||
|
||||
# Header
|
||||
layout["header"].update(HeaderPanel(self.state).render())
|
||||
|
||||
# Summary panel with tabs
|
||||
current_metrics = self._cached_metrics.get(self._active_tab)
|
||||
tab_bar = TabBar(active_tab=self._active_tab)
|
||||
|
||||
layout["summary"].update(
|
||||
build_summary_panel(
|
||||
state=self.state,
|
||||
metrics=current_metrics,
|
||||
tab_bar=tab_bar,
|
||||
has_monthly=self._has_monthly_data(),
|
||||
has_weekly=self._has_weekly_data(),
|
||||
)
|
||||
)
|
||||
|
||||
# Log panel
|
||||
layout["logs"].update(LogPanel(self.log_buffer).render(height=log_height))
|
||||
|
||||
# Help bar
|
||||
layout["help"].update(HelpBar().render())
|
||||
|
||||
return layout
|
||||
|
||||
|
||||
def setup_ui_logging(log_queue: queue.Queue) -> UILogHandler:
|
||||
"""
|
||||
Set up logging to capture messages for UI.
|
||||
|
||||
Args:
|
||||
log_queue: Queue to send log messages to
|
||||
|
||||
Returns:
|
||||
UILogHandler instance
|
||||
"""
|
||||
handler = UILogHandler(log_queue)
|
||||
handler.setLevel(logging.INFO)
|
||||
|
||||
# Add handler to root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.addHandler(handler)
|
||||
|
||||
return handler
|
||||
128
live_trading/ui/keyboard.py
Normal file
128
live_trading/ui/keyboard.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Keyboard input handling for terminal UI."""
|
||||
import sys
|
||||
import select
|
||||
import termios
|
||||
import tty
|
||||
from typing import Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeyAction:
|
||||
"""Represents a keyboard action."""
|
||||
|
||||
key: str
|
||||
action: str
|
||||
description: str
|
||||
|
||||
|
||||
class KeyboardHandler:
|
||||
"""
|
||||
Non-blocking keyboard input handler.
|
||||
|
||||
Uses terminal raw mode to capture single keypresses
|
||||
without waiting for Enter.
|
||||
"""
|
||||
|
||||
# Key mappings
|
||||
ACTIONS = {
|
||||
"q": "quit",
|
||||
"Q": "quit",
|
||||
"\x03": "quit", # Ctrl+C
|
||||
"r": "refresh",
|
||||
"R": "refresh",
|
||||
"f": "filter",
|
||||
"F": "filter",
|
||||
"t": "filter_trades",
|
||||
"T": "filter_trades",
|
||||
"l": "filter_all",
|
||||
"L": "filter_all",
|
||||
"e": "filter_errors",
|
||||
"E": "filter_errors",
|
||||
"1": "tab_general",
|
||||
"2": "tab_monthly",
|
||||
"3": "tab_weekly",
|
||||
"4": "tab_daily",
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self._old_settings = None
|
||||
self._enabled = False
|
||||
|
||||
def enable(self) -> bool:
|
||||
"""
|
||||
Enable raw keyboard input mode.
|
||||
|
||||
Returns:
|
||||
True if enabled successfully
|
||||
"""
|
||||
try:
|
||||
if not sys.stdin.isatty():
|
||||
return False
|
||||
|
||||
self._old_settings = termios.tcgetattr(sys.stdin)
|
||||
tty.setcbreak(sys.stdin.fileno())
|
||||
self._enabled = True
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Restore normal terminal mode."""
|
||||
if self._enabled and self._old_settings:
|
||||
try:
|
||||
termios.tcsetattr(
|
||||
sys.stdin,
|
||||
termios.TCSADRAIN,
|
||||
self._old_settings,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._enabled = False
|
||||
|
||||
def get_key(self, timeout: float = 0.1) -> Optional[str]:
|
||||
"""
|
||||
Get a keypress if available (non-blocking).
|
||||
|
||||
Args:
|
||||
timeout: Seconds to wait for input
|
||||
|
||||
Returns:
|
||||
Key character or None if no input
|
||||
"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
readable, _, _ = select.select([sys.stdin], [], [], timeout)
|
||||
if readable:
|
||||
return sys.stdin.read(1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def get_action(self, timeout: float = 0.1) -> Optional[str]:
|
||||
"""
|
||||
Get action name for pressed key.
|
||||
|
||||
Args:
|
||||
timeout: Seconds to wait for input
|
||||
|
||||
Returns:
|
||||
Action name or None
|
||||
"""
|
||||
key = self.get_key(timeout)
|
||||
if key:
|
||||
return self.ACTIONS.get(key)
|
||||
return None
|
||||
|
||||
def __enter__(self):
|
||||
"""Context manager entry."""
|
||||
self.enable()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Context manager exit."""
|
||||
self.disable()
|
||||
return False
|
||||
178
live_trading/ui/log_handler.py
Normal file
178
live_trading/ui/log_handler.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Custom logging handler for UI integration."""
|
||||
import logging
|
||||
import queue
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from collections import deque
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogEntry:
|
||||
"""A single log entry."""
|
||||
|
||||
timestamp: str
|
||||
level: str
|
||||
message: str
|
||||
logger_name: str
|
||||
|
||||
@property
|
||||
def level_color(self) -> str:
|
||||
"""Get Rich color for log level."""
|
||||
colors = {
|
||||
"DEBUG": "dim",
|
||||
"INFO": "white",
|
||||
"WARNING": "yellow",
|
||||
"ERROR": "red",
|
||||
"CRITICAL": "bold red",
|
||||
}
|
||||
return colors.get(self.level, "white")
|
||||
|
||||
|
||||
class UILogHandler(logging.Handler):
|
||||
"""
|
||||
Custom logging handler that sends logs to UI.
|
||||
|
||||
Uses a thread-safe queue to pass log entries from the trading
|
||||
thread to the UI thread.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_queue: queue.Queue,
|
||||
max_entries: int = 1000,
|
||||
):
|
||||
super().__init__()
|
||||
self.log_queue = log_queue
|
||||
self.max_entries = max_entries
|
||||
self.setFormatter(
|
||||
logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
||||
)
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
"""Emit a log record to the queue."""
|
||||
try:
|
||||
entry = LogEntry(
|
||||
timestamp=datetime.fromtimestamp(record.created).strftime(
|
||||
"%H:%M:%S"
|
||||
),
|
||||
level=record.levelname,
|
||||
message=self.format_message(record),
|
||||
logger_name=record.name,
|
||||
)
|
||||
# Non-blocking put, drop if queue is full
|
||||
try:
|
||||
self.log_queue.put_nowait(entry)
|
||||
except queue.Full:
|
||||
pass
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
def format_message(self, record: logging.LogRecord) -> str:
|
||||
"""Format the log message."""
|
||||
return record.getMessage()
|
||||
|
||||
|
||||
class LogBuffer:
|
||||
"""
|
||||
Thread-safe buffer for log entries with filtering support.
|
||||
|
||||
Maintains a fixed-size buffer of log entries and supports
|
||||
filtering by log type.
|
||||
"""
|
||||
|
||||
FILTER_ALL = "all"
|
||||
FILTER_ERRORS = "errors"
|
||||
FILTER_TRADES = "trades"
|
||||
FILTER_SIGNALS = "signals"
|
||||
|
||||
FILTERS = [FILTER_ALL, FILTER_ERRORS, FILTER_TRADES, FILTER_SIGNALS]
|
||||
|
||||
def __init__(self, max_entries: int = 1000):
|
||||
self.max_entries = max_entries
|
||||
self._entries: deque[LogEntry] = deque(maxlen=max_entries)
|
||||
self._current_filter = self.FILTER_ALL
|
||||
|
||||
def add(self, entry: LogEntry) -> None:
|
||||
"""Add a log entry to the buffer."""
|
||||
self._entries.append(entry)
|
||||
|
||||
def get_filtered(self, limit: int = 50) -> list[LogEntry]:
|
||||
"""
|
||||
Get filtered log entries.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of entries to return
|
||||
|
||||
Returns:
|
||||
List of filtered LogEntry objects (most recent first)
|
||||
"""
|
||||
entries = list(self._entries)
|
||||
|
||||
if self._current_filter == self.FILTER_ERRORS:
|
||||
entries = [e for e in entries if e.level in ("ERROR", "CRITICAL")]
|
||||
elif self._current_filter == self.FILTER_TRADES:
|
||||
# Key terms indicating actual trading activity
|
||||
include_keywords = [
|
||||
"order", "entry", "exit", "executed", "filled",
|
||||
"opening", "closing", "position opened", "position closed"
|
||||
]
|
||||
# Terms to exclude (noise)
|
||||
exclude_keywords = [
|
||||
"sync complete", "0 positions", "portfolio: 0 positions"
|
||||
]
|
||||
|
||||
entries = [
|
||||
e for e in entries
|
||||
if any(kw in e.message.lower() for kw in include_keywords)
|
||||
and not any(ex in e.message.lower() for ex in exclude_keywords)
|
||||
]
|
||||
elif self._current_filter == self.FILTER_SIGNALS:
|
||||
signal_keywords = ["signal", "z_score", "prob", "z="]
|
||||
entries = [
|
||||
e for e in entries
|
||||
if any(kw in e.message.lower() for kw in signal_keywords)
|
||||
]
|
||||
|
||||
# Return most recent entries
|
||||
return list(reversed(entries[-limit:]))
|
||||
|
||||
def set_filter(self, filter_name: str) -> None:
|
||||
"""Set a specific filter."""
|
||||
if filter_name in self.FILTERS:
|
||||
self._current_filter = filter_name
|
||||
|
||||
def cycle_filter(self) -> str:
|
||||
"""Cycle to next filter and return its name."""
|
||||
current_idx = self.FILTERS.index(self._current_filter)
|
||||
next_idx = (current_idx + 1) % len(self.FILTERS)
|
||||
self._current_filter = self.FILTERS[next_idx]
|
||||
return self._current_filter
|
||||
|
||||
def get_current_filter(self) -> str:
|
||||
"""Get current filter name."""
|
||||
return self._current_filter
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all log entries."""
|
||||
self._entries.clear()
|
||||
|
||||
def drain_queue(self, log_queue: queue.Queue) -> int:
|
||||
"""
|
||||
Drain log entries from queue into buffer.
|
||||
|
||||
Args:
|
||||
log_queue: Queue to drain from
|
||||
|
||||
Returns:
|
||||
Number of entries drained
|
||||
"""
|
||||
count = 0
|
||||
while True:
|
||||
try:
|
||||
entry = log_queue.get_nowait()
|
||||
self.add(entry)
|
||||
count += 1
|
||||
except queue.Empty:
|
||||
break
|
||||
return count
|
||||
399
live_trading/ui/panels.py
Normal file
399
live_trading/ui/panels.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""UI panel components using Rich."""
|
||||
from typing import Optional
|
||||
|
||||
from rich.console import Console, Group
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
from rich.layout import Layout
|
||||
|
||||
from .state import SharedState, PositionState, StrategyState, AccountState
|
||||
from .log_handler import LogBuffer, LogEntry
|
||||
from ..db.metrics import PeriodMetrics
|
||||
|
||||
|
||||
def format_pnl(value: float, include_sign: bool = True) -> Text:
|
||||
"""Format PnL value with color."""
|
||||
if value > 0:
|
||||
sign = "+" if include_sign else ""
|
||||
return Text(f"{sign}${value:.2f}", style="green")
|
||||
elif value < 0:
|
||||
return Text(f"${value:.2f}", style="red")
|
||||
else:
|
||||
return Text(f"${value:.2f}", style="white")
|
||||
|
||||
|
||||
def format_pct(value: float, include_sign: bool = True) -> Text:
|
||||
"""Format percentage value with color."""
|
||||
if value > 0:
|
||||
sign = "+" if include_sign else ""
|
||||
return Text(f"{sign}{value:.2f}%", style="green")
|
||||
elif value < 0:
|
||||
return Text(f"{value:.2f}%", style="red")
|
||||
else:
|
||||
return Text(f"{value:.2f}%", style="white")
|
||||
|
||||
|
||||
def format_side(side: str) -> Text:
|
||||
"""Format position side with color."""
|
||||
if side.lower() == "long":
|
||||
return Text("LONG", style="bold green")
|
||||
else:
|
||||
return Text("SHORT", style="bold red")
|
||||
|
||||
|
||||
class HeaderPanel:
|
||||
"""Header panel with title and mode indicator."""
|
||||
|
||||
def __init__(self, state: SharedState):
|
||||
self.state = state
|
||||
|
||||
def render(self) -> Panel:
|
||||
"""Render the header panel."""
|
||||
mode = self.state.get_mode()
|
||||
eth_symbol, _ = self.state.get_symbols()
|
||||
|
||||
mode_style = "yellow" if mode == "DEMO" else "bold red"
|
||||
mode_text = Text(f"[{mode}]", style=mode_style)
|
||||
|
||||
title = Text()
|
||||
title.append("REGIME REVERSION STRATEGY - LIVE TRADING", style="bold white")
|
||||
title.append(" ")
|
||||
title.append(mode_text)
|
||||
title.append(" ")
|
||||
title.append(eth_symbol, style="cyan")
|
||||
|
||||
return Panel(title, style="blue", height=3)
|
||||
|
||||
|
||||
class TabBar:
|
||||
"""Tab bar for period selection."""
|
||||
|
||||
TABS = ["1:General", "2:Monthly", "3:Weekly", "4:Daily"]
|
||||
|
||||
def __init__(self, active_tab: int = 0):
|
||||
self.active_tab = active_tab
|
||||
|
||||
def render(
|
||||
self,
|
||||
has_monthly: bool = True,
|
||||
has_weekly: bool = True,
|
||||
) -> Text:
|
||||
"""Render the tab bar."""
|
||||
text = Text()
|
||||
text.append(" ")
|
||||
|
||||
for i, tab in enumerate(self.TABS):
|
||||
# Check if tab should be shown
|
||||
if i == 1 and not has_monthly:
|
||||
continue
|
||||
if i == 2 and not has_weekly:
|
||||
continue
|
||||
|
||||
if i == self.active_tab:
|
||||
text.append(f"[{tab}]", style="bold white on blue")
|
||||
else:
|
||||
text.append(f"[{tab}]", style="dim")
|
||||
text.append(" ")
|
||||
|
||||
return text
|
||||
|
||||
|
||||
class MetricsPanel:
|
||||
"""Panel showing trading metrics."""
|
||||
|
||||
def __init__(self, metrics: Optional[PeriodMetrics] = None):
|
||||
self.metrics = metrics
|
||||
|
||||
def render(self) -> Table:
|
||||
"""Render metrics as a table."""
|
||||
table = Table(
|
||||
show_header=False,
|
||||
show_edge=False,
|
||||
box=None,
|
||||
padding=(0, 1),
|
||||
)
|
||||
table.add_column("Label", style="dim")
|
||||
table.add_column("Value")
|
||||
|
||||
if self.metrics is None or self.metrics.total_trades == 0:
|
||||
table.add_row("Status", Text("No trade data", style="dim"))
|
||||
return table
|
||||
|
||||
m = self.metrics
|
||||
|
||||
table.add_row("Total PnL:", format_pnl(m.total_pnl))
|
||||
table.add_row("Win Rate:", Text(f"{m.win_rate:.1f}%", style="white"))
|
||||
table.add_row("Total Trades:", Text(str(m.total_trades), style="white"))
|
||||
table.add_row(
|
||||
"Win/Loss:",
|
||||
Text(f"{m.winning_trades}/{m.losing_trades}", style="white"),
|
||||
)
|
||||
table.add_row(
|
||||
"Avg Duration:",
|
||||
Text(f"{m.avg_trade_duration_hours:.1f}h", style="white"),
|
||||
)
|
||||
table.add_row("Max Drawdown:", format_pnl(-m.max_drawdown))
|
||||
table.add_row("Best Trade:", format_pnl(m.best_trade))
|
||||
table.add_row("Worst Trade:", format_pnl(m.worst_trade))
|
||||
|
||||
return table
|
||||
|
||||
|
||||
class PositionPanel:
|
||||
"""Panel showing current position."""
|
||||
|
||||
def __init__(self, position: Optional[PositionState] = None):
|
||||
self.position = position
|
||||
|
||||
def render(self) -> Table:
|
||||
"""Render position as a table."""
|
||||
table = Table(
|
||||
show_header=False,
|
||||
show_edge=False,
|
||||
box=None,
|
||||
padding=(0, 1),
|
||||
)
|
||||
table.add_column("Label", style="dim")
|
||||
table.add_column("Value")
|
||||
|
||||
if self.position is None:
|
||||
table.add_row("Status", Text("No open position", style="dim"))
|
||||
return table
|
||||
|
||||
p = self.position
|
||||
|
||||
table.add_row("Side:", format_side(p.side))
|
||||
table.add_row("Entry:", Text(f"${p.entry_price:.2f}", style="white"))
|
||||
table.add_row("Current:", Text(f"${p.current_price:.2f}", style="white"))
|
||||
|
||||
# Unrealized PnL
|
||||
pnl_text = Text()
|
||||
pnl_text.append_text(format_pnl(p.unrealized_pnl))
|
||||
pnl_text.append(" (")
|
||||
pnl_text.append_text(format_pct(p.unrealized_pnl_pct))
|
||||
pnl_text.append(")")
|
||||
table.add_row("Unrealized:", pnl_text)
|
||||
|
||||
table.add_row("Size:", Text(f"${p.size_usdt:.2f}", style="white"))
|
||||
|
||||
# SL/TP
|
||||
if p.side == "long":
|
||||
sl_dist = (p.stop_loss_price / p.entry_price - 1) * 100
|
||||
tp_dist = (p.take_profit_price / p.entry_price - 1) * 100
|
||||
else:
|
||||
sl_dist = (1 - p.stop_loss_price / p.entry_price) * 100
|
||||
tp_dist = (1 - p.take_profit_price / p.entry_price) * 100
|
||||
|
||||
sl_text = Text(f"${p.stop_loss_price:.2f} ({sl_dist:+.1f}%)", style="red")
|
||||
tp_text = Text(f"${p.take_profit_price:.2f} ({tp_dist:+.1f}%)", style="green")
|
||||
|
||||
table.add_row("Stop Loss:", sl_text)
|
||||
table.add_row("Take Profit:", tp_text)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
class AccountPanel:
|
||||
"""Panel showing account information."""
|
||||
|
||||
def __init__(self, account: Optional[AccountState] = None):
|
||||
self.account = account
|
||||
|
||||
def render(self) -> Table:
|
||||
"""Render account info as a table."""
|
||||
table = Table(
|
||||
show_header=False,
|
||||
show_edge=False,
|
||||
box=None,
|
||||
padding=(0, 1),
|
||||
)
|
||||
table.add_column("Label", style="dim")
|
||||
table.add_column("Value")
|
||||
|
||||
if self.account is None:
|
||||
table.add_row("Status", Text("Loading...", style="dim"))
|
||||
return table
|
||||
|
||||
a = self.account
|
||||
|
||||
table.add_row("Balance:", Text(f"${a.balance:.2f}", style="white"))
|
||||
table.add_row("Available:", Text(f"${a.available:.2f}", style="white"))
|
||||
table.add_row("Leverage:", Text(f"{a.leverage}x", style="cyan"))
|
||||
|
||||
return table
|
||||
|
||||
|
||||
class StrategyPanel:
|
||||
"""Panel showing strategy state."""
|
||||
|
||||
def __init__(self, strategy: Optional[StrategyState] = None):
|
||||
self.strategy = strategy
|
||||
|
||||
def render(self) -> Table:
|
||||
"""Render strategy state as a table."""
|
||||
table = Table(
|
||||
show_header=False,
|
||||
show_edge=False,
|
||||
box=None,
|
||||
padding=(0, 1),
|
||||
)
|
||||
table.add_column("Label", style="dim")
|
||||
table.add_column("Value")
|
||||
|
||||
if self.strategy is None:
|
||||
table.add_row("Status", Text("Waiting...", style="dim"))
|
||||
return table
|
||||
|
||||
s = self.strategy
|
||||
|
||||
# Z-score with color based on threshold
|
||||
z_style = "white"
|
||||
if abs(s.z_score) > 1.0:
|
||||
z_style = "yellow"
|
||||
if abs(s.z_score) > 1.5:
|
||||
z_style = "bold yellow"
|
||||
table.add_row("Z-Score:", Text(f"{s.z_score:.2f}", style=z_style))
|
||||
|
||||
# Probability with color
|
||||
prob_style = "white"
|
||||
if s.probability > 0.5:
|
||||
prob_style = "green"
|
||||
if s.probability > 0.7:
|
||||
prob_style = "bold green"
|
||||
table.add_row("Probability:", Text(f"{s.probability:.2f}", style=prob_style))
|
||||
|
||||
# Funding rate
|
||||
funding_style = "green" if s.funding_rate >= 0 else "red"
|
||||
table.add_row(
|
||||
"Funding:",
|
||||
Text(f"{s.funding_rate:.4f}", style=funding_style),
|
||||
)
|
||||
|
||||
# Last action
|
||||
action_style = "white"
|
||||
if s.last_action == "entry":
|
||||
action_style = "bold cyan"
|
||||
elif s.last_action == "check_exit":
|
||||
action_style = "yellow"
|
||||
table.add_row("Last Action:", Text(s.last_action, style=action_style))
|
||||
|
||||
return table
|
||||
|
||||
|
||||
class LogPanel:
|
||||
"""Panel showing log entries."""
|
||||
|
||||
def __init__(self, log_buffer: LogBuffer):
|
||||
self.log_buffer = log_buffer
|
||||
|
||||
def render(self, height: int = 10) -> Panel:
|
||||
"""Render log panel."""
|
||||
filter_name = self.log_buffer.get_current_filter().title()
|
||||
entries = self.log_buffer.get_filtered(limit=height - 2)
|
||||
|
||||
lines = []
|
||||
for entry in entries:
|
||||
line = Text()
|
||||
line.append(f"{entry.timestamp} ", style="dim")
|
||||
line.append(f"[{entry.level}] ", style=entry.level_color)
|
||||
line.append(entry.message)
|
||||
lines.append(line)
|
||||
|
||||
if not lines:
|
||||
lines.append(Text("No logs to display", style="dim"))
|
||||
|
||||
content = Group(*lines)
|
||||
|
||||
# Build "tabbed" title
|
||||
tabs = []
|
||||
|
||||
# All Logs tab
|
||||
if filter_name == "All":
|
||||
tabs.append("[bold white on blue] [L]ogs [/]")
|
||||
else:
|
||||
tabs.append("[dim] [L]ogs [/]")
|
||||
|
||||
# Trades tab
|
||||
if filter_name == "Trades":
|
||||
tabs.append("[bold white on blue] [T]rades [/]")
|
||||
else:
|
||||
tabs.append("[dim] [T]rades [/]")
|
||||
|
||||
# Errors tab
|
||||
if filter_name == "Errors":
|
||||
tabs.append("[bold white on blue] [E]rrors [/]")
|
||||
else:
|
||||
tabs.append("[dim] [E]rrors [/]")
|
||||
|
||||
title = " ".join(tabs)
|
||||
subtitle = "Press 'l', 't', 'e' to switch tabs"
|
||||
|
||||
return Panel(
|
||||
content,
|
||||
title=title,
|
||||
subtitle=subtitle,
|
||||
title_align="left",
|
||||
subtitle_align="right",
|
||||
border_style="blue",
|
||||
)
|
||||
|
||||
|
||||
class HelpBar:
|
||||
"""Bottom help bar with keyboard shortcuts."""
|
||||
|
||||
def render(self) -> Text:
|
||||
"""Render help bar."""
|
||||
text = Text()
|
||||
text.append(" [q]", style="bold")
|
||||
text.append("Quit ", style="dim")
|
||||
text.append("[r]", style="bold")
|
||||
text.append("Refresh ", style="dim")
|
||||
text.append("[1-4]", style="bold")
|
||||
text.append("Tabs ", style="dim")
|
||||
text.append("[l/t/e]", style="bold")
|
||||
text.append("LogView", style="dim")
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def build_summary_panel(
|
||||
state: SharedState,
|
||||
metrics: Optional[PeriodMetrics],
|
||||
tab_bar: TabBar,
|
||||
has_monthly: bool,
|
||||
has_weekly: bool,
|
||||
) -> Panel:
|
||||
"""Build the complete summary panel with all sections."""
|
||||
# Create layout for summary content
|
||||
layout = Layout()
|
||||
|
||||
# Tab bar at top
|
||||
tabs = tab_bar.render(has_monthly, has_weekly)
|
||||
|
||||
# Create tables for each section
|
||||
metrics_table = MetricsPanel(metrics).render()
|
||||
position_table = PositionPanel(state.get_position()).render()
|
||||
account_table = AccountPanel(state.get_account()).render()
|
||||
strategy_table = StrategyPanel(state.get_strategy()).render()
|
||||
|
||||
# Build two-column layout
|
||||
left_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
|
||||
left_col.add_column("PERFORMANCE", style="bold cyan")
|
||||
left_col.add_column("ACCOUNT", style="bold cyan")
|
||||
left_col.add_row(metrics_table, account_table)
|
||||
|
||||
right_col = Table(show_header=True, show_edge=False, box=None, padding=(0, 2))
|
||||
right_col.add_column("CURRENT POSITION", style="bold cyan")
|
||||
right_col.add_column("STRATEGY STATE", style="bold cyan")
|
||||
right_col.add_row(position_table, strategy_table)
|
||||
|
||||
# Combine into main table
|
||||
main_table = Table(show_header=False, show_edge=False, box=None, expand=True)
|
||||
main_table.add_column(ratio=1)
|
||||
main_table.add_column(ratio=1)
|
||||
main_table.add_row(left_col, right_col)
|
||||
|
||||
content = Group(tabs, Text(""), main_table)
|
||||
|
||||
return Panel(content, border_style="blue")
|
||||
195
live_trading/ui/state.py
Normal file
195
live_trading/ui/state.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Thread-safe shared state for UI and trading loop."""
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
@dataclass
|
||||
class PositionState:
|
||||
"""Current position information."""
|
||||
|
||||
trade_id: str = ""
|
||||
symbol: str = ""
|
||||
side: str = ""
|
||||
entry_price: float = 0.0
|
||||
current_price: float = 0.0
|
||||
size: float = 0.0
|
||||
size_usdt: float = 0.0
|
||||
unrealized_pnl: float = 0.0
|
||||
unrealized_pnl_pct: float = 0.0
|
||||
stop_loss_price: float = 0.0
|
||||
take_profit_price: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyState:
|
||||
"""Current strategy signal state."""
|
||||
|
||||
z_score: float = 0.0
|
||||
probability: float = 0.0
|
||||
funding_rate: float = 0.0
|
||||
last_action: str = "hold"
|
||||
last_reason: str = ""
|
||||
last_signal_time: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AccountState:
|
||||
"""Account balance information."""
|
||||
|
||||
balance: float = 0.0
|
||||
available: float = 0.0
|
||||
leverage: int = 1
|
||||
|
||||
|
||||
class SharedState:
|
||||
"""
|
||||
Thread-safe shared state between trading loop and UI.
|
||||
|
||||
All access to state fields should go through the getter/setter methods
|
||||
which use a lock for thread safety.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self._position: Optional[PositionState] = None
|
||||
self._strategy = StrategyState()
|
||||
self._account = AccountState()
|
||||
self._is_running = True
|
||||
self._last_cycle_time: Optional[str] = None
|
||||
self._mode = "DEMO"
|
||||
self._eth_symbol = "ETH/USDT:USDT"
|
||||
self._btc_symbol = "BTC/USDT:USDT"
|
||||
|
||||
# Position methods
|
||||
def get_position(self) -> Optional[PositionState]:
|
||||
"""Get current position state."""
|
||||
with self._lock:
|
||||
return self._position
|
||||
|
||||
def set_position(self, position: Optional[PositionState]) -> None:
|
||||
"""Set current position state."""
|
||||
with self._lock:
|
||||
self._position = position
|
||||
|
||||
def update_position_price(self, current_price: float) -> None:
|
||||
"""Update current price and recalculate PnL."""
|
||||
with self._lock:
|
||||
if self._position is None:
|
||||
return
|
||||
|
||||
self._position.current_price = current_price
|
||||
|
||||
if self._position.side == "long":
|
||||
pnl = (current_price - self._position.entry_price)
|
||||
self._position.unrealized_pnl = pnl * self._position.size
|
||||
pnl_pct = (current_price / self._position.entry_price - 1) * 100
|
||||
else:
|
||||
pnl = (self._position.entry_price - current_price)
|
||||
self._position.unrealized_pnl = pnl * self._position.size
|
||||
pnl_pct = (1 - current_price / self._position.entry_price) * 100
|
||||
|
||||
self._position.unrealized_pnl_pct = pnl_pct
|
||||
|
||||
def clear_position(self) -> None:
|
||||
"""Clear current position."""
|
||||
with self._lock:
|
||||
self._position = None
|
||||
|
||||
# Strategy methods
|
||||
def get_strategy(self) -> StrategyState:
|
||||
"""Get current strategy state."""
|
||||
with self._lock:
|
||||
return StrategyState(
|
||||
z_score=self._strategy.z_score,
|
||||
probability=self._strategy.probability,
|
||||
funding_rate=self._strategy.funding_rate,
|
||||
last_action=self._strategy.last_action,
|
||||
last_reason=self._strategy.last_reason,
|
||||
last_signal_time=self._strategy.last_signal_time,
|
||||
)
|
||||
|
||||
def update_strategy(
|
||||
self,
|
||||
z_score: float,
|
||||
probability: float,
|
||||
funding_rate: float,
|
||||
action: str,
|
||||
reason: str,
|
||||
) -> None:
|
||||
"""Update strategy state."""
|
||||
with self._lock:
|
||||
self._strategy.z_score = z_score
|
||||
self._strategy.probability = probability
|
||||
self._strategy.funding_rate = funding_rate
|
||||
self._strategy.last_action = action
|
||||
self._strategy.last_reason = reason
|
||||
self._strategy.last_signal_time = datetime.now(
|
||||
timezone.utc
|
||||
).isoformat()
|
||||
|
||||
# Account methods
|
||||
def get_account(self) -> AccountState:
|
||||
"""Get current account state."""
|
||||
with self._lock:
|
||||
return AccountState(
|
||||
balance=self._account.balance,
|
||||
available=self._account.available,
|
||||
leverage=self._account.leverage,
|
||||
)
|
||||
|
||||
def update_account(
|
||||
self,
|
||||
balance: float,
|
||||
available: float,
|
||||
leverage: int,
|
||||
) -> None:
|
||||
"""Update account state."""
|
||||
with self._lock:
|
||||
self._account.balance = balance
|
||||
self._account.available = available
|
||||
self._account.leverage = leverage
|
||||
|
||||
# Control methods
|
||||
def is_running(self) -> bool:
|
||||
"""Check if trading loop is running."""
|
||||
with self._lock:
|
||||
return self._is_running
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal to stop trading loop."""
|
||||
with self._lock:
|
||||
self._is_running = False
|
||||
|
||||
def get_last_cycle_time(self) -> Optional[str]:
|
||||
"""Get last trading cycle time."""
|
||||
with self._lock:
|
||||
return self._last_cycle_time
|
||||
|
||||
def set_last_cycle_time(self, time_str: str) -> None:
|
||||
"""Set last trading cycle time."""
|
||||
with self._lock:
|
||||
self._last_cycle_time = time_str
|
||||
|
||||
# Config methods
|
||||
def get_mode(self) -> str:
|
||||
"""Get trading mode (DEMO/LIVE)."""
|
||||
with self._lock:
|
||||
return self._mode
|
||||
|
||||
def set_mode(self, mode: str) -> None:
|
||||
"""Set trading mode."""
|
||||
with self._lock:
|
||||
self._mode = mode
|
||||
|
||||
def get_symbols(self) -> tuple[str, str]:
|
||||
"""Get trading symbols (eth, btc)."""
|
||||
with self._lock:
|
||||
return self._eth_symbol, self._btc_symbol
|
||||
|
||||
def set_symbols(self, eth_symbol: str, btc_symbol: str) -> None:
|
||||
"""Set trading symbols."""
|
||||
with self._lock:
|
||||
self._eth_symbol = eth_symbol
|
||||
self._btc_symbol = btc_symbol
|
||||
Reference in New Issue
Block a user