Enhance backtesting framework with static task processing and progress management. Introduce module-level ("static") task functions for parallel execution with ProcessPoolExecutor, improve error handling, and add a progress manager for task tracking. Update BacktestRunner to support progress callbacks, optimize worker allocation based on system resources, and add new configuration files for flexible backtesting setups.
parent be331ed631
commit 65f30a4020
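The pieces introduced here fit together as follows: BacktestRunner registers each task with a ProgressManager, hands the task a per-task callback, and Backtest.run invokes that callback roughly every 2% of processed candles. A minimal sketch of the wiring, simplified from the diff below (the task ID and row count are illustrative):

    from cycles.utils.progress_manager import ProgressManager

    pm = ProgressManager()
    pm.start_task("1h_0.05", "1h SL:5%", total=10_000)   # total = rows the task will process
    pm.start_display()                                   # background thread renders the bars
    callback = pm.get_task_progress_callback("1h_0.05")  # wraps pm.update_progress(task_id, ...)

    for i in range(1, 10_000):                           # stand-in for Backtest.run's candle loop
        if i % max(1, 10_000 // 50) == 0:                # report roughly every 2% of rows
            callback(i)

    pm.complete_task("1h_0.05")
    pm.stop_display()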
@@ -5,9 +5,85 @@ from typing import List, Tuple, Dict, Any, Optional
from cycles.utils.storage import Storage
from cycles.utils.system import SystemUtils
+from cycles.utils.progress_manager import ProgressManager
from result_processor import ResultProcessor


+def _process_single_task_static(task: Tuple[str, str, pd.DataFrame, float, float], progress_callback=None) -> Tuple[List[Dict], List[Dict]]:
+    """
+    Static version of _process_single_task for use with ProcessPoolExecutor
+
+    Args:
+        task: Tuple of (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
+        progress_callback: Optional progress callback function
+
+    Returns:
+        Tuple of (results, trades)
+    """
+    task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task
+
+    try:
+        if timeframe == "1T" or timeframe == "1min":
+            df = data_1min.copy()
+        else:
+            df = _resample_data_static(data_1min, timeframe)
+
+        # Create required components for processing
+        from cycles.utils.storage import Storage
+        from result_processor import ResultProcessor
+
+        # Create storage with default paths (for subprocess)
+        storage = Storage()
+        result_processor = ResultProcessor(storage)
+
+        results, trades = result_processor.process_timeframe_results(
+            data_1min,
+            df,
+            [stop_loss_pct],
+            timeframe,
+            initial_usd,
+            progress_callback=progress_callback
+        )
+
+        return results, trades
+
+    except Exception as e:
+        error_msg = f"Failed to process {timeframe} with stop loss {stop_loss_pct}: {e}"
+        raise RuntimeError(error_msg) from e
+
+
+def _resample_data_static(data_1min: pd.DataFrame, timeframe: str) -> pd.DataFrame:
+    """
+    Static function to resample 1-minute data to specified timeframe
+
+    Args:
+        data_1min: 1-minute data DataFrame
+        timeframe: Target timeframe string
+
+    Returns:
+        Resampled DataFrame
+    """
+    try:
+        agg_dict = {
+            'open': 'first',
+            'high': 'max',
+            'low': 'min',
+            'close': 'last',
+            'volume': 'sum'
+        }
+
+        if 'predicted_close_price' in data_1min.columns:
+            agg_dict['predicted_close_price'] = 'last'
+
+        resampled = data_1min.resample(timeframe).agg(agg_dict).dropna()
+
+        return resampled.reset_index()
+
+    except Exception as e:
+        error_msg = f"Failed to resample data to {timeframe}: {e}"
+        raise ValueError(error_msg) from e
+
+
class BacktestRunner:
    """Handles the execution of backtests across multiple timeframes and parameters"""

@@ -16,7 +92,8 @@ class BacktestRunner:
        storage: Storage,
        system_utils: SystemUtils,
        result_processor: ResultProcessor,
-        logging_instance: Optional[logging.Logger] = None
+        logging_instance: Optional[logging.Logger] = None,
+        show_progress: bool = True
    ):
        """
        Initialize backtest runner

@@ -26,11 +103,14 @@ class BacktestRunner:
            system_utils: System utilities for resource management
            result_processor: Result processor for handling outputs
            logging_instance: Optional logging instance
+            show_progress: Whether to show visual progress bars
        """
        self.storage = storage
        self.system_utils = system_utils
        self.result_processor = result_processor
        self.logging = logging_instance
+        self.show_progress = show_progress
+        self.progress_manager = ProgressManager() if show_progress else None

    def run_backtests(
        self,

@@ -56,10 +136,13 @@ class BacktestRunner:
        # Create tasks for all combinations
        tasks = self._create_tasks(timeframes, stop_loss_pcts, data_1min, initial_usd)

        if self.logging:
            self.logging.info(f"Starting {len(tasks)} backtest tasks")

        if debug:
-            return self._run_sequential(tasks, debug)
+            return self._run_sequential(tasks)
        else:
-            return self._run_parallel(tasks, debug)
+            return self._run_parallel(tasks)

    def _create_tasks(
        self,

@@ -72,50 +155,92 @@ class BacktestRunner:
        tasks = []
        for timeframe in timeframes:
            for stop_loss_pct in stop_loss_pcts:
-                task = (timeframe, data_1min, stop_loss_pct, initial_usd)
+                task_id = f"{timeframe}_{stop_loss_pct}"
+                task = (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
                tasks.append(task)
        return tasks

-    def _run_sequential(self, tasks: List[Tuple], debug: bool) -> Tuple[List[Dict], List[Dict]]:
+    def _run_sequential(self, tasks: List[Tuple]) -> Tuple[List[Dict], List[Dict]]:
        """Run tasks sequentially (for debug mode)"""
+        # Initialize progress tracking if enabled
+        if self.progress_manager:
+            for task in tasks:
+                task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task
+
+                # Calculate actual DataFrame size that will be processed
+                if timeframe == "1T" or timeframe == "1min":
+                    actual_df_size = len(data_1min)
+                else:
+                    # Get the actual resampled DataFrame size
+                    temp_df = self._resample_data(data_1min, timeframe)
+                    actual_df_size = len(temp_df)
+
+                task_name = f"{timeframe} SL:{stop_loss_pct:.0%}"
+                self.progress_manager.start_task(task_id, task_name, actual_df_size)
+
+            self.progress_manager.start_display()

        all_results = []
        all_trades = []

-        for task in tasks:
-            try:
-                results, trades = self._process_single_task(task, debug)
-                if results:
-                    all_results.extend(results)
-                if trades:
-                    all_trades.extend(trades)
-
-            except Exception as e:
-                error_msg = f"Error processing task {task[0]} with stop loss {task[2]}: {e}"
-                if self.logging:
-                    self.logging.error(error_msg)
-                raise RuntimeError(error_msg) from e
+        try:
+            for task in tasks:
+                try:
+                    # Get progress callback for this task if available
+                    progress_callback = None
+                    if self.progress_manager:
+                        progress_callback = self.progress_manager.get_task_progress_callback(task[0])
+
+                    results, trades = self._process_single_task(task, progress_callback)
+
+                    if results:
+                        all_results.extend(results)
+                    if trades:
+                        all_trades.extend(trades)
+
+                    # Mark task as completed
+                    if self.progress_manager:
+                        self.progress_manager.complete_task(task[0])
+
+                except Exception as e:
+                    error_msg = f"Error processing task {task[1]} with stop loss {task[3]}: {e}"
+
+                    if self.logging:
+                        self.logging.error(error_msg)
+
+                    raise RuntimeError(error_msg) from e
+        finally:
+            # Stop progress display
+            if self.progress_manager:
+                self.progress_manager.stop_display()

        return all_results, all_trades

-    def _run_parallel(self, tasks: List[Tuple], debug: bool) -> Tuple[List[Dict], List[Dict]]:
+    def _run_parallel(self, tasks: List[Tuple]) -> Tuple[List[Dict], List[Dict]]:
        """Run tasks in parallel using ProcessPoolExecutor"""
        workers = self.system_utils.get_optimal_workers()

        if self.logging:
            self.logging.info(f"Running {len(tasks)} tasks with {workers} workers")

+        # OPTIMIZATION: Disable progress manager for parallel execution to reduce overhead
+        # Progress tracking adds significant overhead in multiprocessing
+        if self.progress_manager and self.logging:
+            self.logging.info("Progress tracking disabled for parallel execution (performance optimization)")

        all_results = []
        all_trades = []
+        completed_tasks = 0

        try:
            with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
                # Submit all tasks
                future_to_task = {
-                    executor.submit(self._process_single_task, task, debug): task
+                    executor.submit(_process_single_task_static, task): task
                    for task in tasks
                }

                # Collect results as they complete
                for future in concurrent.futures.as_completed(future_to_task):
                    task = future_to_task[future]
                    try:

@@ -124,9 +249,14 @@ class BacktestRunner:
                            all_results.extend(results)
                        if trades:
                            all_trades.extend(trades)
+
+                        completed_tasks += 1
+
+                        if self.logging:
+                            self.logging.info(f"Completed task {task[0]} ({completed_tasks}/{len(tasks)})")

                    except Exception as e:
-                        error_msg = f"Task {task[0]} with stop loss {task[2]} failed: {e}"
+                        error_msg = f"Task {task[1]} with stop loss {task[3]} failed: {e}"
                        if self.logging:
                            self.logging.error(error_msg)
                        raise RuntimeError(error_msg) from e

@@ -136,46 +266,56 @@ class BacktestRunner:
            if self.logging:
                self.logging.error(error_msg)
            raise RuntimeError(error_msg) from e

+        finally:
+            # Stop progress display
+            if self.progress_manager:
+                self.progress_manager.stop_display()

        if self.logging:
            self.logging.info(f"All {len(tasks)} tasks completed successfully")

        return all_results, all_trades

    def _process_single_task(
        self,
-        task: Tuple[str, pd.DataFrame, float, float],
-        debug: bool = False
+        task: Tuple[str, str, pd.DataFrame, float, float],
+        progress_callback=None
    ) -> Tuple[List[Dict], List[Dict]]:
        """
        Process a single backtest task

        Args:
-            task: Tuple of (timeframe, data_1min, stop_loss_pct, initial_usd)
-            debug: Whether to enable debug output
+            task: Tuple of (task_id, timeframe, data_1min, stop_loss_pct, initial_usd)
+            progress_callback: Optional progress callback function

        Returns:
            Tuple of (results, trades)
        """
-        timeframe, data_1min, stop_loss_pct, initial_usd = task
+        task_id, timeframe, data_1min, stop_loss_pct, initial_usd = task

        try:
            # Resample data if needed
            if timeframe == "1T" or timeframe == "1min":
                df = data_1min.copy()
            else:
                df = self._resample_data(data_1min, timeframe)

            # Process timeframe results
            results, trades = self.result_processor.process_timeframe_results(
                data_1min,
                df,
                [stop_loss_pct],
                timeframe,
-                initial_usd,
-                debug
+                initial_usd,
+                progress_callback=progress_callback
            )

-            # Save individual trade files if trades exist
-            if trades:
-                self.result_processor.save_trade_file(trades, timeframe, stop_loss_pct)
+            # OPTIMIZATION: Skip individual trade file saving during parallel execution
+            # Trade files will be saved in batch at the end
+            # if trades:
+            #     self.result_processor.save_trade_file(trades, timeframe, stop_loss_pct)

+            if self.logging:
+                self.logging.info(f"Completed task {task_id}: {len(results)} results, {len(trades)} trades")

            return results, trades

@@ -197,13 +337,18 @@ class BacktestRunner:
            Resampled DataFrame
        """
        try:
-            resampled = data_1min.resample(timeframe).agg({
+            agg_dict = {
                'open': 'first',
                'high': 'max',
                'low': 'min',
                'close': 'last',
                'volume': 'sum'
-            }).dropna()
+            }
+
+            if 'predicted_close_price' in data_1min.columns:
+                agg_dict['predicted_close_price'] = 'last'
+
+            resampled = data_1min.resample(timeframe).agg(agg_dict).dropna()

            return resampled.reset_index()

@@ -213,6 +358,34 @@ class BacktestRunner:
                self.logging.error(error_msg)
            raise ValueError(error_msg) from e

+    def _get_timeframe_factor(self, timeframe: str) -> int:
+        """
+        Get the factor by which data is reduced when resampling to timeframe
+
+        Args:
+            timeframe: Target timeframe string (e.g., '1h', '4h', '1D')
+
+        Returns:
+            Factor for estimating data size after resampling
+        """
+        timeframe_factors = {
+            '1T': 1, '1min': 1,
+            '5T': 5, '5min': 5,
+            '15T': 15, '15min': 15,
+            '30T': 30, '30min': 30,
+            '1h': 60, '1H': 60,
+            '2h': 120, '2H': 120,
+            '4h': 240, '4H': 240,
+            '6h': 360, '6H': 360,
+            '8h': 480, '8H': 480,
+            '12h': 720, '12H': 720,
+            '1D': 1440, '1d': 1440,
+            '2D': 2880, '2d': 2880,
+            '3D': 4320, '3d': 4320,
+            '1W': 10080, '1w': 10080
+        }
+        return timeframe_factors.get(timeframe, 60)  # Default to 1 hour if unknown
+
    def load_data(self, filename: str, start_date: str, stop_date: str) -> pd.DataFrame:
        """
        Load and validate data for backtesting

@@ -234,8 +407,11 @@ class BacktestRunner:
        if data.empty:
            raise ValueError(f"No data loaded for period {start_date} to {stop_date}")

        # Validate required columns
        required_columns = ['open', 'high', 'low', 'close', 'volume']
+
+        if 'predicted_close_price' in data.columns:
+            required_columns.append('predicted_close_price')

        missing_columns = [col for col in required_columns if col not in data.columns]

        if missing_columns:

@@ -269,11 +445,9 @@ class BacktestRunner:
        Raises:
            ValueError: If any input is invalid
        """
-        # Validate timeframes
        if not timeframes:
            raise ValueError("At least one timeframe must be specified")

-        # Validate stop loss percentages
        if not stop_loss_pcts:
            raise ValueError("At least one stop loss percentage must be specified")

@@ -281,7 +455,6 @@ class BacktestRunner:
            if not 0 < pct < 1:
                raise ValueError(f"Stop loss percentage must be between 0 and 1, got: {pct}")

-        # Validate initial USD
        if initial_usd <= 0:
            raise ValueError("Initial USD must be positive")
@@ -14,7 +14,7 @@ class ConfigManager:
        "initial_usd": 10000,
        "timeframes": ["1D", "6h", "3h", "1h", "30m", "15m", "5m", "1m"],
        "stop_loss_pcts": [0.01, 0.02, 0.03, 0.05],
-        "data_dir": "data",
+        "data_dir": "../data",
        "results_dir": "results"
    }

configs/flat_2021_2024_config.json (new file, 10 lines)
@@ -0,0 +1,10 @@
{
    "start_date": "2021-11-01",
    "stop_date": "2024-04-01",
    "initial_usd": 10000,
    "timeframes": ["1min", "2min", "3min", "4min", "5min", "10min", "15min", "30min", "1h", "2h", "4h", "6h", "8h", "12h", "1d"],
    "stop_loss_pcts": [0.01, 0.02, 0.03, 0.04, 0.05, 0.1],
    "data_dir": "../data",
    "results_dir": "../results",
    "debug": 0
}

configs/full_config.json (new file, 10 lines)
@@ -0,0 +1,10 @@
{
    "start_date": "2020-01-01",
    "stop_date": "2025-07-08",
    "initial_usd": 10000,
    "timeframes": ["1h", "4h", "15ME", "5ME", "1ME"],
    "stop_loss_pcts": [0.01, 0.02, 0.03, 0.05],
    "data_dir": "../data",
    "results_dir": "../results",
    "debug": 1
}

@@ -2,8 +2,9 @@
    "start_date": "2023-01-01",
    "stop_date": "2025-01-15",
    "initial_usd": 10000,
-    "timeframes": ["1h", "4h"],
-    "stop_loss_pcts": [0.02, 0.05],
+    "timeframes": ["4h"],
+    "stop_loss_pcts": [0.05],
    "data_dir": "../data",
-    "results_dir": "../results"
+    "results_dir": "../results",
+    "debug": 0
}

@@ -7,21 +7,32 @@ from cycles.market_fees import MarketFees

class Backtest:
    @staticmethod
-    def run(min1_df, df, initial_usd, stop_loss_pct, debug=False):
+    def run(min1_df, df, initial_usd, stop_loss_pct, progress_callback=None, verbose=False):
        """
        Backtest a simple strategy using the meta supertrend (all three supertrends agree).
        Buys when meta supertrend is positive, sells when negative, applies a percentage stop loss.

        Parameters:
        - min1_df: pandas DataFrame, 1-minute timeframe data for more accurate stop loss checking (optional)
        - df: pandas DataFrame, main timeframe data for signals
        - initial_usd: float, starting USD amount
        - stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%)
-        - debug: bool, whether to print debug info
+        - progress_callback: callable, optional callback function to report progress (current_step)
+        - verbose: bool, enable debug logging for stop loss checks
        """
-        _df = df.copy().reset_index(drop=True)
+        _df = df.copy().reset_index()
+
+        # Ensure we have a timestamp column regardless of original index name
+        if 'timestamp' not in _df.columns:
+            # If reset_index() created a column with the original index name, rename it
+            if len(_df.columns) > 0 and _df.columns[0] not in ['open', 'high', 'low', 'close', 'volume', 'predicted_close_price']:
+                _df = _df.rename(columns={_df.columns[0]: 'timestamp'})
+            else:
+                raise ValueError("Unable to identify timestamp column in DataFrame")
+
+        _df['timestamp'] = pd.to_datetime(_df['timestamp'])

-        supertrends = Supertrends(_df, verbose=False)
+        supertrends = Supertrends(_df, verbose=False, close_column='predicted_close_price')

        supertrend_results_list = supertrends.calculate_supertrend_indicators()
        trends = [st['results']['trend'] for st in supertrend_results_list]

@@ -41,18 +52,21 @@ class Backtest:
        drawdowns = []
        trades = []
        entry_time = None
-        current_trade_min1_start_idx = None
        stop_loss_count = 0  # Track number of stop losses

-        min1_df.index = pd.to_datetime(min1_df.index)
-        min1_timestamps = min1_df.index.values
+        # Ensure min1_df has proper DatetimeIndex
+        if min1_df is not None and not min1_df.empty:
+            min1_df.index = pd.to_datetime(min1_df.index)

-        last_print_time = time.time()
        for i in range(1, len(_df)):
-            current_time = time.time()
-            if current_time - last_print_time >= 5:
-                progress = (i / len(_df)) * 100
-                print(f"\rProgress: {progress:.1f}%", end="", flush=True)
-                last_print_time = current_time
+            # Report progress if callback is provided
+            if progress_callback:
+                # Update more frequently for better responsiveness
+                update_frequency = max(1, len(_df) // 50)  # Update every 2% of dataset (50 updates total)
+                if i % update_frequency == 0 or i == len(_df) - 1:  # Always update on last iteration
+                    if verbose:  # Only print in verbose mode to avoid spam
+                        print(f"DEBUG: Progress callback called with i={i}, total={len(_df)-1}")
+                    progress_callback(i)

            price_open = _df['open'].iloc[i]
            price_close = _df['close'].iloc[i]

@@ -69,16 +83,13 @@ class Backtest:
                    entry_price,
                    stop_loss_pct,
                    coin,
-                    usd,
-                    debug,
-                    current_trade_min1_start_idx
+                    verbose=verbose
                )
                if stop_loss_result is not None:
-                    trade_log_entry, current_trade_min1_start_idx, position, coin, entry_price = stop_loss_result
+                    trade_log_entry, position, coin, entry_price, usd = stop_loss_result
                    trade_log.append(trade_log_entry)
                    stop_loss_count += 1
                    continue
-                # Update the start index for next check
-                current_trade_min1_start_idx = min1_df.index[min1_df.index <= date][-1]

            # Entry: only if not in position and signal changes to 1
            if position == 0 and prev_mt != 1 and curr_mt == 1:

@@ -99,7 +110,9 @@ class Backtest:
                drawdown = (max_balance - balance) / max_balance
                drawdowns.append(drawdown)

-        print("\rProgress: 100%\r\n", end="", flush=True)
+        # Report completion if callback is provided
+        if progress_callback:
+            progress_callback(len(_df) - 1)

        # If still in position at end, sell at last close
        if position == 1:

@@ -122,22 +135,37 @@ class Backtest:
                profit_pct = (trade['exit'] - trade['entry']) / trade['entry']
            else:
                profit_pct = 0.0

+            # Validate fee_usd field
+            if 'fee_usd' not in trade:
+                raise ValueError(f"Trade missing required field 'fee_usd': {trade}")
+            fee_usd = trade['fee_usd']
+            if fee_usd is None:
+                raise ValueError(f"Trade fee_usd is None: {trade}")
+
+            # Validate trade type field
+            if 'type' not in trade:
+                raise ValueError(f"Trade missing required field 'type': {trade}")
+            trade_type = trade['type']
+            if trade_type is None:
+                raise ValueError(f"Trade type is None: {trade}")

            trades.append({
                'entry_time': trade['entry_time'],
                'exit_time': trade['exit_time'],
                'entry': trade['entry'],
                'exit': trade['exit'],
                'profit_pct': profit_pct,
-                'type': trade.get('type', 'SELL'),
-                'fee_usd': trade.get('fee_usd')
+                'type': trade_type,
+                'fee_usd': fee_usd
            })
-            fee_usd = trade.get('fee_usd')
            total_fees_usd += fee_usd

        results = {
            "initial_usd": initial_usd,
            "final_usd": final_balance,
            "n_trades": n_trades,
            "n_stop_loss": stop_loss_count,  # Add stop loss count
            "win_rate": win_rate,
            "max_drawdown": max_drawdown,
            "avg_trade": avg_trade,

@@ -157,38 +185,112 @@ class Backtest:
        return results

    @staticmethod
-    def check_stop_loss(min1_df, entry_time, date, entry_price, stop_loss_pct, coin, usd, debug, current_trade_min1_start_idx):
+    def check_stop_loss(min1_df, entry_time, current_time, entry_price, stop_loss_pct, coin, verbose=False):
+        """
+        Check if stop loss should be triggered based on 1-minute data
+
+        Args:
+            min1_df: 1-minute DataFrame with DatetimeIndex
+            entry_time: Entry timestamp
+            current_time: Current timestamp
+            entry_price: Entry price
+            stop_loss_pct: Stop loss percentage (e.g. 0.05 for 5%)
+            coin: Current coin position
+            verbose: Enable debug logging
+
+        Returns:
+            Tuple of (trade_log_entry, position, coin, entry_price, usd) if stop loss triggered, None otherwise
+        """
        if min1_df is None or min1_df.empty:
            if verbose:
                print("Warning: No 1-minute data available for stop loss checking")
            return None

        stop_price = entry_price * (1 - stop_loss_pct)

-        if current_trade_min1_start_idx is None:
-            current_trade_min1_start_idx = min1_df.index[min1_df.index >= entry_time][0]
-        current_min1_end_idx = min1_df.index[min1_df.index <= date][-1]
-
-        # Check all 1-minute candles in between for stop loss
-        min1_slice = min1_df.loc[current_trade_min1_start_idx:current_min1_end_idx]
-        if (min1_slice['low'] <= stop_price).any():
-            # Stop loss triggered, find the exact candle
-            stop_candle = min1_slice[min1_slice['low'] <= stop_price].iloc[0]
-            # More realistic fill: if open < stop, fill at open, else at stop
-            if stop_candle['open'] < stop_price:
-                sell_price = stop_candle['open']
-            else:
-                sell_price = stop_price
-            if debug:
-                print(f"STOP LOSS triggered: entry={entry_price}, stop={stop_price}, sell_price={sell_price}, entry_time={entry_time}, stop_time={stop_candle.name}")
-            btc_to_sell = coin
-            usd_gross = btc_to_sell * sell_price
-            exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
-            trade_log_entry = {
-                'type': 'STOP',
-                'entry': entry_price,
-                'exit': sell_price,
-                'entry_time': entry_time,
-                'exit_time': stop_candle.name,
-                'fee_usd': exit_fee
-            }
-            # After stop loss, reset position and entry
-            return trade_log_entry, None, 0, 0, 0
+        try:
+            # Ensure min1_df has a DatetimeIndex
+            if not isinstance(min1_df.index, pd.DatetimeIndex):
+                if verbose:
+                    print("Warning: min1_df does not have DatetimeIndex")
+                return None
+
+            # Convert entry_time and current_time to pandas Timestamps for comparison
+            entry_ts = pd.to_datetime(entry_time)
+            current_ts = pd.to_datetime(current_time)
+
+            if verbose:
+                print(f"Checking stop loss from {entry_ts} to {current_ts}, stop_price: {stop_price:.2f}")
+
+            # Handle edge case where entry and current time are the same (1-minute timeframe)
+            if entry_ts == current_ts:
+                if verbose:
+                    print("Entry and current time are the same, no range to check")
+                return None
+
+            # Find the range of 1-minute data to check (exclusive of entry time, inclusive of current time)
+            # We start from the candle AFTER entry to avoid checking the entry candle itself
+            start_check_time = entry_ts + pd.Timedelta(minutes=1)
+
+            # Get the slice of data to check for stop loss
+            mask = (min1_df.index > entry_ts) & (min1_df.index <= current_ts)
+            min1_slice = min1_df.loc[mask]
+
+            if len(min1_slice) == 0:
+                if verbose:
+                    print(f"No 1-minute data found between {start_check_time} and {current_ts}")
+                return None
+
+            if verbose:
+                print(f"Checking {len(min1_slice)} candles for stop loss")
+
+            # Check if any low price in the slice hits the stop loss
+            stop_triggered = (min1_slice['low'] <= stop_price).any()
+
+            if stop_triggered:
+                # Find the exact candle where stop loss was triggered
+                stop_candle = min1_slice[min1_slice['low'] <= stop_price].iloc[0]
+
+                if verbose:
+                    print(f"Stop loss triggered at {stop_candle.name}, low: {stop_candle['low']:.2f}")
+
+                # More realistic fill: if open < stop, fill at open, else at stop
+                if stop_candle['open'] < stop_price:
+                    sell_price = stop_candle['open']
+                    if verbose:
+                        print(f"Filled at open price: {sell_price:.2f}")
+                else:
+                    sell_price = stop_price
+                    if verbose:
+                        print(f"Filled at stop price: {sell_price:.2f}")
+
+                btc_to_sell = coin
+                usd_gross = btc_to_sell * sell_price
+                exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
+                usd_after_stop = usd_gross - exit_fee
+
+                trade_log_entry = {
+                    'type': 'STOP',
+                    'entry': entry_price,
+                    'exit': sell_price,
+                    'entry_time': entry_time,
+                    'exit_time': stop_candle.name,
+                    'fee_usd': exit_fee
+                }
+                # After stop loss, reset position and entry, return USD balance
+                return trade_log_entry, 0, 0, 0, usd_after_stop
+            elif verbose:
+                print(f"No stop loss triggered, min low in range: {min1_slice['low'].min():.2f}")
+
+        except Exception as e:
+            # In case of any error, don't trigger stop loss but log the issue
+            error_msg = f"Warning: Stop loss check failed: {e}"
+            print(error_msg)
+            if verbose:
+                import traceback
+                print(traceback.format_exc())
+            return None
+
+        return None

    @staticmethod
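Note the changed contract: check_stop_loss now takes current_time instead of date, no longer threads usd/debug/current_trade_min1_start_idx through, and on a trigger returns the post-fee USD balance as the fifth element. A hypothetical call site for illustration:

    result = Backtest.check_stop_loss(min1_df, entry_time, current_time,
                                      entry_price, 0.05, coin, verbose=False)
    if result is not None:
        trade_log_entry, position, coin, entry_price, usd = result
        # position, coin and entry_price come back as 0; usd is gross proceeds minus the taker fee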
@@ -65,30 +65,57 @@ def cached_supertrend_calculation(period, multiplier, data_tuple):
        'lower_band': final_lower
    }

-def calculate_supertrend_external(data, period, multiplier):
+def calculate_supertrend_external(data, period, multiplier, close_column='close'):
+    """
+    External function to calculate SuperTrend with configurable close column
+
+    Parameters:
+    - data: DataFrame with OHLC data
+    - period: int, period for ATR calculation
+    - multiplier: float, multiplier for ATR
+    - close_column: str, name of the column to use as close price (default: 'close')
+    """
    high_tuple = tuple(data['high'])
    low_tuple = tuple(data['low'])
-    close_tuple = tuple(data['close'])
+    close_tuple = tuple(data[close_column])
    return cached_supertrend_calculation(period, multiplier, (high_tuple, low_tuple, close_tuple))

class Supertrends:
-    def __init__(self, data, verbose=False, display=False):
+    def __init__(self, data, close_column='close', verbose=False, display=False):
+        """
+        Initialize Supertrends calculator
+
+        Parameters:
+        - data: pandas DataFrame with OHLC data or list of prices
+        - close_column: str, name of the column to use as close price (default: 'close')
+        - verbose: bool, enable verbose logging
+        - display: bool, display mode (currently unused)
+        """
+        self.close_column = close_column
        self.data = data
        self.verbose = verbose
        logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger('TrendDetectorSimple')

        if not isinstance(self.data, pd.DataFrame):
            if isinstance(self.data, list):
-                self.data = pd.DataFrame({'close': self.data})
+                self.data = pd.DataFrame({self.close_column: self.data})
            else:
                raise ValueError("Data must be a pandas DataFrame or a list")

+        # Validate that required columns exist
+        required_columns = ['high', 'low', self.close_column]
+        missing_columns = [col for col in required_columns if col not in self.data.columns]
+        if missing_columns:
+            raise ValueError(f"Missing required columns: {missing_columns}")
+
    def calculate_tr(self):
        """Calculate True Range using the configured close column"""
        df = self.data.copy()
        high = df['high'].values
        low = df['low'].values
-        close = df['close'].values
+        close = df[self.close_column].values
        tr = np.zeros_like(close)
        tr[0] = high[0] - low[0]
        for i in range(1, len(close)):

@@ -99,6 +126,7 @@ class Supertrends:
        return tr

    def calculate_atr(self, period=14):
+        """Calculate Average True Range"""
        tr = self.calculate_tr()
        atr = np.zeros_like(tr)
        atr[0] = tr[0]

@@ -109,18 +137,20 @@ class Supertrends:

    def calculate_supertrend(self, period=10, multiplier=3.0):
        """
-        Calculate SuperTrend indicator for the price data.
+        Calculate SuperTrend indicator for the price data using the configured close column.
        SuperTrend is a trend-following indicator that uses ATR to determine the trend direction.

        Parameters:
        - period: int, the period for the ATR calculation (default: 10)
        - multiplier: float, the multiplier for the ATR (default: 3.0)

        Returns:
        - Dictionary containing SuperTrend values, trend direction, and upper/lower bands
        """
        df = self.data.copy()
        high = df['high'].values
        low = df['low'].values
-        close = df['close'].values
+        close = df[self.close_column].values
        atr = self.calculate_atr(period)
        upper_band = np.zeros_like(close)
        lower_band = np.zeros_like(close)
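With close_column configurable, the same indicator stack can be driven by a model's predicted prices instead of raw closes, which is how Backtest.run now calls it. A short usage sketch (a DataFrame df with high/low and a 'predicted_close_price' column is assumed):

    st = Supertrends(df, close_column='predicted_close_price', verbose=False)
    results_list = st.calculate_supertrend_indicators()  # same downstream API as before
    trends = [r['results']['trend'] for r in results_list]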
cycles/utils/progress_manager.py (new file, 233 lines)
@@ -0,0 +1,233 @@
#!/usr/bin/env python3
"""
Progress Manager for tracking multiple parallel backtest tasks
"""

import threading
import time
import sys
from typing import Dict, Optional, Callable
from dataclasses import dataclass


@dataclass
class TaskProgress:
    """Represents progress information for a single task"""
    task_id: str
    name: str
    current: int
    total: int
    start_time: float
    last_update: float

    @property
    def percentage(self) -> float:
        """Calculate completion percentage"""
        if self.total == 0:
            return 0.0
        return (self.current / self.total) * 100

    @property
    def elapsed_time(self) -> float:
        """Calculate elapsed time in seconds"""
        return time.time() - self.start_time

    @property
    def eta(self) -> Optional[float]:
        """Estimate time to completion in seconds"""
        if self.current == 0 or self.percentage >= 100:
            return None

        elapsed = self.elapsed_time
        rate = self.current / elapsed
        remaining = self.total - self.current
        return remaining / rate if rate > 0 else None


class ProgressManager:
    """Manages progress tracking for multiple parallel tasks"""

    def __init__(self, update_interval: float = 1.0, display_width: int = 50):
        """
        Initialize progress manager

        Args:
            update_interval: How often to update display (seconds)
            display_width: Width of progress bar in characters
        """
        self.tasks: Dict[str, TaskProgress] = {}
        self.update_interval = update_interval
        self.display_width = display_width
        self.lock = threading.Lock()
        self.display_thread: Optional[threading.Thread] = None
        self.running = False
        self.last_display_height = 0

    def start_task(self, task_id: str, name: str, total: int) -> None:
        """
        Start tracking a new task

        Args:
            task_id: Unique identifier for the task
            name: Human-readable name for the task
            total: Total number of steps in the task
        """
        with self.lock:
            self.tasks[task_id] = TaskProgress(
                task_id=task_id,
                name=name,
                current=0,
                total=total,
                start_time=time.time(),
                last_update=time.time()
            )

    def update_progress(self, task_id: str, current: int) -> None:
        """
        Update progress for a specific task

        Args:
            task_id: Task identifier
            current: Current progress value
        """
        with self.lock:
            if task_id in self.tasks:
                self.tasks[task_id].current = current
                self.tasks[task_id].last_update = time.time()

    def complete_task(self, task_id: str) -> None:
        """
        Mark a task as completed

        Args:
            task_id: Task identifier
        """
        with self.lock:
            if task_id in self.tasks:
                task = self.tasks[task_id]
                task.current = task.total
                task.last_update = time.time()

    def start_display(self) -> None:
        """Start the progress display thread"""
        if not self.running:
            self.running = True
            self.display_thread = threading.Thread(target=self._display_loop, daemon=True)
            self.display_thread.start()

    def stop_display(self) -> None:
        """Stop the progress display thread"""
        self.running = False
        if self.display_thread:
            self.display_thread.join(timeout=1.0)
        self._clear_display()

    def _display_loop(self) -> None:
        """Main loop for updating the progress display"""
        while self.running:
            self._update_display()
            time.sleep(self.update_interval)

    def _update_display(self) -> None:
        """Update the console display with current progress"""
        with self.lock:
            if not self.tasks:
                return

            # Clear previous display
            self._clear_display()

            # Build display lines
            lines = []
            for task in sorted(self.tasks.values(), key=lambda t: t.task_id):
                line = self._format_progress_line(task)
                lines.append(line)

            # Print all lines
            for line in lines:
                print(line, flush=True)

            self.last_display_height = len(lines)

    def _clear_display(self) -> None:
        """Clear the previous progress display"""
        if self.last_display_height > 0:
            # Move cursor up and clear lines
            for _ in range(self.last_display_height):
                sys.stdout.write('\033[F')  # Move cursor up one line
                sys.stdout.write('\033[K')  # Clear line
            sys.stdout.flush()

    def _format_progress_line(self, task: TaskProgress) -> str:
        """
        Format a single progress line for display

        Args:
            task: TaskProgress instance

        Returns:
            Formatted progress string
        """
        # Progress bar
        filled_width = int(task.percentage / 100 * self.display_width)
        bar = '█' * filled_width + '░' * (self.display_width - filled_width)

        # Time information
        elapsed_str = self._format_time(task.elapsed_time)
        eta_str = self._format_time(task.eta) if task.eta else "N/A"

        # Format line
        line = (f"{task.name:<25} │{bar}│ "
                f"{task.percentage:5.1f}% "
                f"({task.current:,}/{task.total:,}) "
                f"⏱ {elapsed_str} ETA: {eta_str}")

        return line

    def _format_time(self, seconds: float) -> str:
        """
        Format time duration for display

        Args:
            seconds: Time in seconds

        Returns:
            Formatted time string
        """
        if seconds < 60:
            return f"{seconds:.0f}s"
        elif seconds < 3600:
            minutes = seconds / 60
            return f"{minutes:.1f}m"
        else:
            hours = seconds / 3600
            return f"{hours:.1f}h"

    def get_task_progress_callback(self, task_id: str) -> Callable[[int], None]:
        """
        Get a progress callback function for a specific task

        Args:
            task_id: Task identifier

        Returns:
            Callback function that updates progress for this task
        """
        def callback(current: int) -> None:
            self.update_progress(task_id, current)

        return callback

    def all_tasks_completed(self) -> bool:
        """Check if all tasks are completed"""
        with self.lock:
            return all(task.current >= task.total for task in self.tasks.values())

    def get_summary(self) -> str:
        """Get a summary of all tasks"""
        with self.lock:
            total_tasks = len(self.tasks)
            completed_tasks = sum(1 for task in self.tasks.values()
                                  if task.current >= task.total)

            return f"Tasks: {completed_tasks}/{total_tasks} completed"
@@ -10,10 +10,12 @@ class SystemUtils:
        """Determine optimal number of worker processes based on system resources"""
        cpu_count = os.cpu_count() or 4
        memory_gb = psutil.virtual_memory().total / (1024**3)
-        # Heuristic: Use 75% of cores, but cap based on available memory
-        # Assume each worker needs ~2GB for large datasets
-        workers_by_memory = max(1, int(memory_gb / 2))
-        workers_by_cpu = max(1, int(cpu_count * 0.75))

+        # OPTIMIZATION: More aggressive worker allocation for better performance
+        workers_by_memory = max(1, int(memory_gb / 2))  # 2GB per worker
+        workers_by_cpu = max(1, int(cpu_count * 0.8))  # Use 80% of CPU cores
+        optimal_workers = min(workers_by_cpu, workers_by_memory, 8)  # Cap at 8 workers

        if self.logging is not None:
-            self.logging.info(f"Using {min(workers_by_cpu, workers_by_memory)} workers for processing")
-        return min(workers_by_cpu, workers_by_memory)
+            self.logging.info(f"Using {optimal_workers} workers for processing (CPU-based: {workers_by_cpu}, Memory-based: {workers_by_memory})")
+        return optimal_workers
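Worked example of the new heuristic on a hypothetical 16-core, 32 GB machine:

    workers_by_memory = max(1, int(32 / 2))    # 16
    workers_by_cpu    = max(1, int(16 * 0.8))  # 12
    optimal_workers   = min(12, 16, 8)         # 8 -- the hard cap dominates on larger machines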
main.py (27 lines changed)
@@ -79,7 +79,23 @@ def main():
    )
    system_utils = SystemUtils(logging=logger)
    result_processor = ResultProcessor(storage, logging_instance=logger)
-    runner = BacktestRunner(storage, system_utils, result_processor, logging_instance=logger)
+
+    # OPTIMIZATION: Disable progress for parallel execution to improve performance
+    show_progress = config.get('show_progress', True)
+    debug_mode = config.get('debug', 0) == 1
+
+    # Only show progress in debug (sequential) mode
+    if not debug_mode:
+        show_progress = False
+        logger.info("Progress tracking disabled for parallel execution (performance optimization)")
+
+    runner = BacktestRunner(
+        storage,
+        system_utils,
+        result_processor,
+        logging_instance=logger,
+        show_progress=show_progress
+    )

    # Validate inputs
    logger.info("Validating inputs...")

@@ -91,7 +107,8 @@ def main():

    # Load data
    logger.info("Loading market data...")
-    data_filename = 'btcusd_1-min_data.csv'
+    # data_filename = 'btcusd_1-min_data.csv'
+    data_filename = 'btcusd_1-min_data_with_price_predictions.csv'
    data_1min = runner.load_data(
        data_filename,
        config['start_date'],

@@ -100,7 +117,6 @@ def main():

    # Run backtests
    logger.info("Starting backtest execution...")
-    debug_mode = True  # Can be moved to config

    all_results, all_trades = runner.run_backtests(
        data_1min,

@@ -114,6 +130,11 @@ def main():
    logger.info("Processing and saving results...")
    timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")

+    # OPTIMIZATION: Save trade files in batch after parallel execution
+    if all_trades and not debug_mode:
+        logger.info("Saving trade files in batch...")
+        result_processor.save_all_trade_files(all_trades)

    # Create metadata
    metadata_lines = create_metadata_lines(config, data_1min, result_processor)

@@ -29,8 +29,8 @@ class ResultProcessor:
        df: pd.DataFrame,
        stop_loss_pcts: List[float],
        timeframe_name: str,
-        initial_usd: float,
-        debug: bool = False
+        initial_usd: float,
+        progress_callback=None
    ) -> Tuple[List[Dict], List[Dict]]:
        """
        Process results for a single timeframe with multiple stop loss values

@@ -41,7 +41,7 @@ class ResultProcessor:
            stop_loss_pcts: List of stop loss percentages to test
            timeframe_name: Name of the timeframe (e.g., '1D', '6h')
            initial_usd: Initial USD amount
-            debug: Whether to enable debug output
+            progress_callback: Optional progress callback function

        Returns:
            Tuple of (results_rows, trade_rows)

@@ -59,7 +59,8 @@ class ResultProcessor:
                    df,
                    initial_usd=initial_usd,
                    stop_loss_pct=stop_loss_pct,
-                    debug=debug
+                    progress_callback=progress_callback,
+                    verbose=False  # Default to False for production runs
                )

                # Calculate metrics

@@ -67,15 +68,14 @@ class ResultProcessor:
                results_rows.append(metrics)

                # Process trades
-                trades = self._process_trades(results.get('trades', []), timeframe_name, stop_loss_pct)
+                if 'trades' not in results:
+                    raise ValueError(f"Backtest results missing 'trades' field for {timeframe_name} with {stop_loss_pct} stop loss")
+                trades = self._process_trades(results['trades'], timeframe_name, stop_loss_pct)
                trade_rows.extend(trades)

                if self.logging:
                    self.logging.info(f"Timeframe: {timeframe_name}, Stop Loss: {stop_loss_pct}, Trades: {results['n_trades']}")

-                if debug:
-                    self._debug_output(results)

            except Exception as e:
                error_msg = f"Error processing {timeframe_name} with stop loss {stop_loss_pct}: {e}"
                if self.logging:

@@ -92,36 +92,56 @@ class ResultProcessor:
        timeframe_name: str
    ) -> Dict[str, Any]:
        """Calculate performance metrics from backtest results"""
-        trades = results.get('trades', [])
+        if 'trades' not in results:
+            raise ValueError(f"Backtest results missing 'trades' field for {timeframe_name} with {stop_loss_pct} stop loss")
+        trades = results['trades']
        n_trades = results["n_trades"]

-        # Calculate win metrics
-        winning_trades = [t for t in trades if t.get('exit') is not None and t['exit'] > t['entry']]
+        # Validate that all required fields are present
+        required_fields = ['final_usd', 'max_drawdown', 'total_fees_usd', 'n_trades', 'n_stop_loss', 'win_rate', 'avg_trade']
+        missing_fields = [field for field in required_fields if field not in results]
+        if missing_fields:
+            raise ValueError(f"Backtest results missing required fields: {missing_fields}")
+
+        # Calculate win metrics - validate trade fields
+        winning_trades = []
+        for t in trades:
+            if 'exit' not in t:
+                raise ValueError(f"Trade missing 'exit' field: {t}")
+            if 'entry' not in t:
+                raise ValueError(f"Trade missing 'entry' field: {t}")
+            if t['exit'] is not None and t['exit'] > t['entry']:
+                winning_trades.append(t)
        n_winning_trades = len(winning_trades)
        win_rate = n_winning_trades / n_trades if n_trades > 0 else 0

        # Calculate profit metrics
-        total_profit = sum(trade['profit_pct'] for trade in trades)
-        total_loss = sum(-trade['profit_pct'] for trade in trades if trade['profit_pct'] < 0)
-        avg_trade = total_profit / n_trades if n_trades > 0 else 0
-        profit_ratio = total_profit / total_loss if total_loss > 0 else float('inf')
+        total_profit = sum(trade['profit_pct'] for trade in trades if trade['profit_pct'] > 0)
+        total_loss = abs(sum(trade['profit_pct'] for trade in trades if trade['profit_pct'] < 0))
+        avg_trade = sum(trade['profit_pct'] for trade in trades) / n_trades if n_trades > 0 else 0
+        profit_ratio = total_profit / total_loss if total_loss > 0 else (float('inf') if total_profit > 0 else 0)

-        # Calculate drawdown
-        max_drawdown = self._calculate_max_drawdown(trades)
+        # Get values directly from backtest results (no defaults)
+        max_drawdown = results['max_drawdown']
+        final_usd = results['final_usd']
+        total_fees_usd = results['total_fees_usd']
+        n_stop_loss = results['n_stop_loss']  # Get stop loss count directly from backtest

-        # Calculate final USD
-        final_usd = initial_usd
-        for trade in trades:
-            final_usd *= (1 + trade['profit_pct'])
-
-        # Calculate fees
-        total_fees_usd = sum(trade.get('fee_usd', 0) for trade in trades)
+        # Validate no None values
+        if max_drawdown is None:
+            raise ValueError(f"max_drawdown is None for {timeframe_name} with {stop_loss_pct} stop loss")
+        if final_usd is None:
+            raise ValueError(f"final_usd is None for {timeframe_name} with {stop_loss_pct} stop loss")
+        if total_fees_usd is None:
+            raise ValueError(f"total_fees_usd is None for {timeframe_name} with {stop_loss_pct} stop loss")
+        if n_stop_loss is None:
+            raise ValueError(f"n_stop_loss is None for {timeframe_name} with {stop_loss_pct} stop loss")

        return {
            "timeframe": timeframe_name,
            "stop_loss_pct": stop_loss_pct,
            "n_trades": n_trades,
-            "n_stop_loss": sum(1 for trade in trades if trade.get('type') == 'STOP'),
+            "n_stop_loss": n_stop_loss,
            "win_rate": win_rate,
            "max_drawdown": max_drawdown,
            "avg_trade": avg_trade,

@@ -159,16 +179,22 @@ class ResultProcessor:
        processed_trades = []

        for trade in trades:
+            # Validate all required trade fields
+            required_fields = ["entry_time", "exit_time", "entry", "exit", "profit_pct", "type", "fee_usd"]
+            missing_fields = [field for field in required_fields if field not in trade]
+            if missing_fields:
+                raise ValueError(f"Trade missing required fields: {missing_fields} in trade: {trade}")
+
            processed_trade = {
                "timeframe": timeframe_name,
                "stop_loss_pct": stop_loss_pct,
-                "entry_time": trade.get("entry_time"),
-                "exit_time": trade.get("exit_time"),
-                "entry_price": trade.get("entry"),
-                "exit_price": trade.get("exit"),
-                "profit_pct": trade.get("profit_pct"),
-                "type": trade.get("type"),
-                "fee_usd": trade.get("fee_usd"),
+                "entry_time": trade["entry_time"],
+                "exit_time": trade["exit_time"],
+                "entry_price": trade["entry"],
+                "exit_price": trade["exit"],
+                "profit_pct": trade["profit_pct"],
+                "type": trade["type"],
+                "fee_usd": trade["fee_usd"],
            }
            processed_trades.append(processed_trade)

@@ -176,17 +202,31 @@ class ResultProcessor:

    def _debug_output(self, results: Dict[str, Any]) -> None:
        """Output debug information for backtest results"""
-        trades = results.get('trades', [])
+        if 'trades' not in results:
+            raise ValueError("Backtest results missing 'trades' field for debug output")
+        trades = results['trades']

        # Print stop loss trades
-        stop_loss_trades = [t for t in trades if t.get('type') == 'STOP']
+        stop_loss_trades = []
+        for t in trades:
+            if 'type' not in t:
+                raise ValueError(f"Trade missing 'type' field: {t}")
+            if t['type'] == 'STOP':
+                stop_loss_trades.append(t)

        if stop_loss_trades:
            print("Stop Loss Trades:")
            for trade in stop_loss_trades:
                print(trade)

        # Print large loss trades
-        large_loss_trades = [t for t in trades if t.get('profit_pct', 0) < -0.09]
+        large_loss_trades = []
+        for t in trades:
+            if 'profit_pct' not in t:
+                raise ValueError(f"Trade missing 'profit_pct' field: {t}")
+            if t['profit_pct'] < -0.09:
+                large_loss_trades.append(t)

        if large_loss_trades:
            print("Large Loss Trades:")
            for trade in large_loss_trades:

@@ -216,19 +256,32 @@ class ResultProcessor:

    def _aggregate_group(self, rows: List[Dict], timeframe: str, stop_loss_pct: float) -> Dict:
        """Aggregate a group of rows with the same timeframe and stop loss"""
+        if not rows:
+            raise ValueError(f"No rows to aggregate for {timeframe} with {stop_loss_pct} stop loss")
+
+        # Validate all rows have required fields
+        required_fields = ['n_trades', 'n_stop_loss', 'win_rate', 'max_drawdown', 'avg_trade', 'profit_ratio', 'final_usd', 'total_fees_usd', 'initial_usd']
+        for i, row in enumerate(rows):
+            missing_fields = [field for field in required_fields if field not in row]
+            if missing_fields:
+                raise ValueError(f"Row {i} missing required fields: {missing_fields}")
+
        total_trades = sum(r['n_trades'] for r in rows)
        total_stop_loss = sum(r['n_stop_loss'] for r in rows)

-        # Calculate averages
+        # Calculate averages (no defaults, expect all values to be present)
        avg_win_rate = np.mean([r['win_rate'] for r in rows])
        avg_max_drawdown = np.mean([r['max_drawdown'] for r in rows])
        avg_avg_trade = np.mean([r['avg_trade'] for r in rows])
-        avg_profit_ratio = np.mean([r['profit_ratio'] for r in rows])

-        # Calculate final USD and fees
-        final_usd = np.mean([r.get('final_usd', r.get('initial_usd', 0)) for r in rows])
-        total_fees_usd = np.mean([r.get('total_fees_usd', 0) for r in rows])
-        initial_usd = rows[0].get('initial_usd', 0) if rows else 0
+        # Handle infinite profit ratios properly
+        finite_profit_ratios = [r['profit_ratio'] for r in rows if not np.isinf(r['profit_ratio'])]
+        avg_profit_ratio = np.mean(finite_profit_ratios) if finite_profit_ratios else 0
+
+        # Calculate final USD and fees (no defaults)
+        final_usd = np.mean([r['final_usd'] for r in rows])
+        total_fees_usd = np.mean([r['total_fees_usd'] for r in rows])
+        initial_usd = rows[0]['initial_usd']

        return {
            "timeframe": timeframe,

@@ -278,7 +331,11 @@ class ResultProcessor:
            writer = csv.DictWriter(f, fieldnames=trades_fieldnames)
            writer.writeheader()
            for trade in trades:
-                writer.writerow({k: trade.get(k, "") for k in trades_fieldnames})
+                # Validate all required fields are present
+                missing_fields = [k for k in trades_fieldnames if k not in trade]
+                if missing_fields:
+                    raise ValueError(f"Trade missing required fields for CSV: {missing_fields} in trade: {trade}")
+                writer.writerow({k: trade[k] for k in trades_fieldnames})

        if self.logging:
            self.logging.info(f"Trades saved to {trades_filename}")

@@ -351,4 +408,39 @@ class ResultProcessor:
        except Exception as e:
            if self.logging:
                self.logging.warning(f"Could not get price info for {date}: {e}")
-            return None, None
+            return None, None
+
+    def save_all_trade_files(self, all_trades: List[Dict]) -> None:
+        """
+        Save all trade files in batch after parallel execution completes
+
+        Args:
+            all_trades: List of all trades from all tasks
+        """
+        if not all_trades:
+            return
+
+        try:
+            # Group trades by timeframe and stop loss
+            trade_groups = {}
+            for trade in all_trades:
+                timeframe = trade.get('timeframe')
+                stop_loss_pct = trade.get('stop_loss_pct')
+                if timeframe and stop_loss_pct is not None:
+                    key = (timeframe, stop_loss_pct)
+                    if key not in trade_groups:
+                        trade_groups[key] = []
+                    trade_groups[key].append(trade)
+
+            # Save each group
+            for (timeframe, stop_loss_pct), trades in trade_groups.items():
+                self.save_trade_file(trades, timeframe, stop_loss_pct)
+
+            if self.logging:
+                self.logging.info(f"Saved {len(trade_groups)} trade files in batch")
+
+        except Exception as e:
+            error_msg = f"Failed to save trade files in batch: {e}"
+            if self.logging:
+                self.logging.error(error_msg)
+            raise RuntimeError(error_msg) from e