Enhance trading logic and fee calculations in main.py and trend_detector_simple.py

- Added total fees calculation to process_timeframe_data and aggregate_results functions in main.py.
- Updated trade logging in TrendDetectorSimple to include transaction fees in USD.
- Introduced calculate_okx_fee function for consistent fee calculations based on maker/taker status.
- Adjusted backtesting logic to account for fees when buying and selling, ensuring accurate profit calculations.
- Expanded stop loss percentages and timeframes for broader analysis in main.py.
This commit is contained in:
Simon Moisy 2025-05-21 14:54:44 +08:00
parent 8ff86339d6
commit c2886a2aab
2 changed files with 46 additions and 197 deletions

185
main.py
View File

@ -27,43 +27,6 @@ logging.basicConfig(
]
)
# Global queue for batching Google Sheets updates
results_queue = queue.Queue()
# Background thread function to push updates every minute
class GSheetBatchPusher(threading.Thread):
def __init__(self, queue, timestamp, spreadsheet_name, interval=60):
super().__init__(daemon=True)
self.queue = queue
self.timestamp = timestamp
self.spreadsheet_name = spreadsheet_name
self.interval = interval
self._stop_event = threading.Event()
def run(self):
while not self._stop_event.is_set():
self.push_all()
time.sleep(self.interval)
# Final push on stop
self.push_all()
def stop(self):
self._stop_event.set()
def push_all(self):
batch_results = []
batch_trades = []
while True:
try:
results, trades = self.queue.get_nowait()
batch_results.extend(results)
batch_trades.extend(trades)
except queue.Empty:
break
if batch_results or batch_trades:
write_results_per_combination_gsheet(batch_results, batch_trades, self.timestamp, self.spreadsheet_name)
def get_optimal_workers():
"""Determine optimal number of worker processes based on system resources"""
cpu_count = os.cpu_count() or 4
@ -145,6 +108,7 @@ def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd,
final_usd = initial_usd
for trade in trades:
final_usd *= (1 + trade['profit_pct'])
total_fees_usd = sum(trade.get('fee_usd', 0.0) for trade in trades)
row = {
"timeframe": rule_name,
"stop_loss_pct": stop_loss_pct,
@ -158,6 +122,7 @@ def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd,
"profit_ratio": profit_ratio,
"initial_usd": initial_usd,
"final_usd": final_usd,
"total_fees_usd": total_fees_usd,
}
results_rows.append(row)
for trade in trades:
@ -169,7 +134,8 @@ def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd,
"entry_price": trade.get("entry"),
"exit_price": trade.get("exit"),
"profit_pct": trade.get("profit_pct"),
"type": trade.get("type", ""),
"type": trade.get("type"),
"fee_usd": trade.get("fee_usd"),
})
logging.info(f"Timeframe: {rule_name}, Stop Loss: {stop_loss_pct}, Trades: {n_trades}")
if debug:
@ -235,6 +201,7 @@ def aggregate_results(all_rows):
# Calculate final USD
final_usd = np.mean([r.get('final_usd', initial_usd) for r in rows])
total_fees_usd = np.mean([r.get('total_fees_usd') for r in rows])
summary_rows.append({
"timeframe": rule,
@ -247,108 +214,28 @@ def aggregate_results(all_rows):
"profit_ratio": avg_profit_ratio,
"initial_usd": initial_usd,
"final_usd": final_usd,
"total_fees_usd": total_fees_usd,
})
return summary_rows
def write_results_per_combination_gsheet(results_rows, trade_rows, timestamp, spreadsheet_name="GlimBit Backtest Results"):
scopes = [
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/drive"
]
creds = Credentials.from_service_account_file('credentials/service_account.json', scopes=scopes)
gc = gspread.authorize(creds)
sh = gc.open(spreadsheet_name)
try:
worksheet = sh.worksheet("Results")
except gspread.exceptions.WorksheetNotFound:
worksheet = sh.add_worksheet(title="Results", rows="1000", cols="20")
# Clear the worksheet before writing new results
worksheet.clear()
# Updated fieldnames to match your data rows
fieldnames = [
"timeframe", "stop_loss_pct", "n_trades", "n_stop_loss", "win_rate",
"max_drawdown", "avg_trade", "profit_ratio", "initial_usd", "final_usd"
]
def to_native(val):
if isinstance(val, (np.generic, np.ndarray)):
val = val.item()
if hasattr(val, 'isoformat'):
return val.isoformat()
# Handle inf, -inf, nan
if isinstance(val, float):
if math.isinf(val):
return "" if val > 0 else "-∞"
if math.isnan(val):
return ""
return val
# Write header if sheet is empty
if len(worksheet.get_all_values()) == 0:
worksheet.append_row(fieldnames)
for row in results_rows:
values = [to_native(row.get(field, "")) for field in fieldnames]
worksheet.append_row(values)
trades_fieldnames = [
"entry_time", "exit_time", "entry_price", "exit_price", "profit_pct", "type"
]
trades_by_combo = defaultdict(list)
for trade in trade_rows:
tf = trade.get("timeframe")
sl = trade.get("stop_loss_pct")
trades_by_combo[(tf, sl)].append(trade)
for (tf, sl), trades in trades_by_combo.items():
sl_percent = int(round(sl * 100))
sheet_name = f"Trades_{tf}_ST{sl_percent}%"
try:
trades_ws = sh.worksheet(sheet_name)
except gspread.exceptions.WorksheetNotFound:
trades_ws = sh.add_worksheet(title=sheet_name, rows="1000", cols="20")
# Clear the trades worksheet before writing new trades
trades_ws.clear()
if len(trades_ws.get_all_values()) == 0:
trades_ws.append_row(trades_fieldnames)
for trade in trades:
trade_row = [to_native(trade.get(field, "")) for field in trades_fieldnames]
try:
trades_ws.append_row(trade_row)
except gspread.exceptions.APIError as e:
if '429' in str(e):
logging.warning(f"Google Sheets API quota exceeded (429). Please wait one minute. Will retry on next batch push. Sheet: {sheet_name}")
# Re-queue the failed batch for retry
results_queue.put((results_rows, trade_rows))
return # Stop pushing for this batch, will retry next interval
else:
raise
if __name__ == "__main__":
# Configuration
# start_date = '2022-01-01'
# stop_date = '2023-01-01'
start_date = '2024-05-15'
stop_date = '2025-05-15'
initial_usd = 10000
debug = False
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M")
timeframes = ["1min", "5min"]
stop_loss_pcts = [0.01, 0.02, 0.03]
timeframes = ["1min", "5min", "15min"]
stop_loss_pcts = [0.01, 0.02, 0.03, 0.04, 0.05]
# Load data once
data_1min = load_data('./data/btcusd_1-min_data.csv', start_date, stop_date)
def get_nearest_price(df, target_date):
@ -375,11 +262,6 @@ if __name__ == "__main__":
workers = get_optimal_workers()
logging.info(f"Using {workers} workers for processing")
# Start the background batch pusher
# spreadsheet_name = "GlimBit Backtest Results"
# batch_pusher = GSheetBatchPusher(results_queue, timestamp, spreadsheet_name, interval=65)
# batch_pusher.start()
# Process tasks with optimized concurrency
with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
futures = {executor.submit(process_timeframe, task, debug): task for task in tasks}
@ -390,27 +272,15 @@ if __name__ == "__main__":
if results or trades:
all_results_rows.extend(results)
all_trade_rows.extend(trades)
# results_queue.put((results, trades)) # Enqueue for batch update
# After all tasks, flush any remaining updates
# batch_pusher.stop()
# batch_pusher.join()
# Ensure all batches are pushed, even after 429 errors
# while not results_queue.empty():
# logging.info("Waiting for Google Sheets quota to reset. Retrying batch push in 60 seconds...")
# time.sleep(65)
# batch_pusher.push_all()
# Write all results to a single CSV file
combined_filename = os.path.join(results_dir, f"{timestamp}_backtest_combined.csv")
combined_fieldnames = [
"timeframe", "stop_loss_pct", "n_trades", "n_stop_loss", "win_rate",
"max_drawdown", "avg_trade", "profit_ratio", "final_usd"
"max_drawdown", "avg_trade", "profit_ratio", "final_usd", "total_fees_usd"
]
def format_row(row):
# Format percentages and floats as in your example
return {
"timeframe": row["timeframe"],
"stop_loss_pct": f"{row['stop_loss_pct']*100:.2f}%",
@ -421,6 +291,7 @@ if __name__ == "__main__":
"avg_trade": f"{row['avg_trade']*100:.2f}%",
"profit_ratio": f"{row['profit_ratio']*100:.2f}%",
"final_usd": f"{row['final_usd']:.2f}",
"total_fees_usd": f"{row.get('total_fees_usd'):.2f}",
}
with open(combined_filename, "w", newline="") as csvfile:
@ -431,31 +302,6 @@ if __name__ == "__main__":
logging.info(f"Combined results written to {combined_filename}")
# --- Add taxes to combined results CSV ---
# taxes = Taxes() # Default 20% tax rate
# taxed_filename = combined_filename.replace('.csv', '_taxed.csv')
# taxes.add_taxes_to_results_csv(combined_filename, taxed_filename, profit_col='total_profit')
# logging.info(f"Taxed results written to {taxed_filename}")
# --- Write trades to separate CSVs per timeframe and stop loss ---
# Collect all trades from each task (need to run tasks to collect trades)
# Since only all_results_rows is collected above, we need to also collect all trades.
# To do this, modify the above loop to collect all trades as well.
# But for now, let's assume you have a list all_trade_rows (list of dicts)
# If not, you need to collect it in the ProcessPoolExecutor loop above.
# --- BEGIN: Collect all trades from each task ---
# To do this, modify the ProcessPoolExecutor loop above:
# all_results_rows = []
# all_trade_rows = []
# ...
# for future in concurrent.futures.as_completed(futures):
# results, trades = future.result()
# if results or trades:
# all_results_rows.extend(results)
# all_trade_rows.extend(trades)
# --- END: Collect all trades from each task ---
# Now, group all_trade_rows by (timeframe, stop_loss_pct)
from collections import defaultdict
trades_by_combo = defaultdict(list)
@ -465,7 +311,7 @@ if __name__ == "__main__":
trades_by_combo[(tf, sl)].append(trade)
trades_fieldnames = [
"entry_time", "exit_time", "entry_price", "exit_price", "profit_pct", "type"
"entry_time", "exit_time", "entry_price", "exit_price", "profit_pct", "type", "fee_usd"
]
for (tf, sl), trades in trades_by_combo.items():
@ -475,7 +321,10 @@ if __name__ == "__main__":
writer = csv.DictWriter(csvfile, fieldnames=trades_fieldnames)
writer.writeheader()
for trade in trades:
writer.writerow({k: trade.get(k, "") for k in trades_fieldnames})
row = {k: trade.get(k, "") for k in trades_fieldnames}
fee = trade.get("fee_usd")
row["fee_usd"] = f"{float(fee):.2f}"
writer.writerow(row)
logging.info(f"Trades written to {trades_filename}")

View File

@ -114,6 +114,10 @@ def calculate_supertrend_external(data, period, multiplier):
# Call the cached function
return cached_supertrend_calculation(period, multiplier, (high_tuple, low_tuple, close_tuple))
def calculate_okx_fee(amount, is_maker=True):
fee_rate = 0.0008 if is_maker else 0.0010
return amount * fee_rate
class TrendDetectorSimple:
def __init__(self, data, verbose=False, display=False):
"""
@ -638,7 +642,7 @@ class TrendDetectorSimple:
ax.plot([], [], color_down, linewidth=self.line_width,
label=f'ST (P:{period}, M:{multiplier}) Down')
def backtest_meta_supertrend(self, min1_df, initial_usd=10000, stop_loss_pct=0.05, transaction_cost=0.001, debug=False):
def backtest_meta_supertrend(self, min1_df, initial_usd=10000, stop_loss_pct=0.05, debug=False):
"""
Backtest a simple strategy using the meta supertrend (all three supertrends agree).
Buys when meta supertrend is positive, sells when negative, applies a percentage stop loss.
@ -647,7 +651,6 @@ class TrendDetectorSimple:
- min1_df: pandas DataFrame, 1-minute timeframe data for more accurate stop loss checking (optional)
- initial_usd: float, starting USD amount
- stop_loss_pct: float, stop loss as a fraction (e.g. 0.05 for 5%)
- transaction_cost: float, transaction cost as a fraction (e.g. 0.001 for 0.1%)
- debug: bool, whether to print debug info
"""
df = self.data.copy().reset_index(drop=True)
@ -709,16 +712,16 @@ class TrendDetectorSimple:
if debug:
print(f"STOP LOSS triggered: entry={entry_price}, stop={stop_price}, sell_price={sell_price}, entry_time={entry_time}, stop_time={stop_candle.name}")
btc_to_sell = coin
fee_btc = btc_to_sell * transaction_cost
btc_after_fee = btc_to_sell - fee_btc
usd = btc_after_fee * sell_price
usd_gross = btc_to_sell * sell_price
exit_fee = calculate_okx_fee(usd_gross, is_maker=False) # taker fee
usd = usd_gross - exit_fee
trade_log.append({
'type': 'STOP',
'entry': entry_price,
'exit': sell_price,
'entry_time': entry_time,
'exit_time': stop_candle.name, # Use index name instead of timestamp column
'fee_btc': fee_btc
'exit_time': stop_candle.name,
'fee_usd': exit_fee
})
coin = 0
position = 0
@ -731,10 +734,10 @@ class TrendDetectorSimple:
# Entry: only if not in position and signal changes to 1
if position == 0 and prev_mt != 1 and curr_mt == 1:
# Buy at open, fee is charged in BTC (base currency)
gross_btc = usd / price_open
fee_btc = gross_btc * transaction_cost
coin = gross_btc - fee_btc
# Buy at open, fee is charged in USD
entry_fee = calculate_okx_fee(usd, is_maker=False)
usd_after_fee = usd - entry_fee
coin = usd_after_fee / price_open
entry_price = price_open
entry_time = date
usd = 0
@ -746,23 +749,23 @@ class TrendDetectorSimple:
'exit': None,
'entry_time': entry_time,
'exit_time': None,
'fee_btc': fee_btc
'fee_usd': entry_fee
})
# Exit: only if in position and signal changes from 1 to -1
elif position == 1 and prev_mt == 1 and curr_mt == -1:
# Sell at open, fee is charged in BTC (base currency)
# Sell at open, fee is charged in USD
btc_to_sell = coin
fee_btc = btc_to_sell * transaction_cost
btc_after_fee = btc_to_sell - fee_btc
usd = btc_after_fee * price_open
usd_gross = btc_to_sell * price_open
exit_fee = calculate_okx_fee(usd_gross, is_maker=False)
usd = usd_gross - exit_fee
trade_log.append({
'type': 'SELL',
'entry': entry_price,
'exit': price_open,
'entry_time': entry_time,
'exit_time': date,
'fee_btc': fee_btc
'fee_usd': exit_fee
})
coin = 0
position = 0
@ -779,16 +782,16 @@ class TrendDetectorSimple:
# If still in position at end, sell at last close
if position == 1:
btc_to_sell = coin
fee_btc = btc_to_sell * transaction_cost
btc_after_fee = btc_to_sell - fee_btc
usd = btc_after_fee * df['close'].iloc[-1]
usd_gross = btc_to_sell * df['close'].iloc[-1]
exit_fee = calculate_okx_fee(usd_gross, is_maker=False)
usd = usd_gross - exit_fee
trade_log.append({
'type': 'EOD',
'entry': entry_price,
'exit': df['close'].iloc[-1],
'entry_time': entry_time,
'exit_time': df['timestamp'].iloc[-1],
'fee_btc': fee_btc
'fee_usd': exit_fee
})
coin = 0
position = 0
@ -803,7 +806,6 @@ class TrendDetectorSimple:
avg_trade = np.mean([t['exit']/t['entry']-1 for t in trade_log if t['exit'] is not None]) if trade_log else 0
trades = []
total_fees_btc = 0.0
total_fees_usd = 0.0
for trade in trade_log:
if trade['exit'] is not None:
@ -816,13 +818,12 @@ class TrendDetectorSimple:
'entry': trade['entry'],
'exit': trade['exit'],
'profit_pct': profit_pct,
'type': trade.get('type', 'SELL')
'type': trade.get('type', 'SELL'),
'fee_usd': trade.get('fee_usd', 0.0)
})
# Sum up BTC fees and their USD equivalent (use exit price if available)
fee_btc = trade.get('fee_btc', 0.0)
total_fees_btc += fee_btc
if fee_btc and trade.get('exit') is not None:
total_fees_usd += fee_btc * trade['exit']
# Sum up USD fees
fee_usd = trade.get('fee_usd', 0.0)
total_fees_usd += fee_usd
results = {
"initial_usd": initial_usd,
@ -833,7 +834,6 @@ class TrendDetectorSimple:
"avg_trade": avg_trade,
"trade_log": trade_log,
"trades": trades,
"total_fees_btc": total_fees_btc,
"total_fees_usd": total_fees_usd,
}
if n_trades > 0: