From 8ff86339d69895fdd1a8a058e55d549787c48516 Mon Sep 17 00:00:00 2001 From: Simon Moisy Date: Tue, 20 May 2025 16:14:40 +0800 Subject: [PATCH] Add taxes functionality and refactor trading logic - Introduced Taxes class in taxes.py to calculate and apply taxes on profits in backtest results. - Updated main.py to include tax calculations in the results processing flow. - Refactored trade logging in TrendDetectorSimple to account for transaction fees and ensure accurate profit calculations. - Added a utility script (apply_taxes_to_file.py) for applying taxes to existing CSV files. - Adjusted date range and timeframe settings in main.py for broader analysis. --- main.py | 22 +++++++----- taxes.py | 25 +++++++++++++ trend_detector_simple.py | 69 ++++++++++++++++++++++++------------ utils/apply_taxes_to_file.py | 23 ++++++++++++ 4 files changed, 108 insertions(+), 31 deletions(-) create mode 100644 taxes.py create mode 100644 utils/apply_taxes_to_file.py diff --git a/main.py b/main.py index 96ebece..51343d8 100644 --- a/main.py +++ b/main.py @@ -15,6 +15,7 @@ import queue import time import math import json +from taxes import Taxes # Set up logging logging.basicConfig( @@ -124,7 +125,8 @@ def process_timeframe_data(min1_df, df, stop_loss_pcts, rule_name, initial_usd, ) n_trades = results["n_trades"] trades = results.get('trades', []) - n_winning_trades = sum(1 for trade in trades if trade['profit_pct'] > 0) + wins = [1 for t in trades if t['exit'] is not None and t['exit'] > t['entry']] + n_winning_trades = len(wins) total_profit = sum(trade['profit_pct'] for trade in trades) total_loss = sum(-trade['profit_pct'] for trade in trades if trade['profit_pct'] < 0) win_rate = n_winning_trades / n_trades if n_trades > 0 else 0 @@ -332,8 +334,10 @@ def write_results_per_combination_gsheet(results_rows, trade_rows, timestamp, sp if __name__ == "__main__": # Configuration - start_date = '2022-01-01' - stop_date = '2023-01-01' + # start_date = '2022-01-01' + # stop_date = 
'2023-01-01' + start_date = '2024-05-15' + stop_date = '2025-05-15' initial_usd = 10000 debug = False @@ -341,14 +345,12 @@ if __name__ == "__main__": os.makedirs(results_dir, exist_ok=True) timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M") - timeframes = ["1min", "5min", "15min", "30min", "1h", "4h", "6h", "12h", "1D"] + timeframes = ["1min", "5min"] stop_loss_pcts = [0.01, 0.02, 0.03] # Load data once data_1min = load_data('./data/btcusd_1-min_data.csv', start_date, stop_date) - logging.info(f"1min rows: {len(data_1min)}") - # Log the price at the nearest timestamp to start_date and stop_date def get_nearest_price(df, target_date): if len(df) == 0: return None, None @@ -364,14 +366,12 @@ if __name__ == "__main__": logging.info(f"Price at start_date ({start_date}) [nearest timestamp: {nearest_start_time}]: {start_price}") logging.info(f"Price at stop_date ({stop_date}) [nearest timestamp: {nearest_stop_time}]: {stop_price}") - # Prepare tasks tasks = [ (name, data_1min, stop_loss_pct, initial_usd) for name in timeframes for stop_loss_pct in stop_loss_pcts ] - # Determine optimal worker count workers = get_optimal_workers() logging.info(f"Using {workers} workers for processing") @@ -431,6 +431,12 @@ if __name__ == "__main__": logging.info(f"Combined results written to {combined_filename}") + # --- Add taxes to combined results CSV --- + # taxes = Taxes() # Default 20% tax rate + # taxed_filename = combined_filename.replace('.csv', '_taxed.csv') + # taxes.add_taxes_to_results_csv(combined_filename, taxed_filename, profit_col='total_profit') + # logging.info(f"Taxed results written to {taxed_filename}") + # --- Write trades to separate CSVs per timeframe and stop loss --- # Collect all trades from each task (need to run tasks to collect trades) # Since only all_results_rows is collected above, we need to also collect all trades. 
diff --git a/taxes.py b/taxes.py new file mode 100644 index 0000000..30669c7 --- /dev/null +++ b/taxes.py @@ -0,0 +1,25 @@ +import pandas as pd + +class Taxes: + def __init__(self, tax_rate=0.20): + """ + tax_rate: flat tax rate on positive profits (e.g., 0.20 for 20%) + """ + self.tax_rate = tax_rate + + def add_taxes_to_results_csv(self, input_csv, output_csv=None, profit_col='final_usd'): + """ + Reads a backtest results CSV, adds tax columns, and writes to a new CSV. + - input_csv: path to the input CSV file + - output_csv: path to the output CSV file (if None, overwrite input) + - profit_col: column name for profit (default: 'final_usd') + """ + df = pd.read_csv(input_csv, delimiter=None) + # Compute tax only on positive profits + df['tax_paid'] = df[profit_col].apply(lambda x: self.tax_rate * x if x > 0 else 0) + df['net_profit_after_tax'] = df[profit_col] - df['tax_paid'] + df['cumulative_tax_paid'] = df['tax_paid'].cumsum() + if not output_csv: + output_csv = input_csv + df.to_csv(output_csv, index=False) + return output_csv diff --git a/trend_detector_simple.py b/trend_detector_simple.py index 0610ec5..95a1bb5 100644 --- a/trend_detector_simple.py +++ b/trend_detector_simple.py @@ -660,17 +660,6 @@ class TrendDetectorSimple: meta_trend = np.where((trends_arr[:,0] == trends_arr[:,1]) & (trends_arr[:,1] == trends_arr[:,2]), trends_arr[:,0], 0) - if debug: - # Count flips (ignoring 0s) - flips = 0 - last = meta_trend[0] - for val in meta_trend[1:]: - if val != 0 and val != last: - flips += 1 - last = val - print(f"Meta trend flips (ignoring 0): {flips}") - print(f"Meta trend value counts: {np.unique(meta_trend, return_counts=True)}") - position = 0 # 0 = no position, 1 = long entry_price = 0 usd = initial_usd @@ -719,13 +708,17 @@ class TrendDetectorSimple: sell_price = stop_price if debug: print(f"STOP LOSS triggered: entry={entry_price}, stop={stop_price}, sell_price={sell_price}, entry_time={entry_time}, stop_time={stop_candle.name}") - usd = coin * 
sell_price * (1 - transaction_cost) # Apply transaction cost + btc_to_sell = coin + fee_btc = btc_to_sell * transaction_cost + btc_after_fee = btc_to_sell - fee_btc + usd = btc_after_fee * sell_price trade_log.append({ 'type': 'STOP', 'entry': entry_price, 'exit': sell_price, 'entry_time': entry_time, - 'exit_time': stop_candle.name # Use index name instead of timestamp column + 'exit_time': stop_candle.name, # Use index name instead of timestamp column + 'fee_btc': fee_btc }) coin = 0 position = 0 @@ -738,24 +731,38 @@ class TrendDetectorSimple: # Entry: only if not in position and signal changes to 1 if position == 0 and prev_mt != 1 and curr_mt == 1: - # Buy at open, apply transaction cost - coin = (usd * (1 - transaction_cost)) / price_open + # Buy at open, fee is charged in BTC (base currency) + gross_btc = usd / price_open + fee_btc = gross_btc * transaction_cost + coin = gross_btc - fee_btc entry_price = price_open entry_time = date usd = 0 position = 1 current_trade_min1_start_idx = None # Will be set on first stop loss check + trade_log.append({ + 'type': 'BUY', + 'entry': entry_price, + 'exit': None, + 'entry_time': entry_time, + 'exit_time': None, + 'fee_btc': fee_btc + }) # Exit: only if in position and signal changes from 1 to -1 elif position == 1 and prev_mt == 1 and curr_mt == -1: - # Sell at open, apply transaction cost - usd = coin * price_open * (1 - transaction_cost) + # Sell at open, fee is charged in BTC (base currency) + btc_to_sell = coin + fee_btc = btc_to_sell * transaction_cost + btc_after_fee = btc_to_sell - fee_btc + usd = btc_after_fee * price_open trade_log.append({ 'type': 'SELL', 'entry': entry_price, 'exit': price_open, 'entry_time': entry_time, - 'exit_time': date + 'exit_time': date, + 'fee_btc': fee_btc }) coin = 0 position = 0 @@ -771,13 +778,17 @@ class TrendDetectorSimple: # If still in position at end, sell at last close if position == 1: - usd = coin * df['close'].iloc[-1] * (1 - transaction_cost) # Apply transaction cost + 
btc_to_sell = coin + fee_btc = btc_to_sell * transaction_cost + btc_after_fee = btc_to_sell - fee_btc + usd = btc_after_fee * df['close'].iloc[-1] trade_log.append({ 'type': 'EOD', 'entry': entry_price, 'exit': df['close'].iloc[-1], 'entry_time': entry_time, - 'exit_time': df['timestamp'].iloc[-1] + 'exit_time': df['timestamp'].iloc[-1], + 'fee_btc': fee_btc }) coin = 0 position = 0 @@ -786,14 +797,19 @@ class TrendDetectorSimple: # Calculate statistics final_balance = usd n_trades = len(trade_log) - wins = [1 for t in trade_log if t['exit'] > t['entry']] + wins = [1 for t in trade_log if t['exit'] is not None and t['exit'] > t['entry']] win_rate = len(wins) / n_trades if n_trades > 0 else 0 max_drawdown = max(drawdowns) if drawdowns else 0 - avg_trade = np.mean([t['exit']/t['entry']-1 for t in trade_log]) if trade_log else 0 + avg_trade = np.mean([t['exit']/t['entry']-1 for t in trade_log if t['exit'] is not None]) if trade_log else 0 trades = [] + total_fees_btc = 0.0 + total_fees_usd = 0.0 for trade in trade_log: - profit_pct = (trade['exit'] - trade['entry']) / trade['entry'] + if trade['exit'] is not None: + profit_pct = (trade['exit'] - trade['entry']) / trade['entry'] + else: + profit_pct = 0.0 trades.append({ 'entry_time': trade['entry_time'], 'exit_time': trade['exit_time'], @@ -802,6 +818,11 @@ class TrendDetectorSimple: 'profit_pct': profit_pct, 'type': trade.get('type', 'SELL') }) + # Sum up BTC fees and their USD equivalent (use exit price if available) + fee_btc = trade.get('fee_btc', 0.0) + total_fees_btc += fee_btc + if fee_btc and trade.get('exit') is not None: + total_fees_usd += fee_btc * trade['exit'] results = { "initial_usd": initial_usd, @@ -812,6 +833,8 @@ class TrendDetectorSimple: "avg_trade": avg_trade, "trade_log": trade_log, "trades": trades, + "total_fees_btc": total_fees_btc, + "total_fees_usd": total_fees_usd, } if n_trades > 0: results["first_trade"] = { diff --git a/utils/apply_taxes_to_file.py b/utils/apply_taxes_to_file.py new 
file mode 100644 index 0000000..a5073db --- /dev/null +++ b/utils/apply_taxes_to_file.py @@ -0,0 +1,23 @@ +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from taxes import Taxes + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python apply_taxes_to_file.py input_csv [profit_col]") + sys.exit(1) + + input_csv = sys.argv[1] + profit_col = sys.argv[2] if len(sys.argv) > 2 else 'final_usd' + + if not os.path.isfile(input_csv): + print(f"File not found: {input_csv}") + sys.exit(1) + + base, ext = os.path.splitext(input_csv) + output_csv = f"{base}_taxed.csv" + + taxes = Taxes() # Default 20% tax rate + taxes.add_taxes_to_results_csv(input_csv, output_csv, profit_col=profit_col) + print(f"Taxed file saved as: {output_csv}")