344 lines
13 KiB
Python
344 lines
13 KiB
Python
|
|
"""
|
||
|
|
Incremental Trader for backtesting incremental strategies.
|
||
|
|
|
||
|
|
This module provides the IncTrader class that manages a single incremental strategy
|
||
|
|
during backtesting, handling position state, trade execution, and performance tracking.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pandas as pd
|
||
|
|
import numpy as np
|
||
|
|
from typing import Dict, Optional, List, Any
|
||
|
|
import logging
|
||
|
|
from dataclasses import dataclass
|
||
|
|
|
||
|
|
from .base import IncStrategyBase, IncStrategySignal
|
||
|
|
from ..market_fees import MarketFees
|
||
|
|
|
||
|
|
# Module-level logger; IncTrader reports initialization, warm-up completion,
# entries/exits and errors through it.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(init=True, repr=True, eq=True)
class TradeRecord:
    """A completed round-trip trade (one entry followed by one exit).

    Fields are declared in constructor (positional) order.
    """

    # Timestamps at which the position was opened and closed.
    entry_time: pd.Timestamp
    exit_time: pd.Timestamp

    # Prices at which the entry and exit were filled.
    entry_price: float
    exit_price: float

    # Fees charged on each leg of the trade (USD amounts).
    entry_fee: float
    exit_fee: float

    # Fractional price change between entry and exit.
    profit_pct: float

    # Why the position was closed, e.g. "STRATEGY_EXIT", "STOP_LOSS",
    # "TAKE_PROFIT" or "EOD".
    exit_reason: str

    # Name of the strategy that produced the trade.
    strategy_name: str
|
||
|
|
|
||
|
|
|
||
|
|
class IncTrader:
    """
    Incremental trader that manages a single strategy during backtesting.

    This class handles:
    - Strategy initialization and data feeding
    - Position management (USD/coin balance)
    - Trade execution based on strategy signals
    - Performance tracking and metrics collection
    - Fee calculation and trade logging

    The trader processes data points sequentially, feeding them to the strategy
    and executing trades based on the generated signals. It is long-only and
    all-in: it either holds only USD (``position == 0``) or only coin
    (``position == 1``).

    Example:
        strategy = IncRandomStrategy(params={"timeframe": "15min"})
        trader = IncTrader(
            strategy=strategy,
            initial_usd=10000,
            params={"stop_loss_pct": 0.02}
        )

        # Process data sequentially
        for timestamp, ohlcv_data in data_stream:
            trader.process_data_point(timestamp, ohlcv_data)

        # Close any open position, then get results
        trader.finalize()
        results = trader.get_results()
    """

    def __init__(self, strategy: "IncStrategyBase", initial_usd: float = 10000,
                 params: Optional[Dict] = None):
        """
        Initialize the incremental trader.

        Args:
            strategy: Incremental strategy instance.
            initial_usd: Initial USD balance.
            params: Trader parameters. Keys read here:
                ``stop_loss_pct`` (fraction of entry price, e.g. 0.02; 0 disables)
                and ``take_profit_pct`` (fraction of entry price; 0 disables).
        """
        self.strategy = strategy
        self.initial_usd = initial_usd
        self.params = params or {}

        # Position state
        self.usd = initial_usd
        self.coin = 0.0
        self.position = 0  # 0 = no position, 1 = long position
        self.entry_price = 0.0
        self.entry_time = None
        # Fee actually charged when the open position was entered. It is kept
        # so the exit can record the real value instead of re-deriving it.
        self.entry_fee = 0.0

        # Performance tracking
        self.max_balance = initial_usd
        self.drawdowns: List[float] = []
        self.trade_records: List["TradeRecord"] = []
        self.current_timestamp = None
        self.current_price = None

        # Strategy state
        self.data_points_processed = 0
        self.warmup_complete = False

        # Parameters
        self.stop_loss_pct = self.params.get("stop_loss_pct", 0.0)
        self.take_profit_pct = self.params.get("take_profit_pct", 0.0)

        logger.info(f"IncTrader initialized: strategy={strategy.name}, "
                    f"initial_usd=${initial_usd}, stop_loss={self.stop_loss_pct*100:.1f}%")

    def process_data_point(self, timestamp: pd.Timestamp, ohlcv_data: Dict[str, float]) -> None:
        """
        Process a single data point through the strategy and handle trading logic.

        Args:
            timestamp: Data point timestamp.
            ohlcv_data: OHLCV data dictionary with keys: open, high, low, close, volume.

        Raises:
            Any exception raised while updating the strategy or metrics. It is
            logged and then re-raised.
        """
        self.current_timestamp = timestamp
        self.current_price = ohlcv_data['close']
        self.data_points_processed += 1

        try:
            # Feed data to strategy (handles timeframe aggregation internally)
            result = self.strategy.update_minute_data(timestamp, ohlcv_data)

            # Check if strategy is warmed up
            if not self.warmup_complete and self.strategy.is_warmed_up:
                self.warmup_complete = True
                logger.info(f"Strategy {self.strategy.name} warmed up after "
                            f"{self.data_points_processed} data points")

            # Only process signals if strategy is warmed up and we have a
            # complete timeframe bar. NOTE: this means stop-loss/take-profit are
            # also evaluated only on completed bars, not on every minute.
            if self.warmup_complete and result is not None:
                self._process_trading_logic()

            # Update performance tracking (every data point, including warm-up)
            self._update_performance_metrics()

        except Exception as e:
            logger.error(f"Error processing data point at {timestamp}: {e}")
            raise

    def _process_trading_logic(self) -> None:
        """Process trading logic based on current position and strategy signals."""
        if self.position == 0:
            # No position - check for entry signals
            self._check_entry_signals()
        else:
            # In position - check for exit signals
            self._check_exit_signals()

    def _check_entry_signals(self) -> None:
        """Enter a position when the strategy emits an ENTRY signal with positive confidence.

        Errors are logged (with traceback) and swallowed, so one bad signal
        does not abort the backtest.
        """
        try:
            entry_signal = self.strategy.get_entry_signal()

            if entry_signal.signal_type == "ENTRY" and entry_signal.confidence > 0:
                self._execute_entry(entry_signal)

        except Exception as e:
            logger.exception(f"Error checking entry signals: {e}")

    def _check_exit_signals(self) -> None:
        """Exit the open position on a strategy EXIT signal, stop loss, or take profit.

        The checks run in that priority order; the first one that fires wins.
        Errors are logged (with traceback) and swallowed.
        """
        try:
            # Check strategy exit signals
            exit_signal = self.strategy.get_exit_signal()

            if exit_signal.signal_type == "EXIT" and exit_signal.confidence > 0:
                # metadata may be missing; fall back to a generic reason rather
                # than raising and silently skipping the exit.
                metadata = exit_signal.metadata or {}
                exit_reason = metadata.get("type", "STRATEGY_EXIT")
                self._execute_exit(exit_reason, exit_signal.price)
                return

            # Check stop loss (evaluated against the latest close price)
            if self.stop_loss_pct > 0:
                stop_loss_price = self.entry_price * (1 - self.stop_loss_pct)
                if self.current_price <= stop_loss_price:
                    self._execute_exit("STOP_LOSS", self.current_price)
                    return

            # Check take profit (evaluated against the latest close price)
            if self.take_profit_pct > 0:
                take_profit_price = self.entry_price * (1 + self.take_profit_pct)
                if self.current_price >= take_profit_price:
                    self._execute_exit("TAKE_PROFIT", self.current_price)
                    return

        except Exception as e:
            logger.exception(f"Error checking exit signals: {e}")

    def _execute_entry(self, signal: "IncStrategySignal") -> None:
        """Execute an entry trade, converting the whole USD balance into coin.

        Args:
            signal: Entry signal. Its ``price`` is used when set (truthy);
                otherwise the latest close price is used.
        """
        entry_price = signal.price if signal.price else self.current_price
        # Taker fee on the full USD balance, deducted before buying.
        entry_fee = MarketFees.calculate_okx_taker_maker_fee(self.usd, is_maker=False)
        usd_after_fee = self.usd - entry_fee

        self.coin = usd_after_fee / entry_price
        self.entry_price = entry_price
        self.entry_time = self.current_timestamp
        self.entry_fee = entry_fee
        self.usd = 0.0
        self.position = 1

        logger.info(f"ENTRY: {self.strategy.name} at ${entry_price:.2f}, "
                    f"confidence={signal.confidence:.2f}, fee=${entry_fee:.2f}")

    def _execute_exit(self, exit_reason: str, exit_price: Optional[float] = None) -> None:
        """Execute an exit trade, selling all coin and recording a TradeRecord.

        Args:
            exit_reason: Label stored on the trade record (e.g. "STOP_LOSS").
            exit_price: Fill price; falls back to the latest close when falsy.
        """
        exit_price = exit_price if exit_price else self.current_price
        usd_gross = self.coin * exit_price
        exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)

        self.usd = usd_gross - exit_fee

        # Profit is the raw price change; fees are reported separately.
        profit_pct = (exit_price - self.entry_price) / self.entry_price

        # Record trade, using the entry fee that was actually charged.
        trade_record = TradeRecord(
            entry_time=self.entry_time,
            exit_time=self.current_timestamp,
            entry_price=self.entry_price,
            exit_price=exit_price,
            entry_fee=self.entry_fee,
            exit_fee=exit_fee,
            profit_pct=profit_pct,
            exit_reason=exit_reason,
            strategy_name=self.strategy.name
        )
        self.trade_records.append(trade_record)

        # Reset position
        self.coin = 0.0
        self.position = 0
        self.entry_price = 0.0
        self.entry_time = None
        self.entry_fee = 0.0

        logger.info(f"EXIT: {self.strategy.name} at ${exit_price:.2f}, "
                    f"reason={exit_reason}, profit={profit_pct*100:.2f}%, fee=${exit_fee:.2f}")

    def _current_balance(self) -> float:
        """Return account value: USD when flat, coin marked at the latest close when long."""
        if self.position == 0:
            return self.usd
        return self.coin * self.current_price

    def _update_performance_metrics(self) -> None:
        """Update the running peak balance and append the current drawdown fraction."""
        current_balance = self._current_balance()

        # Update max balance and drawdown
        if current_balance > self.max_balance:
            self.max_balance = current_balance

        # A non-positive peak (e.g. initial_usd == 0) would divide by zero.
        if self.max_balance > 0:
            drawdown = (self.max_balance - current_balance) / self.max_balance
        else:
            drawdown = 0.0
        self.drawdowns.append(drawdown)

    def finalize(self) -> None:
        """Finalize trading session (close any open position at the latest price)."""
        if self.position == 1:
            self._execute_exit("EOD", self.current_price)
            logger.info(f"Closed final position for {self.strategy.name} at EOD")

    def get_results(self) -> Dict[str, Any]:
        """
        Get comprehensive trading results.

        If a position is still open (``finalize()`` not called), ``final_usd``
        is the position marked to the latest close price (before exit fees)
        rather than the zero USD cash balance.

        Returns:
            Dict containing performance metrics, trade records, and statistics.
        """
        final_balance = self._current_balance()
        n_trades = len(self.trade_records)

        # Calculate statistics
        if n_trades > 0:
            profits = [trade.profit_pct for trade in self.trade_records]
            wins = [p for p in profits if p > 0]
            win_rate = len(wins) / n_trades
            avg_trade = np.mean(profits)
            total_fees = sum(trade.entry_fee + trade.exit_fee for trade in self.trade_records)
        else:
            win_rate = 0.0
            avg_trade = 0.0
            total_fees = 0.0

        max_drawdown = max(self.drawdowns) if self.drawdowns else 0.0
        # Avoid dividing by zero when the trader was started with no capital.
        if self.initial_usd:
            profit_ratio = (final_balance - self.initial_usd) / self.initial_usd
        else:
            profit_ratio = 0.0

        # Convert trade records to dictionaries
        trades = [
            {
                'entry_time': trade.entry_time,
                'exit_time': trade.exit_time,
                'entry': trade.entry_price,
                'exit': trade.exit_price,
                'profit_pct': trade.profit_pct,
                'type': trade.exit_reason,
                'fee_usd': trade.entry_fee + trade.exit_fee,
                'strategy': trade.strategy_name
            }
            for trade in self.trade_records
        ]

        results = {
            "strategy_name": self.strategy.name,
            "strategy_params": self.strategy.params,
            "trader_params": self.params,
            "initial_usd": self.initial_usd,
            "final_usd": final_balance,
            "profit_ratio": profit_ratio,
            "n_trades": n_trades,
            "win_rate": win_rate,
            "max_drawdown": max_drawdown,
            "avg_trade": avg_trade,
            "total_fees_usd": total_fees,
            "data_points_processed": self.data_points_processed,
            "warmup_complete": self.warmup_complete,
            "trades": trades
        }

        # Add first and last trade info if available
        if n_trades > 0:
            results["first_trade"] = {
                "entry_time": self.trade_records[0].entry_time,
                "entry": self.trade_records[0].entry_price
            }
            results["last_trade"] = {
                "exit_time": self.trade_records[-1].exit_time,
                "exit": self.trade_records[-1].exit_price
            }

        return results

    def get_current_state(self) -> Dict[str, Any]:
        """Get current trader state for debugging."""
        return {
            "strategy": self.strategy.name,
            "position": self.position,
            "usd": self.usd,
            "coin": self.coin,
            "current_price": self.current_price,
            "entry_price": self.entry_price,
            "data_points_processed": self.data_points_processed,
            "warmup_complete": self.warmup_complete,
            "n_trades": len(self.trade_records),
            "strategy_state": self.strategy.get_current_state_summary()
        }

    def __repr__(self) -> str:
        """String representation of the trader."""
        return (f"IncTrader(strategy={self.strategy.name}, "
                f"position={self.position}, usd=${self.usd:.2f}, "
                f"trades={len(self.trade_records)})")
|