Refactor Backtest class and update strategy functions for improved modularity

- Refactored the Backtest class to encapsulate state and behavior, enhancing clarity and maintainability.
- Updated strategy functions to accept the Backtest instance, streamlining data access and manipulation.
- Introduced a new plotting method in BacktestCharts for visualizing close prices with trend indicators.
- Improved handling of meta_trend data to ensure proper visualization and trend representation.
- Adjusted main execution logic to support the new Backtest structure and enhanced debugging capabilities.
This commit is contained in:
Simon Moisy
2025-05-22 20:02:14 +08:00
parent 00873d593f
commit e5c2988d71
3 changed files with 181 additions and 198 deletions

View File

@@ -4,84 +4,85 @@ import numpy as np
from cycles.market_fees import MarketFees
class Backtest:
class Data:
def __init__(self, initial_usd, df, min1_df, init_strategy_fields) -> None:
self.initial_usd = initial_usd
self.usd = initial_usd
self.max_balance = initial_usd
self.coin = 0
self.position = 0
self.entry_price = 0
self.entry_time = None
self.current_trade_min1_start_idx = None
self.current_min1_end_idx = None
self.price_open = None
self.price_close = None
self.current_date = None
self.strategies = {}
self.df = df
self.min1_df = min1_df
def __init__(self, initial_usd, df, min1_df, init_strategy_fields) -> None:
self.initial_usd = initial_usd
self.usd = initial_usd
self.max_balance = initial_usd
self.coin = 0
self.position = 0
self.entry_price = 0
self.entry_time = None
self.current_trade_min1_start_idx = None
self.current_min1_end_idx = None
self.price_open = None
self.price_close = None
self.current_date = None
self.strategies = {}
self.df = df
self.min1_df = min1_df
self = init_strategy_fields(self)
self.trade_log = []
self.drawdowns = []
self.trades = []
@staticmethod
def run(data, entry_strategy, exit_strategy, debug=False):
self = init_strategy_fields(self)
def run(self, entry_strategy, exit_strategy, debug=False):
"""
Backtest a simple strategy using the meta supertrend (all three supertrends agree).
Buys when meta supertrend is positive, sells when negative, applies a percentage stop loss.
Runs the backtest using provided entry and exit strategy functions.
The method iterates over the main DataFrame (self.df), simulating trades based on the entry and exit strategies. It tracks balances, drawdowns, and logs each trade, including fees. At the end, it returns a dictionary of performance statistics.
Parameters:
- min1_df: pandas DataFrame, 1-minute timeframe data for more accurate stop loss checking (optional)
- initial_usd: float, starting USD amount
- stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%)
- debug: bool, whether to print debug info
- entry_strategy: function, determines when to enter a trade. Should accept (self, i) and return True to enter.
- exit_strategy: function, determines when to exit a trade. Should accept (self, i) and return (exit_reason, sell_price) or (None, None) to hold.
- debug: bool, whether to print debug info (default: False)
Returns:
- dict with keys: initial_usd, final_usd, n_trades, win_rate, max_drawdown, avg_trade, trade_log, trades, total_fees_usd, and optionally first_trade and last_trade.
"""
trade_log = []
drawdowns = []
trades = []
for i in range(1, len(data.df)):
data.price_open = data.df['open'].iloc[i]
data.price_close = data.df['close'].iloc[i]
for i in range(1, len(self.df)):
self.price_open = self.df['open'].iloc[i]
self.price_close = self.df['close'].iloc[i]
data.current_date = data.df['timestamp'].iloc[i]
self.current_date = self.df['timestamp'].iloc[i]
if data.position == 0:
if entry_strategy(data, i):
data, entry_log_entry = Backtest.handle_entry(data)
trade_log.append(entry_log_entry)
elif data.position == 1:
exit_test_results, data, sell_price = exit_strategy(data, i)
if self.position == 0:
if entry_strategy(self, i):
self.handle_entry()
elif self.position == 1:
exit_test_results, sell_price = exit_strategy(self, i)
if exit_test_results is not None:
data, exit_log_entry = Backtest.handle_exit(data, exit_test_results, sell_price)
trade_log.append(exit_log_entry)
self.handle_exit(exit_test_results, sell_price)
# Track drawdown
balance = data.usd if data.position == 0 else data.coin * data.price_close
balance = self.usd if self.position == 0 else self.coin * self.price_close
if balance > data.max_balance:
data.max_balance = balance
if balance > self.max_balance:
self.max_balance = balance
drawdown = (data.max_balance - balance) / data.max_balance
drawdowns.append(drawdown)
drawdown = (self.max_balance - balance) / self.max_balance
self.drawdowns.append(drawdown)
# If still in position at end, sell at last close
if data.position == 1:
data, exit_log_entry = Backtest.handle_exit(data, "EOD", None)
trade_log.append(exit_log_entry)
if self.position == 1:
self.handle_exit("EOD", None)
# Calculate statistics
final_balance = data.usd
n_trades = len(trade_log)
wins = [1 for t in trade_log if t['exit'] is not None and t['exit'] > t['entry']]
final_balance = self.usd
n_trades = len(self.trade_log)
wins = [1 for t in self.trade_log if t['exit'] is not None and t['exit'] > t['entry']]
win_rate = len(wins) / n_trades if n_trades > 0 else 0
max_drawdown = max(drawdowns) if drawdowns else 0
avg_trade = np.mean([t['exit']/t['entry']-1 for t in trade_log if t['exit'] is not None]) if trade_log else 0
max_drawdown = max(self.drawdowns) if self.drawdowns else 0
avg_trade = np.mean([t['exit']/t['entry']-1 for t in self.trade_log if t['exit'] is not None]) if self.trade_log else 0
trades = []
total_fees_usd = 0.0
for trade in trade_log:
for trade in self.trade_log:
if trade['exit'] is not None:
profit_pct = (trade['exit'] - trade['entry']) / trade['entry']
else:
@@ -99,67 +100,66 @@ class Backtest:
total_fees_usd += fee_usd
results = {
"initial_usd": data.initial_usd,
"initial_usd": self.initial_usd,
"final_usd": final_balance,
"n_trades": n_trades,
"win_rate": win_rate,
"max_drawdown": max_drawdown,
"avg_trade": avg_trade,
"trade_log": trade_log,
"trade_log": self.trade_log,
"trades": trades,
"total_fees_usd": total_fees_usd,
}
if n_trades > 0:
results["first_trade"] = {
"entry_time": trade_log[0]['entry_time'],
"entry": trade_log[0]['entry']
"entry_time": self.trade_log[0]['entry_time'],
"entry": self.trade_log[0]['entry']
}
results["last_trade"] = {
"exit_time": trade_log[-1]['exit_time'],
"exit": trade_log[-1]['exit']
"exit_time": self.trade_log[-1]['exit_time'],
"exit": self.trade_log[-1]['exit']
}
return results
@staticmethod
def handle_entry(data):
entry_fee = MarketFees.calculate_okx_taker_maker_fee(data.usd, is_maker=False)
usd_after_fee = data.usd - entry_fee
def handle_entry(self):
entry_fee = MarketFees.calculate_okx_taker_maker_fee(self.usd, is_maker=False)
usd_after_fee = self.usd - entry_fee
data.coin = usd_after_fee / data.price_open
data.entry_price = data.price_open
data.entry_time = data.current_date
data.usd = 0
data.position = 1
self.coin = usd_after_fee / self.price_open
self.entry_price = self.price_open
self.entry_time = self.current_date
self.usd = 0
self.position = 1
trade_log_entry = {
'type': 'BUY',
'entry': data.entry_price,
'entry': self.entry_price,
'exit': None,
'entry_time': data.entry_time,
'entry_time': self.entry_time,
'exit_time': None,
'fee_usd': entry_fee
}
return data, trade_log_entry
self.trade_log.append(trade_log_entry)
@staticmethod
def handle_exit(data, exit_reason, sell_price):
btc_to_sell = data.coin
exit_price = sell_price if sell_price is not None else data.price_open
def handle_exit(self, exit_reason, sell_price):
btc_to_sell = self.coin
exit_price = sell_price if sell_price is not None else self.price_open
usd_gross = btc_to_sell * exit_price
exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
data.usd = usd_gross - exit_fee
self.usd = usd_gross - exit_fee
exit_log_entry = {
'type': exit_reason,
'entry': data.entry_price,
'entry': self.entry_price,
'exit': exit_price,
'entry_time': data.entry_time,
'exit_time': data.current_date,
'entry_time': self.entry_time,
'exit_time': self.current_date,
'fee_usd': exit_fee
}
data.coin = 0
data.position = 0
data.entry_price = 0
self.coin = 0
self.position = 0
self.entry_price = 0
return data, exit_log_entry
self.trade_log.append(exit_log_entry)

View File

@@ -1,86 +1,71 @@
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
class BacktestCharts:
def __init__(self, charts_dir="charts"):
self.charts_dir = charts_dir
os.makedirs(self.charts_dir, exist_ok=True)
def plot_profit_ratio_vs_stop_loss(self, results, filename="profit_ratio_vs_stop_loss.png"):
@staticmethod
def plot(df, meta_trend):
"""
Plots profit ratio vs stop loss percentage for each timeframe.
Parameters:
- results: list of dicts, each with keys: 'timeframe', 'stop_loss_pct', 'profit_ratio'
- filename: output filename (will be saved in charts_dir)
Plot close price line chart with a bar at the bottom: green when trend is 1, red when trend is 0.
The bar stays at the bottom even when zooming/panning.
- df: DataFrame with columns ['close', ...] and a datetime index or 'timestamp' column.
- meta_trend: array-like, same length as df, values 1 (green) or 0 (red).
"""
# Organize data by timeframe
from collections import defaultdict
data = defaultdict(lambda: {"stop_loss_pct": [], "profit_ratio": []})
for row in results:
tf = row["timeframe"]
data[tf]["stop_loss_pct"].append(row["stop_loss_pct"])
data[tf]["profit_ratio"].append(row["profit_ratio"])
fig, (ax_price, ax_bar) = plt.subplots(
nrows=2, ncols=1, figsize=(16, 8), sharex=True,
gridspec_kw={'height_ratios': [12, 1]}
)
plt.figure(figsize=(10, 6))
for tf, vals in data.items():
# Sort by stop_loss_pct for smooth lines
sorted_pairs = sorted(zip(vals["stop_loss_pct"], vals["profit_ratio"]))
stop_loss, profit_ratio = zip(*sorted_pairs)
plt.plot(
[s * 100 for s in stop_loss], # Convert to percent
profit_ratio,
marker="o",
label=tf
)
sns.lineplot(x=df.index, y=df['close'], label='Close Price', color='blue', ax=ax_price)
ax_price.set_title('Close Price with Trend Bar (Green=1, Red=0)')
ax_price.set_ylabel('Price')
ax_price.grid(True, alpha=0.3)
ax_price.legend()
plt.xlabel("Stop Loss (%)")
plt.ylabel("Profit Ratio")
plt.title("Profit Ratio vs Stop Loss (%) per Timeframe")
plt.legend(title="Timeframe")
plt.grid(True, linestyle="--", alpha=0.5)
plt.tight_layout()
# Clean meta_trend: ensure only 0/1, handle NaNs by forward-fill then fill remaining with 0
meta_trend_arr = np.asarray(meta_trend)
if not np.issubdtype(meta_trend_arr.dtype, np.number):
meta_trend_arr = pd.Series(meta_trend_arr).astype(float).to_numpy()
if np.isnan(meta_trend_arr).any():
meta_trend_arr = pd.Series(meta_trend_arr).fillna(method='ffill').fillna(0).astype(int).to_numpy()
else:
meta_trend_arr = meta_trend_arr.astype(int)
meta_trend_arr = np.where(meta_trend_arr != 1, 0, 1) # force only 0 or 1
if hasattr(df.index, 'to_numpy'):
x_vals = df.index.to_numpy()
else:
x_vals = np.array(df.index)
output_path = os.path.join(self.charts_dir, filename)
plt.savefig(output_path)
plt.close()
# Find contiguous regions
regions = []
start = 0
for i in range(1, len(meta_trend_arr)):
if meta_trend_arr[i] != meta_trend_arr[i-1]:
regions.append((start, i-1, meta_trend_arr[i-1]))
start = i
regions.append((start, len(meta_trend_arr)-1, meta_trend_arr[-1]))
def plot_average_trade_vs_stop_loss(self, results, filename="average_trade_vs_stop_loss.png"):
"""
Plots average trade vs stop loss percentage for each timeframe.
# Draw red vertical lines at the start of each new region (except the first)
for region_idx in range(1, len(regions)):
region_start = regions[region_idx][0]
ax_price.axvline(x=x_vals[region_start], color='black', linestyle='--', alpha=0.7, linewidth=1)
Parameters:
- results: list of dicts, each with keys: 'timeframe', 'stop_loss_pct', 'average_trade'
- filename: output filename (will be saved in charts_dir)
"""
from collections import defaultdict
data = defaultdict(lambda: {"stop_loss_pct": [], "average_trade": []})
for row in results:
tf = row["timeframe"]
if "average_trade" not in row:
continue # Skip rows without average_trade
data[tf]["stop_loss_pct"].append(row["stop_loss_pct"])
data[tf]["average_trade"].append(row["average_trade"])
for start, end, trend in regions:
color = '#089981' if trend == 1 else '#F23645'
# Offset by 1 on x: span from x_vals[start] to x_vals[end+1] if possible
x_start = x_vals[start]
x_end = x_vals[end+1] if end+1 < len(x_vals) else x_vals[end]
ax_bar.axvspan(x_start, x_end, color=color, alpha=1, ymin=0, ymax=1)
plt.figure(figsize=(10, 6))
for tf, vals in data.items():
# Sort by stop_loss_pct for smooth lines
sorted_pairs = sorted(zip(vals["stop_loss_pct"], vals["average_trade"]))
stop_loss, average_trade = zip(*sorted_pairs)
plt.plot(
[s * 100 for s in stop_loss], # Convert to percent
average_trade,
marker="o",
label=tf
)
ax_bar.set_ylim(0, 1)
ax_bar.set_yticks([])
ax_bar.set_ylabel('Trend')
ax_bar.set_xlabel('Time')
ax_bar.grid(False)
ax_bar.set_title('Meta Trend')
plt.xlabel("Stop Loss (%)")
plt.ylabel("Average Trade")
plt.title("Average Trade vs Stop Loss (%) per Timeframe")
plt.legend(title="Timeframe")
plt.grid(True, linestyle="--", alpha=0.5)
plt.tight_layout()
output_path = os.path.join(self.charts_dir, filename)
plt.savefig(output_path)
plt.close()
plt.tight_layout(h_pad=0.1)
plt.show()