Files
lowkey_backtest/live_trading/position_manager.py
Simon Moisy 0c82c4f366 Implement FastAPI backend and Vue 3 frontend for Lowkey Backtest UI
- Added FastAPI backend with core API endpoints for strategies, backtests, and data management.
- Introduced Vue 3 frontend with a dark theme, enabling users to run backtests, adjust parameters, and compare results.
- Implemented Pydantic schemas for request/response validation and SQLAlchemy models for database interactions.
- Enhanced project structure with dedicated modules for services, routers, and components.
- Updated dependencies in `pyproject.toml` and `frontend/package.json` to include FastAPI, SQLAlchemy, and Vue-related packages.
- Improved `.gitignore` to exclude unnecessary files and directories.
2026-01-14 21:44:04 +08:00

370 lines
13 KiB
Python

"""
Position Manager for Live Trading.
Tracks open positions, manages risk, and handles SL/TP logic.
"""
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from dataclasses import dataclass, field, asdict
from .okx_client import OKXClient
from .config import TradingConfig, PathConfig
logger = logging.getLogger(__name__)
@dataclass
class Position:
    """
    Represents an open trading position.

    Any ``side`` value other than ``"long"`` is treated as a short by the
    PnL and stop-loss / take-profit helpers.
    """

    trade_id: str
    symbol: str
    side: str  # "long" or "short"
    entry_price: float
    entry_time: str  # ISO format
    size: float  # Amount in base currency (e.g., ETH)
    size_usdt: float  # Notional value in USDT
    stop_loss_price: float
    take_profit_price: float
    current_price: float = 0.0
    unrealized_pnl: float = 0.0
    unrealized_pnl_pct: float = 0.0
    order_id: str = ""  # Entry order ID from exchange

    def update_pnl(self, current_price: float) -> None:
        """
        Update unrealized PnL based on current price.

        Args:
            current_price: Latest price for the position's symbol

        The percentage is left at 0.0 when ``entry_price`` is zero, because
        a return relative to a zero entry is undefined and dividing would
        raise ZeroDivisionError.
        """
        self.current_price = current_price
        is_long = self.side == "long"

        if is_long:
            self.unrealized_pnl = (current_price - self.entry_price) * self.size
        else:  # short
            self.unrealized_pnl = (self.entry_price - current_price) * self.size

        if not self.entry_price:
            self.unrealized_pnl_pct = 0.0
        elif is_long:
            self.unrealized_pnl_pct = (current_price / self.entry_price - 1) * 100
        else:
            self.unrealized_pnl_pct = (1 - current_price / self.entry_price) * 100

    def should_stop_loss(self, current_price: float) -> bool:
        """Check if stop-loss should trigger (price at or beyond the stop level)."""
        if self.side == "long":
            return current_price <= self.stop_loss_price
        return current_price >= self.stop_loss_price

    def should_take_profit(self, current_price: float) -> bool:
        """Check if take-profit should trigger (price at or beyond the target level)."""
        if self.side == "long":
            return current_price >= self.take_profit_price
        return current_price <= self.take_profit_price

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> 'Position':
        """
        Create Position from dictionary.

        Args:
            data: Mapping of field name -> value, as produced by to_dict()

        Returns:
            Position built from the recognised fields

        Keys that are not Position fields are ignored so that a record
        carrying extra keys can still be loaded. Missing required fields
        still raise TypeError.
        """
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known})
class PositionManager:
    """
    Manages trading positions with persistence.

    Tracks open positions, enforces risk limits, and handles
    position lifecycle (open, update, close).

    Open positions are kept in ``self.positions`` (trade ID -> Position)
    and mirrored to ``paths.positions_file`` as JSON. Closed trades are
    kept in ``self.trade_log`` and appended to ``paths.trade_log_file``
    as CSV rows.
    """

    def __init__(
        self,
        okx_client: "OKXClient",
        trading_config: "TradingConfig",
        path_config: "PathConfig",
    ):
        self.client = okx_client
        self.config = trading_config
        self.paths = path_config
        self.positions: dict[str, Position] = {}
        self.trade_log: list[dict] = []
        self._load_positions()

    def _load_positions(self) -> None:
        """
        Load positions from file.

        A missing file is not an error. A record that cannot be turned into
        a Position is skipped with a warning so that the remaining records
        are still loaded; an unreadable file is logged and leaves the
        manager with whatever was loaded so far.
        """
        if not self.paths.positions_file.exists():
            return
        try:
            with open(self.paths.positions_file, 'r') as f:
                data = json.load(f)
            for trade_id, pos_data in data.items():
                try:
                    self.positions[trade_id] = Position.from_dict(pos_data)
                except Exception as e:
                    logger.warning(f"Skipping unreadable position {trade_id}: {e}")
            logger.info(f"Loaded {len(self.positions)} positions from file")
        except Exception as e:
            logger.warning(f"Could not load positions: {e}")

    def save_positions(self) -> None:
        """
        Save positions to file.

        The JSON is written to a sibling ``<name>.tmp`` file and then renamed
        over the real file, so an interrupted write cannot leave a truncated
        positions file behind. Failures are logged, not raised.
        """
        try:
            data = {
                trade_id: pos.to_dict()
                for trade_id, pos in self.positions.items()
            }
            # assumes positions_file is a pathlib.Path (it is already used
            # with .exists() when loading)
            target = self.paths.positions_file
            tmp_path = target.with_name(target.name + ".tmp")
            with open(tmp_path, 'w') as f:
                json.dump(data, f, indent=2)
            tmp_path.replace(target)
            logger.debug(f"Saved {len(self.positions)} positions")
        except Exception as e:
            logger.error(f"Could not save positions: {e}")

    def can_open_position(self) -> bool:
        """Check if we can open a new position."""
        return len(self.positions) < self.config.max_concurrent_positions

    def get_position_for_symbol(self, symbol: str) -> Optional["Position"]:
        """Get position for a specific symbol, or None if there is none."""
        return next(
            (pos for pos in self.positions.values() if pos.symbol == symbol),
            None,
        )

    def open_position(
        self,
        symbol: str,
        side: str,
        entry_price: float,
        size: float,
        stop_loss_price: float,
        take_profit_price: float,
        order_id: str = ""
    ) -> Optional["Position"]:
        """
        Open a new position.

        Args:
            symbol: Trading pair symbol
            side: "long" or "short"
            entry_price: Entry price
            size: Position size in base currency
            stop_loss_price: Stop-loss price
            take_profit_price: Take-profit price
            order_id: Entry order ID from exchange

        Returns:
            Position object or None if failed (position limit reached or a
            position for the symbol already exists)
        """
        if not self.can_open_position():
            logger.warning("Cannot open position: max concurrent positions reached")
            return None

        # Check if already have position for this symbol
        existing = self.get_position_for_symbol(symbol)
        if existing:
            logger.warning(f"Already have position for {symbol}")
            return None

        # Generate trade ID
        now = datetime.now(timezone.utc)
        trade_id = f"{symbol}_{now.strftime('%Y%m%d_%H%M%S')}"

        position = Position(
            trade_id=trade_id,
            symbol=symbol,
            side=side,
            entry_price=entry_price,
            entry_time=now.isoformat(),
            size=size,
            size_usdt=entry_price * size,
            stop_loss_price=stop_loss_price,
            take_profit_price=take_profit_price,
            current_price=entry_price,
            order_id=order_id,
        )
        self.positions[trade_id] = position
        self.save_positions()

        logger.info(
            f"Opened {side.upper()} position: {size} {symbol} @ {entry_price}, "
            f"SL={stop_loss_price}, TP={take_profit_price}"
        )
        return position

    def close_position(
        self,
        trade_id: str,
        exit_price: float,
        reason: str = "manual",
        exit_order_id: str = ""
    ) -> Optional[dict]:
        """
        Close a position and record the trade.

        This only updates local state (memory, positions file, trade log);
        it does not send any order to the exchange.

        Args:
            trade_id: Position trade ID
            exit_price: Exit price
            reason: Reason for closing (e.g., "stop_loss", "take_profit", "signal")
            exit_order_id: Exit order ID from exchange

        Returns:
            Trade record dictionary, or None if the trade ID is unknown
        """
        if trade_id not in self.positions:
            logger.warning(f"Position {trade_id} not found")
            return None

        position = self.positions[trade_id]
        position.update_pnl(exit_price)

        # Calculate final PnL
        entry_time = datetime.fromisoformat(position.entry_time)
        if entry_time.tzinfo is None:
            # A naive timestamp cannot be subtracted from the aware exit
            # time; treat it as UTC, which is what open_position records.
            entry_time = entry_time.replace(tzinfo=timezone.utc)
        exit_time = datetime.now(timezone.utc)
        hold_duration = (exit_time - entry_time).total_seconds() / 3600  # hours

        trade_record = {
            'trade_id': trade_id,
            'symbol': position.symbol,
            'side': position.side,
            'entry_price': position.entry_price,
            'exit_price': exit_price,
            'size': position.size,
            'size_usdt': position.size_usdt,
            'pnl_usd': position.unrealized_pnl,
            'pnl_pct': position.unrealized_pnl_pct,
            'entry_time': position.entry_time,
            'exit_time': exit_time.isoformat(),
            'hold_duration_hours': hold_duration,
            'reason': reason,
            'order_id_entry': position.order_id,
            'order_id_exit': exit_order_id,
        }

        self.trade_log.append(trade_record)
        del self.positions[trade_id]
        self.save_positions()
        self._append_trade_log(trade_record)

        logger.info(
            f"Closed {position.side.upper()} position: {position.size} {position.symbol} "
            f"@ {exit_price}, PnL=${position.unrealized_pnl:.2f} ({position.unrealized_pnl_pct:.2f}%), "
            f"reason={reason}"
        )
        return trade_record

    def _append_trade_log(self, trade_record: dict) -> None:
        """
        Append trade record to CSV log file.

        The header row is written when the file is missing or empty.
        I/O failures are logged rather than raised: by the time this runs
        the position has already been removed and persisted, so raising
        would only hide the trade record from the caller.
        """
        import csv

        log_path = self.paths.trade_log_file
        try:
            needs_header = not log_path.exists() or log_path.stat().st_size == 0
            with open(log_path, 'a', newline='') as f:
                writer = csv.DictWriter(f, fieldnames=list(trade_record))
                if needs_header:
                    writer.writeheader()
                writer.writerow(trade_record)
        except OSError as e:
            logger.error(f"Could not append trade log: {e}")

    def _close_on_exchange(self, symbol: str) -> tuple[bool, str]:
        """
        Ask the exchange client to close the position for a symbol.

        Returns:
            (True, exit order ID or "") when the client call returned,
            (False, "") when it raised.
        """
        try:
            exit_order = self.client.close_position(symbol)
        except Exception as e:
            logger.error(f"Failed to close position on exchange: {e}")
            return False, ""

        if not exit_order:
            return True, ""
        try:
            return True, exit_order.get('id', '')
        except AttributeError:
            # Client returned something without .get(); the close call
            # itself succeeded, we just have no order ID to record.
            return True, ""

    def update_positions(self, current_prices: dict[str, float]) -> list[dict]:
        """
        Update all positions with current prices and check SL/TP.

        Stop-loss is checked before take-profit. When either triggers, the
        position is closed on the exchange first and only then recorded as
        closed locally. If the exchange call raises, the position is kept
        so the close is attempted again on the next call instead of leaving
        an untracked position open on the exchange.

        Args:
            current_prices: Dictionary of symbol -> current price

        Returns:
            List of closed trade records
        """
        closed_trades = []

        # Iterate over a snapshot: close_position() deletes from self.positions.
        for trade_id, position in list(self.positions.items()):
            if position.symbol not in current_prices:
                continue

            current_price = current_prices[position.symbol]
            position.update_pnl(current_price)

            if position.should_stop_loss(current_price):
                logger.warning(
                    f"Stop-loss triggered for {trade_id} at {current_price}"
                )
                reason = "stop_loss"
            elif position.should_take_profit(current_price):
                logger.info(
                    f"Take-profit triggered for {trade_id} at {current_price}"
                )
                reason = "take_profit"
            else:
                continue

            closed, exit_order_id = self._close_on_exchange(position.symbol)
            if not closed:
                logger.warning(
                    f"Keeping {trade_id} open locally; exchange close will be retried"
                )
                continue

            record = self.close_position(trade_id, current_price, reason, exit_order_id)
            if record:
                closed_trades.append(record)

        # Persist the refreshed prices / PnL of the positions that remain.
        self.save_positions()
        return closed_trades

    def sync_with_exchange(self) -> None:
        """
        Sync local positions with exchange positions.

        Reconciles any discrepancies between local tracking
        and actual exchange positions: a local position whose symbol is
        not reported by the exchange is closed locally with reason
        "sync_removed". Positions that exist only on the exchange are not
        imported. Any failure is logged, not raised.
        """
        try:
            exchange_positions = self.client.get_positions()
            exchange_symbols = {p['symbol'] for p in exchange_positions}

            # Check for positions we have locally but not on exchange
            for trade_id, pos in list(self.positions.items()):
                if pos.symbol in exchange_symbols:
                    continue

                logger.warning(
                    f"Position {trade_id} not found on exchange, removing"
                )
                # Get last price and close; fall back to the last price we saw
                try:
                    ticker = self.client.get_ticker(pos.symbol)
                    exit_price = ticker['last']
                except Exception:
                    exit_price = pos.current_price
                self.close_position(trade_id, exit_price, "sync_removed")

            logger.info(f"Position sync complete: {len(self.positions)} local positions")
        except Exception as e:
            logger.error(f"Position sync failed: {e}")

    def get_portfolio_summary(self) -> dict:
        """
        Get portfolio summary.

        Returns:
            Dictionary with portfolio statistics: number of open positions,
            summed notional exposure, summed unrealized PnL, and the
            positions themselves as dictionaries
        """
        open_positions = list(self.positions.values())
        total_exposure = sum(p.size_usdt for p in open_positions)
        total_unrealized_pnl = sum(p.unrealized_pnl for p in open_positions)
        return {
            'open_positions': len(open_positions),
            'total_exposure_usdt': total_exposure,
            'total_unrealized_pnl': total_unrealized_pnl,
            'positions': [p.to_dict() for p in open_positions],
        }