Refactor backtesting logic and introduce new components

- Replaced TrendDetectorSimple with a new Backtest class that splits the backtest loop into entry, exit, and stop-loss helpers.
- Integrated argparse so parameters can be supplied via a JSON config file, with interactive prompts as a fallback (see the sample config after this list).
- Added MarketFees and Supertrends classes to handle fee calculation and trend detection, respectively.
- Removed deprecated main_debug.py and trend_detector_simple.py files to streamline the codebase.
- Enhanced process_timeframe_data to utilize the new Backtest class for executing trades and calculating results.
- Updated Storage class to support writing backtest results with metadata.
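For reference, a minimal config file for the new entry point might look as follows; the keys mirror default_config in main.py, and the values shown here are illustrative:

{
    "start_date": "2024-05-15",
    "stop_date": "2025-05-15",
    "initial_usd": 10000,
    "timeframes": ["1D"],
    "stop_loss_pcts": [0.01, 0.02, 0.03]
}

Run as python main.py config.json; with no argument the script falls back to the interactive prompts.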
Simon Moisy 2025-05-21 17:03:34 +08:00
parent 14905017c8
commit 806697116d
9 changed files with 650 additions and 1134 deletions

cycles/backtest.py (new file)

@@ -0,0 +1,222 @@
import pandas as pd
import numpy as np
from cycles.supertrend import Supertrends
from cycles.market_fees import MarketFees
class Backtest:
@staticmethod
def run(min1_df, df, initial_usd, stop_loss_pct, debug=False):
"""
Backtest a simple strategy using the meta supertrend (all three supertrends agree).
Buys when meta supertrend is positive, sells when negative, applies a percentage stop loss.
Parameters:
- min1_df: pandas DataFrame, 1-minute data indexed by timestamp, used for precise stop loss checks
- df: pandas DataFrame, OHLCV data at the trading timeframe, with a 'timestamp' column
- initial_usd: float, starting USD amount
- stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%)
- debug: bool, whether to print debug info
"""
_df = df.copy().reset_index(drop=True)
_df['timestamp'] = pd.to_datetime(_df['timestamp'])
supertrends = Supertrends(_df, verbose=False)
supertrend_results_list = supertrends.calculate_supertrend_indicators()
trends = [st['results']['trend'] for st in supertrend_results_list]
trends_arr = np.stack(trends, axis=1)
meta_trend = np.where((trends_arr[:,0] == trends_arr[:,1]) & (trends_arr[:,1] == trends_arr[:,2]),
trends_arr[:,0], 0)
position = 0 # 0 = no position, 1 = long
entry_price = 0
usd = initial_usd
coin = 0
trade_log = []
max_balance = initial_usd
drawdowns = []
trades = []
entry_time = None
current_trade_min1_start_idx = None
min1_df['timestamp'] = pd.to_datetime(min1_df.index)
for i in range(1, len(_df)):
price_open = _df['open'].iloc[i]
price_close = _df['close'].iloc[i]
date = _df['timestamp'].iloc[i]
prev_mt = meta_trend[i-1]
curr_mt = meta_trend[i]
# Check stop loss if in position
if position == 1:
stop_loss_result = Backtest.check_stop_loss(
min1_df,
entry_time,
date,
entry_price,
stop_loss_pct,
coin,
usd,
debug,
current_trade_min1_start_idx
)
if stop_loss_result is not None:
# Unpack proceeds too, so the cash balance reflects the stop-loss sale
trade_log_entry, usd, current_trade_min1_start_idx, position, coin, entry_price = stop_loss_result
trade_log.append(trade_log_entry)
continue
# Update the start index for next check
current_trade_min1_start_idx = Backtest.get_current_min1_end_idx(min1_df, date)
# Entry: only if not in position and signal changes to 1
if position == 0 and prev_mt != 1 and curr_mt == 1:
entry_result = Backtest.handle_entry(usd, price_open, date)
coin, entry_price, entry_time, usd, position, trade_log_entry = entry_result
current_trade_min1_start_idx = None  # reset so the first stop loss check seeds from entry_time
trade_log.append(trade_log_entry)
# Exit: only if in position and signal changes from 1 to -1
elif position == 1 and prev_mt == 1 and curr_mt == -1:
exit_result = Backtest.handle_exit(coin, price_open, entry_price, entry_time, date)
usd, coin, position, entry_price, trade_log_entry = exit_result
trade_log.append(trade_log_entry)
# Track drawdown
balance = usd if position == 0 else coin * price_close
if balance > max_balance:
max_balance = balance
drawdown = (max_balance - balance) / max_balance
drawdowns.append(drawdown)
# If still in position at end, sell at last close
if position == 1:
exit_result = Backtest.handle_exit(coin, _df['close'].iloc[-1], entry_price, entry_time, _df['timestamp'].iloc[-1])
usd, coin, position, entry_price, trade_log_entry = exit_result
trade_log.append(trade_log_entry)
# Calculate statistics
final_balance = usd
n_trades = len(trade_log)
wins = [1 for t in trade_log if t['exit'] is not None and t['exit'] > t['entry']]
win_rate = len(wins) / n_trades if n_trades > 0 else 0
max_drawdown = max(drawdowns) if drawdowns else 0
avg_trade = np.mean([t['exit']/t['entry']-1 for t in trade_log if t['exit'] is not None]) if trade_log else 0
trades = []
total_fees_usd = 0.0
for trade in trade_log:
if trade['exit'] is not None:
profit_pct = (trade['exit'] - trade['entry']) / trade['entry']
else:
profit_pct = 0.0
trades.append({
'entry_time': trade['entry_time'],
'exit_time': trade['exit_time'],
'entry': trade['entry'],
'exit': trade['exit'],
'profit_pct': profit_pct,
'type': trade.get('type', 'SELL'),
'fee_usd': trade.get('fee_usd')
})
fee_usd = trade.get('fee_usd') or 0.0  # guard against entries logged without a fee
total_fees_usd += fee_usd
results = {
"initial_usd": initial_usd,
"final_usd": final_balance,
"n_trades": n_trades,
"win_rate": win_rate,
"max_drawdown": max_drawdown,
"avg_trade": avg_trade,
"trade_log": trade_log,
"trades": trades,
"total_fees_usd": total_fees_usd,
}
if n_trades > 0:
results["first_trade"] = {
"entry_time": trade_log[0]['entry_time'],
"entry": trade_log[0]['entry']
}
results["last_trade"] = {
"exit_time": trade_log[-1]['exit_time'],
"exit": trade_log[-1]['exit']
}
return results
@staticmethod
def check_stop_loss(min1_df, entry_time, date, entry_price, stop_loss_pct, coin, usd, debug, current_trade_min1_start_idx):
stop_price = entry_price * (1 - stop_loss_pct)
if current_trade_min1_start_idx is None:
current_trade_min1_start_idx = min1_df.index[min1_df.index >= entry_time][0]
current_min1_end_idx = min1_df.index[min1_df.index <= date][-1]
# Check all 1-minute candles in between for stop loss
min1_slice = min1_df.loc[current_trade_min1_start_idx:current_min1_end_idx]
if (min1_slice['low'] <= stop_price).any():
# Stop loss triggered, find the exact candle
stop_candle = min1_slice[min1_slice['low'] <= stop_price].iloc[0]
# More realistic fill: if open < stop, fill at open, else at stop
if stop_candle['open'] < stop_price:
sell_price = stop_candle['open']
else:
sell_price = stop_price
if debug:
print(f"STOP LOSS triggered: entry={entry_price}, stop={stop_price}, sell_price={sell_price}, entry_time={entry_time}, stop_time={stop_candle.name}")
btc_to_sell = coin
usd_gross = btc_to_sell * sell_price
exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
usd = usd_gross - exit_fee  # credit the net sale proceeds back to cash
trade_log_entry = {
'type': 'STOP',
'entry': entry_price,
'exit': sell_price,
'entry_time': entry_time,
'exit_time': stop_candle.name,
'fee_usd': exit_fee
}
# After stop loss, reset min1 start index, position, coin, and entry price
return trade_log_entry, usd, None, 0, 0, 0
return None
@staticmethod
def handle_entry(usd, price_open, date):
entry_fee = MarketFees.calculate_okx_taker_maker_fee(usd, is_maker=False)
usd_after_fee = usd - entry_fee
coin = usd_after_fee / price_open
entry_price = price_open
entry_time = date
usd = 0
position = 1
trade_log_entry = {
'type': 'BUY',
'entry': entry_price,
'exit': None,
'entry_time': entry_time,
'exit_time': None,
'fee_usd': entry_fee
}
return coin, entry_price, entry_time, usd, position, trade_log_entry
@staticmethod
def handle_exit(coin, price_open, entry_price, entry_time, date):
btc_to_sell = coin
usd_gross = btc_to_sell * price_open
exit_fee = MarketFees.calculate_okx_taker_maker_fee(usd_gross, is_maker=False)
usd = usd_gross - exit_fee
trade_log_entry = {
'type': 'SELL',
'entry': entry_price,
'exit': price_open,
'entry_time': entry_time,
'exit_time': date,
'fee_usd': exit_fee
}
coin = 0
position = 0
entry_price = 0
return usd, coin, position, entry_price, trade_log_entry
@staticmethod
def get_current_min1_end_idx(min1_df, date):
# Return the index of the last 1-minute candle at or before `date`
return min1_df.index[min1_df.index <= date][-1]
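A minimal usage sketch for the new class, mirroring how main.py drives it (the CSV name and resample rule follow main.py; Storage.load_data is assumed to return a 1-minute OHLCV DataFrame indexed by timestamp):

import logging
from cycles.utils.storage import Storage
from cycles.backtest import Backtest

storage = Storage(logging=logging)
min1_df = storage.load_data('btcusd_1-min_data.csv', '2024-05-15', '2025-05-15')
# Resample to the trading timeframe and restore 'timestamp' as a column, as Backtest.run expects
df = min1_df.resample('1D').agg({'open': 'first', 'high': 'max', 'low': 'min', 'close': 'last', 'volume': 'sum'}).dropna().reset_index()
results = Backtest.run(min1_df, df, initial_usd=10000, stop_loss_pct=0.03)
print(results['final_usd'], results['n_trades'], results['win_rate'])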

main_debug.py (deleted file)

@@ -1,197 +0,0 @@
import pandas as pd
import numpy as np
from trend_detector_simple import TrendDetectorSimple
import os
import datetime
import csv
def load_data(file_path, start_date, stop_date):
"""Load and filter data by date range."""
data = pd.read_csv(file_path)
data['Timestamp'] = pd.to_datetime(data['Timestamp'], unit='s')
data = data[(data['Timestamp'] >= start_date) & (data['Timestamp'] <= stop_date)]
data.columns = data.columns.str.lower()
return data.set_index('timestamp')
def process_month_timeframe(min1_df, month_df, stop_loss_pcts, rule_name, initial_usd):
"""Process a single month for a given timeframe with all stop loss values."""
month_df = month_df.copy().reset_index(drop=True)
trend_detector = TrendDetectorSimple(month_df, verbose=False)
analysis_results = trend_detector.detect_trends()
signal_df = analysis_results.get('signal_df')
results_rows = []
trade_rows = []
for stop_loss_pct in stop_loss_pcts:
results = trend_detector.backtest_meta_supertrend(
min1_df,
initial_usd=initial_usd,
stop_loss_pct=stop_loss_pct
)
trades = results.get('trades', [])
n_trades = results["n_trades"]
n_winning_trades = sum(1 for trade in trades if trade['profit_pct'] > 0)
total_profit = sum(trade['profit_pct'] for trade in trades)
total_loss = sum(-trade['profit_pct'] for trade in trades if trade['profit_pct'] < 0)
win_rate = n_winning_trades / n_trades if n_trades > 0 else 0
avg_trade = total_profit / n_trades if n_trades > 0 else 0
profit_ratio = total_profit / total_loss if total_loss > 0 else float('inf')
# Max drawdown
cumulative_profit = 0
max_drawdown = 0
peak = 0
for trade in trades:
cumulative_profit += trade['profit_pct']
if cumulative_profit > peak:
peak = cumulative_profit
drawdown = peak - cumulative_profit
if drawdown > max_drawdown:
max_drawdown = drawdown
# Final USD
final_usd = initial_usd
for trade in trades:
final_usd *= (1 + trade['profit_pct'])
row = {
"timeframe": rule_name,
"month": str(month_df['timestamp'].iloc[0].to_period('M')),
"stop_loss_pct": stop_loss_pct,
"n_trades": n_trades,
"n_stop_loss": sum(1 for trade in trades if 'type' in trade and trade['type'] == 'STOP'),
"win_rate": win_rate,
"max_drawdown": max_drawdown,
"avg_trade": avg_trade,
"profit_ratio": profit_ratio,
"initial_usd": initial_usd,
"final_usd": final_usd,
}
results_rows.append(row)
for trade in trades:
trade_rows.append({
"timeframe": rule_name,
"month": str(month_df['timestamp'].iloc[0].to_period('M')),
"stop_loss_pct": stop_loss_pct,
"entry_time": trade.get("entry_time"),
"exit_time": trade.get("exit_time"),
"entry_price": trade.get("entry_price"),
"exit_price": trade.get("exit_price"),
"profit_pct": trade.get("profit_pct"),
"type": trade.get("type", ""),
})
return results_rows, trade_rows
def process_timeframe(rule, data_1min, stop_loss_pcts, initial_usd):
"""Process an entire timeframe sequentially."""
if rule == "1T":
df = data_1min.copy()
else:
df = data_1min.resample(rule).agg({
'open': 'first',
'high': 'max',
'low': 'min',
'close': 'last',
'volume': 'sum'
}).dropna()
df = df.reset_index()
df['month'] = df['timestamp'].dt.to_period('M')
results_rows = []
all_trade_rows = []
for month, month_df in df.groupby('month'):
if len(month_df) < 10:
continue
month_results, month_trades = process_month_timeframe(data_1min, month_df, stop_loss_pcts, rule, initial_usd)
results_rows.extend(month_results)
all_trade_rows.extend(month_trades)
return results_rows, all_trade_rows
def aggregate_results(all_rows, initial_usd):
"""Aggregate results per stop_loss_pct and per rule (timeframe)."""
from collections import defaultdict
grouped = defaultdict(list)
for row in all_rows:
key = (row['timeframe'], row['stop_loss_pct'])
grouped[key].append(row)
summary_rows = []
for (rule, stop_loss_pct), rows in grouped.items():
n_months = len(rows)
total_trades = sum(r['n_trades'] for r in rows)
total_stop_loss = sum(r['n_stop_loss'] for r in rows)
avg_win_rate = np.mean([r['win_rate'] for r in rows])
avg_max_drawdown = np.mean([r['max_drawdown'] for r in rows])
avg_avg_trade = np.mean([r['avg_trade'] for r in rows])
avg_profit_ratio = np.mean([r['profit_ratio'] for r in rows])
final_usd = np.mean([r.get('final_usd', initial_usd) for r in rows])
summary_rows.append({
"timeframe": rule,
"stop_loss_pct": stop_loss_pct,
"n_trades": total_trades,
"n_stop_loss": total_stop_loss,
"win_rate": avg_win_rate,
"max_drawdown": avg_max_drawdown,
"avg_trade": avg_avg_trade,
"profit_ratio": avg_profit_ratio,
"initial_usd": initial_usd,
"final_usd": final_usd,
})
return summary_rows
def write_results(filename, fieldnames, rows):
"""Write results to a CSV file."""
with open(filename, 'w', newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for row in rows:
writer.writerow(row)
if __name__ == "__main__":
# Config
start_date = '2020-01-01'
stop_date = '2025-05-15'
initial_usd = 10000
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M")
timeframes = ["6h", "1D"]
stop_loss_pcts = [0.01, 0.02, 0.03, 0.05, 0.07, 0.10]
data_1min = load_data('./data/btcusd_1-min_data.csv', start_date, stop_date)
print(f"1min rows: {len(data_1min)}")
filename = os.path.join(
results_dir,
f"{timestamp}_backtest_results_{start_date}_{stop_date}_multi_timeframe_stoploss.csv"
)
fieldnames = ["timeframe", "stop_loss_pct", "n_trades", "n_stop_loss", "win_rate", "max_drawdown", "avg_trade", "profit_ratio", "initial_usd", "final_usd"]
all_results = []
all_trades = []
for name in timeframes:
print(f"Processing timeframe: {name}")
results, trades = process_timeframe(name, data_1min, stop_loss_pcts, initial_usd)
all_results.extend(results)
all_trades.extend(trades)
summary_rows = aggregate_results(all_results, initial_usd)
# write_results(filename, fieldnames, summary_rows)
trades_filename = os.path.join(
results_dir,
f"{timestamp}_backtest_trades.csv"
)
trades_fieldnames = [
"timeframe", "month", "stop_loss_pct", "entry_time", "exit_time",
"entry_price", "exit_price", "profit_pct", "type"
]
# write_results(trades_filename, trades_fieldnames, all_trades)

cycles/market_fees.py (new file)

@@ -0,0 +1,7 @@
import pandas as pd
class MarketFees:
@staticmethod
def calculate_okx_taker_maker_fee(amount, is_maker=True):
fee_rate = 0.0008 if is_maker else 0.0010
return amount * fee_rate
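A quick worked check of the committed rates (0.08% maker, 0.10% taker), applied to the notional USD amount:

fee = MarketFees.calculate_okx_taker_maker_fee(10000, is_maker=False)
print(fee)  # ~10.0: 10,000 USD * 0.0010 taker rate

Both Backtest.handle_entry and handle_exit call this with is_maker=False, i.e. every fill is charged the taker rate.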

cycles/supertrend.py (new file)

@@ -0,0 +1,336 @@
import pandas as pd
import numpy as np
import logging
from scipy.signal import find_peaks
from matplotlib.patches import Rectangle
from scipy import stats
import concurrent.futures
from functools import partial
from functools import lru_cache
import matplotlib.pyplot as plt
# Color configuration
# Plot colors
DARK_BG_COLOR = '#181C27'
LEGEND_BG_COLOR = '#333333'
TITLE_COLOR = 'white'
AXIS_LABEL_COLOR = 'white'
# Candlestick colors
CANDLE_UP_COLOR = '#089981' # Green
CANDLE_DOWN_COLOR = '#F23645' # Red
# Marker colors
MIN_COLOR = 'red'
MAX_COLOR = 'green'
# Line style colors
MIN_LINE_STYLE = 'g--' # Green dashed
MAX_LINE_STYLE = 'r--' # Red dashed
SMA7_LINE_STYLE = 'y-' # Yellow solid
SMA15_LINE_STYLE = 'm-' # Magenta solid
# SuperTrend colors
ST_COLOR_UP = 'g-'
ST_COLOR_DOWN = 'r-'
# Cache the calculation results by function parameters
@lru_cache(maxsize=32)
def cached_supertrend_calculation(period, multiplier, data_tuple):
# Convert tuple back to numpy arrays
high = np.array(data_tuple[0])
low = np.array(data_tuple[1])
close = np.array(data_tuple[2])
# Calculate TR and ATR using vectorized operations
tr = np.zeros_like(close)
tr[0] = high[0] - low[0]
hc_range = np.abs(high[1:] - close[:-1])
lc_range = np.abs(low[1:] - close[:-1])
hl_range = high[1:] - low[1:]
tr[1:] = np.maximum.reduce([hl_range, hc_range, lc_range])
# Use numpy's exponential moving average
atr = np.zeros_like(tr)
atr[0] = tr[0]
multiplier_ema = 2.0 / (period + 1)
for i in range(1, len(tr)):
atr[i] = (tr[i] * multiplier_ema) + (atr[i-1] * (1 - multiplier_ema))
# Calculate bands
upper_band = np.zeros_like(close)
lower_band = np.zeros_like(close)
for i in range(len(close)):
hl_avg = (high[i] + low[i]) / 2
upper_band[i] = hl_avg + (multiplier * atr[i])
lower_band[i] = hl_avg - (multiplier * atr[i])
final_upper = np.zeros_like(close)
final_lower = np.zeros_like(close)
supertrend = np.zeros_like(close)
trend = np.zeros_like(close)
final_upper[0] = upper_band[0]
final_lower[0] = lower_band[0]
if close[0] <= upper_band[0]:
supertrend[0] = upper_band[0]
trend[0] = -1
else:
supertrend[0] = lower_band[0]
trend[0] = 1
for i in range(1, len(close)):
if (upper_band[i] < final_upper[i-1]) or (close[i-1] > final_upper[i-1]):
final_upper[i] = upper_band[i]
else:
final_upper[i] = final_upper[i-1]
if (lower_band[i] > final_lower[i-1]) or (close[i-1] < final_lower[i-1]):
final_lower[i] = lower_band[i]
else:
final_lower[i] = final_lower[i-1]
if supertrend[i-1] == final_upper[i-1] and close[i] <= final_upper[i]:
supertrend[i] = final_upper[i]
trend[i] = -1
elif supertrend[i-1] == final_upper[i-1] and close[i] > final_upper[i]:
supertrend[i] = final_lower[i]
trend[i] = 1
elif supertrend[i-1] == final_lower[i-1] and close[i] >= final_lower[i]:
supertrend[i] = final_lower[i]
trend[i] = 1
elif supertrend[i-1] == final_lower[i-1] and close[i] < final_lower[i]:
supertrend[i] = final_upper[i]
trend[i] = -1
return {
'supertrend': supertrend,
'trend': trend,
'upper_band': final_upper,
'lower_band': final_lower
}
def calculate_supertrend_external(data, period, multiplier):
# Convert DataFrame columns to hashable tuples
high_tuple = tuple(data['high'])
low_tuple = tuple(data['low'])
close_tuple = tuple(data['close'])
# Call the cached function
return cached_supertrend_calculation(period, multiplier, (high_tuple, low_tuple, close_tuple))
class Supertrends:
def __init__(self, data, verbose=False, display=False):
"""
Initialize the Supertrends class.
Parameters:
- data: pandas DataFrame containing price data
- verbose: boolean, whether to display detailed logging information
- display: boolean, whether to enable display/plotting features
"""
self.data = data
self.verbose = verbose
self.display = display
# Only define display-related variables if display is True
if self.display:
# Plot style configuration
self.plot_style = 'dark_background'
self.bg_color = DARK_BG_COLOR
self.plot_size = (12, 8)
# Candlestick configuration
self.candle_width = 0.6
self.candle_up_color = CANDLE_UP_COLOR
self.candle_down_color = CANDLE_DOWN_COLOR
self.candle_alpha = 0.8
self.wick_width = 1
# Marker configuration
self.min_marker = '^'
self.min_color = MIN_COLOR
self.min_size = 100
self.max_marker = 'v'
self.max_color = MAX_COLOR
self.max_size = 100
self.marker_zorder = 100
# Line configuration
self.line_width = 1
self.min_line_style = MIN_LINE_STYLE
self.max_line_style = MAX_LINE_STYLE
self.sma7_line_style = SMA7_LINE_STYLE
self.sma15_line_style = SMA15_LINE_STYLE
# Text configuration
self.title_size = 14
self.title_color = TITLE_COLOR
self.axis_label_size = 12
self.axis_label_color = AXIS_LABEL_COLOR
# Legend configuration
self.legend_loc = 'best'
self.legend_bg_color = LEGEND_BG_COLOR
# Configure logging
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
format='%(asctime)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger('Supertrends')
# Convert data to pandas DataFrame if it's not already
if not isinstance(self.data, pd.DataFrame):
if isinstance(self.data, list):
self.data = pd.DataFrame({'close': self.data})
else:
raise ValueError("Data must be a pandas DataFrame or a list")
def calculate_tr(self):
"""
Calculate True Range (TR) for the price data.
True Range is the greatest of:
1. Current high - current low
2. |Current high - previous close|
3. |Current low - previous close|
Returns:
- Numpy array of TR values
"""
df = self.data.copy()
high = df['high'].values
low = df['low'].values
close = df['close'].values
tr = np.zeros_like(close)
tr[0] = high[0] - low[0] # First TR is just the first day's range
for i in range(1, len(close)):
# Current high - current low
hl_range = high[i] - low[i]
# |Current high - previous close|
hc_range = abs(high[i] - close[i-1])
# |Current low - previous close|
lc_range = abs(low[i] - close[i-1])
# TR is the maximum of these three values
tr[i] = max(hl_range, hc_range, lc_range)
return tr
def calculate_atr(self, period=14):
"""
Calculate Average True Range (ATR) for the price data.
ATR is the exponential moving average of the True Range over a specified period.
Parameters:
- period: int, the period for the ATR calculation (default: 14)
Returns:
- Numpy array of ATR values
"""
tr = self.calculate_tr()
atr = np.zeros_like(tr)
# First ATR value is just the first TR
atr[0] = tr[0]
# Calculate exponential moving average (EMA) of TR
multiplier = 2.0 / (period + 1)
for i in range(1, len(tr)):
atr[i] = (tr[i] * multiplier) + (atr[i-1] * (1 - multiplier))
return atr
def detect_trends(self):
"""
Detect trends by identifying local minima and maxima in the price data
using scipy.signal.find_peaks.
Returns:
- Dictionary of analysis results, currently containing the SuperTrend indicators
"""
df = self.data
# close_prices = df['close'].values
# max_peaks, _ = find_peaks(close_prices)
# min_peaks, _ = find_peaks(-close_prices)
# df['is_min'] = False
# df['is_max'] = False
# for peak in max_peaks:
# df.at[peak, 'is_max'] = True
# for peak in min_peaks:
# df.at[peak, 'is_min'] = True
# result = df[['timestamp', 'close', 'is_min', 'is_max']].copy()
# Perform linear regression on min_peaks and max_peaks
# min_prices = df['close'].iloc[min_peaks].values
# max_prices = df['close'].iloc[max_peaks].values
# Linear regression for min peaks if we have at least 2 points
# min_slope, min_intercept, min_r_value, _, _ = stats.linregress(min_peaks, min_prices)
# Linear regression for max peaks if we have at least 2 points
# max_slope, max_intercept, max_r_value, _, _ = stats.linregress(max_peaks, max_prices)
# Calculate Simple Moving Averages (SMA) for 7 and 15 periods
# sma_7 = pd.Series(close_prices).rolling(window=7, min_periods=1).mean().values
# sma_15 = pd.Series(close_prices).rolling(window=15, min_periods=1).mean().values
analysis_results = {}
# analysis_results['linear_regression'] = {
# 'min': {
# 'slope': min_slope,
# 'intercept': min_intercept,
# 'r_squared': min_r_value ** 2
# },
# 'max': {
# 'slope': max_slope,
# 'intercept': max_intercept,
# 'r_squared': max_r_value ** 2
# }
# }
# analysis_results['sma'] = {
# '7': sma_7,
# '15': sma_15
# }
# Calculate SuperTrend indicators
supertrend_results_list = self.calculate_supertrend_indicators()
analysis_results['supertrend'] = supertrend_results_list
return analysis_results
def calculate_supertrend_indicators(self):
"""
Calculate SuperTrend indicators with different parameter sets in parallel.
Returns:
- list, the SuperTrend results
"""
supertrend_params = [
{"period": 12, "multiplier": 3.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
{"period": 10, "multiplier": 1.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
{"period": 11, "multiplier": 2.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN}
]
data = self.data.copy()
# For just 3 calculations, direct calculation might be faster than process pool
results = []
for p in supertrend_params:
result = calculate_supertrend_external(data, p["period"], p["multiplier"])
results.append(result)
supertrend_results_list = []
for params, result in zip(supertrend_params, results):
supertrend_results_list.append({
"results": result,
"params": params
})
return supertrend_results_list
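The three parameter sets above (period/multiplier 12/3.0, 10/1.0, 11/2.0) feed the meta-trend logic in Backtest.run, which only treats a bar as trending when all three SuperTrends agree. A minimal sketch of that consumption, assuming df is an OHLCV DataFrame:

import numpy as np
from cycles.supertrend import Supertrends

sts = Supertrends(df, verbose=False)
trends = [st['results']['trend'] for st in sts.calculate_supertrend_indicators()]
trends_arr = np.stack(trends, axis=1)
# +1 when all three agree up, -1 when all agree down, 0 when there is no consensus
meta_trend = np.where((trends_arr[:, 0] == trends_arr[:, 1]) & (trends_arr[:, 1] == trends_arr[:, 2]), trends_arr[:, 0], 0)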

cycles/taxes.py (deleted file)

@@ -1,25 +0,0 @@
import pandas as pd
class Taxes:
def __init__(self, tax_rate=0.20):
"""
tax_rate: flat tax rate on positive profits (e.g., 0.20 for 20%)
"""
self.tax_rate = tax_rate
def add_taxes_to_results_csv(self, input_csv, output_csv=None, profit_col='final_usd'):
"""
Reads a backtest results CSV, adds tax columns, and writes to a new CSV.
- input_csv: path to the input CSV file
- output_csv: path to the output CSV file (if None, overwrite input)
- profit_col: column name for profit (default: 'final_usd')
"""
df = pd.read_csv(input_csv, delimiter=None)
# Compute tax only on positive profits
df['tax_paid'] = df[profit_col].apply(lambda x: self.tax_rate * x if x > 0 else 0)
df['net_profit_after_tax'] = df[profit_col] - df['tax_paid']
df['cumulative_tax_paid'] = df['tax_paid'].cumsum()
if not output_csv:
output_csv = input_csv
df.to_csv(output_csv, index=False)
return output_csv

cycles/trend_detector_simple.py (deleted file)

@@ -1,848 +0,0 @@
import pandas as pd
import numpy as np
import logging
from scipy.signal import find_peaks
from matplotlib.patches import Rectangle
from scipy import stats
import concurrent.futures
from functools import partial
from functools import lru_cache
import matplotlib.pyplot as plt
# Color configuration
# Plot colors
DARK_BG_COLOR = '#181C27'
LEGEND_BG_COLOR = '#333333'
TITLE_COLOR = 'white'
AXIS_LABEL_COLOR = 'white'
# Candlestick colors
CANDLE_UP_COLOR = '#089981' # Green
CANDLE_DOWN_COLOR = '#F23645' # Red
# Marker colors
MIN_COLOR = 'red'
MAX_COLOR = 'green'
# Line style colors
MIN_LINE_STYLE = 'g--' # Green dashed
MAX_LINE_STYLE = 'r--' # Red dashed
SMA7_LINE_STYLE = 'y-' # Yellow solid
SMA15_LINE_STYLE = 'm-' # Magenta solid
# SuperTrend colors
ST_COLOR_UP = 'g-'
ST_COLOR_DOWN = 'r-'
# Cache the calculation results by function parameters
@lru_cache(maxsize=32)
def cached_supertrend_calculation(period, multiplier, data_tuple):
# Convert tuple back to numpy arrays
high = np.array(data_tuple[0])
low = np.array(data_tuple[1])
close = np.array(data_tuple[2])
# Calculate TR and ATR using vectorized operations
tr = np.zeros_like(close)
tr[0] = high[0] - low[0]
hc_range = np.abs(high[1:] - close[:-1])
lc_range = np.abs(low[1:] - close[:-1])
hl_range = high[1:] - low[1:]
tr[1:] = np.maximum.reduce([hl_range, hc_range, lc_range])
# Use numpy's exponential moving average
atr = np.zeros_like(tr)
atr[0] = tr[0]
multiplier_ema = 2.0 / (period + 1)
for i in range(1, len(tr)):
atr[i] = (tr[i] * multiplier_ema) + (atr[i-1] * (1 - multiplier_ema))
# Calculate bands
upper_band = np.zeros_like(close)
lower_band = np.zeros_like(close)
for i in range(len(close)):
hl_avg = (high[i] + low[i]) / 2
upper_band[i] = hl_avg + (multiplier * atr[i])
lower_band[i] = hl_avg - (multiplier * atr[i])
final_upper = np.zeros_like(close)
final_lower = np.zeros_like(close)
supertrend = np.zeros_like(close)
trend = np.zeros_like(close)
final_upper[0] = upper_band[0]
final_lower[0] = lower_band[0]
if close[0] <= upper_band[0]:
supertrend[0] = upper_band[0]
trend[0] = -1
else:
supertrend[0] = lower_band[0]
trend[0] = 1
for i in range(1, len(close)):
if (upper_band[i] < final_upper[i-1]) or (close[i-1] > final_upper[i-1]):
final_upper[i] = upper_band[i]
else:
final_upper[i] = final_upper[i-1]
if (lower_band[i] > final_lower[i-1]) or (close[i-1] < final_lower[i-1]):
final_lower[i] = lower_band[i]
else:
final_lower[i] = final_lower[i-1]
if supertrend[i-1] == final_upper[i-1] and close[i] <= final_upper[i]:
supertrend[i] = final_upper[i]
trend[i] = -1
elif supertrend[i-1] == final_upper[i-1] and close[i] > final_upper[i]:
supertrend[i] = final_lower[i]
trend[i] = 1
elif supertrend[i-1] == final_lower[i-1] and close[i] >= final_lower[i]:
supertrend[i] = final_lower[i]
trend[i] = 1
elif supertrend[i-1] == final_lower[i-1] and close[i] < final_lower[i]:
supertrend[i] = final_upper[i]
trend[i] = -1
return {
'supertrend': supertrend,
'trend': trend,
'upper_band': final_upper,
'lower_band': final_lower
}
def calculate_supertrend_external(data, period, multiplier):
# Convert DataFrame columns to hashable tuples
high_tuple = tuple(data['high'])
low_tuple = tuple(data['low'])
close_tuple = tuple(data['close'])
# Call the cached function
return cached_supertrend_calculation(period, multiplier, (high_tuple, low_tuple, close_tuple))
def calculate_okx_fee(amount, is_maker=True):
fee_rate = 0.0008 if is_maker else 0.0010
return amount * fee_rate
class TrendDetectorSimple:
def __init__(self, data, verbose=False, display=False):
"""
Initialize the TrendDetectorSimple class.
Parameters:
- data: pandas DataFrame containing price data
- verbose: boolean, whether to display detailed logging information
- display: boolean, whether to enable display/plotting features
"""
self.data = data
self.verbose = verbose
self.display = display
# Only define display-related variables if display is True
if self.display:
# Plot style configuration
self.plot_style = 'dark_background'
self.bg_color = DARK_BG_COLOR
self.plot_size = (12, 8)
# Candlestick configuration
self.candle_width = 0.6
self.candle_up_color = CANDLE_UP_COLOR
self.candle_down_color = CANDLE_DOWN_COLOR
self.candle_alpha = 0.8
self.wick_width = 1
# Marker configuration
self.min_marker = '^'
self.min_color = MIN_COLOR
self.min_size = 100
self.max_marker = 'v'
self.max_color = MAX_COLOR
self.max_size = 100
self.marker_zorder = 100
# Line configuration
self.line_width = 1
self.min_line_style = MIN_LINE_STYLE
self.max_line_style = MAX_LINE_STYLE
self.sma7_line_style = SMA7_LINE_STYLE
self.sma15_line_style = SMA15_LINE_STYLE
# Text configuration
self.title_size = 14
self.title_color = TITLE_COLOR
self.axis_label_size = 12
self.axis_label_color = AXIS_LABEL_COLOR
# Legend configuration
self.legend_loc = 'best'
self.legend_bg_color = LEGEND_BG_COLOR
# Configure logging
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING,
format='%(asctime)s - %(levelname)s - %(message)s')
self.logger = logging.getLogger('TrendDetectorSimple')
# Convert data to pandas DataFrame if it's not already
if not isinstance(self.data, pd.DataFrame):
if isinstance(self.data, list):
self.data = pd.DataFrame({'close': self.data})
else:
raise ValueError("Data must be a pandas DataFrame or a list")
def calculate_tr(self):
"""
Calculate True Range (TR) for the price data.
True Range is the greatest of:
1. Current high - current low
2. |Current high - previous close|
3. |Current low - previous close|
Returns:
- Numpy array of TR values
"""
df = self.data.copy()
high = df['high'].values
low = df['low'].values
close = df['close'].values
tr = np.zeros_like(close)
tr[0] = high[0] - low[0] # First TR is just the first day's range
for i in range(1, len(close)):
# Current high - current low
hl_range = high[i] - low[i]
# |Current high - previous close|
hc_range = abs(high[i] - close[i-1])
# |Current low - previous close|
lc_range = abs(low[i] - close[i-1])
# TR is the maximum of these three values
tr[i] = max(hl_range, hc_range, lc_range)
return tr
def calculate_atr(self, period=14):
"""
Calculate Average True Range (ATR) for the price data.
ATR is the exponential moving average of the True Range over a specified period.
Parameters:
- period: int, the period for the ATR calculation (default: 14)
Returns:
- Numpy array of ATR values
"""
tr = self.calculate_tr()
atr = np.zeros_like(tr)
# First ATR value is just the first TR
atr[0] = tr[0]
# Calculate exponential moving average (EMA) of TR
multiplier = 2.0 / (period + 1)
for i in range(1, len(tr)):
atr[i] = (tr[i] * multiplier) + (atr[i-1] * (1 - multiplier))
return atr
def detect_trends(self):
"""
Detect trends by identifying local minima and maxima in the price data
using scipy.signal.find_peaks.
Parameters:
- prominence: float, required prominence of peaks (relative to the price range)
- width: int, required width of peaks in data points
Returns:
- DataFrame with columns for timestamps, prices, and trend indicators
- Dictionary containing analysis results including linear regression, SMAs, and SuperTrend indicators
"""
df = self.data
# close_prices = df['close'].values
# max_peaks, _ = find_peaks(close_prices)
# min_peaks, _ = find_peaks(-close_prices)
# df['is_min'] = False
# df['is_max'] = False
# for peak in max_peaks:
# df.at[peak, 'is_max'] = True
# for peak in min_peaks:
# df.at[peak, 'is_min'] = True
# result = df[['timestamp', 'close', 'is_min', 'is_max']].copy()
# Perform linear regression on min_peaks and max_peaks
# min_prices = df['close'].iloc[min_peaks].values
# max_prices = df['close'].iloc[max_peaks].values
# Linear regression for min peaks if we have at least 2 points
# min_slope, min_intercept, min_r_value, _, _ = stats.linregress(min_peaks, min_prices)
# Linear regression for max peaks if we have at least 2 points
# max_slope, max_intercept, max_r_value, _, _ = stats.linregress(max_peaks, max_prices)
# Calculate Simple Moving Averages (SMA) for 7 and 15 periods
# sma_7 = pd.Series(close_prices).rolling(window=7, min_periods=1).mean().values
# sma_15 = pd.Series(close_prices).rolling(window=15, min_periods=1).mean().values
analysis_results = {}
# analysis_results['linear_regression'] = {
# 'min': {
# 'slope': min_slope,
# 'intercept': min_intercept,
# 'r_squared': min_r_value ** 2
# },
# 'max': {
# 'slope': max_slope,
# 'intercept': max_intercept,
# 'r_squared': max_r_value ** 2
# }
# }
# analysis_results['sma'] = {
# '7': sma_7,
# '15': sma_15
# }
# Calculate SuperTrend indicators
supertrend_results_list = self._calculate_supertrend_indicators()
analysis_results['supertrend'] = supertrend_results_list
return analysis_results
def _calculate_supertrend_indicators(self):
"""
Calculate SuperTrend indicators with different parameter sets in parallel.
Returns:
- list, the SuperTrend results
"""
supertrend_params = [
{"period": 12, "multiplier": 3.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
{"period": 10, "multiplier": 1.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN},
{"period": 11, "multiplier": 2.0, "color_up": ST_COLOR_UP, "color_down": ST_COLOR_DOWN}
]
data = self.data.copy()
# For just 3 calculations, direct calculation might be faster than process pool
results = []
for p in supertrend_params:
result = calculate_supertrend_external(data, p["period"], p["multiplier"])
results.append(result)
supertrend_results_list = []
for params, result in zip(supertrend_params, results):
supertrend_results_list.append({
"results": result,
"params": params
})
return supertrend_results_list
def plot_trends(self, trend_data, analysis_results, view="both"):
"""
Plot the price data with detected trends using a candlestick chart.
Also plots SuperTrend indicators with three different parameter sets.
Parameters:
- trend_data: DataFrame, the output from detect_trends()
- analysis_results: Dictionary containing analysis results from detect_trends()
- view: str, one of 'both', 'trend', 'supertrend'; determines which plot(s) to display
Returns:
- None (displays the plot)
"""
if not self.display:
return # Do nothing if display is False
plt.style.use(self.plot_style)
if view == "both":
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(self.plot_size[0]*2, self.plot_size[1]))
else:
fig, ax = plt.subplots(figsize=self.plot_size)
ax1 = ax2 = None
if view == "trend":
ax1 = ax
elif view == "supertrend":
ax2 = ax
fig.patch.set_facecolor(self.bg_color)
if ax1: ax1.set_facecolor(self.bg_color)
if ax2: ax2.set_facecolor(self.bg_color)
df = self.data.copy()
if ax1:
self._plot_trend_analysis(ax1, df, trend_data, analysis_results)
if ax2:
self._plot_supertrend_analysis(ax2, df, analysis_results['supertrend'])
plt.tight_layout()
plt.show()
def _plot_candlesticks(self, ax, df):
"""
Plot candlesticks on the given axis.
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
"""
from matplotlib.patches import Rectangle
for i in range(len(df)):
# Get OHLC values for this candle
open_val = df['open'].iloc[i]
close_val = df['close'].iloc[i]
high_val = df['high'].iloc[i]
low_val = df['low'].iloc[i]
# Determine candle color
color = self.candle_up_color if close_val >= open_val else self.candle_down_color
# Plot candle body
body_height = abs(close_val - open_val)
bottom = min(open_val, close_val)
rect = Rectangle((i - self.candle_width/2, bottom), self.candle_width, body_height,
color=color, alpha=self.candle_alpha)
ax.add_patch(rect)
# Plot candle wicks
ax.plot([i, i], [low_val, high_val], color=color, linewidth=self.wick_width)
def _plot_trend_analysis(self, ax, df, trend_data, analysis_results):
"""
Plot trend analysis on the given axis.
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- trend_data: pandas.DataFrame, the trend data
- analysis_results: dict, the analysis results
"""
# Draw candlesticks
self._plot_candlesticks(ax, df)
# Plot minima and maxima points
self._plot_min_max_points(ax, df, trend_data)
# Plot trend lines and moving averages
if analysis_results:
self._plot_trend_lines(ax, df, analysis_results)
# Configure the subplot
self._configure_subplot(ax, 'Price Chart with Trend Analysis', len(df))
def _plot_min_max_points(self, ax, df, trend_data):
"""
Plot minimum and maximum points on the given axis.
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- trend_data: pandas.DataFrame, the trend data
"""
min_indices = trend_data.index[trend_data['is_min'] == True].tolist()
if min_indices:
min_y = [df['close'].iloc[i] for i in min_indices]
ax.scatter(min_indices, min_y, color=self.min_color, s=self.min_size,
marker=self.min_marker, label='Local Minima', zorder=self.marker_zorder)
max_indices = trend_data.index[trend_data['is_max'] == True].tolist()
if max_indices:
max_y = [df['close'].iloc[i] for i in max_indices]
ax.scatter(max_indices, max_y, color=self.max_color, s=self.max_size,
marker=self.max_marker, label='Local Maxima', zorder=self.marker_zorder)
def _plot_trend_lines(self, ax, df, analysis_results):
"""
Plot trend lines on the given axis.
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- analysis_results: dict, the analysis results
"""
x_vals = np.arange(len(df))
# Minima regression line (support)
min_slope = analysis_results['linear_regression']['min']['slope']
min_intercept = analysis_results['linear_regression']['min']['intercept']
min_line = min_slope * x_vals + min_intercept
ax.plot(x_vals, min_line, self.min_line_style, linewidth=self.line_width,
label='Minima Regression')
# Maxima regression line (resistance)
max_slope = analysis_results['linear_regression']['max']['slope']
max_intercept = analysis_results['linear_regression']['max']['intercept']
max_line = max_slope * x_vals + max_intercept
ax.plot(x_vals, max_line, self.max_line_style, linewidth=self.line_width,
label='Maxima Regression')
# SMA-7 line
sma_7 = analysis_results['sma']['7']
ax.plot(x_vals, sma_7, self.sma7_line_style, linewidth=self.line_width,
label='SMA-7')
# SMA-15 line
sma_15 = analysis_results['sma']['15']
valid_idx_15 = ~np.isnan(sma_15)
ax.plot(x_vals[valid_idx_15], sma_15[valid_idx_15], self.sma15_line_style,
linewidth=self.line_width, label='SMA-15')
def _configure_subplot(self, ax, title, data_length):
"""
Configure the subplot with title, labels, limits, and legend.
Parameters:
- ax: matplotlib.axes.Axes, the axis to configure
- title: str, the title of the subplot
- data_length: int, the length of the data
"""
# Set title and labels
ax.set_title(title, fontsize=self.title_size, color=self.title_color)
ax.set_xlabel('Date', fontsize=self.axis_label_size, color=self.axis_label_color)
ax.set_ylabel('Price', fontsize=self.axis_label_size, color=self.axis_label_color)
# Set appropriate x-axis limits
ax.set_xlim(-0.5, data_length - 0.5)
# Add a legend
ax.legend(loc=self.legend_loc, facecolor=self.legend_bg_color)
def _plot_supertrend_analysis(self, ax, df, supertrend_results_list=None):
"""
Plot SuperTrend analysis on the given axis.
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- supertrend_results_list: list, the SuperTrend results (optional)
"""
self._plot_candlesticks(ax, df)
self._plot_supertrend_lines(ax, df, supertrend_results_list, style='Both')
self._configure_subplot(ax, 'Multiple SuperTrend Indicators', len(df))
def _plot_supertrend_lines(self, ax, df, supertrend_results_list, style="Horizontal"):
"""
Plot SuperTrend lines on the given axis.
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- supertrend_results_list: list, the SuperTrend results
"""
x_vals = np.arange(len(df))
if style == 'Horizontal' or style == 'Both':
if len(supertrend_results_list) != 3:
raise ValueError("Expected exactly 3 SuperTrend results for meta calculation")
trends = [st["results"]["trend"] for st in supertrend_results_list]
band_height = 0.02 * (df["high"].max() - df["low"].min())
y_base = df["low"].min() - band_height * 1.5
prev_color = None
for i in range(1, len(x_vals)):
t_vals = [t[i] for t in trends]
up_count = t_vals.count(1)
down_count = t_vals.count(-1)
if down_count == 3:
color = "red"
elif down_count == 2 and up_count == 1:
color = "orange"
elif down_count == 1 and up_count == 2:
color = "yellow"
elif up_count == 3:
color = "green"
else:
continue # skip if unknown or inconsistent values
ax.add_patch(Rectangle(
(x_vals[i-1], y_base),
1,
band_height,
color=color,
linewidth=0,
alpha=0.6
))
# Draw a vertical line at the change of color
if prev_color and prev_color != color:
ax.axvline(x_vals[i-1], color="grey", alpha=0.3, linewidth=1)
prev_color = color
ax.set_ylim(bottom=y_base - band_height * 0.5)
if style == 'Curves' or style == 'Both':
for st in supertrend_results_list:
params = st["params"]
results = st["results"]
supertrend = results["supertrend"]
trend = results["trend"]
# Plot SuperTrend line with color based on trend
for i in range(1, len(x_vals)):
if trend[i] == 1: # Uptrend
ax.plot(x_vals[i-1:i+1], supertrend[i-1:i+1], params["color_up"], linewidth=self.line_width)
else: # Downtrend
ax.plot(x_vals[i-1:i+1], supertrend[i-1:i+1], params["color_down"], linewidth=self.line_width)
self._plot_metasupertrend_lines(ax, df, supertrend_results_list)
self._add_supertrend_legend(ax, supertrend_results_list)
def _plot_metasupertrend_lines(self, ax, df, supertrend_results_list):
"""
Plot a Meta SuperTrend line where all individual SuperTrends agree on trend.
Parameters:
- ax: matplotlib.axes.Axes, the axis to plot on
- df: pandas.DataFrame, the data to plot
- supertrend_results_list: list, each item contains SuperTrend 'results' and 'params'
"""
x_vals = np.arange(len(df))
if len(supertrend_results_list) != 3:
raise ValueError("Expected exactly 3 SuperTrend results for meta calculation")
trends = [st["results"]["trend"] for st in supertrend_results_list]
supertrends = [st["results"]["supertrend"] for st in supertrend_results_list]
params = supertrend_results_list[0]["params"] # Use first config for styling
trends_arr = np.stack(trends, axis=1)
meta_trend = np.where((trends_arr[:,0] == trends_arr[:,1]) & (trends_arr[:,1] == trends_arr[:,2]), trends_arr[:,0], 0)
for i in range(1, len(x_vals)):
t1, t2, t3 = trends[0][i], trends[1][i], trends[2][i]
if t1 == t2 == t3:
meta_trend = t1
# Average the 3 supertrend values
st_avg_prev = np.mean([s[i-1] for s in supertrends])
st_avg_curr = np.mean([s[i] for s in supertrends])
color = params["color_up"] if meta_trend == 1 else params["color_down"]
ax.plot(x_vals[i-1:i+1], [st_avg_prev, st_avg_curr], color, linewidth=self.line_width)
def _add_supertrend_legend(self, ax, supertrend_results_list):
"""
Add SuperTrend legend entries to the given axis.
Parameters:
- ax: matplotlib.axes.Axes, the axis to add legend entries to
- supertrend_results_list: list, the SuperTrend results
"""
for st in supertrend_results_list:
params = st["params"]
period = params["period"]
multiplier = params["multiplier"]
color_up = params["color_up"]
color_down = params["color_down"]
ax.plot([], [], color_up, linewidth=self.line_width,
label=f'ST (P:{period}, M:{multiplier}) Up')
ax.plot([], [], color_down, linewidth=self.line_width,
label=f'ST (P:{period}, M:{multiplier}) Down')
def backtest_meta_supertrend(self, min1_df, initial_usd=10000, stop_loss_pct=0.05, debug=False):
"""
Backtest a simple strategy using the meta supertrend (all three supertrends agree).
Buys when meta supertrend is positive, sells when negative, applies a percentage stop loss.
Parameters:
- min1_df: pandas DataFrame, 1-minute timeframe data for more accurate stop loss checking (optional)
- initial_usd: float, starting USD amount
- stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%)
- debug: bool, whether to print debug info
"""
df = self.data.copy().reset_index(drop=True)
df['timestamp'] = pd.to_datetime(df['timestamp'])
# Get meta supertrend (all three agree)
supertrend_results_list = self._calculate_supertrend_indicators()
trends = [st['results']['trend'] for st in supertrend_results_list]
trends_arr = np.stack(trends, axis=1)
meta_trend = np.where((trends_arr[:,0] == trends_arr[:,1]) & (trends_arr[:,1] == trends_arr[:,2]),
trends_arr[:,0], 0)
position = 0 # 0 = no position, 1 = long
entry_price = 0
usd = initial_usd
coin = 0
trade_log = []
max_balance = initial_usd
drawdowns = []
trades = []
entry_time = None
current_trade_min1_start_idx = None
min1_df['timestamp'] = pd.to_datetime(min1_df.index)
for i in range(1, len(df)):
if i % 100 == 0 and debug:
self.logger.debug(f"Progress: {i}/{len(df)} rows processed.")
price_open = df['open'].iloc[i]
price_high = df['high'].iloc[i]
price_low = df['low'].iloc[i]
price_close = df['close'].iloc[i]
date = df['timestamp'].iloc[i]
prev_mt = meta_trend[i-1]
curr_mt = meta_trend[i]
# Check stop loss if in position
if position == 1:
stop_price = entry_price * (1 - stop_loss_pct)
if current_trade_min1_start_idx is None:
# First check after entry, find the entry point in 1-min data
current_trade_min1_start_idx = min1_df.index[min1_df.index >= entry_time][0]
# Get the end index for current check
current_min1_end_idx = min1_df.index[min1_df.index <= date][-1]
# Check all 1-minute candles in between for stop loss
min1_slice = min1_df.loc[current_trade_min1_start_idx:current_min1_end_idx]
if (min1_slice['low'] <= stop_price).any():
# Stop loss triggered, find the exact candle
stop_candle = min1_slice[min1_slice['low'] <= stop_price].iloc[0]
# More realistic fill: if open < stop, fill at open, else at stop
if stop_candle['open'] < stop_price:
sell_price = stop_candle['open']
else:
sell_price = stop_price
if debug:
print(f"STOP LOSS triggered: entry={entry_price}, stop={stop_price}, sell_price={sell_price}, entry_time={entry_time}, stop_time={stop_candle.name}")
btc_to_sell = coin
usd_gross = btc_to_sell * sell_price
exit_fee = calculate_okx_fee(usd_gross, is_maker=False) # taker fee
usd = usd_gross - exit_fee
trade_log.append({
'type': 'STOP',
'entry': entry_price,
'exit': sell_price,
'entry_time': entry_time,
'exit_time': stop_candle.name,
'fee_usd': exit_fee
})
coin = 0
position = 0
entry_price = 0
current_trade_min1_start_idx = None
continue
# Update the start index for next check
current_trade_min1_start_idx = current_min1_end_idx
# Entry: only if not in position and signal changes to 1
if position == 0 and prev_mt != 1 and curr_mt == 1:
# Buy at open, fee is charged in USD
entry_fee = calculate_okx_fee(usd, is_maker=False)
usd_after_fee = usd - entry_fee
coin = usd_after_fee / price_open
entry_price = price_open
entry_time = date
usd = 0
position = 1
current_trade_min1_start_idx = None # Will be set on first stop loss check
trade_log.append({
'type': 'BUY',
'entry': entry_price,
'exit': None,
'entry_time': entry_time,
'exit_time': None,
'fee_usd': entry_fee
})
# Exit: only if in position and signal changes from 1 to -1
elif position == 1 and prev_mt == 1 and curr_mt == -1:
# Sell at open, fee is charged in USD
btc_to_sell = coin
usd_gross = btc_to_sell * price_open
exit_fee = calculate_okx_fee(usd_gross, is_maker=False)
usd = usd_gross - exit_fee
trade_log.append({
'type': 'SELL',
'entry': entry_price,
'exit': price_open,
'entry_time': entry_time,
'exit_time': date,
'fee_usd': exit_fee
})
coin = 0
position = 0
entry_price = 0
current_trade_min1_start_idx = None
# Track drawdown
balance = usd if position == 0 else coin * price_close
if balance > max_balance:
max_balance = balance
drawdown = (max_balance - balance) / max_balance
drawdowns.append(drawdown)
# If still in position at end, sell at last close
if position == 1:
btc_to_sell = coin
usd_gross = btc_to_sell * df['close'].iloc[-1]
exit_fee = calculate_okx_fee(usd_gross, is_maker=False)
usd = usd_gross - exit_fee
trade_log.append({
'type': 'EOD',
'entry': entry_price,
'exit': df['close'].iloc[-1],
'entry_time': entry_time,
'exit_time': df['timestamp'].iloc[-1],
'fee_usd': exit_fee
})
coin = 0
position = 0
entry_price = 0
# Calculate statistics
final_balance = usd
n_trades = len(trade_log)
wins = [1 for t in trade_log if t['exit'] is not None and t['exit'] > t['entry']]
win_rate = len(wins) / n_trades if n_trades > 0 else 0
max_drawdown = max(drawdowns) if drawdowns else 0
avg_trade = np.mean([t['exit']/t['entry']-1 for t in trade_log if t['exit'] is not None]) if trade_log else 0
trades = []
total_fees_usd = 0.0
for trade in trade_log:
if trade['exit'] is not None:
profit_pct = (trade['exit'] - trade['entry']) / trade['entry']
else:
profit_pct = 0.0
trades.append({
'entry_time': trade['entry_time'],
'exit_time': trade['exit_time'],
'entry': trade['entry'],
'exit': trade['exit'],
'profit_pct': profit_pct,
'type': trade.get('type', 'SELL'),
'fee_usd': trade.get('fee_usd')
})
fee_usd = trade.get('fee_usd')
total_fees_usd += fee_usd
results = {
"initial_usd": initial_usd,
"final_usd": final_balance,
"n_trades": n_trades,
"win_rate": win_rate,
"max_drawdown": max_drawdown,
"avg_trade": avg_trade,
"trade_log": trade_log,
"trades": trades,
"total_fees_usd": total_fees_usd,
}
if n_trades > 0:
results["first_trade"] = {
"entry_time": trade_log[0]['entry_time'],
"entry": trade_log[0]['entry']
}
results["last_trade"] = {
"exit_time": trade_log[-1]['exit_time'],
"exit": trade_log[-1]['exit']
}
return results

apply_taxes_to_file.py (deleted file)

@@ -1,23 +0,0 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from taxes import Taxes
if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python apply_taxes_to_file.py <input_csv> [profit_col]")
sys.exit(1)
input_csv = sys.argv[1]
profit_col = sys.argv[2] if len(sys.argv) > 2 else 'final_usd'
if not os.path.isfile(input_csv):
print(f"File not found: {input_csv}")
sys.exit(1)
base, ext = os.path.splitext(input_csv)
output_csv = f"{base}_taxed.csv"
taxes = Taxes() # Default 20% tax rate
taxes.add_taxes_to_results_csv(input_csv, output_csv, profit_col=profit_col)
print(f"Taxed file saved as: {output_csv}")

cycles/utils/storage.py (modified)

@@ -169,15 +169,19 @@ class Storage:
                 filtered_row = {k: v for k, v in row.items() if k in fieldnames}
                 writer.writerow(filtered_row)
 
-    def write_results_combined(self, filename, fieldnames, rows):
+    def write_backtest_results(self, filename, fieldnames, rows, metadata_lines=None):
         """Write a combined results to a CSV file
 
         Args:
             filename: filename to write to
             fieldnames: list of fieldnames
             rows: list of rows
+            metadata_lines: optional list of strings to write as header comments
         """
         fname = os.path.join(self.results_dir, filename)
         with open(fname, "w", newline="") as csvfile:
+            if metadata_lines:
+                for line in metadata_lines:
+                    csvfile.write(f"{line}\n")
             writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter='\t')
             writer.writeheader()
             for row in rows:
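With the metadata_lines built in main.py, the head of the emitted tab-separated file would look roughly like this (price values elided):

Start date	2024-05-15	Price	<nearest 1-min price>
Stop date	2025-05-15	Price	<nearest 1-min price>
Initial USD	10000
timeframe	stop_loss_pct	n_trades	...

Note that the metadata rows precede the DictWriter header, so anything consuming this file needs to skip the first three lines before parsing the table.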

main.py (modified)

@@ -4,15 +4,14 @@ import logging
 import concurrent.futures
 import os
 import datetime
-import queue
-import ast
-from cycles.trend_detector_simple import TrendDetectorSimple
-from cycles.taxes import Taxes
+import argparse
+import json
 from cycles.utils.storage import Storage
-from cycles.utils.gsheets import GSheetBatchPusher
 from cycles.utils.system import SystemUtils
+from cycles.backtest import Backtest
 
-# Set up logging
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s [%(levelname)s] %(message)s",
@@ -22,19 +21,17 @@ logging.basicConfig(
     ]
 )
 
-# Global queue for batching Google Sheets updates
-results_queue = queue.Queue()
 
 def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd, debug=False):
     """Process the entire timeframe with all stop loss values (no monthly split)"""
     df = df.copy().reset_index(drop=True)
-    trend_detector = TrendDetectorSimple(df, verbose=False)
     results_rows = []
     trade_rows = []
     for stop_loss_pct in stop_loss_pcts:
-        results = trend_detector.backtest_meta_supertrend(
+        results = Backtest.run(
             min1_df,
+            df,
             initial_usd=initial_usd,
             stop_loss_pct=stop_loss_pct,
             debug=debug
@@ -100,9 +97,10 @@ def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd,
                 print("Large loss trade:", trade)
     return results_rows, trade_rows
 
-def process_timeframe(timeframe_info, debug=False):
+def process(timeframe_info, debug=False):
     """Process a single (timeframe, stop_loss_pct) combination (no monthly split)"""
     rule, data_1min, stop_loss_pct, initial_usd = timeframe_info
+
     if rule == "1T":
         df = data_1min.copy()
     else:
@@ -114,7 +112,6 @@ def process(timeframe_info, debug=False):
             'volume': 'sum'
         }).dropna()
     df = df.reset_index()
 
-    # Only process one stop loss
     results_rows, all_trade_rows = process_timeframe_data(data_1min, df, [stop_loss_pct], rule, initial_usd, debug=debug)
     return results_rows, all_trade_rows
@@ -166,32 +163,69 @@ def get_nearest_price(df, target_date):
     return nearest_time, price
 
 if __name__ == "__main__":
-    # Configuration
-    # start_date = '2022-01-01'
-    # stop_date = '2023-01-01'
-    start_date = '2024-05-15'
-    stop_date = '2025-05-15'
-    initial_usd = 10000
-    debug = False
+    debug = True
 
-    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M")
+    parser = argparse.ArgumentParser(description="Run backtest with config file.")
+    parser.add_argument("config", type=str, nargs="?", help="Path to config JSON file.")
+    args = parser.parse_args()
+
+    # Default values (from config.json)
+    default_config = {
+        "start_date": "2024-05-15",
+        "stop_date": datetime.datetime.today().strftime('%Y-%m-%d'),
+        "initial_usd": 10000,
+        "timeframes": ["1D"],
+        "stop_loss_pcts": [0.01, 0.02, 0.03],
+    }
+
+    if args.config:
+        with open(args.config, 'r') as f:
+            config = json.load(f)
+    else:
+        print("No config file provided. Please enter the following values (press Enter to use default):")
+        start_date = input(f"Start date [{default_config['start_date']}]: ") or default_config['start_date']
+        stop_date = input(f"Stop date [{default_config['stop_date']}]: ") or default_config['stop_date']
+        initial_usd_str = input(f"Initial USD [{default_config['initial_usd']}]: ") or str(default_config['initial_usd'])
+        initial_usd = float(initial_usd_str)
+        timeframes_str = input(f"Timeframes (comma separated) [{', '.join(default_config['timeframes'])}]: ") or ','.join(default_config['timeframes'])
+        timeframes = [tf.strip() for tf in timeframes_str.split(',') if tf.strip()]
+        stop_loss_pcts_str = input(f"Stop loss pcts (comma separated) [{', '.join(str(x) for x in default_config['stop_loss_pcts'])}]: ") or ','.join(str(x) for x in default_config['stop_loss_pcts'])
+        stop_loss_pcts = [float(x.strip()) for x in stop_loss_pcts_str.split(',') if x.strip()]
+        config = {
+            'start_date': start_date,
+            'stop_date': stop_date,
+            'initial_usd': initial_usd,
+            'timeframes': timeframes,
+            'stop_loss_pcts': stop_loss_pcts,
+        }
+
+    # Use config values
+    start_date = config['start_date']
+    stop_date = config['stop_date']
+    initial_usd = config['initial_usd']
+    timeframes = config['timeframes']
+    stop_loss_pcts = config['stop_loss_pcts']
+
+    timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")
 
     storage = Storage(logging=logging)
     system_utils = SystemUtils(logging=logging)
 
-    timeframes = ["1D"]
-    stop_loss_pcts = [0.01, 0.02, 0.03]
-
-    # Load data once
     data_1min = storage.load_data('btcusd_1-min_data.csv', start_date, stop_date)
 
     nearest_start_time, start_price = get_nearest_price(data_1min, start_date)
     nearest_stop_time, stop_price = get_nearest_price(data_1min, stop_date)
-    logging.info(f"Price at start_date ({start_date}) [nearest timestamp: {nearest_start_time}]: {start_price}")
-    logging.info(f"Price at stop_date ({stop_date}) [nearest timestamp: {nearest_stop_time}]: {stop_price}")
+
+    metadata_lines = [
+        f"Start date\t{start_date}\tPrice\t{start_price}",
+        f"Stop date\t{stop_date}\tPrice\t{stop_price}",
+        f"Initial USD\t{initial_usd}"
+    ]
 
     tasks = [
         (name, data_1min, stop_loss_pct, initial_usd)
@@ -201,29 +235,35 @@ if __name__ == "__main__":
     workers = system_utils.get_optimal_workers()
 
-    # Process tasks with optimized concurrency
-    with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
-        futures = {executor.submit(process_timeframe, task, debug): task for task in tasks}
-        all_results_rows = []
-        all_trade_rows = []
-        for future in concurrent.futures.as_completed(futures):
-            results, trades = future.result()
-            if results or trades:
-                all_results_rows.extend(results)
-                all_trade_rows.extend(trades)
+    if debug:
+        all_results_rows = []
+        all_trade_rows = []
+        for task in tasks:
+            results, trades = process(task, debug)
+            if results or trades:
+                all_results_rows.extend(results)
+                all_trade_rows.extend(trades)
+    else:
+        with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
+            futures = {executor.submit(process, task, debug): task for task in tasks}
+            all_results_rows = []
+            all_trade_rows = []
+            for future in concurrent.futures.as_completed(futures):
+                results, trades = future.result()
+                if results or trades:
+                    all_results_rows.extend(results)
+                    all_trade_rows.extend(trades)
 
-    # Write all results to a single CSV file
-    combined_filename = os.path.join(f"{timestamp}_backtest_combined.csv")
-    combined_fieldnames = [
+    backtest_filename = os.path.join(f"{timestamp}_backtest.csv")
+    backtest_fieldnames = [
         "timeframe", "stop_loss_pct", "n_trades", "n_stop_loss", "win_rate",
         "max_drawdown", "avg_trade", "profit_ratio", "final_usd", "total_fees_usd"
     ]
-    storage.write_results_combined(combined_filename, combined_fieldnames, all_results_rows)
+    storage.write_backtest_results(backtest_filename, backtest_fieldnames, all_results_rows, metadata_lines)
 
-    # Now, group all_trade_rows by (timeframe, stop_loss_pct)
-    trades_fieldnames = [
-        "entry_time", "exit_time", "entry_price", "exit_price", "profit_pct", "type", "fee_usd"
-    ]
+    trades_fieldnames = ["entry_time", "exit_time", "entry_price", "exit_price", "profit_pct", "type", "fee_usd"]
     storage.write_trades(all_trade_rows, trades_fieldnames)