Refactor Backtest class and update strategy functions for improved modularity

- Refactored the Backtest class to encapsulate state and behavior, enhancing clarity and maintainability.
- Updated strategy functions to accept the Backtest instance, streamlining data access and manipulation.
- Introduced a new plotting method in BacktestCharts for visualizing close prices with trend indicators.
- Cleaned meta_trend data before plotting (NaNs are forward-filled and values are coerced to 0/1) so the trend bar renders correctly.
- Adjusted main execution logic to use the new Backtest instance API; in debug mode, each backtest run is now plotted.
This commit is contained in:
Simon Moisy 2025-05-22 20:02:14 +08:00
parent 00873d593f
commit e5c2988d71
3 changed files with 181 additions and 198 deletions

View File

@ -4,7 +4,6 @@ import numpy as np
from cycles.market_fees import MarketFees
class Backtest:
class Data:
def __init__(self, initial_usd, df, min1_df, init_strategy_fields) -> None:
self.initial_usd = initial_usd
self.usd = initial_usd
@ -22,66 +21,68 @@ class Backtest:
self.df = df
self.min1_df = min1_df
self.trade_log = []
self.drawdowns = []
self.trades = []
self = init_strategy_fields(self)
@staticmethod
def run(data, entry_strategy, exit_strategy, debug=False):
def run(self, entry_strategy, exit_strategy, debug=False):
"""
Backtest a simple strategy using the meta supertrend (all three supertrends agree).
Buys when meta supertrend is positive, sells when negative, applies a percentage stop loss.
Runs the backtest using provided entry and exit strategy functions.
The method iterates over the main DataFrame (self.df), simulating trades based on the entry and exit strategies. It tracks balances, drawdowns, and logs each trade, including fees. At the end, it returns a dictionary of performance statistics.
Parameters:
- min1_df: pandas DataFrame, 1-minute timeframe data for more accurate stop loss checking (optional)
- initial_usd: float, starting USD amount
- stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%)
- debug: bool, whether to print debug info
- entry_strategy: function, determines when to enter a trade. Should accept (self, i) and return True to enter.
- exit_strategy: function, determines when to exit a trade. Should accept (self, i) and return (exit_reason, sell_price) or (None, None) to hold.
- debug: bool, whether to print debug info (default: False)
Returns:
- dict with keys: initial_usd, final_usd, n_trades, win_rate, max_drawdown, avg_trade, trade_log, trades, total_fees_usd, and optionally first_trade and last_trade.
"""
trade_log = []
drawdowns = []
trades = []
for i in range(1, len(data.df)):
data.price_open = data.df['open'].iloc[i]
data.price_close = data.df['close'].iloc[i]
for i in range(1, len(self.df)):
self.price_open = self.df['open'].iloc[i]
self.price_close = self.df['close'].iloc[i]
data.current_date = data.df['timestamp'].iloc[i]
self.current_date = self.df['timestamp'].iloc[i]
if data.position == 0:
if entry_strategy(data, i):
data, entry_log_entry = Backtest.handle_entry(data)
trade_log.append(entry_log_entry)
elif data.position == 1:
exit_test_results, data, sell_price = exit_strategy(data, i)
if self.position == 0:
if entry_strategy(self, i):
self.handle_entry()
elif self.position == 1:
exit_test_results, sell_price = exit_strategy(self, i)
if exit_test_results is not None:
data, exit_log_entry = Backtest.handle_exit(data, exit_test_results, sell_price)
trade_log.append(exit_log_entry)
self.handle_exit(exit_test_results, sell_price)
# Track drawdown
balance = data.usd if data.position == 0 else data.coin * data.price_close
balance = self.usd if self.position == 0 else self.coin * self.price_close
if balance > data.max_balance:
data.max_balance = balance
if balance > self.max_balance:
self.max_balance = balance
drawdown = (data.max_balance - balance) / data.max_balance
drawdowns.append(drawdown)
drawdown = (self.max_balance - balance) / self.max_balance
self.drawdowns.append(drawdown)
# If still in position at end, sell at last close
if data.position == 1:
data, exit_log_entry = Backtest.handle_exit(data, "EOD", None)
trade_log.append(exit_log_entry)
if self.position == 1:
self.handle_exit("EOD", None)
# Calculate statistics
final_balance = data.usd
n_trades = len(trade_log)
wins = [1 for t in trade_log if t['exit'] is not None and t['exit'] > t['entry']]
final_balance = self.usd
n_trades = len(self.trade_log)
wins = [1 for t in self.trade_log if t['exit'] is not None and t['exit'] > t['entry']]
win_rate = len(wins) / n_trades if n_trades > 0 else 0
max_drawdown = max(drawdowns) if drawdowns else 0
avg_trade = np.mean([t['exit']/t['entry']-1 for t in trade_log if t['exit'] is not None]) if trade_log else 0
max_drawdown = max(self.drawdowns) if self.drawdowns else 0
avg_trade = np.mean([t['exit']/t['entry']-1 for t in self.trade_log if t['exit'] is not None]) if self.trade_log else 0
trades = []
total_fees_usd = 0.0
for trade in trade_log:
for trade in self.trade_log:
if trade['exit'] is not None:
profit_pct = (trade['exit'] - trade['entry']) / trade['entry']
else:
@ -99,67 +100,66 @@ class Backtest:
total_fees_usd += fee_usd
results = {
"initial_usd": data.initial_usd,
"initial_usd": self.initial_usd,
"final_usd": final_balance,
"n_trades": n_trades,
"win_rate": win_rate,
"max_drawdown": max_drawdown,
"avg_trade": avg_trade,
"trade_log": trade_log,
"trade_log": self.trade_log,
"trades": trades,
"total_fees_usd": total_fees_usd,
}
if n_trades > 0:
results["first_trade"] = {
"entry_time": trade_log[0]['entry_time'],
"entry": trade_log[0]['entry']
"entry_time": self.trade_log[0]['entry_time'],
"entry": self.trade_log[0]['entry']
}
results["last_trade"] = {
"exit_time": trade_log[-1]['exit_time'],
"exit": trade_log[-1]['exit']
"exit_time": self.trade_log[-1]['exit_time'],
"exit": self.trade_log[-1]['exit']
}
return results
@staticmethod
def handle_entry(data):
entry_fee = MarketFees.calculate_okx_taker_maker_fee(data.usd, is_maker=False)
usd_after_fee = data.usd - entry_fee
def handle_entry(self):
entry_fee = MarketFees.calculate_okx_taker_maker_fee(self.usd, is_maker=False)
usd_after_fee = self.usd - entry_fee
data.coin = usd_after_fee / data.price_open
data.entry_price = data.price_open
data.entry_time = data.current_date
data.usd = 0
data.position = 1
self.coin = usd_after_fee / self.price_open
self.entry_price = self.price_open
self.entry_time = self.current_date
self.usd = 0
self.position = 1
trade_log_entry = {
'type': 'BUY',
'entry': data.entry_price,
'entry': self.entry_price,
'exit': None,
'entry_time': data.entry_time,
'entry_time': self.entry_time,
'exit_time': None,
'fee_usd': entry_fee
}
return data, trade_log_entry
self.trade_log.append(trade_log_entry)
@staticmethod
def handle_exit(data, exit_reason, sell_price):
btc_to_sell = data.coin
exit_price = sell_price if sell_price is not None else data.price_open
def handle_exit(self, exit_reason, sell_price):
btc_to_sell = self.coin
exit_price = sell_price if sell_price is not None else self.price_open
usd_gross = btc_to_sell * exit_price
exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
data.usd = usd_gross - exit_fee
self.usd = usd_gross - exit_fee
exit_log_entry = {
'type': exit_reason,
'entry': data.entry_price,
'entry': self.entry_price,
'exit': exit_price,
'entry_time': data.entry_time,
'exit_time': data.current_date,
'entry_time': self.entry_time,
'exit_time': self.current_date,
'fee_usd': exit_fee
}
data.coin = 0
data.position = 0
data.entry_price = 0
self.coin = 0
self.position = 0
self.entry_price = 0
self.trade_log.append(exit_log_entry)
return data, exit_log_entry

View File

@ -1,86 +1,71 @@
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
class BacktestCharts:
def __init__(self, charts_dir="charts"):
self.charts_dir = charts_dir
os.makedirs(self.charts_dir, exist_ok=True)
def plot_profit_ratio_vs_stop_loss(self, results, filename="profit_ratio_vs_stop_loss.png"):
@staticmethod
def plot(df, meta_trend):
"""
Plots profit ratio vs stop loss percentage for each timeframe.
Parameters:
- results: list of dicts, each with keys: 'timeframe', 'stop_loss_pct', 'profit_ratio'
- filename: output filename (will be saved in charts_dir)
Plot close price line chart with a bar at the bottom: green when trend is 1, red when trend is 0.
The bar stays at the bottom even when zooming/panning.
- df: DataFrame with columns ['close', ...] and a datetime index or 'timestamp' column.
- meta_trend: array-like, same length as df, values 1 (green) or 0 (red).
"""
# Organize data by timeframe
from collections import defaultdict
data = defaultdict(lambda: {"stop_loss_pct": [], "profit_ratio": []})
for row in results:
tf = row["timeframe"]
data[tf]["stop_loss_pct"].append(row["stop_loss_pct"])
data[tf]["profit_ratio"].append(row["profit_ratio"])
plt.figure(figsize=(10, 6))
for tf, vals in data.items():
# Sort by stop_loss_pct for smooth lines
sorted_pairs = sorted(zip(vals["stop_loss_pct"], vals["profit_ratio"]))
stop_loss, profit_ratio = zip(*sorted_pairs)
plt.plot(
[s * 100 for s in stop_loss], # Convert to percent
profit_ratio,
marker="o",
label=tf
fig, (ax_price, ax_bar) = plt.subplots(
nrows=2, ncols=1, figsize=(16, 8), sharex=True,
gridspec_kw={'height_ratios': [12, 1]}
)
plt.xlabel("Stop Loss (%)")
plt.ylabel("Profit Ratio")
plt.title("Profit Ratio vs Stop Loss (%) per Timeframe")
plt.legend(title="Timeframe")
plt.grid(True, linestyle="--", alpha=0.5)
plt.tight_layout()
sns.lineplot(x=df.index, y=df['close'], label='Close Price', color='blue', ax=ax_price)
ax_price.set_title('Close Price with Trend Bar (Green=1, Red=0)')
ax_price.set_ylabel('Price')
ax_price.grid(True, alpha=0.3)
ax_price.legend()
output_path = os.path.join(self.charts_dir, filename)
plt.savefig(output_path)
plt.close()
# Clean meta_trend: ensure only 0/1, handle NaNs by forward-fill then fill remaining with 0
meta_trend_arr = np.asarray(meta_trend)
if not np.issubdtype(meta_trend_arr.dtype, np.number):
meta_trend_arr = pd.Series(meta_trend_arr).astype(float).to_numpy()
if np.isnan(meta_trend_arr).any():
meta_trend_arr = pd.Series(meta_trend_arr).fillna(method='ffill').fillna(0).astype(int).to_numpy()
else:
meta_trend_arr = meta_trend_arr.astype(int)
meta_trend_arr = np.where(meta_trend_arr != 1, 0, 1) # force only 0 or 1
if hasattr(df.index, 'to_numpy'):
x_vals = df.index.to_numpy()
else:
x_vals = np.array(df.index)
def plot_average_trade_vs_stop_loss(self, results, filename="average_trade_vs_stop_loss.png"):
"""
Plots average trade vs stop loss percentage for each timeframe.
# Find contiguous regions
regions = []
start = 0
for i in range(1, len(meta_trend_arr)):
if meta_trend_arr[i] != meta_trend_arr[i-1]:
regions.append((start, i-1, meta_trend_arr[i-1]))
start = i
regions.append((start, len(meta_trend_arr)-1, meta_trend_arr[-1]))
Parameters:
- results: list of dicts, each with keys: 'timeframe', 'stop_loss_pct', 'average_trade'
- filename: output filename (will be saved in charts_dir)
"""
from collections import defaultdict
data = defaultdict(lambda: {"stop_loss_pct": [], "average_trade": []})
for row in results:
tf = row["timeframe"]
if "average_trade" not in row:
continue # Skip rows without average_trade
data[tf]["stop_loss_pct"].append(row["stop_loss_pct"])
data[tf]["average_trade"].append(row["average_trade"])
# Draw black dashed vertical lines at the start of each new region (except the first)
for region_idx in range(1, len(regions)):
region_start = regions[region_idx][0]
ax_price.axvline(x=x_vals[region_start], color='black', linestyle='--', alpha=0.7, linewidth=1)
plt.figure(figsize=(10, 6))
for tf, vals in data.items():
# Sort by stop_loss_pct for smooth lines
sorted_pairs = sorted(zip(vals["stop_loss_pct"], vals["average_trade"]))
stop_loss, average_trade = zip(*sorted_pairs)
plt.plot(
[s * 100 for s in stop_loss], # Convert to percent
average_trade,
marker="o",
label=tf
)
for start, end, trend in regions:
color = '#089981' if trend == 1 else '#F23645'
# Offset by 1 on x: span from x_vals[start] to x_vals[end+1] if possible
x_start = x_vals[start]
x_end = x_vals[end+1] if end+1 < len(x_vals) else x_vals[end]
ax_bar.axvspan(x_start, x_end, color=color, alpha=1, ymin=0, ymax=1)
plt.xlabel("Stop Loss (%)")
plt.ylabel("Average Trade")
plt.title("Average Trade vs Stop Loss (%) per Timeframe")
plt.legend(title="Timeframe")
plt.grid(True, linestyle="--", alpha=0.5)
plt.tight_layout()
ax_bar.set_ylim(0, 1)
ax_bar.set_yticks([])
ax_bar.set_ylabel('Trend')
ax_bar.set_xlabel('Time')
ax_bar.grid(False)
ax_bar.set_title('Meta Trend')
plt.tight_layout(h_pad=0.1)
plt.show()
output_path = os.path.join(self.charts_dir, filename)
plt.savefig(output_path)
plt.close()

74
main.py
View File

@ -11,6 +11,7 @@ from cycles.utils.storage import Storage
from cycles.utils.system import SystemUtils
from cycles.backtest import Backtest
from cycles.Analysis.supertrend import Supertrends
from cycles.charts import BacktestCharts
logging.basicConfig(
level=logging.INFO,
@ -21,8 +22,8 @@ logging.basicConfig(
]
)
def default_init_strategy(data: Backtest.Data) -> Backtest.Data:
supertrends = Supertrends(data.df, verbose=False)
def default_init_strategy(backtester: Backtest):
supertrends = Supertrends(backtester.df, verbose=False)
supertrend_results_list = supertrends.calculate_supertrend_indicators()
trends = [st['results']['trend'] for st in supertrend_results_list]
@ -30,31 +31,25 @@ def default_init_strategy(data: Backtest.Data) -> Backtest.Data:
meta_trend = np.where((trends_arr[:,0] == trends_arr[:,1]) & (trends_arr[:,1] == trends_arr[:,2]),
trends_arr[:,0], 0)
data.strategies["meta_trend"] = meta_trend
backtester.strategies["meta_trend"] = meta_trend
return data
def default_entry_strategy(backtester: Backtest, df_index):
return backtester.strategies["meta_trend"][df_index - 1] != 1 and backtester.strategies["meta_trend"][df_index] == 1
def default_entry_strategy(data, df_index):
return data.strategies["meta_trend"][df_index - 1] != 1 and data.strategies["meta_trend"][df_index] == 1
def stop_loss_strategy(backtester: Backtest):
stop_price = backtester.entry_price * (1 - backtester.strategies["stop_loss_pct"])
def stop_loss_strategy(data):
stop_price = data.entry_price * (1 - data.strategies["stop_loss_pct"])
min1_index = backtester.min1_df.index
start_candidates = min1_index[min1_index >= backtester.entry_time]
backtester.current_trade_min1_start_idx = start_candidates[0]
end_candidates = min1_index[min1_index <= backtester.current_date]
# Ensure index is sorted and is a DatetimeIndex
min1_index = data.min1_df.index
# Find the first index >= entry_time
start_candidates = min1_index[min1_index >= data.entry_time]
data.current_trade_min1_start_idx = start_candidates[0]
# Find the last index <= current_date
end_candidates = min1_index[min1_index <= data.current_date]
if len(end_candidates) == 0:
print("Warning: no end candidate here. Need to be checked")
return False, None
data.current_min1_end_idx = end_candidates[-1]
backtester.current_min1_end_idx = end_candidates[-1]
min1_slice = data.min1_df.loc[data.current_trade_min1_start_idx:data.current_min1_end_idx]
min1_slice = backtester.min1_df.loc[backtester.current_trade_min1_start_idx:backtester.current_min1_end_idx]
# print(f"lowest low in that range: {min1_slice['low'].min()}, count: {len(min1_slice)}")
# print(f"slice start: {min1_slice.index[0]}, slice end: {min1_slice.index[-1]}")
@ -70,18 +65,18 @@ def stop_loss_strategy(data):
return False, None
def default_exit_strategy(data: Backtest.Data, df_index):
if data.strategies["meta_trend"][df_index - 1] != 1 and \
data.strategies["meta_trend"][df_index] == -1:
return "META_TREND_EXIT_SIGNAL", data, None
def default_exit_strategy(backtester: Backtest, df_index):
if backtester.strategies["meta_trend"][df_index - 1] != 1 and \
backtester.strategies["meta_trend"][df_index] == -1:
return "META_TREND_EXIT_SIGNAL", None
stop_loss_result, sell_price = stop_loss_strategy(data)
stop_loss_result, sell_price = stop_loss_strategy(backtester)
if stop_loss_result:
data.strategies["current_trade_min1_start_idx"] = \
data.min1_df.index[data.min1_df.index <= data.current_date][-1]
return "STOP_LOSS", data, sell_price
backtester.strategies["current_trade_min1_start_idx"] = \
backtester.min1_df.index[backtester.min1_df.index <= backtester.current_date][-1]
return "STOP_LOSS", sell_price
return None, data, None
return None, None
def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd, debug=False):
"""Process the entire timeframe with all stop loss values (no monthly split)"""
@ -93,11 +88,10 @@ def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd,
min1_df['timestamp'] = pd.to_datetime(min1_df.index) # need ?
for stop_loss_pct in stop_loss_pcts:
data = Backtest.Data(initial_usd, df, min1_df, default_init_strategy)
data.strategies["stop_loss_pct"] = stop_loss_pct
backtester = Backtest(initial_usd, df, min1_df, default_init_strategy)
backtester.strategies["stop_loss_pct"] = stop_loss_pct
results = Backtest.run(
data,
results = backtester.run(
default_entry_strategy,
default_exit_strategy,
debug
@ -164,8 +158,12 @@ def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd,
logging.info(f"Timeframe: {rule_name}, Stop Loss: {stop_loss_pct}, Trades: {n_trades}")
if debug:
for trade in trades:
print(trade)
# Plot after each backtest run
try:
meta_trend = backtester.strategies["meta_trend"]
BacktestCharts.plot(df, meta_trend)
except Exception as e:
print(f"Plotting failed: {e}")
return results_rows, trade_rows
@ -235,7 +233,7 @@ def get_nearest_price(df, target_date):
return nearest_time, price
if __name__ == "__main__":
debug = False
debug = True
parser = argparse.ArgumentParser(description="Run backtest with config file.")
parser.add_argument("config", type=str, nargs="?", help="Path to config JSON file.")
@ -243,7 +241,7 @@ if __name__ == "__main__":
# Default values (from config.json)
default_config = {
"start_date": "2024-05-15",
"start_date": "2025-05-01",
"stop_date": datetime.datetime.today().strftime('%Y-%m-%d'),
"initial_usd": 10000,
"timeframes": ["15min"],
@ -306,8 +304,6 @@ if __name__ == "__main__":
for stop_loss_pct in stop_loss_pcts
]
workers = system_utils.get_optimal_workers()
if debug:
all_results_rows = []
all_trade_rows = []
@ -317,6 +313,8 @@ if __name__ == "__main__":
all_results_rows.extend(results)
all_trade_rows.extend(trades)
else:
workers = system_utils.get_optimal_workers()
with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
futures = {executor.submit(process, task, debug): task for task in tasks}
all_results_rows = []