Enhance backtesting framework with static task processing and progress management. Introduced static task processing for parallel execution, improved error handling, and added a progress manager for better task tracking. Updated BacktestRunner to support progress callbacks and optimized worker allocation based on system resources. Added new configuration files for flexible backtesting setups.

This commit is contained in:
Simon Moisy
2025-07-10 10:23:41 +08:00
parent be331ed631
commit 65f30a4020
11 changed files with 830 additions and 156 deletions

View File

@@ -7,21 +7,32 @@ from cycles.market_fees import MarketFees
class Backtest:
@staticmethod
def run(min1_df, df, initial_usd, stop_loss_pct, debug=False):
def run(min1_df, df, initial_usd, stop_loss_pct, progress_callback=None, verbose=False):
"""
Backtest a simple strategy using the meta supertrend (all three supertrends agree).
Buys when meta supertrend is positive, sells when negative, applies a percentage stop loss.
Parameters:
- min1_df: pandas DataFrame, 1-minute timeframe data for more accurate stop loss checking (optional)
- df: pandas DataFrame, main timeframe data for signals
- initial_usd: float, starting USD amount
- stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%)
- debug: bool, whether to print debug info
- progress_callback: callable, optional callback function to report progress (current_step)
- verbose: bool, enable debug logging for stop loss checks
"""
_df = df.copy().reset_index(drop=True)
_df = df.copy().reset_index()
# Ensure we have a timestamp column regardless of original index name
if 'timestamp' not in _df.columns:
# If reset_index() created a column with the original index name, rename it
if len(_df.columns) > 0 and _df.columns[0] not in ['open', 'high', 'low', 'close', 'volume', 'predicted_close_price']:
_df = _df.rename(columns={_df.columns[0]: 'timestamp'})
else:
raise ValueError("Unable to identify timestamp column in DataFrame")
_df['timestamp'] = pd.to_datetime(_df['timestamp'])
supertrends = Supertrends(_df, verbose=False)
supertrends = Supertrends(_df, verbose=False, close_column='predicted_close_price')
supertrend_results_list = supertrends.calculate_supertrend_indicators()
trends = [st['results']['trend'] for st in supertrend_results_list]
@@ -41,18 +52,21 @@ class Backtest:
drawdowns = []
trades = []
entry_time = None
current_trade_min1_start_idx = None
stop_loss_count = 0 # Track number of stop losses
min1_df.index = pd.to_datetime(min1_df.index)
min1_timestamps = min1_df.index.values
# Ensure min1_df has proper DatetimeIndex
if min1_df is not None and not min1_df.empty:
min1_df.index = pd.to_datetime(min1_df.index)
last_print_time = time.time()
for i in range(1, len(_df)):
current_time = time.time()
if current_time - last_print_time >= 5:
progress = (i / len(_df)) * 100
print(f"\rProgress: {progress:.1f}%", end="", flush=True)
last_print_time = current_time
# Report progress if callback is provided
if progress_callback:
# Update more frequently for better responsiveness
update_frequency = max(1, len(_df) // 50) # Update every 2% of dataset (50 updates total)
if i % update_frequency == 0 or i == len(_df) - 1: # Always update on last iteration
if verbose: # Only print in verbose mode to avoid spam
print(f"DEBUG: Progress callback called with i={i}, total={len(_df)-1}")
progress_callback(i)
price_open = _df['open'].iloc[i]
price_close = _df['close'].iloc[i]
@@ -69,16 +83,13 @@ class Backtest:
entry_price,
stop_loss_pct,
coin,
usd,
debug,
current_trade_min1_start_idx
verbose=verbose
)
if stop_loss_result is not None:
trade_log_entry, current_trade_min1_start_idx, position, coin, entry_price = stop_loss_result
trade_log_entry, position, coin, entry_price, usd = stop_loss_result
trade_log.append(trade_log_entry)
stop_loss_count += 1
continue
# Update the start index for next check
current_trade_min1_start_idx = min1_df.index[min1_df.index <= date][-1]
# Entry: only if not in position and signal changes to 1
if position == 0 and prev_mt != 1 and curr_mt == 1:
@@ -99,7 +110,9 @@ class Backtest:
drawdown = (max_balance - balance) / max_balance
drawdowns.append(drawdown)
print("\rProgress: 100%\r\n", end="", flush=True)
# Report completion if callback is provided
if progress_callback:
progress_callback(len(_df) - 1)
# If still in position at end, sell at last close
if position == 1:
@@ -122,22 +135,37 @@ class Backtest:
profit_pct = (trade['exit'] - trade['entry']) / trade['entry']
else:
profit_pct = 0.0
# Validate fee_usd field
if 'fee_usd' not in trade:
raise ValueError(f"Trade missing required field 'fee_usd': {trade}")
fee_usd = trade['fee_usd']
if fee_usd is None:
raise ValueError(f"Trade fee_usd is None: {trade}")
# Validate trade type field
if 'type' not in trade:
raise ValueError(f"Trade missing required field 'type': {trade}")
trade_type = trade['type']
if trade_type is None:
raise ValueError(f"Trade type is None: {trade}")
trades.append({
'entry_time': trade['entry_time'],
'exit_time': trade['exit_time'],
'entry': trade['entry'],
'exit': trade['exit'],
'profit_pct': profit_pct,
'type': trade.get('type', 'SELL'),
'fee_usd': trade.get('fee_usd')
'type': trade_type,
'fee_usd': fee_usd
})
fee_usd = trade.get('fee_usd')
total_fees_usd += fee_usd
results = {
"initial_usd": initial_usd,
"final_usd": final_balance,
"n_trades": n_trades,
"n_stop_loss": stop_loss_count, # Add stop loss count
"win_rate": win_rate,
"max_drawdown": max_drawdown,
"avg_trade": avg_trade,
@@ -157,38 +185,112 @@ class Backtest:
return results
@staticmethod
def check_stop_loss(min1_df, entry_time, current_time, entry_price, stop_loss_pct, coin, verbose=False):
    """
    Check if stop loss should be triggered based on 1-minute data.

    Looks at every 1-minute candle strictly after entry_time and up to and
    including current_time. The first candle whose low is at or below
    entry_price * (1 - stop_loss_pct) triggers the stop.

    Args:
        min1_df: 1-minute DataFrame with DatetimeIndex and 'open'/'low' columns
        entry_time: Entry timestamp
        current_time: Current timestamp
        entry_price: Entry price
        stop_loss_pct: Stop loss percentage (e.g. 0.05 for 5%)
        coin: Current coin position
        verbose: Enable debug logging

    Returns:
        Tuple of (trade_log_entry, position, coin, entry_price, usd) if stop
        loss triggered, where position, coin and entry_price are reset to 0
        and usd is the sale proceeds net of the exit fee. None otherwise,
        which includes: no 1-minute data, a non-DatetimeIndex index, an empty
        time range, or any exception raised while checking.
    """
    if min1_df is None or min1_df.empty:
        if verbose:
            print("Warning: No 1-minute data available for stop loss checking")
        return None

    stop_price = entry_price * (1 - stop_loss_pct)

    try:
        # Timestamp comparisons below only make sense on a DatetimeIndex
        if not isinstance(min1_df.index, pd.DatetimeIndex):
            if verbose:
                print("Warning: min1_df does not have DatetimeIndex")
            return None

        # Convert entry_time and current_time to pandas Timestamps for comparison
        entry_ts = pd.to_datetime(entry_time)
        current_ts = pd.to_datetime(current_time)

        if verbose:
            print(f"Checking stop loss from {entry_ts} to {current_ts}, stop_price: {stop_price:.2f}")

        # Handle edge case where entry and current time are the same (1-minute timeframe)
        if entry_ts == current_ts:
            if verbose:
                print("Entry and current time are the same, no range to check")
            return None

        # Exclusive of entry time, inclusive of current time: the entry candle
        # itself is never checked against the stop.
        mask = (min1_df.index > entry_ts) & (min1_df.index <= current_ts)
        min1_slice = min1_df.loc[mask]

        if len(min1_slice) == 0:
            if verbose:
                start_check_time = entry_ts + pd.Timedelta(minutes=1)
                print(f"No 1-minute data found between {start_check_time} and {current_ts}")
            return None

        if verbose:
            print(f"Checking {len(min1_slice)} candles for stop loss")

        # Evaluate the breach condition once and reuse it for both the
        # "any hit?" test and the lookup of the first breaching candle.
        breached = min1_slice['low'] <= stop_price
        if not breached.any():
            if verbose:
                print(f"No stop loss triggered, min low in range: {min1_slice['low'].min():.2f}")
            return None

        # Find the exact candle where stop loss was triggered
        stop_candle = min1_slice[breached].iloc[0]
        if verbose:
            print(f"Stop loss triggered at {stop_candle.name}, low: {stop_candle['low']:.2f}")

        # More realistic fill: if open < stop, fill at open, else at stop
        if stop_candle['open'] < stop_price:
            sell_price = stop_candle['open']
            if verbose:
                print(f"Filled at open price: {sell_price:.2f}")
        else:
            sell_price = stop_price
            if verbose:
                print(f"Filled at stop price: {sell_price:.2f}")

        btc_to_sell = coin
        usd_gross = btc_to_sell * sell_price
        exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
        usd_after_stop = usd_gross - exit_fee

        trade_log_entry = {
            'type': 'STOP',
            'entry': entry_price,
            'exit': sell_price,
            'entry_time': entry_time,
            'exit_time': stop_candle.name,
            'fee_usd': exit_fee
        }
        # After stop loss, reset position and entry, return USD balance
        return trade_log_entry, 0, 0, 0, usd_after_stop

    except Exception as e:
        # Deliberate best-effort: a failed check must not abort the backtest,
        # so the error is reported and treated as "no stop loss triggered".
        error_msg = f"Warning: Stop loss check failed: {e}"
        print(error_msg)
        if verbose:
            import traceback
            print(traceback.format_exc())
        return None
@staticmethod

View File

@@ -65,30 +65,57 @@ def cached_supertrend_calculation(period, multiplier, data_tuple):
'lower_band': final_lower
}
def calculate_supertrend_external(data, period, multiplier, close_column='close'):
    """
    External function to calculate SuperTrend with configurable close column.

    The 'high', 'low' and close_column series are each converted to a tuple
    and passed, together with period and multiplier, to
    cached_supertrend_calculation.
    NOTE(review): the tuple conversion presumably exists so the arguments are
    hashable for a cache on cached_supertrend_calculation - confirm against
    its decorator.

    Parameters:
    - data: DataFrame with OHLC data
    - period: int, period for ATR calculation
    - multiplier: float, multiplier for ATR
    - close_column: str, name of the column to use as close price (default: 'close')

    Returns:
    - Whatever cached_supertrend_calculation returns for these inputs
    """
    price_series = tuple(
        tuple(data[column]) for column in ('high', 'low', close_column)
    )
    return cached_supertrend_calculation(period, multiplier, price_series)
class Supertrends:
def __init__(self, data, verbose=False, display=False):
def __init__(self, data, close_column='close', verbose=False, display=False):
"""
Initialize Supertrends calculator
Parameters:
- data: pandas DataFrame with OHLC data or list of prices
- close_column: str, name of the column to use as close price (default: 'close')
- verbose: bool, enable verbose logging
- display: bool, display mode (currently unused)
"""
self.close_column = close_column
self.data = data
self.verbose = verbose
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
format='%(asctime)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger('TrendDetectorSimple')
if not isinstance(self.data, pd.DataFrame):
if isinstance(self.data, list):
self.data = pd.DataFrame({'close': self.data})
self.data = pd.DataFrame({self.close_column: self.data})
else:
raise ValueError("Data must be a pandas DataFrame or a list")
# Validate that required columns exist
required_columns = ['high', 'low', self.close_column]
missing_columns = [col for col in required_columns if col not in self.data.columns]
if missing_columns:
raise ValueError(f"Missing required columns: {missing_columns}")
def calculate_tr(self):
"""Calculate True Range using the configured close column"""
df = self.data.copy()
high = df['high'].values
low = df['low'].values
close = df['close'].values
close = df[self.close_column].values
tr = np.zeros_like(close)
tr[0] = high[0] - low[0]
for i in range(1, len(close)):
@@ -99,6 +126,7 @@ class Supertrends:
return tr
def calculate_atr(self, period=14):
"""Calculate Average True Range"""
tr = self.calculate_tr()
atr = np.zeros_like(tr)
atr[0] = tr[0]
@@ -109,18 +137,20 @@ class Supertrends:
def calculate_supertrend(self, period=10, multiplier=3.0):
"""
Calculate SuperTrend indicator for the price data.
Calculate SuperTrend indicator for the price data using the configured close column.
SuperTrend is a trend-following indicator that uses ATR to determine the trend direction.
Parameters:
- period: int, the period for the ATR calculation (default: 10)
- multiplier: float, the multiplier for the ATR (default: 3.0)
Returns:
- Dictionary containing SuperTrend values, trend direction, and upper/lower bands
"""
df = self.data.copy()
high = df['high'].values
low = df['low'].values
close = df['close'].values
close = df[self.close_column].values
atr = self.calculate_atr(period)
upper_band = np.zeros_like(close)
lower_band = np.zeros_like(close)

View File

@@ -0,0 +1,233 @@
#!/usr/bin/env python3
"""
Progress Manager for tracking multiple parallel backtest tasks
"""
import threading
import time
import sys
from typing import Dict, Optional, Callable
from dataclasses import dataclass
@dataclass
class TaskProgress:
    """Represents progress information for a single task"""
    # Unique identifier of the task
    task_id: str
    # Human-readable task name
    name: str
    # Number of steps completed so far
    current: int
    # Total number of steps in the task
    total: int
    # time.time() value recorded when the task was created
    start_time: float
    # time.time() value of the most recent progress update
    last_update: float

    @property
    def percentage(self) -> float:
        """Calculate completion percentage, clamped to the range 0-100"""
        if self.total <= 0:
            return 0.0
        # Clamp so an over- or under-reported step count cannot produce a
        # value outside 0-100.
        return min(100.0, max(0.0, (self.current / self.total) * 100))

    @property
    def elapsed_time(self) -> float:
        """Calculate elapsed time in seconds since start_time"""
        return time.time() - self.start_time

    @property
    def eta(self) -> Optional[float]:
        """
        Estimate time to completion in seconds.

        Returns None when no estimate is possible: no progress yet, the task
        is already at or past its total, or no measurable time has elapsed.
        """
        if self.current <= 0 or self.current >= self.total:
            return None
        elapsed = self.elapsed_time
        # Coarse clocks can report zero elapsed time right after start;
        # dividing by it would raise ZeroDivisionError.
        if elapsed <= 0:
            return None
        rate = self.current / elapsed
        return (self.total - self.current) / rate
class ProgressManager:
    """Manages progress tracking for multiple parallel tasks"""

    def __init__(self, update_interval: float = 1.0, display_width: int = 50):
        """
        Initialize progress manager

        Args:
            update_interval: How often to update display (seconds)
            display_width: Width of progress bar in characters
        """
        self.tasks: Dict[str, "TaskProgress"] = {}
        self.update_interval = update_interval
        self.display_width = display_width
        self.lock = threading.Lock()
        self.display_thread: Optional[threading.Thread] = None
        self.running = False
        self.last_display_height = 0
        # Lets stop_display() wake the display thread immediately instead of
        # waiting for the current sleep interval to run out.
        self._stop_event = threading.Event()

    def start_task(self, task_id: str, name: str, total: int) -> None:
        """
        Start tracking a new task

        Args:
            task_id: Unique identifier for the task
            name: Human-readable name for the task
            total: Total number of steps in the task
        """
        now = time.time()
        with self.lock:
            self.tasks[task_id] = TaskProgress(
                task_id=task_id,
                name=name,
                current=0,
                total=total,
                start_time=now,
                last_update=now,
            )

    def update_progress(self, task_id: str, current: int) -> None:
        """
        Update progress for a specific task; unknown task ids are ignored

        Args:
            task_id: Task identifier
            current: Current progress value
        """
        with self.lock:
            task = self.tasks.get(task_id)
            if task is not None:
                task.current = current
                task.last_update = time.time()

    def complete_task(self, task_id: str) -> None:
        """
        Mark a task as completed by setting its progress to its total

        Args:
            task_id: Task identifier
        """
        with self.lock:
            task = self.tasks.get(task_id)
            if task is not None:
                task.current = task.total
                task.last_update = time.time()

    def start_display(self) -> None:
        """Start the progress display thread (no-op if already running)"""
        if not self.running:
            self.running = True
            self._stop_event.clear()
            self.display_thread = threading.Thread(target=self._display_loop, daemon=True)
            self.display_thread.start()

    def stop_display(self) -> None:
        """Stop the progress display thread and erase the progress lines"""
        self.running = False
        self._stop_event.set()
        if self.display_thread:
            self.display_thread.join(timeout=1.0)
        # Take the lock so the final clear cannot interleave with a redraw
        # still in flight if the join above timed out.
        with self.lock:
            self._clear_display()

    def _display_loop(self) -> None:
        """Main loop for updating the progress display"""
        while self.running:
            self._update_display()
            # Sleeps for update_interval, but returns early once stop is requested
            self._stop_event.wait(self.update_interval)

    def _update_display(self) -> None:
        """Update the console display with current progress"""
        with self.lock:
            if not self.tasks:
                return

            # Clear previous display
            self._clear_display()

            # One line per task, in a stable order so lines do not jump around
            lines = [
                self._format_progress_line(task)
                for task in sorted(self.tasks.values(), key=lambda t: t.task_id)
            ]
            print("\n".join(lines), flush=True)
            self.last_display_height = len(lines)

    def _clear_display(self) -> None:
        """Clear the previous progress display"""
        if self.last_display_height > 0:
            # Move cursor up and clear lines
            for _ in range(self.last_display_height):
                sys.stdout.write('\033[F')  # Move cursor up one line
                sys.stdout.write('\033[K')  # Clear line
            sys.stdout.flush()
            # Nothing of ours is on screen any more; a second clear must not
            # erase unrelated output above the cursor.
            self.last_display_height = 0

    def _format_progress_line(self, task: "TaskProgress") -> str:
        """
        Format a single progress line for display

        Args:
            task: TaskProgress instance

        Returns:
            Formatted progress string
        """
        # Progress bar; the fill is clamped so an out-of-range percentage
        # cannot make the bar longer than display_width or negative.
        bounded_pct = min(100.0, max(0.0, task.percentage))
        filled_width = int(bounded_pct / 100 * self.display_width)
        bar = '█' * filled_width + '░' * (self.display_width - filled_width)

        # Time information; read eta once, and compare against None so a
        # zero-second estimate is still shown.
        elapsed_str = self._format_time(task.elapsed_time)
        eta = task.eta
        eta_str = self._format_time(eta) if eta is not None else "N/A"

        # Format line
        return (f"{task.name:<25} |{bar}| "
                f"{task.percentage:5.1f}% "
                f"({task.current:,}/{task.total:,}) "
                f"{elapsed_str} ETA: {eta_str}")

    def _format_time(self, seconds: float) -> str:
        """
        Format time duration for display

        Args:
            seconds: Time in seconds

        Returns:
            Formatted time string: whole seconds below one minute, fractional
            minutes below one hour, fractional hours otherwise
        """
        if seconds < 60:
            return f"{seconds:.0f}s"
        if seconds < 3600:
            return f"{seconds / 60:.1f}m"
        return f"{seconds / 3600:.1f}h"

    def get_task_progress_callback(self, task_id: str) -> Callable[[int], None]:
        """
        Get a progress callback function for a specific task

        Args:
            task_id: Task identifier

        Returns:
            Callback function that updates progress for this task
        """
        def callback(current: int) -> None:
            self.update_progress(task_id, current)

        return callback

    def all_tasks_completed(self) -> bool:
        """Check if all tasks are completed (True when no tasks are tracked)"""
        with self.lock:
            return all(task.current >= task.total for task in self.tasks.values())

    def get_summary(self) -> str:
        """Get a summary of all tasks"""
        with self.lock:
            total_tasks = len(self.tasks)
            completed_tasks = sum(1 for task in self.tasks.values()
                                  if task.current >= task.total)
            return f"Tasks: {completed_tasks}/{total_tasks} completed"

View File

@@ -10,10 +10,12 @@ class SystemUtils:
"""Determine optimal number of worker processes based on system resources"""
cpu_count = os.cpu_count() or 4
memory_gb = psutil.virtual_memory().total / (1024**3)
# Heuristic: Use 75% of cores, but cap based on available memory
# Assume each worker needs ~2GB for large datasets
workers_by_memory = max(1, int(memory_gb / 2))
workers_by_cpu = max(1, int(cpu_count * 0.75))
# OPTIMIZATION: More aggressive worker allocation for better performance
workers_by_memory = max(1, int(memory_gb / 2)) # 2GB per worker
workers_by_cpu = max(1, int(cpu_count * 0.8)) # Use 80% of CPU cores
optimal_workers = min(workers_by_cpu, workers_by_memory, 8) # Cap at 8 workers
if self.logging is not None:
self.logging.info(f"Using {min(workers_by_cpu, workers_by_memory)} workers for processing")
return min(workers_by_cpu, workers_by_memory)
self.logging.info(f"Using {optimal_workers} workers for processing (CPU-based: {workers_by_cpu}, Memory-based: {workers_by_memory})")
return optimal_workers