- Introduced `train_daily.sh` for automating daily model retraining, including data download and model training steps. - Added `install_cron.sh` for setting up a cron job to run the daily training script. - Created `setup_schedule.sh` for configuring Systemd timers for daily training tasks. - Implemented a terminal UI using Rich for real-time monitoring of trading performance, including metrics display and log handling. - Updated `pyproject.toml` to include the `rich` dependency for UI functionality. - Enhanced `.gitignore` to exclude model and log files. - Added database support for trade persistence and metrics calculation. - Updated README with installation and usage instructions for the new features.
416 lines
15 KiB
Python
416 lines
15 KiB
Python
"""
Position Manager for Live Trading.

Tracks open positions, manages risk, and handles stop-loss / take-profit (SL/TP) logic.
"""
|
|
import csv
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Optional, TYPE_CHECKING
|
|
from dataclasses import dataclass, asdict
|
|
|
|
from .okx_client import OKXClient
|
|
from .config import TradingConfig, PathConfig
|
|
|
|
if TYPE_CHECKING:
|
|
from .db.database import TradingDatabase
|
|
from .db.models import Trade
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass
class Position:
    """Represents an open trading position.

    Instances round-trip through JSON via :meth:`to_dict` / :meth:`from_dict`.
    Any ``side`` other than ``"long"`` is treated as a short position.
    """

    trade_id: str
    symbol: str
    side: str  # "long" or "short"
    entry_price: float
    entry_time: str  # ISO format
    size: float  # Amount in base currency (e.g., ETH)
    size_usdt: float  # Notional value in USDT
    stop_loss_price: float
    take_profit_price: float
    current_price: float = 0.0
    unrealized_pnl: float = 0.0
    unrealized_pnl_pct: float = 0.0
    order_id: str = ""  # Entry order ID from exchange

    def update_pnl(self, current_price: float) -> None:
        """Record ``current_price`` and recompute the unrealized PnL fields.

        Args:
            current_price: Latest market price for this position's symbol.

        ``unrealized_pnl`` is in quote currency (price delta * size);
        ``unrealized_pnl_pct`` is the percentage move relative to the entry
        price, signed so that a favourable move is positive for either side.
        """
        self.current_price = current_price
        entry = self.entry_price

        # A zero entry price (e.g. a corrupt persisted record) would raise
        # ZeroDivisionError and abort the caller's whole price-update loop;
        # report 0% instead.
        if self.side == "long":
            self.unrealized_pnl = (current_price - entry) * self.size
            self.unrealized_pnl_pct = (current_price / entry - 1) * 100 if entry else 0.0
        else:  # short
            self.unrealized_pnl = (entry - current_price) * self.size
            self.unrealized_pnl_pct = (1 - current_price / entry) * 100 if entry else 0.0

    def should_stop_loss(self, current_price: float) -> bool:
        """Return True if ``current_price`` has reached the stop-loss level.

        Longs trigger at or below ``stop_loss_price``; shorts at or above it.
        """
        if self.side == "long":
            return current_price <= self.stop_loss_price
        return current_price >= self.stop_loss_price

    def should_take_profit(self, current_price: float) -> bool:
        """Return True if ``current_price`` has reached the take-profit level.

        Longs trigger at or above ``take_profit_price``; shorts at or below it.
        """
        if self.side == "long":
            return current_price >= self.take_profit_price
        return current_price <= self.take_profit_price

    def to_dict(self) -> dict:
        """Convert to a plain dictionary for JSON serialization."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> 'Position':
        """Create a Position from a dictionary produced by :meth:`to_dict`.

        Keys that are not fields of this dataclass are ignored, so a file
        written by a different version of the code (with extra keys) still
        loads. Missing required fields still raise ``TypeError``.
        """
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known})
|
|
|
|
|
|
class PositionManager:
    """
    Manages trading positions with persistence.

    Tracks open positions, enforces risk limits, and handles
    position lifecycle (open, update, close).

    Open positions live in ``self.positions`` (keyed by trade ID) and are
    mirrored to ``paths.positions_file`` as JSON. Each closed trade is
    appended to ``self.trade_log``, to the CSV at ``paths.trade_log_file``
    and, when a database was supplied, to that database.
    """

    def __init__(
        self,
        okx_client: OKXClient,
        trading_config: TradingConfig,
        path_config: PathConfig,
        database: Optional["TradingDatabase"] = None,
    ):
        """
        Create the manager and load any previously persisted positions.

        Args:
            okx_client: Exchange client used to close positions and to fetch
                positions/tickers during reconciliation.
            trading_config: Risk settings (``max_concurrent_positions`` is read).
            path_config: Locations of ``positions_file`` and ``trade_log_file``.
            database: Optional trade store; when ``None``, closed trades are
                only written to CSV.
        """
        self.client = okx_client
        self.config = trading_config
        self.paths = path_config
        self.db = database
        self.positions: dict[str, Position] = {}
        # Trades closed by this instance; not reloaded from disk on start-up.
        self.trade_log: list[dict] = []
        self._load_positions()

    def _load_positions(self) -> None:
        """Load positions from ``paths.positions_file``.

        A missing file is not an error. An unreadable or non-object file is
        logged and ignored. Individual malformed entries are skipped, so one
        bad record does not discard every other open position.
        """
        path = self.paths.positions_file
        if not path.exists():
            return

        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError
            logger.warning(f"Could not load positions: {e}")
            return

        if not isinstance(data, dict):
            logger.warning(
                f"Could not load positions: expected a JSON object, got {type(data).__name__}"
            )
            return

        for trade_id, pos_data in data.items():
            try:
                self.positions[trade_id] = Position.from_dict(pos_data)
            except (TypeError, ValueError) as e:
                logger.warning(f"Skipping malformed position {trade_id}: {e}")

        logger.info(f"Loaded {len(self.positions)} positions from file")

    def save_positions(self) -> None:
        """Persist open positions to ``paths.positions_file`` as JSON.

        The JSON is written to a sibling ``.tmp`` file that then replaces the
        target. A crash mid-write therefore cannot leave a truncated file
        that would drop every position on the next load. Failures are logged
        rather than raised, so a disk problem does not interrupt trading.
        """
        target = self.paths.positions_file
        tmp = target.with_name(target.name + '.tmp')
        try:
            data = {
                trade_id: pos.to_dict()
                for trade_id, pos in self.positions.items()
            }
            with open(tmp, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            tmp.replace(target)
            logger.debug(f"Saved {len(self.positions)} positions")
        except Exception as e:  # best-effort persistence; never crash the trading loop
            logger.error(f"Could not save positions: {e}")

    def can_open_position(self) -> bool:
        """Return True while fewer than ``max_concurrent_positions`` are open."""
        return len(self.positions) < self.config.max_concurrent_positions

    def get_position_for_symbol(self, symbol: str) -> Optional[Position]:
        """Return the open position for ``symbol``, or ``None`` if there is none."""
        return next(
            (pos for pos in self.positions.values() if pos.symbol == symbol),
            None,
        )

    def open_position(
        self,
        symbol: str,
        side: str,
        entry_price: float,
        size: float,
        stop_loss_price: float,
        take_profit_price: float,
        order_id: str = ""
    ) -> Optional[Position]:
        """
        Open a new position (local bookkeeping only; no order is placed here).

        Args:
            symbol: Trading pair symbol
            side: "long" or "short" (case-insensitive)
            entry_price: Entry price; must be positive
            size: Position size in base currency; must be positive
            stop_loss_price: Stop-loss price
            take_profit_price: Take-profit price
            order_id: Entry order ID from exchange

        Returns:
            Position object, or None if the position limit is reached, the
            symbol already has a position, or the arguments are invalid.
        """
        if not self.can_open_position():
            logger.warning("Cannot open position: max concurrent positions reached")
            return None

        # Position treats every side other than "long" as short, so an
        # unexpected value (e.g. "LONG" or "buy") would silently invert PnL
        # and SL/TP checks.
        side = side.lower()
        if side not in ("long", "short"):
            logger.warning(
                f"Cannot open position: invalid side {side!r} (expected 'long' or 'short')"
            )
            return None

        if size <= 0 or entry_price <= 0:
            logger.warning(
                f"Cannot open position: size and entry_price must be positive "
                f"(size={size}, entry_price={entry_price})"
            )
            return None

        # Only one position per symbol is tracked.
        if self.get_position_for_symbol(symbol):
            logger.warning(f"Already have position for {symbol}")
            return None

        now = datetime.now(timezone.utc)
        trade_id = f"{symbol}_{now.strftime('%Y%m%d_%H%M%S')}"

        position = Position(
            trade_id=trade_id,
            symbol=symbol,
            side=side,
            entry_price=entry_price,
            entry_time=now.isoformat(),
            size=size,
            size_usdt=entry_price * size,
            stop_loss_price=stop_loss_price,
            take_profit_price=take_profit_price,
            current_price=entry_price,
            order_id=order_id,
        )

        self.positions[trade_id] = position
        self.save_positions()

        logger.info(
            f"Opened {side.upper()} position: {size} {symbol} @ {entry_price}, "
            f"SL={stop_loss_price}, TP={take_profit_price}"
        )

        return position

    @staticmethod
    def _hours_since(entry_time: str, exit_time: datetime) -> Optional[float]:
        """Return the hours between ISO ``entry_time`` and aware ``exit_time``.

        A trailing ``Z`` is accepted, and naive timestamps are assumed to be
        UTC, so subtracting from the aware ``exit_time`` cannot raise.
        Returns ``None`` if ``entry_time`` cannot be parsed, so bad
        bookkeeping data never prevents a position from being closed.
        """
        try:
            text = entry_time[:-1] + '+00:00' if entry_time.endswith('Z') else entry_time
            entry = datetime.fromisoformat(text)
        except (AttributeError, TypeError, ValueError):
            return None
        if entry.tzinfo is None:
            entry = entry.replace(tzinfo=timezone.utc)
        return (exit_time - entry).total_seconds() / 3600

    def close_position(
        self,
        trade_id: str,
        exit_price: float,
        reason: str = "manual",
        exit_order_id: str = ""
    ) -> Optional[dict]:
        """
        Close a position locally and record the trade.

        Does not place an exchange order; callers close on the exchange first.

        Args:
            trade_id: Position trade ID
            exit_price: Exit price
            reason: Reason for closing (e.g., "stop_loss", "take_profit", "signal")
            exit_order_id: Exit order ID from exchange

        Returns:
            Trade record dictionary, or None if ``trade_id`` is not open.
            ``hold_duration_hours`` is None when the stored entry time is
            unparseable.
        """
        position = self.positions.get(trade_id)
        if position is None:
            logger.warning(f"Position {trade_id} not found")
            return None

        # Final PnL is the unrealized PnL at the exit price.
        position.update_pnl(exit_price)

        exit_time = datetime.now(timezone.utc)
        hold_duration = self._hours_since(position.entry_time, exit_time)
        if hold_duration is None:
            logger.warning(
                f"Could not parse entry_time {position.entry_time!r} for {trade_id}"
            )

        trade_record = {
            'trade_id': trade_id,
            'symbol': position.symbol,
            'side': position.side,
            'entry_price': position.entry_price,
            'exit_price': exit_price,
            'size': position.size,
            'size_usdt': position.size_usdt,
            'pnl_usd': position.unrealized_pnl,
            'pnl_pct': position.unrealized_pnl_pct,
            'entry_time': position.entry_time,
            'exit_time': exit_time.isoformat(),
            'hold_duration_hours': hold_duration,
            'reason': reason,
            'order_id_entry': position.order_id,
            'order_id_exit': exit_order_id,
        }

        self.trade_log.append(trade_record)
        del self.positions[trade_id]
        self.save_positions()
        self._append_trade_log(trade_record)

        logger.info(
            f"Closed {position.side.upper()} position: {position.size} {position.symbol} "
            f"@ {exit_price}, PnL=${position.unrealized_pnl:.2f} ({position.unrealized_pnl_pct:.2f}%), "
            f"reason={reason}"
        )

        return trade_record

    def _append_trade_log(self, trade_record: dict) -> None:
        """Append trade record to CSV and SQLite database."""
        # Write to CSV (backup/compatibility)
        self._append_trade_csv(trade_record)

        # Write to SQLite (primary)
        self._append_trade_db(trade_record)

    def _append_trade_csv(self, trade_record: dict) -> None:
        """Append trade record to the CSV log file, writing a header if needed.

        Errors are logged, not raised.
        """
        path = self.paths.trade_log_file
        try:
            # A header is needed for a new file *or* an existing empty one
            # (e.g. created by `touch`), otherwise the CSV has no column names.
            needs_header = not path.exists() or path.stat().st_size == 0
            with open(path, 'a', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=trade_record.keys())
                if needs_header:
                    writer.writeheader()
                writer.writerow(trade_record)
        except Exception as e:  # best-effort backup log
            logger.error(f"Failed to write trade to CSV: {e}")

    def _append_trade_db(self, trade_record: dict) -> None:
        """Insert the trade record via ``self.db.insert_trade``.

        This is a no-op when no database was supplied. Errors are logged,
        not raised.
        """
        if self.db is None:
            return

        try:
            # Imported lazily; at module level it is only imported under TYPE_CHECKING.
            from .db.models import Trade

            trade = Trade(
                trade_id=trade_record['trade_id'],
                symbol=trade_record['symbol'],
                side=trade_record['side'],
                entry_price=trade_record['entry_price'],
                exit_price=trade_record.get('exit_price'),
                size=trade_record['size'],
                size_usdt=trade_record['size_usdt'],
                pnl_usd=trade_record.get('pnl_usd'),
                pnl_pct=trade_record.get('pnl_pct'),
                entry_time=trade_record['entry_time'],
                exit_time=trade_record.get('exit_time'),
                hold_duration_hours=trade_record.get('hold_duration_hours'),
                reason=trade_record.get('reason'),
                order_id_entry=trade_record.get('order_id_entry'),
                order_id_exit=trade_record.get('order_id_exit'),
            )
            self.db.insert_trade(trade)
            logger.debug(f"Trade {trade.trade_id} saved to database")
        except Exception as e:
            logger.error(f"Failed to write trade to database: {e}")

    def _close_on_exchange(self, symbol: str) -> Optional[str]:
        """Ask the exchange client to close ``symbol``; return the exit order ID.

        Returns ``''`` when the client succeeds but gives no usable order ID,
        and ``None`` when the client raises. ``None`` tells the caller the
        exchange position may still be open.
        """
        try:
            exit_order = self.client.close_position(symbol)
        except Exception as e:
            logger.error(f"Failed to close position on exchange: {e}")
            return None
        # The close went through; an odd/empty response only loses the order ID.
        if isinstance(exit_order, dict):
            return exit_order.get('id') or ''
        return ''

    def update_positions(self, current_prices: dict[str, float]) -> list[dict]:
        """
        Update all positions with current prices and check SL/TP.

        Stop-loss is checked before take-profit. When a trigger fires, the
        position is closed on the exchange first. If that call fails, the
        position stays tracked locally so the trigger (and the close) is
        retried on the next update; ``sync_with_exchange`` removes it if the
        exchange no longer reports it.

        Args:
            current_prices: Dictionary of symbol -> current price. Symbols that
                are missing, or mapped to None, are skipped.

        Returns:
            List of closed trade records
        """
        closed_trades = []

        # Iterate over a snapshot because close_position deletes entries.
        for trade_id, position in list(self.positions.items()):
            current_price = current_prices.get(position.symbol)
            if current_price is None:
                continue

            position.update_pnl(current_price)

            if position.should_stop_loss(current_price):
                logger.warning(
                    f"Stop-loss triggered for {trade_id} at {current_price}"
                )
                reason = "stop_loss"
            elif position.should_take_profit(current_price):
                logger.info(
                    f"Take-profit triggered for {trade_id} at {current_price}"
                )
                reason = "take_profit"
            else:
                continue

            exit_order_id = self._close_on_exchange(position.symbol)
            if exit_order_id is None:
                # Dropping the local record here would leave a live exchange
                # position with no SL/TP monitoring; keep it and retry.
                continue

            record = self.close_position(trade_id, current_price, reason, exit_order_id)
            if record:
                closed_trades.append(record)

        # Persist the refreshed current_price / PnL of positions still open.
        self.save_positions()
        return closed_trades

    def sync_with_exchange(self) -> None:
        """
        Sync local positions with exchange positions.

        Any local position whose symbol is not among the exchange's positions
        is closed locally with reason "sync_removed". The exit price is the
        ticker's last price, or the last locally seen price when no usable
        quote is available. Positions that exist only on the exchange are not
        imported. Failures are logged, not raised.
        """
        try:
            exchange_positions = self.client.get_positions()
            exchange_symbols = {p['symbol'] for p in exchange_positions}

            # Check for positions we have locally but not on exchange
            for trade_id in list(self.positions.keys()):
                pos = self.positions[trade_id]
                if pos.symbol in exchange_symbols:
                    continue

                logger.warning(
                    f"Position {trade_id} not found on exchange, removing"
                )
                try:
                    ticker = self.client.get_ticker(pos.symbol)
                    exit_price = ticker['last']
                except Exception:
                    exit_price = None
                if exit_price is None:
                    # A missing quote would otherwise crash update_pnl and abort
                    # the whole sync; fall back to the last price we saw.
                    exit_price = pos.current_price

                self.close_position(trade_id, exit_price, "sync_removed")

            logger.info(f"Position sync complete: {len(self.positions)} local positions")

        except Exception as e:
            logger.error(f"Position sync failed: {e}")

    def get_portfolio_summary(self) -> dict:
        """
        Get portfolio summary.

        Returns:
            Dictionary with the open-position count, total entry notional
            (``size_usdt`` summed), total unrealized PnL as of the last price
            update, and each position as a dict.
        """
        total_exposure = sum(p.size_usdt for p in self.positions.values())
        total_unrealized_pnl = sum(p.unrealized_pnl for p in self.positions.values())

        return {
            'open_positions': len(self.positions),
            'total_exposure_usdt': total_exposure,
            'total_unrealized_pnl': total_unrealized_pnl,
            'positions': [p.to_dict() for p in self.positions.values()],
        }
|