Enhance backtesting framework with static task processing and progress management. Introduce static task processing for parallel execution, improve error handling, and add a progress manager for better task tracking. Update BacktestRunner to support progress callbacks and optimize worker allocation based on system resources. Add new configuration files for flexible backtesting setups.

This commit is contained in:
Simon Moisy 2025-07-10 10:23:41 +08:00
parent be331ed631
commit 65f30a4020
11 changed files with 830 additions and 156 deletions
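
A note on the "static task processing" wording: ProcessPoolExecutor sends work to child processes by pickling the submitted callable and its arguments, and bound methods drag their whole instance (logger, storage handles, and so on) through that pickling step. Moving the task body into a module-level function sidesteps this, which is what _process_single_task_static below does. A minimal, repo-independent illustration of the pattern:

    import concurrent.futures

    def square(x):
        # Module-level function: picklable, so worker processes can run it.
        return x * x

    if __name__ == "__main__":
        with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
            futures = [executor.submit(square, n) for n in range(4)]
            done = [f.result() for f in concurrent.futures.as_completed(futures)]
            print(sorted(done))  # [0, 1, 4, 9]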


@@ -5,9 +5,85 @@ from typing import List, Tuple, Dict, Any, Optional
 from cycles.utils.storage import Storage
 from cycles.utils.system import SystemUtils
+from cycles.utils.progress_manager import ProgressManager
 from result_processor import ResultProcessor
+
+
+def _process_single_task_static(task: Tuple[str, str, pd.DataFrame, float, float], progress_callback=None) -> Tuple[List[Dict], List[Dict]]:
+    """
+    Static version of _process_single_task for use with ProcessPoolExecutor
+
+    Args:
+        task: Tuple of (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
+        progress_callback: Optional progress callback function
+
+    Returns:
+        Tuple of (results, trades)
+    """
+    task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task
+    try:
+        if timeframe == "1T" or timeframe == "1min":
+            df = data_1min.copy()
+        else:
+            df = _resample_data_static(data_1min, timeframe)
+
+        # Create required components for processing
+        from cycles.utils.storage import Storage
+        from result_processor import ResultProcessor
+
+        # Create storage with default paths (for subprocess)
+        storage = Storage()
+        result_processor = ResultProcessor(storage)
+
+        results, trades = result_processor.process_timeframe_results(
+            data_1min,
+            df,
+            [stop_loss_pct],
+            timeframe,
+            initial_usd,
+            progress_callback=progress_callback
+        )
+        return results, trades
+    except Exception as e:
+        error_msg = f"Failed to process {timeframe} with stop loss {stop_loss_pct}: {e}"
+        raise RuntimeError(error_msg) from e
+
+
+def _resample_data_static(data_1min: pd.DataFrame, timeframe: str) -> pd.DataFrame:
+    """
+    Static function to resample 1-minute data to specified timeframe
+
+    Args:
+        data_1min: 1-minute data DataFrame
+        timeframe: Target timeframe string
+
+    Returns:
+        Resampled DataFrame
+    """
+    try:
+        agg_dict = {
+            'open': 'first',
+            'high': 'max',
+            'low': 'min',
+            'close': 'last',
+            'volume': 'sum'
+        }
+        if 'predicted_close_price' in data_1min.columns:
+            agg_dict['predicted_close_price'] = 'last'
+        resampled = data_1min.resample(timeframe).agg(agg_dict).dropna()
+        return resampled.reset_index()
+    except Exception as e:
+        error_msg = f"Failed to resample data to {timeframe}: {e}"
+        raise ValueError(error_msg) from e
+
+
 class BacktestRunner:
     """Handles the execution of backtests across multiple timeframes and parameters"""

@@ -16,7 +92,8 @@ class BacktestRunner:
         storage: Storage,
         system_utils: SystemUtils,
         result_processor: ResultProcessor,
-        logging_instance: Optional[logging.Logger] = None
+        logging_instance: Optional[logging.Logger] = None,
+        show_progress: bool = True
     ):
         """
         Initialize backtest runner

@@ -26,11 +103,14 @@ class BacktestRunner:
             system_utils: System utilities for resource management
             result_processor: Result processor for handling outputs
             logging_instance: Optional logging instance
+            show_progress: Whether to show visual progress bars
         """
         self.storage = storage
         self.system_utils = system_utils
         self.result_processor = result_processor
         self.logging = logging_instance
+        self.show_progress = show_progress
+        self.progress_manager = ProgressManager() if show_progress else None

     def run_backtests(
         self,

@@ -56,10 +136,13 @@ class BacktestRunner:
         # Create tasks for all combinations
         tasks = self._create_tasks(timeframes, stop_loss_pcts, data_1min, initial_usd)

+        if self.logging:
+            self.logging.info(f"Starting {len(tasks)} backtest tasks")
+
         if debug:
-            return self._run_sequential(tasks, debug)
+            return self._run_sequential(tasks)
         else:
-            return self._run_parallel(tasks, debug)
+            return self._run_parallel(tasks)

     def _create_tasks(
         self,

@@ -72,50 +155,92 @@ class BacktestRunner:
         tasks = []
         for timeframe in timeframes:
             for stop_loss_pct in stop_loss_pcts:
-                task = (timeframe, data_1min, stop_loss_pct, initial_usd)
+                task_id = f"{timeframe}_{stop_loss_pct}"
+                task = (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
                 tasks.append(task)
         return tasks

-    def _run_sequential(self, tasks: List[Tuple], debug: bool) -> Tuple[List[Dict], List[Dict]]:
+    def _run_sequential(self, tasks: List[Tuple]) -> Tuple[List[Dict], List[Dict]]:
         """Run tasks sequentially (for debug mode)"""
+        # Initialize progress tracking if enabled
+        if self.progress_manager:
+            for task in tasks:
+                task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task
+                # Calculate actual DataFrame size that will be processed
+                if timeframe == "1T" or timeframe == "1min":
+                    actual_df_size = len(data_1min)
+                else:
+                    # Get the actual resampled DataFrame size
+                    temp_df = self._resample_data(data_1min, timeframe)
+                    actual_df_size = len(temp_df)
+                task_name = f"{timeframe} SL:{stop_loss_pct:.0%}"
+                self.progress_manager.start_task(task_id, task_name, actual_df_size)
+            self.progress_manager.start_display()
+
         all_results = []
         all_trades = []
-        for task in tasks:
-            try:
-                results, trades = self._process_single_task(task, debug)
-                if results:
-                    all_results.extend(results)
-                if trades:
-                    all_trades.extend(trades)
-            except Exception as e:
-                error_msg = f"Error processing task {task[0]} with stop loss {task[2]}: {e}"
-                if self.logging:
-                    self.logging.error(error_msg)
-                raise RuntimeError(error_msg) from e
+        try:
+            for task in tasks:
+                try:
+                    # Get progress callback for this task if available
+                    progress_callback = None
+                    if self.progress_manager:
+                        progress_callback = self.progress_manager.get_task_progress_callback(task[0])
+
+                    results, trades = self._process_single_task(task, progress_callback)
+
+                    if results:
+                        all_results.extend(results)
+                    if trades:
+                        all_trades.extend(trades)
+
+                    # Mark task as completed
+                    if self.progress_manager:
+                        self.progress_manager.complete_task(task[0])
+                except Exception as e:
+                    error_msg = f"Error processing task {task[1]} with stop loss {task[3]}: {e}"
+                    if self.logging:
+                        self.logging.error(error_msg)
+                    raise RuntimeError(error_msg) from e
+        finally:
+            # Stop progress display
+            if self.progress_manager:
+                self.progress_manager.stop_display()

         return all_results, all_trades

-    def _run_parallel(self, tasks: List[Tuple], debug: bool) -> Tuple[List[Dict], List[Dict]]:
+    def _run_parallel(self, tasks: List[Tuple]) -> Tuple[List[Dict], List[Dict]]:
         """Run tasks in parallel using ProcessPoolExecutor"""
         workers = self.system_utils.get_optimal_workers()

         if self.logging:
             self.logging.info(f"Running {len(tasks)} tasks with {workers} workers")

+        # OPTIMIZATION: Disable progress manager for parallel execution to reduce overhead
+        # Progress tracking adds significant overhead in multiprocessing
+        if self.progress_manager and self.logging:
+            self.logging.info("Progress tracking disabled for parallel execution (performance optimization)")
+
         all_results = []
         all_trades = []
+        completed_tasks = 0

         try:
             with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
+                # Submit all tasks
                 future_to_task = {
-                    executor.submit(self._process_single_task, task, debug): task
+                    executor.submit(_process_single_task_static, task): task
                     for task in tasks
                 }

+                # Collect results as they complete
                 for future in concurrent.futures.as_completed(future_to_task):
                     task = future_to_task[future]
                     try:

@@ -124,9 +249,14 @@ class BacktestRunner:
                         all_results.extend(results)
                         if trades:
                             all_trades.extend(trades)
+
+                        completed_tasks += 1
+                        if self.logging:
+                            self.logging.info(f"Completed task {task[0]} ({completed_tasks}/{len(tasks)})")
                     except Exception as e:
-                        error_msg = f"Task {task[0]} with stop loss {task[2]} failed: {e}"
+                        error_msg = f"Task {task[1]} with stop loss {task[3]} failed: {e}"
                         if self.logging:
                             self.logging.error(error_msg)
                         raise RuntimeError(error_msg) from e

@@ -136,46 +266,56 @@ class BacktestRunner:
             if self.logging:
                 self.logging.error(error_msg)
             raise RuntimeError(error_msg) from e
+        finally:
+            # Stop progress display
+            if self.progress_manager:
+                self.progress_manager.stop_display()
+
+        if self.logging:
+            self.logging.info(f"All {len(tasks)} tasks completed successfully")

         return all_results, all_trades

     def _process_single_task(
         self,
-        task: Tuple[str, pd.DataFrame, float, float],
-        debug: bool = False
+        task: Tuple[str, str, pd.DataFrame, float, float],
+        progress_callback=None
     ) -> Tuple[List[Dict], List[Dict]]:
         """
         Process a single backtest task

         Args:
-            task: Tuple of (timeframe, data_1min, stop_loss_pct, initial_usd)
-            debug: Whether to enable debug output
+            task: Tuple of (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
+            progress_callback: Optional progress callback function

         Returns:
             Tuple of (results, trades)
         """
-        timeframe, data_1min, stop_loss_pct, initial_usd = task
+        task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task
         try:
+            # Resample data if needed
             if timeframe == "1T" or timeframe == "1min":
                 df = data_1min.copy()
             else:
                 df = self._resample_data(data_1min, timeframe)

+            # Process timeframe results
             results, trades = self.result_processor.process_timeframe_results(
                 data_1min,
                 df,
                 [stop_loss_pct],
                 timeframe,
                 initial_usd,
-                debug
+                progress_callback=progress_callback
             )

-            # Save individual trade files if trades exist
-            if trades:
-                self.result_processor.save_trade_file(trades, timeframe, stop_loss_pct)
+            # OPTIMIZATION: Skip individual trade file saving during parallel execution
+            # Trade files will be saved in batch at the end
+            # if trades:
+            #     self.result_processor.save_trade_file(trades, timeframe, stop_loss_pct)
+
+            if self.logging:
+                self.logging.info(f"Completed task {task_id}: {len(results)} results, {len(trades)} trades")

             return results, trades

@@ -197,13 +337,18 @@ class BacktestRunner:
             Resampled DataFrame
         """
         try:
-            resampled = data_1min.resample(timeframe).agg({
+            agg_dict = {
                 'open': 'first',
                 'high': 'max',
                 'low': 'min',
                 'close': 'last',
                 'volume': 'sum'
-            }).dropna()
+            }
+            if 'predicted_close_price' in data_1min.columns:
+                agg_dict['predicted_close_price'] = 'last'
+            resampled = data_1min.resample(timeframe).agg(agg_dict).dropna()

             return resampled.reset_index()

@@ -213,6 +358,34 @@ class BacktestRunner:
                 self.logging.error(error_msg)
             raise ValueError(error_msg) from e

+    def _get_timeframe_factor(self, timeframe: str) -> int:
+        """
+        Get the factor by which data is reduced when resampling to timeframe
+
+        Args:
+            timeframe: Target timeframe string (e.g., '1h', '4h', '1D')
+
+        Returns:
+            Factor for estimating data size after resampling
+        """
+        timeframe_factors = {
+            '1T': 1, '1min': 1,
+            '5T': 5, '5min': 5,
+            '15T': 15, '15min': 15,
+            '30T': 30, '30min': 30,
+            '1h': 60, '1H': 60,
+            '2h': 120, '2H': 120,
+            '4h': 240, '4H': 240,
+            '6h': 360, '6H': 360,
+            '8h': 480, '8H': 480,
+            '12h': 720, '12H': 720,
+            '1D': 1440, '1d': 1440,
+            '2D': 2880, '2d': 2880,
+            '3D': 4320, '3d': 4320,
+            '1W': 10080, '1w': 10080
+        }
+        return timeframe_factors.get(timeframe, 60)  # Default to 1 hour if unknown
+
     def load_data(self, filename: str, start_date: str, stop_date: str) -> pd.DataFrame:
         """
         Load and validate data for backtesting

@@ -234,8 +407,11 @@ class BacktestRunner:
         if data.empty:
             raise ValueError(f"No data loaded for period {start_date} to {stop_date}")

+        # Validate required columns
         required_columns = ['open', 'high', 'low', 'close', 'volume']
+        if 'predicted_close_price' in data.columns:
+            required_columns.append('predicted_close_price')

         missing_columns = [col for col in required_columns if col not in data.columns]

         if missing_columns:

@@ -269,11 +445,9 @@ class BacktestRunner:
         Raises:
             ValueError: If any input is invalid
         """
-        # Validate timeframes
         if not timeframes:
             raise ValueError("At least one timeframe must be specified")

-        # Validate stop loss percentages
         if not stop_loss_pcts:
             raise ValueError("At least one stop loss percentage must be specified")

@@ -281,7 +455,6 @@ class BacktestRunner:
             if not 0 < pct < 1:
                 raise ValueError(f"Stop loss percentage must be between 0 and 1, got: {pct}")

-        # Validate initial USD
         if initial_usd <= 0:
             raise ValueError("Initial USD must be positive")


@@ -14,7 +14,7 @@ class ConfigManager:
             "initial_usd": 10000,
             "timeframes": ["1D", "6h", "3h", "1h", "30m", "15m", "5m", "1m"],
             "stop_loss_pcts": [0.01, 0.02, 0.03, 0.05],
-            "data_dir": "data",
+            "data_dir": "../data",
             "results_dir": "results"
         }


@@ -0,0 +1,10 @@
+{
+    "start_date": "2021-11-01",
+    "stop_date": "2024-04-01",
+    "initial_usd": 10000,
+    "timeframes": ["1min", "2min", "3min", "4min", "5min", "10min", "15min", "30min", "1h", "2h", "4h", "6h", "8h", "12h", "1d"],
+    "stop_loss_pcts": [0.01, 0.02, 0.03, 0.04, 0.05, 0.1],
+    "data_dir": "../data",
+    "results_dir": "../results",
+    "debug": 0
+}

configs/full_config.json (new file)

@@ -0,0 +1,10 @@
+{
+    "start_date": "2020-01-01",
+    "stop_date": "2025-07-08",
+    "initial_usd": 10000,
+    "timeframes": ["1h", "4h", "15ME", "5ME", "1ME"],
+    "stop_loss_pcts": [0.01, 0.02, 0.03, 0.05],
+    "data_dir": "../data",
+    "results_dir": "../results",
+    "debug": 1
+}


@@ -2,8 +2,9 @@
     "start_date": "2023-01-01",
     "stop_date": "2025-01-15",
     "initial_usd": 10000,
-    "timeframes": ["1h", "4h"],
-    "stop_loss_pcts": [0.02, 0.05],
+    "timeframes": ["4h"],
+    "stop_loss_pcts": [0.05],
     "data_dir": "../data",
-    "results_dir": "../results"
+    "results_dir": "../results",
+    "debug": 0
 }


@@ -7,21 +7,32 @@ from cycles.market_fees import MarketFees
 class Backtest:
     @staticmethod
-    def run(min1_df, df, initial_usd, stop_loss_pct, debug=False):
+    def run(min1_df, df, initial_usd, stop_loss_pct, progress_callback=None, verbose=False):
         """
         Backtest a simple strategy using the meta supertrend (all three supertrends agree).
         Buys when meta supertrend is positive, sells when negative, applies a percentage stop loss.

         Parameters:
         - min1_df: pandas DataFrame, 1-minute timeframe data for more accurate stop loss checking (optional)
+        - df: pandas DataFrame, main timeframe data for signals
         - initial_usd: float, starting USD amount
         - stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%)
-        - debug: bool, whether to print debug info
+        - progress_callback: callable, optional callback function to report progress (current_step)
+        - verbose: bool, enable debug logging for stop loss checks
         """
-        _df = df.copy().reset_index(drop=True)
+        _df = df.copy().reset_index()
+
+        # Ensure we have a timestamp column regardless of original index name
+        if 'timestamp' not in _df.columns:
+            # If reset_index() created a column with the original index name, rename it
+            if len(_df.columns) > 0 and _df.columns[0] not in ['open', 'high', 'low', 'close', 'volume', 'predicted_close_price']:
+                _df = _df.rename(columns={_df.columns[0]: 'timestamp'})
+            else:
+                raise ValueError("Unable to identify timestamp column in DataFrame")
+
         _df['timestamp'] = pd.to_datetime(_df['timestamp'])
-        supertrends = Supertrends(_df, verbose=False)
+        supertrends = Supertrends(_df, verbose=False, close_column='predicted_close_price')
         supertrend_results_list = supertrends.calculate_supertrend_indicators()
         trends = [st['results']['trend'] for st in supertrend_results_list]

@@ -41,18 +52,21 @@
         drawdowns = []
         trades = []
         entry_time = None
-        current_trade_min1_start_idx = None
+        stop_loss_count = 0  # Track number of stop losses

-        min1_df.index = pd.to_datetime(min1_df.index)
-        min1_timestamps = min1_df.index.values
-        last_print_time = time.time()
+        # Ensure min1_df has proper DatetimeIndex
+        if min1_df is not None and not min1_df.empty:
+            min1_df.index = pd.to_datetime(min1_df.index)

         for i in range(1, len(_df)):
-            current_time = time.time()
-            if current_time - last_print_time >= 5:
-                progress = (i / len(_df)) * 100
-                print(f"\rProgress: {progress:.1f}%", end="", flush=True)
-                last_print_time = current_time
+            # Report progress if callback is provided
+            if progress_callback:
+                # Update more frequently for better responsiveness
+                update_frequency = max(1, len(_df) // 50)  # Update every 2% of dataset (50 updates total)
+                if i % update_frequency == 0 or i == len(_df) - 1:  # Always update on last iteration
+                    if verbose:  # Only print in verbose mode to avoid spam
+                        print(f"DEBUG: Progress callback called with i={i}, total={len(_df)-1}")
+                    progress_callback(i)

             price_open = _df['open'].iloc[i]
             price_close = _df['close'].iloc[i]

@@ -69,16 +83,13 @@
                     entry_price,
                     stop_loss_pct,
                     coin,
-                    usd,
-                    debug,
-                    current_trade_min1_start_idx
+                    verbose=verbose
                 )
                 if stop_loss_result is not None:
-                    trade_log_entry, current_trade_min1_start_idx, position, coin, entry_price = stop_loss_result
+                    trade_log_entry, position, coin, entry_price, usd = stop_loss_result
                     trade_log.append(trade_log_entry)
+                    stop_loss_count += 1
                     continue
-                # Update the start index for next check
-                current_trade_min1_start_idx = min1_df.index[min1_df.index <= date][-1]

             # Entry: only if not in position and signal changes to 1
             if position == 0 and prev_mt != 1 and curr_mt == 1:

@@ -99,7 +110,9 @@
             drawdown = (max_balance - balance) / max_balance
             drawdowns.append(drawdown)

-        print("\rProgress: 100%\r\n", end="", flush=True)
+        # Report completion if callback is provided
+        if progress_callback:
+            progress_callback(len(_df) - 1)

         # If still in position at end, sell at last close
         if position == 1:

@@ -122,22 +135,37 @@
                 profit_pct = (trade['exit'] - trade['entry']) / trade['entry']
             else:
                 profit_pct = 0.0
+
+            # Validate fee_usd field
+            if 'fee_usd' not in trade:
+                raise ValueError(f"Trade missing required field 'fee_usd': {trade}")
+            fee_usd = trade['fee_usd']
+            if fee_usd is None:
+                raise ValueError(f"Trade fee_usd is None: {trade}")
+
+            # Validate trade type field
+            if 'type' not in trade:
+                raise ValueError(f"Trade missing required field 'type': {trade}")
+            trade_type = trade['type']
+            if trade_type is None:
+                raise ValueError(f"Trade type is None: {trade}")
+
             trades.append({
                 'entry_time': trade['entry_time'],
                 'exit_time': trade['exit_time'],
                 'entry': trade['entry'],
                 'exit': trade['exit'],
                 'profit_pct': profit_pct,
-                'type': trade.get('type', 'SELL'),
-                'fee_usd': trade.get('fee_usd')
+                'type': trade_type,
+                'fee_usd': fee_usd
             })
-            fee_usd = trade.get('fee_usd')
             total_fees_usd += fee_usd

         results = {
             "initial_usd": initial_usd,
             "final_usd": final_balance,
             "n_trades": n_trades,
+            "n_stop_loss": stop_loss_count,  # Add stop loss count
             "win_rate": win_rate,
             "max_drawdown": max_drawdown,
             "avg_trade": avg_trade,

@@ -157,38 +185,112 @@
         return results

     @staticmethod
-    def check_stop_loss(min1_df, entry_time, date, entry_price, stop_loss_pct, coin, usd, debug, current_trade_min1_start_idx):
+    def check_stop_loss(min1_df, entry_time, current_time, entry_price, stop_loss_pct, coin, verbose=False):
+        """
+        Check if stop loss should be triggered based on 1-minute data
+
+        Args:
+            min1_df: 1-minute DataFrame with DatetimeIndex
+            entry_time: Entry timestamp
+            current_time: Current timestamp
+            entry_price: Entry price
+            stop_loss_pct: Stop loss percentage (e.g. 0.05 for 5%)
+            coin: Current coin position
+            verbose: Enable debug logging
+
+        Returns:
+            Tuple of (trade_log_entry, position, coin, entry_price, usd) if stop loss triggered, None otherwise
+        """
+        if min1_df is None or min1_df.empty:
+            if verbose:
+                print("Warning: No 1-minute data available for stop loss checking")
+            return None
+
         stop_price = entry_price * (1 - stop_loss_pct)
-        if current_trade_min1_start_idx is None:
-            current_trade_min1_start_idx = min1_df.index[min1_df.index >= entry_time][0]
-        current_min1_end_idx = min1_df.index[min1_df.index <= date][-1]
-
-        # Check all 1-minute candles in between for stop loss
-        min1_slice = min1_df.loc[current_trade_min1_start_idx:current_min1_end_idx]
-        if (min1_slice['low'] <= stop_price).any():
-            # Stop loss triggered, find the exact candle
-            stop_candle = min1_slice[min1_slice['low'] <= stop_price].iloc[0]
-            # More realistic fill: if open < stop, fill at open, else at stop
-            if stop_candle['open'] < stop_price:
-                sell_price = stop_candle['open']
-            else:
-                sell_price = stop_price
-            if debug:
-                print(f"STOP LOSS triggered: entry={entry_price}, stop={stop_price}, sell_price={sell_price}, entry_time={entry_time}, stop_time={stop_candle.name}")
-            btc_to_sell = coin
-            usd_gross = btc_to_sell * sell_price
-            exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
-            trade_log_entry = {
-                'type': 'STOP',
-                'entry': entry_price,
-                'exit': sell_price,
-                'entry_time': entry_time,
-                'exit_time': stop_candle.name,
-                'fee_usd': exit_fee
-            }
-            # After stop loss, reset position and entry
-            return trade_log_entry, None, 0, 0, 0
+        try:
+            # Ensure min1_df has a DatetimeIndex
+            if not isinstance(min1_df.index, pd.DatetimeIndex):
+                if verbose:
+                    print("Warning: min1_df does not have DatetimeIndex")
+                return None
+
+            # Convert entry_time and current_time to pandas Timestamps for comparison
+            entry_ts = pd.to_datetime(entry_time)
+            current_ts = pd.to_datetime(current_time)
+
+            if verbose:
+                print(f"Checking stop loss from {entry_ts} to {current_ts}, stop_price: {stop_price:.2f}")
+
+            # Handle edge case where entry and current time are the same (1-minute timeframe)
+            if entry_ts == current_ts:
+                if verbose:
+                    print("Entry and current time are the same, no range to check")
+                return None
+
+            # Find the range of 1-minute data to check (exclusive of entry time, inclusive of current time)
+            # We start from the candle AFTER entry to avoid checking the entry candle itself
+            start_check_time = entry_ts + pd.Timedelta(minutes=1)
+
+            # Get the slice of data to check for stop loss
+            mask = (min1_df.index > entry_ts) & (min1_df.index <= current_ts)
+            min1_slice = min1_df.loc[mask]
+
+            if len(min1_slice) == 0:
+                if verbose:
+                    print(f"No 1-minute data found between {start_check_time} and {current_ts}")
+                return None
+
+            if verbose:
+                print(f"Checking {len(min1_slice)} candles for stop loss")
+
+            # Check if any low price in the slice hits the stop loss
+            stop_triggered = (min1_slice['low'] <= stop_price).any()
+
+            if stop_triggered:
+                # Find the exact candle where stop loss was triggered
+                stop_candle = min1_slice[min1_slice['low'] <= stop_price].iloc[0]
+
+                if verbose:
+                    print(f"Stop loss triggered at {stop_candle.name}, low: {stop_candle['low']:.2f}")
+
+                # More realistic fill: if open < stop, fill at open, else at stop
+                if stop_candle['open'] < stop_price:
+                    sell_price = stop_candle['open']
+                    if verbose:
+                        print(f"Filled at open price: {sell_price:.2f}")
+                else:
+                    sell_price = stop_price
+                    if verbose:
+                        print(f"Filled at stop price: {sell_price:.2f}")
+
+                btc_to_sell = coin
+                usd_gross = btc_to_sell * sell_price
+                exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
+                usd_after_stop = usd_gross - exit_fee
+
+                trade_log_entry = {
+                    'type': 'STOP',
+                    'entry': entry_price,
+                    'exit': sell_price,
+                    'entry_time': entry_time,
+                    'exit_time': stop_candle.name,
+                    'fee_usd': exit_fee
+                }
+                # After stop loss, reset position and entry, return USD balance
+                return trade_log_entry, 0, 0, 0, usd_after_stop
+            elif verbose:
+                print(f"No stop loss triggered, min low in range: {min1_slice['low'].min():.2f}")
+
+        except Exception as e:
+            # In case of any error, don't trigger stop loss but log the issue
+            error_msg = f"Warning: Stop loss check failed: {e}"
+            print(error_msg)
+            if verbose:
+                import traceback
+                print(traceback.format_exc())
+            return None
+
         return None

     @staticmethod
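
The rewritten check_stop_loss window rule (strictly after the entry candle, up to and including the current candle) can be exercised in isolation; the frame and prices here are invented for illustration:

    import numpy as np
    import pandas as pd

    idx = pd.date_range("2024-01-01 10:00", periods=10, freq="1min")
    min1 = pd.DataFrame({"open": 100.0, "low": np.linspace(100, 91, 10)}, index=idx)

    entry_ts = pd.Timestamp("2024-01-01 10:00")
    current_ts = pd.Timestamp("2024-01-01 10:05")
    stop_price = 100.0 * (1 - 0.05)  # 5% stop below a 100.0 entry

    # Exclusive of the entry candle, inclusive of the current one.
    mask = (min1.index > entry_ts) & (min1.index <= current_ts)
    window = min1.loc[mask]
    print((window["low"] <= stop_price).any())  # True: the 10:05 low touches 95.0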


@@ -65,30 +65,57 @@ def cached_supertrend_calculation(period, multiplier, data_tuple):
         'lower_band': final_lower
     }

-def calculate_supertrend_external(data, period, multiplier):
+def calculate_supertrend_external(data, period, multiplier, close_column='close'):
+    """
+    External function to calculate SuperTrend with configurable close column
+
+    Parameters:
+    - data: DataFrame with OHLC data
+    - period: int, period for ATR calculation
+    - multiplier: float, multiplier for ATR
+    - close_column: str, name of the column to use as close price (default: 'close')
+    """
     high_tuple = tuple(data['high'])
     low_tuple = tuple(data['low'])
-    close_tuple = tuple(data['close'])
+    close_tuple = tuple(data[close_column])
     return cached_supertrend_calculation(period, multiplier, (high_tuple, low_tuple, close_tuple))

 class Supertrends:
-    def __init__(self, data, verbose=False, display=False):
+    def __init__(self, data, close_column='close', verbose=False, display=False):
+        """
+        Initialize Supertrends calculator
+
+        Parameters:
+        - data: pandas DataFrame with OHLC data or list of prices
+        - close_column: str, name of the column to use as close price (default: 'close')
+        - verbose: bool, enable verbose logging
+        - display: bool, display mode (currently unused)
+        """
+        self.close_column = close_column
         self.data = data
         self.verbose = verbose
         logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
                             format='%(asctime)s - %(levelname)s - %(message)s')
         self.logger = logging.getLogger('TrendDetectorSimple')
         if not isinstance(self.data, pd.DataFrame):
             if isinstance(self.data, list):
-                self.data = pd.DataFrame({'close': self.data})
+                self.data = pd.DataFrame({self.close_column: self.data})
             else:
                 raise ValueError("Data must be a pandas DataFrame or a list")

+        # Validate that required columns exist
+        required_columns = ['high', 'low', self.close_column]
+        missing_columns = [col for col in required_columns if col not in self.data.columns]
+        if missing_columns:
+            raise ValueError(f"Missing required columns: {missing_columns}")
+
     def calculate_tr(self):
+        """Calculate True Range using the configured close column"""
         df = self.data.copy()
         high = df['high'].values
         low = df['low'].values
-        close = df['close'].values
+        close = df[self.close_column].values
         tr = np.zeros_like(close)
         tr[0] = high[0] - low[0]
         for i in range(1, len(close)):

@@ -99,6 +126,7 @@ class Supertrends:
         return tr

     def calculate_atr(self, period=14):
+        """Calculate Average True Range"""
         tr = self.calculate_tr()
         atr = np.zeros_like(tr)
         atr[0] = tr[0]

@@ -109,18 +137,20 @@ class Supertrends:
     def calculate_supertrend(self, period=10, multiplier=3.0):
         """
-        Calculate SuperTrend indicator for the price data.
+        Calculate SuperTrend indicator for the price data using the configured close column.

         SuperTrend is a trend-following indicator that uses ATR to determine the trend direction.

         Parameters:
         - period: int, the period for the ATR calculation (default: 10)
         - multiplier: float, the multiplier for the ATR (default: 3.0)

         Returns:
         - Dictionary containing SuperTrend values, trend direction, and upper/lower bands
         """
         df = self.data.copy()
         high = df['high'].values
         low = df['low'].values
-        close = df['close'].values
+        close = df[self.close_column].values
         atr = self.calculate_atr(period)
         upper_band = np.zeros_like(close)
         lower_band = np.zeros_like(close)
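
A usage sketch for the new close_column parameter; the import path is an assumption (the diff shows only the class body), and calculate_supertrend_indicators is the entry point referenced from the backtest diff above:

    import numpy as np
    import pandas as pd
    from cycles.supertrends import Supertrends  # assumed module path

    rng = np.random.default_rng(1)
    close = 100 + rng.standard_normal(200).cumsum()
    df = pd.DataFrame({
        "high": close + 1, "low": close - 1, "close": close,
        "predicted_close_price": close + rng.normal(0, 0.2, 200),  # stand-in model output
    })

    st = Supertrends(df, close_column="predicted_close_price", verbose=False)
    indicators = st.calculate_supertrend_indicators()
    trends = [s["results"]["trend"] for s in indicators]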


@@ -0,0 +1,233 @@
+#!/usr/bin/env python3
+"""
+Progress Manager for tracking multiple parallel backtest tasks
+"""
+
+import threading
+import time
+import sys
+from typing import Dict, Optional, Callable
+from dataclasses import dataclass
+
+
+@dataclass
+class TaskProgress:
+    """Represents progress information for a single task"""
+    task_id: str
+    name: str
+    current: int
+    total: int
+    start_time: float
+    last_update: float
+
+    @property
+    def percentage(self) -> float:
+        """Calculate completion percentage"""
+        if self.total == 0:
+            return 0.0
+        return (self.current / self.total) * 100
+
+    @property
+    def elapsed_time(self) -> float:
+        """Calculate elapsed time in seconds"""
+        return time.time() - self.start_time
+
+    @property
+    def eta(self) -> Optional[float]:
+        """Estimate time to completion in seconds"""
+        if self.current == 0 or self.percentage >= 100:
+            return None
+        elapsed = self.elapsed_time
+        rate = self.current / elapsed
+        remaining = self.total - self.current
+        return remaining / rate if rate > 0 else None
+
+
+class ProgressManager:
+    """Manages progress tracking for multiple parallel tasks"""
+
+    def __init__(self, update_interval: float = 1.0, display_width: int = 50):
+        """
+        Initialize progress manager
+
+        Args:
+            update_interval: How often to update display (seconds)
+            display_width: Width of progress bar in characters
+        """
+        self.tasks: Dict[str, TaskProgress] = {}
+        self.update_interval = update_interval
+        self.display_width = display_width
+        self.lock = threading.Lock()
+        self.display_thread: Optional[threading.Thread] = None
+        self.running = False
+        self.last_display_height = 0
+
+    def start_task(self, task_id: str, name: str, total: int) -> None:
+        """
+        Start tracking a new task
+
+        Args:
+            task_id: Unique identifier for the task
+            name: Human-readable name for the task
+            total: Total number of steps in the task
+        """
+        with self.lock:
+            self.tasks[task_id] = TaskProgress(
+                task_id=task_id,
+                name=name,
+                current=0,
+                total=total,
+                start_time=time.time(),
+                last_update=time.time()
+            )
+
+    def update_progress(self, task_id: str, current: int) -> None:
+        """
+        Update progress for a specific task
+
+        Args:
+            task_id: Task identifier
+            current: Current progress value
+        """
+        with self.lock:
+            if task_id in self.tasks:
+                self.tasks[task_id].current = current
+                self.tasks[task_id].last_update = time.time()
+
+    def complete_task(self, task_id: str) -> None:
+        """
+        Mark a task as completed
+
+        Args:
+            task_id: Task identifier
+        """
+        with self.lock:
+            if task_id in self.tasks:
+                task = self.tasks[task_id]
+                task.current = task.total
+                task.last_update = time.time()
+
+    def start_display(self) -> None:
+        """Start the progress display thread"""
+        if not self.running:
+            self.running = True
+            self.display_thread = threading.Thread(target=self._display_loop, daemon=True)
+            self.display_thread.start()
+
+    def stop_display(self) -> None:
+        """Stop the progress display thread"""
+        self.running = False
+        if self.display_thread:
+            self.display_thread.join(timeout=1.0)
+        self._clear_display()
+
+    def _display_loop(self) -> None:
+        """Main loop for updating the progress display"""
+        while self.running:
+            self._update_display()
+            time.sleep(self.update_interval)
+
+    def _update_display(self) -> None:
+        """Update the console display with current progress"""
+        with self.lock:
+            if not self.tasks:
+                return
+
+            # Clear previous display
+            self._clear_display()
+
+            # Build display lines
+            lines = []
+            for task in sorted(self.tasks.values(), key=lambda t: t.task_id):
+                line = self._format_progress_line(task)
+                lines.append(line)
+
+            # Print all lines
+            for line in lines:
+                print(line, flush=True)
+
+            self.last_display_height = len(lines)
+
+    def _clear_display(self) -> None:
+        """Clear the previous progress display"""
+        if self.last_display_height > 0:
+            # Move cursor up and clear lines
+            for _ in range(self.last_display_height):
+                sys.stdout.write('\033[F')  # Move cursor up one line
+                sys.stdout.write('\033[K')  # Clear line
+            sys.stdout.flush()
+
+    def _format_progress_line(self, task: TaskProgress) -> str:
+        """
+        Format a single progress line for display
+
+        Args:
+            task: TaskProgress instance
+
+        Returns:
+            Formatted progress string
+        """
+        # Progress bar
+        filled_width = int(task.percentage / 100 * self.display_width)
+        bar = '█' * filled_width + '░' * (self.display_width - filled_width)
+
+        # Time information
+        elapsed_str = self._format_time(task.elapsed_time)
+        eta_str = self._format_time(task.eta) if task.eta else "N/A"
+
+        # Format line
+        line = (f"{task.name:<25} {bar} "
+                f"{task.percentage:5.1f}% "
+                f"({task.current:,}/{task.total:,}) "
+                f"{elapsed_str} ETA: {eta_str}")
+        return line
+
+    def _format_time(self, seconds: float) -> str:
+        """
+        Format time duration for display
+
+        Args:
+            seconds: Time in seconds
+
+        Returns:
+            Formatted time string
+        """
+        if seconds < 60:
+            return f"{seconds:.0f}s"
+        elif seconds < 3600:
+            minutes = seconds / 60
+            return f"{minutes:.1f}m"
+        else:
+            hours = seconds / 3600
+            return f"{hours:.1f}h"
+
+    def get_task_progress_callback(self, task_id: str) -> Callable[[int], None]:
+        """
+        Get a progress callback function for a specific task
+
+        Args:
+            task_id: Task identifier
+
+        Returns:
+            Callback function that updates progress for this task
+        """
+        def callback(current: int) -> None:
+            self.update_progress(task_id, current)
+        return callback
+
+    def all_tasks_completed(self) -> bool:
+        """Check if all tasks are completed"""
+        with self.lock:
+            return all(task.current >= task.total for task in self.tasks.values())
+
+    def get_summary(self) -> str:
+        """Get a summary of all tasks"""
+        with self.lock:
+            total_tasks = len(self.tasks)
+            completed_tasks = sum(1 for task in self.tasks.values()
+                                  if task.current >= task.total)
+            return f"Tasks: {completed_tasks}/{total_tasks} completed"


@@ -10,10 +10,12 @@ class SystemUtils:
         """Determine optimal number of worker processes based on system resources"""
         cpu_count = os.cpu_count() or 4
         memory_gb = psutil.virtual_memory().total / (1024**3)
-        # Heuristic: Use 75% of cores, but cap based on available memory
-        # Assume each worker needs ~2GB for large datasets
-        workers_by_memory = max(1, int(memory_gb / 2))
-        workers_by_cpu = max(1, int(cpu_count * 0.75))
+
+        # OPTIMIZATION: More aggressive worker allocation for better performance
+        workers_by_memory = max(1, int(memory_gb / 2))  # 2GB per worker
+        workers_by_cpu = max(1, int(cpu_count * 0.8))  # Use 80% of CPU cores
+        optimal_workers = min(workers_by_cpu, workers_by_memory, 8)  # Cap at 8 workers
+
         if self.logging is not None:
-            self.logging.info(f"Using {min(workers_by_cpu, workers_by_memory)} workers for processing")
-        return min(workers_by_cpu, workers_by_memory)
+            self.logging.info(f"Using {optimal_workers} workers for processing (CPU-based: {workers_by_cpu}, Memory-based: {workers_by_memory})")
+        return optimal_workers
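
A worked example of the new allocation heuristic on a hypothetical 16-core, 32 GB host:

    cpu_count, memory_gb = 16, 32                    # hypothetical host
    workers_by_cpu = max(1, int(cpu_count * 0.8))    # 12
    workers_by_memory = max(1, int(memory_gb / 2))   # 16
    optimal_workers = min(workers_by_cpu, workers_by_memory, 8)
    print(optimal_workers)  # 8: the hard cap binds on larger machines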

main.py

@@ -79,7 +79,23 @@ def main():
     )
     system_utils = SystemUtils(logging=logger)
     result_processor = ResultProcessor(storage, logging_instance=logger)
-    runner = BacktestRunner(storage, system_utils, result_processor, logging_instance=logger)
+
+    # OPTIMIZATION: Disable progress for parallel execution to improve performance
+    show_progress = config.get('show_progress', True)
+    debug_mode = config.get('debug', 0) == 1
+
+    # Only show progress in debug (sequential) mode
+    if not debug_mode:
+        show_progress = False
+        logger.info("Progress tracking disabled for parallel execution (performance optimization)")
+
+    runner = BacktestRunner(
+        storage,
+        system_utils,
+        result_processor,
+        logging_instance=logger,
+        show_progress=show_progress
+    )

     # Validate inputs
     logger.info("Validating inputs...")

@@ -91,7 +107,8 @@
     # Load data
     logger.info("Loading market data...")
-    data_filename = 'btcusd_1-min_data.csv'
+    # data_filename = 'btcusd_1-min_data.csv'
+    data_filename = 'btcusd_1-min_data_with_price_predictions.csv'
     data_1min = runner.load_data(
         data_filename,
         config['start_date'],

@@ -100,7 +117,6 @@
     # Run backtests
     logger.info("Starting backtest execution...")
-    debug_mode = True  # Can be moved to config

     all_results, all_trades = runner.run_backtests(
         data_1min,

@@ -114,6 +130,11 @@
     logger.info("Processing and saving results...")
     timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")

+    # OPTIMIZATION: Save trade files in batch after parallel execution
+    if all_trades and not debug_mode:
+        logger.info("Saving trade files in batch...")
+        result_processor.save_all_trade_files(all_trades)
+
     # Create metadata
     metadata_lines = create_metadata_lines(config, data_1min, result_processor)


@@ -29,8 +29,8 @@ class ResultProcessor:
         df: pd.DataFrame,
         stop_loss_pcts: List[float],
         timeframe_name: str,
         initial_usd: float,
-        debug: bool = False
+        progress_callback=None
     ) -> Tuple[List[Dict], List[Dict]]:
         """
         Process results for a single timeframe with multiple stop loss values

@@ -41,7 +41,7 @@
             stop_loss_pcts: List of stop loss percentages to test
             timeframe_name: Name of the timeframe (e.g., '1D', '6h')
             initial_usd: Initial USD amount
-            debug: Whether to enable debug output
+            progress_callback: Optional progress callback function

         Returns:
             Tuple of (results_rows, trade_rows)

@@ -59,7 +59,8 @@
                     df,
                     initial_usd=initial_usd,
                     stop_loss_pct=stop_loss_pct,
-                    debug=debug
+                    progress_callback=progress_callback,
+                    verbose=False  # Default to False for production runs
                 )

                 # Calculate metrics

@@ -67,15 +68,14 @@
                 results_rows.append(metrics)

                 # Process trades
-                trades = self._process_trades(results.get('trades', []), timeframe_name, stop_loss_pct)
+                if 'trades' not in results:
+                    raise ValueError(f"Backtest results missing 'trades' field for {timeframe_name} with {stop_loss_pct} stop loss")
+                trades = self._process_trades(results['trades'], timeframe_name, stop_loss_pct)
                 trade_rows.extend(trades)

                 if self.logging:
                     self.logging.info(f"Timeframe: {timeframe_name}, Stop Loss: {stop_loss_pct}, Trades: {results['n_trades']}")

-                if debug:
-                    self._debug_output(results)
-
             except Exception as e:
                 error_msg = f"Error processing {timeframe_name} with stop loss {stop_loss_pct}: {e}"
                 if self.logging:

@@ -92,36 +92,56 @@
         timeframe_name: str
     ) -> Dict[str, Any]:
         """Calculate performance metrics from backtest results"""
-        trades = results.get('trades', [])
+        if 'trades' not in results:
+            raise ValueError(f"Backtest results missing 'trades' field for {timeframe_name} with {stop_loss_pct} stop loss")
+        trades = results['trades']
         n_trades = results["n_trades"]

-        # Calculate win metrics
-        winning_trades = [t for t in trades if t.get('exit') is not None and t['exit'] > t['entry']]
+        # Validate that all required fields are present
+        required_fields = ['final_usd', 'max_drawdown', 'total_fees_usd', 'n_trades', 'n_stop_loss', 'win_rate', 'avg_trade']
+        missing_fields = [field for field in required_fields if field not in results]
+        if missing_fields:
+            raise ValueError(f"Backtest results missing required fields: {missing_fields}")
+
+        # Calculate win metrics - validate trade fields
+        winning_trades = []
+        for t in trades:
+            if 'exit' not in t:
+                raise ValueError(f"Trade missing 'exit' field: {t}")
+            if 'entry' not in t:
+                raise ValueError(f"Trade missing 'entry' field: {t}")
+            if t['exit'] is not None and t['exit'] > t['entry']:
+                winning_trades.append(t)
         n_winning_trades = len(winning_trades)
         win_rate = n_winning_trades / n_trades if n_trades > 0 else 0

         # Calculate profit metrics
-        total_profit = sum(trade['profit_pct'] for trade in trades)
-        total_loss = sum(-trade['profit_pct'] for trade in trades if trade['profit_pct'] < 0)
-        avg_trade = total_profit / n_trades if n_trades > 0 else 0
-        profit_ratio = total_profit / total_loss if total_loss > 0 else float('inf')
+        total_profit = sum(trade['profit_pct'] for trade in trades if trade['profit_pct'] > 0)
+        total_loss = abs(sum(trade['profit_pct'] for trade in trades if trade['profit_pct'] < 0))
+        avg_trade = sum(trade['profit_pct'] for trade in trades) / n_trades if n_trades > 0 else 0
+        profit_ratio = total_profit / total_loss if total_loss > 0 else (float('inf') if total_profit > 0 else 0)

-        # Calculate drawdown
-        max_drawdown = self._calculate_max_drawdown(trades)
+        # Get values directly from backtest results (no defaults)
+        max_drawdown = results['max_drawdown']
+        final_usd = results['final_usd']
+        total_fees_usd = results['total_fees_usd']
+        n_stop_loss = results['n_stop_loss']  # Get stop loss count directly from backtest

-        # Calculate final USD
-        final_usd = initial_usd
-        for trade in trades:
-            final_usd *= (1 + trade['profit_pct'])
-
-        # Calculate fees
-        total_fees_usd = sum(trade.get('fee_usd', 0) for trade in trades)
+        # Validate no None values
+        if max_drawdown is None:
+            raise ValueError(f"max_drawdown is None for {timeframe_name} with {stop_loss_pct} stop loss")
+        if final_usd is None:
+            raise ValueError(f"final_usd is None for {timeframe_name} with {stop_loss_pct} stop loss")
+        if total_fees_usd is None:
+            raise ValueError(f"total_fees_usd is None for {timeframe_name} with {stop_loss_pct} stop loss")
+        if n_stop_loss is None:
+            raise ValueError(f"n_stop_loss is None for {timeframe_name} with {stop_loss_pct} stop loss")

         return {
             "timeframe": timeframe_name,
             "stop_loss_pct": stop_loss_pct,
             "n_trades": n_trades,
-            "n_stop_loss": sum(1 for trade in trades if trade.get('type') == 'STOP'),
+            "n_stop_loss": n_stop_loss,
             "win_rate": win_rate,
             "max_drawdown": max_drawdown,
             "avg_trade": avg_trade,

@@ -159,16 +179,22 @@
         processed_trades = []

         for trade in trades:
+            # Validate all required trade fields
+            required_fields = ["entry_time", "exit_time", "entry", "exit", "profit_pct", "type", "fee_usd"]
+            missing_fields = [field for field in required_fields if field not in trade]
+            if missing_fields:
+                raise ValueError(f"Trade missing required fields: {missing_fields} in trade: {trade}")
+
             processed_trade = {
                 "timeframe": timeframe_name,
                 "stop_loss_pct": stop_loss_pct,
-                "entry_time": trade.get("entry_time"),
-                "exit_time": trade.get("exit_time"),
-                "entry_price": trade.get("entry"),
-                "exit_price": trade.get("exit"),
-                "profit_pct": trade.get("profit_pct"),
-                "type": trade.get("type"),
-                "fee_usd": trade.get("fee_usd"),
+                "entry_time": trade["entry_time"],
+                "exit_time": trade["exit_time"],
+                "entry_price": trade["entry"],
+                "exit_price": trade["exit"],
+                "profit_pct": trade["profit_pct"],
+                "type": trade["type"],
+                "fee_usd": trade["fee_usd"],
             }
             processed_trades.append(processed_trade)

@@ -176,17 +202,31 @@
     def _debug_output(self, results: Dict[str, Any]) -> None:
         """Output debug information for backtest results"""
-        trades = results.get('trades', [])
+        if 'trades' not in results:
+            raise ValueError("Backtest results missing 'trades' field for debug output")
+        trades = results['trades']

         # Print stop loss trades
-        stop_loss_trades = [t for t in trades if t.get('type') == 'STOP']
+        stop_loss_trades = []
+        for t in trades:
+            if 'type' not in t:
+                raise ValueError(f"Trade missing 'type' field: {t}")
+            if t['type'] == 'STOP':
+                stop_loss_trades.append(t)
+
         if stop_loss_trades:
             print("Stop Loss Trades:")
             for trade in stop_loss_trades:
                 print(trade)

         # Print large loss trades
-        large_loss_trades = [t for t in trades if t.get('profit_pct', 0) < -0.09]
+        large_loss_trades = []
+        for t in trades:
+            if 'profit_pct' not in t:
+                raise ValueError(f"Trade missing 'profit_pct' field: {t}")
+            if t['profit_pct'] < -0.09:
+                large_loss_trades.append(t)
+
         if large_loss_trades:
             print("Large Loss Trades:")
             for trade in large_loss_trades:

@@ -216,19 +256,32 @@
     def _aggregate_group(self, rows: List[Dict], timeframe: str, stop_loss_pct: float) -> Dict:
         """Aggregate a group of rows with the same timeframe and stop loss"""
+        if not rows:
+            raise ValueError(f"No rows to aggregate for {timeframe} with {stop_loss_pct} stop loss")
+
+        # Validate all rows have required fields
+        required_fields = ['n_trades', 'n_stop_loss', 'win_rate', 'max_drawdown', 'avg_trade', 'profit_ratio', 'final_usd', 'total_fees_usd', 'initial_usd']
+        for i, row in enumerate(rows):
+            missing_fields = [field for field in required_fields if field not in row]
+            if missing_fields:
+                raise ValueError(f"Row {i} missing required fields: {missing_fields}")
+
         total_trades = sum(r['n_trades'] for r in rows)
         total_stop_loss = sum(r['n_stop_loss'] for r in rows)

-        # Calculate averages
+        # Calculate averages (no defaults, expect all values to be present)
         avg_win_rate = np.mean([r['win_rate'] for r in rows])
         avg_max_drawdown = np.mean([r['max_drawdown'] for r in rows])
         avg_avg_trade = np.mean([r['avg_trade'] for r in rows])
-        avg_profit_ratio = np.mean([r['profit_ratio'] for r in rows])

-        # Calculate final USD and fees
-        final_usd = np.mean([r.get('final_usd', r.get('initial_usd', 0)) for r in rows])
-        total_fees_usd = np.mean([r.get('total_fees_usd', 0) for r in rows])
-        initial_usd = rows[0].get('initial_usd', 0) if rows else 0
+        # Handle infinite profit ratios properly
+        finite_profit_ratios = [r['profit_ratio'] for r in rows if not np.isinf(r['profit_ratio'])]
+        avg_profit_ratio = np.mean(finite_profit_ratios) if finite_profit_ratios else 0
+
+        # Calculate final USD and fees (no defaults)
+        final_usd = np.mean([r['final_usd'] for r in rows])
+        total_fees_usd = np.mean([r['total_fees_usd'] for r in rows])
+        initial_usd = rows[0]['initial_usd']

         return {
             "timeframe": timeframe,

@@ -278,7 +331,11 @@
             writer = csv.DictWriter(f, fieldnames=trades_fieldnames)
             writer.writeheader()
             for trade in trades:
-                writer.writerow({k: trade.get(k, "") for k in trades_fieldnames})
+                # Validate all required fields are present
+                missing_fields = [k for k in trades_fieldnames if k not in trade]
+                if missing_fields:
+                    raise ValueError(f"Trade missing required fields for CSV: {missing_fields} in trade: {trade}")
+                writer.writerow({k: trade[k] for k in trades_fieldnames})

         if self.logging:
             self.logging.info(f"Trades saved to {trades_filename}")

@@ -351,4 +408,39 @@
         except Exception as e:
             if self.logging:
                 self.logging.warning(f"Could not get price info for {date}: {e}")
             return None, None
+
+    def save_all_trade_files(self, all_trades: List[Dict]) -> None:
+        """
+        Save all trade files in batch after parallel execution completes
+
+        Args:
+            all_trades: List of all trades from all tasks
+        """
+        if not all_trades:
+            return
+
+        try:
+            # Group trades by timeframe and stop loss
+            trade_groups = {}
+            for trade in all_trades:
+                timeframe = trade.get('timeframe')
+                stop_loss_pct = trade.get('stop_loss_pct')
+                if timeframe and stop_loss_pct is not None:
+                    key = (timeframe, stop_loss_pct)
+                    if key not in trade_groups:
+                        trade_groups[key] = []
+                    trade_groups[key].append(trade)
+
+            # Save each group
+            for (timeframe, stop_loss_pct), trades in trade_groups.items():
+                self.save_trade_file(trades, timeframe, stop_loss_pct)
+
+            if self.logging:
+                self.logging.info(f"Saved {len(trade_groups)} trade files in batch")
+
+        except Exception as e:
+            error_msg = f"Failed to save trade files in batch: {e}"
+            if self.logging:
+                self.logging.error(error_msg)
+            raise RuntimeError(error_msg) from e