"""
Position Manager for Live Trading.

Tracks open positions, manages risk, and handles SL/TP logic.
"""

import csv
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from dataclasses import dataclass, field, asdict, fields

from .okx_client import OKXClient
from .config import TradingConfig, PathConfig

logger = logging.getLogger(__name__)


@dataclass
class Position:
    """Represents an open trading position."""

    trade_id: str
    symbol: str
    side: str  # "long" or "short"
    entry_price: float
    entry_time: str  # ISO format
    size: float  # Amount in base currency (e.g., ETH)
    size_usdt: float  # Notional value in USDT
    stop_loss_price: float
    take_profit_price: float
    current_price: float = 0.0
    unrealized_pnl: float = 0.0
    unrealized_pnl_pct: float = 0.0
    order_id: str = ""  # Entry order ID from exchange

    def update_pnl(self, current_price: float) -> None:
        """Update unrealized PnL based on current price.

        Stores ``current_price`` and recomputes both the absolute PnL and the
        percentage PnL relative to ``entry_price``. Any side other than
        ``"long"`` is treated as short.

        Args:
            current_price: Latest price for this position's symbol.
        """
        self.current_price = current_price
        if self.side == "long":
            self.unrealized_pnl = (current_price - self.entry_price) * self.size
        else:  # short
            self.unrealized_pnl = (self.entry_price - current_price) * self.size

        # A zero entry price would make the percentage undefined; report 0%
        # instead of raising ZeroDivisionError in the middle of a price update.
        if not self.entry_price:
            self.unrealized_pnl_pct = 0.0
        elif self.side == "long":
            self.unrealized_pnl_pct = (current_price / self.entry_price - 1) * 100
        else:  # short
            self.unrealized_pnl_pct = (1 - current_price / self.entry_price) * 100

    def should_stop_loss(self, current_price: float) -> bool:
        """Check if stop-loss should trigger.

        A long triggers at or below the stop price; any other side triggers
        at or above it.
        """
        if self.side == "long":
            return current_price <= self.stop_loss_price
        return current_price >= self.stop_loss_price

    def should_take_profit(self, current_price: float) -> bool:
        """Check if take-profit should trigger.

        A long triggers at or above the target price; any other side triggers
        at or below it.
        """
        if self.side == "long":
            return current_price >= self.take_profit_price
        return current_price <= self.take_profit_price

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> 'Position':
        """Create Position from dictionary.

        Keys that are not fields of this dataclass are ignored so that a
        persisted file carrying extra keys can still be loaded. Missing
        required fields still raise ``TypeError``.
        """
        known = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known})


class PositionManager:
    """
    Manages trading positions with persistence.

    Tracks open positions, enforces risk limits, and handles position
    lifecycle (open, update, close).
    """

    def __init__(
        self,
        okx_client: OKXClient,
        trading_config: TradingConfig,
        path_config: PathConfig
    ):
        """Store collaborators and load any previously persisted positions.

        Args:
            okx_client: Exchange client used to close positions, list
                exchange positions and fetch tickers.
            trading_config: Provides ``max_concurrent_positions``.
            path_config: Provides ``positions_file`` and ``trade_log_file``.
        """
        self.client = okx_client
        self.config = trading_config
        self.paths = path_config
        self.positions: dict[str, Position] = {}  # keyed by trade_id
        self.trade_log: list[dict] = []  # in-memory copy of closed trades
        self._load_positions()

    def _load_positions(self) -> None:
        """Load positions from file.

        Never raises: problems are logged as warnings. Each record is loaded
        independently so one malformed entry does not discard the others.
        """
        if not self.paths.positions_file.exists():
            return

        try:
            with open(self.paths.positions_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            logger.warning(f"Could not load positions: {e}")
            return

        if not isinstance(data, dict):
            logger.warning(
                f"Could not load positions: expected a JSON object, "
                f"got {type(data).__name__}"
            )
            return

        for trade_id, pos_data in data.items():
            try:
                self.positions[trade_id] = Position.from_dict(pos_data)
            except Exception as e:
                logger.warning(f"Could not load position {trade_id}: {e}")

        logger.info(f"Loaded {len(self.positions)} positions from file")

    def save_positions(self) -> None:
        """Save positions to file.

        The JSON is written to a sibling ``.tmp`` file and then renamed over
        the real file, so an interrupted write cannot leave a truncated
        positions file behind. Never raises: failures are logged as errors.
        """
        try:
            target = Path(self.paths.positions_file)
            tmp_path = target.with_name(target.name + ".tmp")
            data = {
                trade_id: pos.to_dict()
                for trade_id, pos in self.positions.items()
            }
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            tmp_path.replace(target)
            logger.debug(f"Saved {len(self.positions)} positions")
        except Exception as e:
            logger.error(f"Could not save positions: {e}")

    def can_open_position(self) -> bool:
        """Check if we can open a new position."""
        return len(self.positions) < self.config.max_concurrent_positions

    def get_position_for_symbol(self, symbol: str) -> Optional[Position]:
        """Get position for a specific symbol.

        Returns:
            The first tracked position whose symbol matches, or None.
        """
        for pos in self.positions.values():
            if pos.symbol == symbol:
                return pos
        return None

    def open_position(
        self,
        symbol: str,
        side: str,
        entry_price: float,
        size: float,
        stop_loss_price: float,
        take_profit_price: float,
        order_id: str = ""
    ) -> Optional[Position]:
        """
        Open a new position.

        Args:
            symbol: Trading pair symbol
            side: "long" or "short"
            entry_price: Entry price
            size: Position size in base currency
            stop_loss_price: Stop-loss price
            take_profit_price: Take-profit price
            order_id: Entry order ID from exchange

        Returns:
            Position object or None if failed
        """
        if not self.can_open_position():
            logger.warning("Cannot open position: max concurrent positions reached")
            return None

        # Check if already have position for this symbol
        existing = self.get_position_for_symbol(symbol)
        if existing:
            logger.warning(f"Already have position for {symbol}")
            return None

        # Generate trade ID (symbol plus UTC timestamp, one-second resolution)
        now = datetime.now(timezone.utc)
        trade_id = f"{symbol}_{now.strftime('%Y%m%d_%H%M%S')}"

        position = Position(
            trade_id=trade_id,
            symbol=symbol,
            side=side,
            entry_price=entry_price,
            entry_time=now.isoformat(),
            size=size,
            size_usdt=entry_price * size,
            stop_loss_price=stop_loss_price,
            take_profit_price=take_profit_price,
            current_price=entry_price,
            order_id=order_id,
        )

        self.positions[trade_id] = position
        self.save_positions()

        logger.info(
            f"Opened {side.upper()} position: {size} {symbol} @ {entry_price}, "
            f"SL={stop_loss_price}, TP={take_profit_price}"
        )
        return position

    def close_position(
        self,
        trade_id: str,
        exit_price: float,
        reason: str = "manual",
        exit_order_id: str = ""
    ) -> Optional[dict]:
        """
        Close a position and record the trade.

        Removes the position from local tracking, persists the remaining
        positions and appends the trade to the CSV trade log. This method does
        not itself talk to the exchange.

        Args:
            trade_id: Position trade ID
            exit_price: Exit price
            reason: Reason for closing (e.g., "stop_loss", "take_profit", "signal")
            exit_order_id: Exit order ID from exchange

        Returns:
            Trade record dictionary, or None if ``trade_id`` is not tracked
        """
        if trade_id not in self.positions:
            logger.warning(f"Position {trade_id} not found")
            return None

        position = self.positions[trade_id]
        position.update_pnl(exit_price)

        # Calculate final PnL
        entry_time = datetime.fromisoformat(position.entry_time)
        if entry_time.tzinfo is None:
            # open_position() stores UTC-aware timestamps; a naive value
            # (e.g. from a hand-edited file) is assumed to be UTC so the
            # subtraction below does not raise TypeError.
            entry_time = entry_time.replace(tzinfo=timezone.utc)
        exit_time = datetime.now(timezone.utc)
        hold_duration = (exit_time - entry_time).total_seconds() / 3600  # hours

        trade_record = {
            'trade_id': trade_id,
            'symbol': position.symbol,
            'side': position.side,
            'entry_price': position.entry_price,
            'exit_price': exit_price,
            'size': position.size,
            'size_usdt': position.size_usdt,
            'pnl_usd': position.unrealized_pnl,
            'pnl_pct': position.unrealized_pnl_pct,
            'entry_time': position.entry_time,
            'exit_time': exit_time.isoformat(),
            'hold_duration_hours': hold_duration,
            'reason': reason,
            'order_id_entry': position.order_id,
            'order_id_exit': exit_order_id,
        }

        self.trade_log.append(trade_record)
        del self.positions[trade_id]
        self.save_positions()
        self._append_trade_log(trade_record)

        logger.info(
            f"Closed {position.side.upper()} position: {position.size} {position.symbol} "
            f"@ {exit_price}, PnL=${position.unrealized_pnl:.2f} ({position.unrealized_pnl_pct:.2f}%), "
            f"reason={reason}"
        )
        return trade_record

    def _append_trade_log(self, trade_record: dict) -> None:
        """Append trade record to CSV log file.

        The header row is written when the file does not exist yet or is
        empty. I/O failures are logged instead of raised, because the caller
        has already removed the position from local state by the time this
        runs; the record is still held in ``self.trade_log``.
        """
        log_path = self.paths.trade_log_file
        try:
            needs_header = (not log_path.exists()) or log_path.stat().st_size == 0
            with open(log_path, 'a', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=list(trade_record))
                if needs_header:
                    writer.writeheader()
                writer.writerow(trade_record)
        except OSError as e:
            logger.error(f"Could not append trade log: {e}")

    def _close_on_exchange(self, position: Position) -> str:
        """Ask the exchange to close ``position``; return the exit order ID.

        Returns an empty string when the exchange call raises or returns a
        falsy result. Failures are logged, not raised.
        """
        try:
            exit_order = self.client.close_position(position.symbol)
        except Exception as e:
            logger.error(f"Failed to close position on exchange: {e}")
            return ""
        return exit_order.get('id', '') if exit_order else ''

    def update_positions(self, current_prices: dict[str, float]) -> list[dict]:
        """
        Update all positions with current prices and check SL/TP.

        Stop-loss is checked before take-profit. Positions whose symbol is
        missing from ``current_prices`` are left untouched.

        Args:
            current_prices: Dictionary of symbol -> current price

        Returns:
            List of closed trade records
        """
        closed_trades = []

        # Iterate over a snapshot of the keys: close_position() mutates the dict.
        for trade_id in list(self.positions):
            position = self.positions[trade_id]
            if position.symbol not in current_prices:
                continue

            current_price = current_prices[position.symbol]
            position.update_pnl(current_price)

            if position.should_stop_loss(current_price):
                reason = "stop_loss"
                logger.warning(
                    f"Stop-loss triggered for {trade_id} at {current_price}"
                )
            elif position.should_take_profit(current_price):
                reason = "take_profit"
                logger.info(
                    f"Take-profit triggered for {trade_id} at {current_price}"
                )
            else:
                continue

            # NOTE(review): the position is closed locally even when the
            # exchange call fails (existing behaviour, kept as-is). Since
            # sync_with_exchange() only removes local positions and never
            # adopts exchange ones, confirm this best-effort is intended.
            exit_order_id = self._close_on_exchange(position)
            record = self.close_position(trade_id, current_price, reason, exit_order_id)
            if record:
                closed_trades.append(record)

        # Persist the refreshed prices/PnL of the positions that stay open.
        self.save_positions()
        return closed_trades

    def sync_with_exchange(self) -> None:
        """
        Sync local positions with exchange positions.

        Any locally tracked position whose symbol is absent from the exchange
        position list is closed locally with reason "sync_removed". Positions
        that exist only on the exchange are not added. Never raises: a
        failure is logged as an error.
        """
        try:
            exchange_positions = self.client.get_positions()
            exchange_symbols = {p['symbol'] for p in exchange_positions}

            # Check for positions we have locally but not on exchange
            for trade_id in list(self.positions):
                pos = self.positions[trade_id]
                if pos.symbol in exchange_symbols:
                    continue

                logger.warning(
                    f"Position {trade_id} not found on exchange, removing"
                )
                # Get last price and close
                try:
                    ticker = self.client.get_ticker(pos.symbol)
                    exit_price = ticker['last']
                except Exception:
                    exit_price = pos.current_price
                if exit_price is None:
                    # A ticker without a last price would make the PnL
                    # arithmetic raise and abort the whole sync.
                    exit_price = pos.current_price
                self.close_position(trade_id, exit_price, "sync_removed")

            logger.info(f"Position sync complete: {len(self.positions)} local positions")
        except Exception as e:
            logger.error(f"Position sync failed: {e}")

    def get_portfolio_summary(self) -> dict:
        """
        Get portfolio summary.

        Returns:
            Dictionary with portfolio statistics: number of open positions,
            summed entry notional (``size_usdt``), summed unrealized PnL, and
            each position as a dictionary
        """
        total_exposure = sum(p.size_usdt for p in self.positions.values())
        total_unrealized_pnl = sum(p.unrealized_pnl for p in self.positions.values())

        return {
            'open_positions': len(self.positions),
            'total_exposure_usdt': total_exposure,
            'total_unrealized_pnl': total_unrealized_pnl,
            'positions': [p.to_dict() for p in self.positions.values()],
        }