Refactor Backtest class and update strategy functions for improved modularity

- Refactored the Backtest class to encapsulate state and behavior, enhancing clarity and maintainability.
- Updated strategy functions to accept the Backtest instance, streamlining data access and manipulation.
- Introduced a new plotting method in BacktestCharts for visualizing close prices with trend indicators.
- Improved handling of meta_trend data to ensure proper visualization and trend representation.
- Adjusted main execution logic to support the new Backtest structure and enhanced debugging capabilities.
This commit is contained in:
Simon Moisy 2025-05-22 20:02:14 +08:00
parent 00873d593f
commit e5c2988d71
3 changed files with 181 additions and 198 deletions

View File

@ -4,84 +4,85 @@ import numpy as np
from cycles.market_fees import MarketFees from cycles.market_fees import MarketFees
class Backtest: class Backtest:
class Data: def __init__(self, initial_usd, df, min1_df, init_strategy_fields) -> None:
def __init__(self, initial_usd, df, min1_df, init_strategy_fields) -> None: self.initial_usd = initial_usd
self.initial_usd = initial_usd self.usd = initial_usd
self.usd = initial_usd self.max_balance = initial_usd
self.max_balance = initial_usd self.coin = 0
self.coin = 0 self.position = 0
self.position = 0 self.entry_price = 0
self.entry_price = 0 self.entry_time = None
self.entry_time = None self.current_trade_min1_start_idx = None
self.current_trade_min1_start_idx = None self.current_min1_end_idx = None
self.current_min1_end_idx = None self.price_open = None
self.price_open = None self.price_close = None
self.price_close = None self.current_date = None
self.current_date = None self.strategies = {}
self.strategies = {} self.df = df
self.df = df self.min1_df = min1_df
self.min1_df = min1_df
self = init_strategy_fields(self) self.trade_log = []
self.drawdowns = []
self.trades = []
@staticmethod self = init_strategy_fields(self)
def run(data, entry_strategy, exit_strategy, debug=False):
def run(self, entry_strategy, exit_strategy, debug=False):
""" """
Backtest a simple strategy using the meta supertrend (all three supertrends agree). Runs the backtest using provided entry and exit strategy functions.
Buys when meta supertrend is positive, sells when negative, applies a percentage stop loss.
The method iterates over the main DataFrame (self.df), simulating trades based on the entry and exit strategies. It tracks balances, drawdowns, and logs each trade, including fees. At the end, it returns a dictionary of performance statistics.
Parameters: Parameters:
- min1_df: pandas DataFrame, 1-minute timeframe data for more accurate stop loss checking (optional) - entry_strategy: function, determines when to enter a trade. Should accept (self, i) and return True to enter.
- initial_usd: float, starting USD amount - exit_strategy: function, determines when to exit a trade. Should accept (self, i) and return (exit_reason, sell_price) or (None, None) to hold.
- stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%) - debug: bool, whether to print debug info (default: False)
- debug: bool, whether to print debug info
Returns:
- dict with keys: initial_usd, final_usd, n_trades, win_rate, max_drawdown, avg_trade, trade_log, trades, total_fees_usd, and optionally first_trade and last_trade.
""" """
trade_log = []
drawdowns = []
trades = []
for i in range(1, len(data.df)): for i in range(1, len(self.df)):
data.price_open = data.df['open'].iloc[i] self.price_open = self.df['open'].iloc[i]
data.price_close = data.df['close'].iloc[i] self.price_close = self.df['close'].iloc[i]
data.current_date = data.df['timestamp'].iloc[i] self.current_date = self.df['timestamp'].iloc[i]
if data.position == 0: if self.position == 0:
if entry_strategy(data, i): if entry_strategy(self, i):
data, entry_log_entry = Backtest.handle_entry(data) self.handle_entry()
trade_log.append(entry_log_entry) elif self.position == 1:
elif data.position == 1: exit_test_results, sell_price = exit_strategy(self, i)
exit_test_results, data, sell_price = exit_strategy(data, i)
if exit_test_results is not None: if exit_test_results is not None:
data, exit_log_entry = Backtest.handle_exit(data, exit_test_results, sell_price) self.handle_exit(exit_test_results, sell_price)
trade_log.append(exit_log_entry)
# Track drawdown # Track drawdown
balance = data.usd if data.position == 0 else data.coin * data.price_close balance = self.usd if self.position == 0 else self.coin * self.price_close
if balance > data.max_balance: if balance > self.max_balance:
data.max_balance = balance self.max_balance = balance
drawdown = (data.max_balance - balance) / data.max_balance drawdown = (self.max_balance - balance) / self.max_balance
drawdowns.append(drawdown) self.drawdowns.append(drawdown)
# If still in position at end, sell at last close # If still in position at end, sell at last close
if data.position == 1: if self.position == 1:
data, exit_log_entry = Backtest.handle_exit(data, "EOD", None) self.handle_exit("EOD", None)
trade_log.append(exit_log_entry)
# Calculate statistics # Calculate statistics
final_balance = data.usd final_balance = self.usd
n_trades = len(trade_log) n_trades = len(self.trade_log)
wins = [1 for t in trade_log if t['exit'] is not None and t['exit'] > t['entry']] wins = [1 for t in self.trade_log if t['exit'] is not None and t['exit'] > t['entry']]
win_rate = len(wins) / n_trades if n_trades > 0 else 0 win_rate = len(wins) / n_trades if n_trades > 0 else 0
max_drawdown = max(drawdowns) if drawdowns else 0 max_drawdown = max(self.drawdowns) if self.drawdowns else 0
avg_trade = np.mean([t['exit']/t['entry']-1 for t in trade_log if t['exit'] is not None]) if trade_log else 0 avg_trade = np.mean([t['exit']/t['entry']-1 for t in self.trade_log if t['exit'] is not None]) if self.trade_log else 0
trades = [] trades = []
total_fees_usd = 0.0 total_fees_usd = 0.0
for trade in trade_log:
for trade in self.trade_log:
if trade['exit'] is not None: if trade['exit'] is not None:
profit_pct = (trade['exit'] - trade['entry']) / trade['entry'] profit_pct = (trade['exit'] - trade['entry']) / trade['entry']
else: else:
@ -99,67 +100,66 @@ class Backtest:
total_fees_usd += fee_usd total_fees_usd += fee_usd
results = { results = {
"initial_usd": data.initial_usd, "initial_usd": self.initial_usd,
"final_usd": final_balance, "final_usd": final_balance,
"n_trades": n_trades, "n_trades": n_trades,
"win_rate": win_rate, "win_rate": win_rate,
"max_drawdown": max_drawdown, "max_drawdown": max_drawdown,
"avg_trade": avg_trade, "avg_trade": avg_trade,
"trade_log": trade_log, "trade_log": self.trade_log,
"trades": trades, "trades": trades,
"total_fees_usd": total_fees_usd, "total_fees_usd": total_fees_usd,
} }
if n_trades > 0: if n_trades > 0:
results["first_trade"] = { results["first_trade"] = {
"entry_time": trade_log[0]['entry_time'], "entry_time": self.trade_log[0]['entry_time'],
"entry": trade_log[0]['entry'] "entry": self.trade_log[0]['entry']
} }
results["last_trade"] = { results["last_trade"] = {
"exit_time": trade_log[-1]['exit_time'], "exit_time": self.trade_log[-1]['exit_time'],
"exit": trade_log[-1]['exit'] "exit": self.trade_log[-1]['exit']
} }
return results return results
@staticmethod def handle_entry(self):
def handle_entry(data): entry_fee = MarketFees.calculate_okx_taker_maker_fee(self.usd, is_maker=False)
entry_fee = MarketFees.calculate_okx_taker_maker_fee(data.usd, is_maker=False) usd_after_fee = self.usd - entry_fee
usd_after_fee = data.usd - entry_fee
data.coin = usd_after_fee / data.price_open self.coin = usd_after_fee / self.price_open
data.entry_price = data.price_open self.entry_price = self.price_open
data.entry_time = data.current_date self.entry_time = self.current_date
data.usd = 0 self.usd = 0
data.position = 1 self.position = 1
trade_log_entry = { trade_log_entry = {
'type': 'BUY', 'type': 'BUY',
'entry': data.entry_price, 'entry': self.entry_price,
'exit': None, 'exit': None,
'entry_time': data.entry_time, 'entry_time': self.entry_time,
'exit_time': None, 'exit_time': None,
'fee_usd': entry_fee 'fee_usd': entry_fee
} }
return data, trade_log_entry self.trade_log.append(trade_log_entry)
@staticmethod def handle_exit(self, exit_reason, sell_price):
def handle_exit(data, exit_reason, sell_price): btc_to_sell = self.coin
btc_to_sell = data.coin exit_price = sell_price if sell_price is not None else self.price_open
exit_price = sell_price if sell_price is not None else data.price_open
usd_gross = btc_to_sell * exit_price usd_gross = btc_to_sell * exit_price
exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False) exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
data.usd = usd_gross - exit_fee self.usd = usd_gross - exit_fee
exit_log_entry = { exit_log_entry = {
'type': exit_reason, 'type': exit_reason,
'entry': data.entry_price, 'entry': self.entry_price,
'exit': exit_price, 'exit': exit_price,
'entry_time': data.entry_time, 'entry_time': self.entry_time,
'exit_time': data.current_date, 'exit_time': self.current_date,
'fee_usd': exit_fee 'fee_usd': exit_fee
} }
data.coin = 0 self.coin = 0
data.position = 0 self.position = 0
data.entry_price = 0 self.entry_price = 0
return data, exit_log_entry self.trade_log.append(exit_log_entry)

View File

@ -1,86 +1,71 @@
import os import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
class BacktestCharts: class BacktestCharts:
def __init__(self, charts_dir="charts"): @staticmethod
self.charts_dir = charts_dir def plot(df, meta_trend):
os.makedirs(self.charts_dir, exist_ok=True)
def plot_profit_ratio_vs_stop_loss(self, results, filename="profit_ratio_vs_stop_loss.png"):
""" """
Plots profit ratio vs stop loss percentage for each timeframe. Plot close price line chart with a bar at the bottom: green when trend is 1, red when trend is 0.
The bar stays at the bottom even when zooming/panning.
Parameters: - df: DataFrame with columns ['close', ...] and a datetime index or 'timestamp' column.
- results: list of dicts, each with keys: 'timeframe', 'stop_loss_pct', 'profit_ratio' - meta_trend: array-like, same length as df, values 1 (green) or 0 (red).
- filename: output filename (will be saved in charts_dir)
""" """
# Organize data by timeframe fig, (ax_price, ax_bar) = plt.subplots(
from collections import defaultdict nrows=2, ncols=1, figsize=(16, 8), sharex=True,
data = defaultdict(lambda: {"stop_loss_pct": [], "profit_ratio": []}) gridspec_kw={'height_ratios': [12, 1]}
for row in results: )
tf = row["timeframe"]
data[tf]["stop_loss_pct"].append(row["stop_loss_pct"])
data[tf]["profit_ratio"].append(row["profit_ratio"])
plt.figure(figsize=(10, 6)) sns.lineplot(x=df.index, y=df['close'], label='Close Price', color='blue', ax=ax_price)
for tf, vals in data.items(): ax_price.set_title('Close Price with Trend Bar (Green=1, Red=0)')
# Sort by stop_loss_pct for smooth lines ax_price.set_ylabel('Price')
sorted_pairs = sorted(zip(vals["stop_loss_pct"], vals["profit_ratio"])) ax_price.grid(True, alpha=0.3)
stop_loss, profit_ratio = zip(*sorted_pairs) ax_price.legend()
plt.plot(
[s * 100 for s in stop_loss], # Convert to percent
profit_ratio,
marker="o",
label=tf
)
plt.xlabel("Stop Loss (%)") # Clean meta_trend: ensure only 0/1, handle NaNs by forward-fill then fill remaining with 0
plt.ylabel("Profit Ratio") meta_trend_arr = np.asarray(meta_trend)
plt.title("Profit Ratio vs Stop Loss (%) per Timeframe") if not np.issubdtype(meta_trend_arr.dtype, np.number):
plt.legend(title="Timeframe") meta_trend_arr = pd.Series(meta_trend_arr).astype(float).to_numpy()
plt.grid(True, linestyle="--", alpha=0.5) if np.isnan(meta_trend_arr).any():
plt.tight_layout() meta_trend_arr = pd.Series(meta_trend_arr).fillna(method='ffill').fillna(0).astype(int).to_numpy()
else:
meta_trend_arr = meta_trend_arr.astype(int)
meta_trend_arr = np.where(meta_trend_arr != 1, 0, 1) # force only 0 or 1
if hasattr(df.index, 'to_numpy'):
x_vals = df.index.to_numpy()
else:
x_vals = np.array(df.index)
output_path = os.path.join(self.charts_dir, filename) # Find contiguous regions
plt.savefig(output_path) regions = []
plt.close() start = 0
for i in range(1, len(meta_trend_arr)):
if meta_trend_arr[i] != meta_trend_arr[i-1]:
regions.append((start, i-1, meta_trend_arr[i-1]))
start = i
regions.append((start, len(meta_trend_arr)-1, meta_trend_arr[-1]))
def plot_average_trade_vs_stop_loss(self, results, filename="average_trade_vs_stop_loss.png"): # Draw red vertical lines at the start of each new region (except the first)
""" for region_idx in range(1, len(regions)):
Plots average trade vs stop loss percentage for each timeframe. region_start = regions[region_idx][0]
ax_price.axvline(x=x_vals[region_start], color='black', linestyle='--', alpha=0.7, linewidth=1)
Parameters: for start, end, trend in regions:
- results: list of dicts, each with keys: 'timeframe', 'stop_loss_pct', 'average_trade' color = '#089981' if trend == 1 else '#F23645'
- filename: output filename (will be saved in charts_dir) # Offset by 1 on x: span from x_vals[start] to x_vals[end+1] if possible
""" x_start = x_vals[start]
from collections import defaultdict x_end = x_vals[end+1] if end+1 < len(x_vals) else x_vals[end]
data = defaultdict(lambda: {"stop_loss_pct": [], "average_trade": []}) ax_bar.axvspan(x_start, x_end, color=color, alpha=1, ymin=0, ymax=1)
for row in results:
tf = row["timeframe"]
if "average_trade" not in row:
continue # Skip rows without average_trade
data[tf]["stop_loss_pct"].append(row["stop_loss_pct"])
data[tf]["average_trade"].append(row["average_trade"])
plt.figure(figsize=(10, 6)) ax_bar.set_ylim(0, 1)
for tf, vals in data.items(): ax_bar.set_yticks([])
# Sort by stop_loss_pct for smooth lines ax_bar.set_ylabel('Trend')
sorted_pairs = sorted(zip(vals["stop_loss_pct"], vals["average_trade"])) ax_bar.set_xlabel('Time')
stop_loss, average_trade = zip(*sorted_pairs) ax_bar.grid(False)
plt.plot( ax_bar.set_title('Meta Trend')
[s * 100 for s in stop_loss], # Convert to percent
average_trade,
marker="o",
label=tf
)
plt.xlabel("Stop Loss (%)") plt.tight_layout(h_pad=0.1)
plt.ylabel("Average Trade") plt.show()
plt.title("Average Trade vs Stop Loss (%) per Timeframe")
plt.legend(title="Timeframe")
plt.grid(True, linestyle="--", alpha=0.5)
plt.tight_layout()
output_path = os.path.join(self.charts_dir, filename)
plt.savefig(output_path)
plt.close()

76
main.py
View File

@ -11,6 +11,7 @@ from cycles.utils.storage import Storage
from cycles.utils.system import SystemUtils from cycles.utils.system import SystemUtils
from cycles.backtest import Backtest from cycles.backtest import Backtest
from cycles.Analysis.supertrend import Supertrends from cycles.Analysis.supertrend import Supertrends
from cycles.charts import BacktestCharts
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@ -21,8 +22,8 @@ logging.basicConfig(
] ]
) )
def default_init_strategy(data: Backtest.Data) -> Backtest.Data: def default_init_strategy(backtester: Backtest):
supertrends = Supertrends(data.df, verbose=False) supertrends = Supertrends(backtester.df, verbose=False)
supertrend_results_list = supertrends.calculate_supertrend_indicators() supertrend_results_list = supertrends.calculate_supertrend_indicators()
trends = [st['results']['trend'] for st in supertrend_results_list] trends = [st['results']['trend'] for st in supertrend_results_list]
@ -30,31 +31,25 @@ def default_init_strategy(data: Backtest.Data) -> Backtest.Data:
meta_trend = np.where((trends_arr[:,0] == trends_arr[:,1]) & (trends_arr[:,1] == trends_arr[:,2]), meta_trend = np.where((trends_arr[:,0] == trends_arr[:,1]) & (trends_arr[:,1] == trends_arr[:,2]),
trends_arr[:,0], 0) trends_arr[:,0], 0)
data.strategies["meta_trend"] = meta_trend backtester.strategies["meta_trend"] = meta_trend
return data
def default_entry_strategy(data, df_index): def default_entry_strategy(backtester: Backtest, df_index):
return data.strategies["meta_trend"][df_index - 1] != 1 and data.strategies["meta_trend"][df_index] == 1 return backtester.strategies["meta_trend"][df_index - 1] != 1 and backtester.strategies["meta_trend"][df_index] == 1
def stop_loss_strategy(data): def stop_loss_strategy(backtester: Backtest):
stop_price = data.entry_price * (1 - data.strategies["stop_loss_pct"]) stop_price = backtester.entry_price * (1 - backtester.strategies["stop_loss_pct"])
# Ensure index is sorted and is a DatetimeIndex min1_index = backtester.min1_df.index
min1_index = data.min1_df.index start_candidates = min1_index[min1_index >= backtester.entry_time]
backtester.current_trade_min1_start_idx = start_candidates[0]
end_candidates = min1_index[min1_index <= backtester.current_date]
# Find the first index >= entry_time
start_candidates = min1_index[min1_index >= data.entry_time]
data.current_trade_min1_start_idx = start_candidates[0]
# Find the last index <= current_date
end_candidates = min1_index[min1_index <= data.current_date]
if len(end_candidates) == 0: if len(end_candidates) == 0:
print("Warning: no end candidate here. Need to be checked") print("Warning: no end candidate here. Need to be checked")
return False, None return False, None
data.current_min1_end_idx = end_candidates[-1] backtester.current_min1_end_idx = end_candidates[-1]
min1_slice = data.min1_df.loc[data.current_trade_min1_start_idx:data.current_min1_end_idx] min1_slice = backtester.min1_df.loc[backtester.current_trade_min1_start_idx:backtester.current_min1_end_idx]
# print(f"lowest low in that range: {min1_slice['low'].min()}, count: {len(min1_slice)}") # print(f"lowest low in that range: {min1_slice['low'].min()}, count: {len(min1_slice)}")
# print(f"slice start: {min1_slice.index[0]}, slice end: {min1_slice.index[-1]}") # print(f"slice start: {min1_slice.index[0]}, slice end: {min1_slice.index[-1]}")
@ -70,18 +65,18 @@ def stop_loss_strategy(data):
return False, None return False, None
def default_exit_strategy(data: Backtest.Data, df_index): def default_exit_strategy(backtester: Backtest, df_index):
if data.strategies["meta_trend"][df_index - 1] != 1 and \ if backtester.strategies["meta_trend"][df_index - 1] != 1 and \
data.strategies["meta_trend"][df_index] == -1: backtester.strategies["meta_trend"][df_index] == -1:
return "META_TREND_EXIT_SIGNAL", data, None return "META_TREND_EXIT_SIGNAL", None
stop_loss_result, sell_price = stop_loss_strategy(data) stop_loss_result, sell_price = stop_loss_strategy(backtester)
if stop_loss_result: if stop_loss_result:
data.strategies["current_trade_min1_start_idx"] = \ backtester.strategies["current_trade_min1_start_idx"] = \
data.min1_df.index[data.min1_df.index <= data.current_date][-1] backtester.min1_df.index[backtester.min1_df.index <= backtester.current_date][-1]
return "STOP_LOSS", data, sell_price return "STOP_LOSS", sell_price
return None, data, None return None, None
def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd, debug=False): def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd, debug=False):
"""Process the entire timeframe with all stop loss values (no monthly split)""" """Process the entire timeframe with all stop loss values (no monthly split)"""
@ -90,14 +85,13 @@ def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd,
results_rows = [] results_rows = []
trade_rows = [] trade_rows = []
min1_df['timestamp'] = pd.to_datetime(min1_df.index) # need ? min1_df['timestamp'] = pd.to_datetime(min1_df.index) # need ?
for stop_loss_pct in stop_loss_pcts: for stop_loss_pct in stop_loss_pcts:
data = Backtest.Data(initial_usd, df, min1_df, default_init_strategy) backtester = Backtest(initial_usd, df, min1_df, default_init_strategy)
data.strategies["stop_loss_pct"] = stop_loss_pct backtester.strategies["stop_loss_pct"] = stop_loss_pct
results = Backtest.run( results = backtester.run(
data,
default_entry_strategy, default_entry_strategy,
default_exit_strategy, default_exit_strategy,
debug debug
@ -164,8 +158,12 @@ def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd,
logging.info(f"Timeframe: {rule_name}, Stop Loss: {stop_loss_pct}, Trades: {n_trades}") logging.info(f"Timeframe: {rule_name}, Stop Loss: {stop_loss_pct}, Trades: {n_trades}")
if debug: if debug:
for trade in trades: # Plot after each backtest run
print(trade) try:
meta_trend = backtester.strategies["meta_trend"]
BacktestCharts.plot(df, meta_trend)
except Exception as e:
print(f"Plotting failed: {e}")
return results_rows, trade_rows return results_rows, trade_rows
@ -235,7 +233,7 @@ def get_nearest_price(df, target_date):
return nearest_time, price return nearest_time, price
if __name__ == "__main__": if __name__ == "__main__":
debug = False debug = True
parser = argparse.ArgumentParser(description="Run backtest with config file.") parser = argparse.ArgumentParser(description="Run backtest with config file.")
parser.add_argument("config", type=str, nargs="?", help="Path to config JSON file.") parser.add_argument("config", type=str, nargs="?", help="Path to config JSON file.")
@ -243,7 +241,7 @@ if __name__ == "__main__":
# Default values (from config.json) # Default values (from config.json)
default_config = { default_config = {
"start_date": "2024-05-15", "start_date": "2025-05-01",
"stop_date": datetime.datetime.today().strftime('%Y-%m-%d'), "stop_date": datetime.datetime.today().strftime('%Y-%m-%d'),
"initial_usd": 10000, "initial_usd": 10000,
"timeframes": ["15min"], "timeframes": ["15min"],
@ -306,8 +304,6 @@ if __name__ == "__main__":
for stop_loss_pct in stop_loss_pcts for stop_loss_pct in stop_loss_pcts
] ]
workers = system_utils.get_optimal_workers()
if debug: if debug:
all_results_rows = [] all_results_rows = []
all_trade_rows = [] all_trade_rows = []
@ -317,6 +313,8 @@ if __name__ == "__main__":
all_results_rows.extend(results) all_results_rows.extend(results)
all_trade_rows.extend(trades) all_trade_rows.extend(trades)
else: else:
workers = system_utils.get_optimal_workers()
with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor: with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
futures = {executor.submit(process, task, debug): task for task in tasks} futures = {executor.submit(process, task, debug): task for task in tasks}
all_results_rows = [] all_results_rows = []