Enhance strategy output standardization and improve plotting logic
- Introduced a new method to standardize output column names across different strategies, ensuring consistency in data handling and plotting. - Updated plotting logic in test_bbrsi.py to utilize standardized column names, improving clarity and maintainability. - Enhanced error handling for missing data in plots and adjusted visual elements for better representation of trading signals. - Improved the overall structure of strategy implementations to support additional indicators and metadata for better analysis.
This commit is contained in:
parent
3a9dec543c
commit
00873d593f
@ -16,14 +16,66 @@ class Strategy:
|
|||||||
|
|
||||||
def run(self, data, strategy_name):
|
def run(self, data, strategy_name):
|
||||||
if strategy_name == "MarketRegimeStrategy":
|
if strategy_name == "MarketRegimeStrategy":
|
||||||
return self.MarketRegimeStrategy(data)
|
result = self.MarketRegimeStrategy(data)
|
||||||
|
return self.standardize_output(result, strategy_name)
|
||||||
elif strategy_name == "CryptoTradingStrategy":
|
elif strategy_name == "CryptoTradingStrategy":
|
||||||
return self.CryptoTradingStrategy(data)
|
result = self.CryptoTradingStrategy(data)
|
||||||
|
return self.standardize_output(result, strategy_name)
|
||||||
else:
|
else:
|
||||||
if self.logging is not None:
|
if self.logging is not None:
|
||||||
self.logging.warning(f"Strategy {strategy_name} not found. Using no_strategy instead.")
|
self.logging.warning(f"Strategy {strategy_name} not found. Using no_strategy instead.")
|
||||||
return self.no_strategy(data)
|
return self.no_strategy(data)
|
||||||
|
|
||||||
|
def standardize_output(self, data, strategy_name):
|
||||||
|
"""
|
||||||
|
Standardize column names across different strategies to ensure consistent plotting and analysis
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (DataFrame): Strategy output DataFrame
|
||||||
|
strategy_name (str): Name of the strategy that generated this data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame: Data with standardized column names
|
||||||
|
"""
|
||||||
|
if data.empty:
|
||||||
|
return data
|
||||||
|
|
||||||
|
# Create a copy to avoid modifying the original
|
||||||
|
standardized = data.copy()
|
||||||
|
|
||||||
|
# Standardize column names based on strategy
|
||||||
|
if strategy_name == "MarketRegimeStrategy":
|
||||||
|
# MarketRegimeStrategy already has standard column names for most fields
|
||||||
|
# Just ensure all standard columns exist
|
||||||
|
pass
|
||||||
|
elif strategy_name == "CryptoTradingStrategy":
|
||||||
|
# Map strategy-specific column names to standard names
|
||||||
|
column_mapping = {
|
||||||
|
'UpperBand_15m': 'UpperBand',
|
||||||
|
'LowerBand_15m': 'LowerBand',
|
||||||
|
'SMA_15m': 'SMA',
|
||||||
|
'RSI_15m': 'RSI',
|
||||||
|
'VolumeMA_15m': 'VolumeMA',
|
||||||
|
# Keep StopLoss and TakeProfit as they are
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add standard columns from mapped columns
|
||||||
|
for old_col, new_col in column_mapping.items():
|
||||||
|
if old_col in standardized.columns and new_col not in standardized.columns:
|
||||||
|
standardized[new_col] = standardized[old_col]
|
||||||
|
|
||||||
|
# Add additional strategy-specific data as metadata columns
|
||||||
|
if 'UpperBand_1h' in standardized.columns:
|
||||||
|
standardized['UpperBand_1h_meta'] = standardized['UpperBand_1h']
|
||||||
|
if 'LowerBand_1h' in standardized.columns:
|
||||||
|
standardized['LowerBand_1h_meta'] = standardized['LowerBand_1h']
|
||||||
|
|
||||||
|
# Ensure all strategies have BBWidth if possible
|
||||||
|
if 'BBWidth' not in standardized.columns and 'UpperBand' in standardized.columns and 'LowerBand' in standardized.columns:
|
||||||
|
standardized['BBWidth'] = (standardized['UpperBand'] - standardized['LowerBand']) / standardized['SMA'] if 'SMA' in standardized.columns else np.nan
|
||||||
|
|
||||||
|
return standardized
|
||||||
|
|
||||||
def no_strategy(self, data):
|
def no_strategy(self, data):
|
||||||
"""No strategy: returns False for both buy and sell conditions"""
|
"""No strategy: returns False for both buy and sell conditions"""
|
||||||
buy_condition = pd.Series([False] * len(data), index=data.index)
|
buy_condition = pd.Series([False] * len(data), index=data.index)
|
||||||
@ -74,6 +126,7 @@ class Strategy:
|
|||||||
DataFrame: A unified DataFrame containing original data, BB, RSI, and signals.
|
DataFrame: A unified DataFrame containing original data, BB, RSI, and signals.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# data = aggregate_to_hourly(data, 4)
|
||||||
data = aggregate_to_daily(data)
|
data = aggregate_to_daily(data)
|
||||||
|
|
||||||
# Calculate Bollinger Bands
|
# Calculate Bollinger Bands
|
||||||
|
|||||||
@ -33,7 +33,7 @@ config_strategy = {
|
|||||||
"rsi_threshold": [40, 60],
|
"rsi_threshold": [40, 60],
|
||||||
"bb_std_dev_multiplier": 1.8,
|
"bb_std_dev_multiplier": 1.8,
|
||||||
},
|
},
|
||||||
"strategy_name": "MarketRegimeStrategy",
|
"strategy_name": "MarketRegimeStrategy", # CryptoTradingStrategy
|
||||||
"SqueezeStrategy": True
|
"SqueezeStrategy": True
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -65,18 +65,19 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Plot 1: Close Price and Strategy-Specific Bands/Levels
|
# Plot 1: Close Price and Strategy-Specific Bands/Levels
|
||||||
sns.lineplot(x=processed_data.index, y='close', data=processed_data, label='Close Price', ax=ax1)
|
sns.lineplot(x=processed_data.index, y='close', data=processed_data, label='Close Price', ax=ax1)
|
||||||
if strategy_name == "MarketRegimeStrategy":
|
|
||||||
|
# Use standardized column names for bands
|
||||||
if 'UpperBand' in processed_data.columns and 'LowerBand' in processed_data.columns:
|
if 'UpperBand' in processed_data.columns and 'LowerBand' in processed_data.columns:
|
||||||
sns.lineplot(x=processed_data.index, y='UpperBand', data=processed_data, label='Upper Band (BB)', ax=ax1)
|
# Instead of lines, shade the area between upper and lower bands
|
||||||
sns.lineplot(x=processed_data.index, y='LowerBand', data=processed_data, label='Lower Band (BB)', ax=ax1)
|
ax1.fill_between(processed_data.index,
|
||||||
|
processed_data['LowerBand'],
|
||||||
|
processed_data['UpperBand'],
|
||||||
|
alpha=0.1, color='blue', label='Bollinger Bands')
|
||||||
else:
|
else:
|
||||||
logging.warning("MarketRegimeStrategy: UpperBand or LowerBand not found for plotting.")
|
logging.warning(f"{strategy_name}: UpperBand or LowerBand not found for plotting.")
|
||||||
elif strategy_name == "CryptoTradingStrategy":
|
|
||||||
if 'UpperBand_15m' in processed_data.columns and 'LowerBand_15m' in processed_data.columns:
|
# Add strategy-specific extra indicators if available
|
||||||
sns.lineplot(x=processed_data.index, y='UpperBand_15m', data=processed_data, label='Upper Band (15m)', ax=ax1)
|
if strategy_name == "CryptoTradingStrategy":
|
||||||
sns.lineplot(x=processed_data.index, y='LowerBand_15m', data=processed_data, label='Lower Band (15m)', ax=ax1)
|
|
||||||
else:
|
|
||||||
logging.warning("CryptoTradingStrategy: UpperBand_15m or LowerBand_15m not found for plotting.")
|
|
||||||
if 'StopLoss' in processed_data.columns:
|
if 'StopLoss' in processed_data.columns:
|
||||||
sns.lineplot(x=processed_data.index, y='StopLoss', data=processed_data, label='Stop Loss', ax=ax1, linestyle='--', color='orange')
|
sns.lineplot(x=processed_data.index, y='StopLoss', data=processed_data, label='Stop Loss', ax=ax1, linestyle='--', color='orange')
|
||||||
if 'TakeProfit' in processed_data.columns:
|
if 'TakeProfit' in processed_data.columns:
|
||||||
@ -84,54 +85,68 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Plot Buy/Sell signals on Price chart
|
# Plot Buy/Sell signals on Price chart
|
||||||
if not buy_signals.empty:
|
if not buy_signals.empty:
|
||||||
ax1.scatter(buy_signals.index, buy_signals['close'], color='green', marker='o', s=10, label='Buy Signal', zorder=5)
|
ax1.scatter(buy_signals.index, buy_signals['close'], color='green', marker='o', s=20, label='Buy Signal', zorder=5)
|
||||||
if not sell_signals.empty:
|
if not sell_signals.empty:
|
||||||
ax1.scatter(sell_signals.index, sell_signals['close'], color='red', marker='o', s=10, label='Sell Signal', zorder=5)
|
ax1.scatter(sell_signals.index, sell_signals['close'], color='red', marker='o', s=20, label='Sell Signal', zorder=5)
|
||||||
ax1.set_title(f'Price and Signals ({strategy_name})')
|
ax1.set_title(f'Price and Signals ({strategy_name})')
|
||||||
ax1.set_ylabel('Price')
|
ax1.set_ylabel('Price')
|
||||||
ax1.legend()
|
ax1.legend()
|
||||||
ax1.grid(True)
|
ax1.grid(True)
|
||||||
|
|
||||||
# Plot 2: RSI and Strategy-Specific Thresholds
|
# Plot 2: RSI and Strategy-Specific Thresholds
|
||||||
rsi_col_name = 'RSI' if strategy_name == "MarketRegimeStrategy" else 'RSI_15m'
|
if 'RSI' in processed_data.columns:
|
||||||
if rsi_col_name in processed_data.columns:
|
sns.lineplot(x=processed_data.index, y='RSI', data=processed_data, label=f'RSI (' + str(config_strategy.get("rsi_period", 14)) + ')', ax=ax2, color='purple')
|
||||||
sns.lineplot(x=processed_data.index, y=rsi_col_name, data=processed_data, label=f'{rsi_col_name} (' + str(config_strategy.get("rsi_period", 14)) + ')', ax=ax2, color='purple')
|
|
||||||
if strategy_name == "MarketRegimeStrategy":
|
if strategy_name == "MarketRegimeStrategy":
|
||||||
# Assuming trending thresholds are what we want to show generally
|
# Get threshold values
|
||||||
ax2.axhline(config_strategy.get("trending", {}).get("rsi_threshold", [30,70])[1], color='red', linestyle='--', linewidth=0.8, label=f'Overbought (' + str(config_strategy.get("trending", {}).get("rsi_threshold", [30,70])[1]) + ')')
|
upper_threshold = config_strategy.get("trending", {}).get("rsi_threshold", [30,70])[1]
|
||||||
ax2.axhline(config_strategy.get("trending", {}).get("rsi_threshold", [30,70])[0], color='green', linestyle='--', linewidth=0.8, label=f'Oversold (' + str(config_strategy.get("trending", {}).get("rsi_threshold", [30,70])[0]) + ')')
|
lower_threshold = config_strategy.get("trending", {}).get("rsi_threshold", [30,70])[0]
|
||||||
|
|
||||||
|
# Shade overbought area (upper)
|
||||||
|
ax2.fill_between(processed_data.index, upper_threshold, 100,
|
||||||
|
alpha=0.1, color='red', label=f'Overbought (>{upper_threshold})')
|
||||||
|
|
||||||
|
# Shade oversold area (lower)
|
||||||
|
ax2.fill_between(processed_data.index, 0, lower_threshold,
|
||||||
|
alpha=0.1, color='green', label=f'Oversold (<{lower_threshold})')
|
||||||
|
|
||||||
elif strategy_name == "CryptoTradingStrategy":
|
elif strategy_name == "CryptoTradingStrategy":
|
||||||
ax2.axhline(65, color='red', linestyle='--', linewidth=0.8, label='Overbought (65)') # As per Crypto strategy logic
|
# Shade overbought area (upper)
|
||||||
ax2.axhline(35, color='green', linestyle='--', linewidth=0.8, label='Oversold (35)') # As per Crypto strategy logic
|
ax2.fill_between(processed_data.index, 65, 100,
|
||||||
|
alpha=0.1, color='red', label='Overbought (>65)')
|
||||||
|
|
||||||
|
# Shade oversold area (lower)
|
||||||
|
ax2.fill_between(processed_data.index, 0, 35,
|
||||||
|
alpha=0.1, color='green', label='Oversold (<35)')
|
||||||
|
|
||||||
# Plot Buy/Sell signals on RSI chart
|
# Plot Buy/Sell signals on RSI chart
|
||||||
if not buy_signals.empty and rsi_col_name in buy_signals.columns:
|
if not buy_signals.empty and 'RSI' in buy_signals.columns:
|
||||||
ax2.scatter(buy_signals.index, buy_signals[rsi_col_name], color='green', marker='o', s=20, label=f'Buy Signal ({rsi_col_name})', zorder=5)
|
ax2.scatter(buy_signals.index, buy_signals['RSI'], color='green', marker='o', s=20, label='Buy Signal (RSI)', zorder=5)
|
||||||
if not sell_signals.empty and rsi_col_name in sell_signals.columns:
|
if not sell_signals.empty and 'RSI' in sell_signals.columns:
|
||||||
ax2.scatter(sell_signals.index, sell_signals[rsi_col_name], color='red', marker='o', s=20, label=f'Sell Signal ({rsi_col_name})', zorder=5)
|
ax2.scatter(sell_signals.index, sell_signals['RSI'], color='red', marker='o', s=20, label='Sell Signal (RSI)', zorder=5)
|
||||||
ax2.set_title(f'Relative Strength Index ({rsi_col_name}) with Signals')
|
ax2.set_title('Relative Strength Index (RSI) with Signals')
|
||||||
ax2.set_ylabel(f'{rsi_col_name} Value')
|
ax2.set_ylabel('RSI Value')
|
||||||
ax2.set_ylim(0, 100)
|
ax2.set_ylim(0, 100)
|
||||||
ax2.legend()
|
ax2.legend()
|
||||||
ax2.grid(True)
|
ax2.grid(True)
|
||||||
else:
|
else:
|
||||||
logging.info(f"{rsi_col_name} data not available for plotting.")
|
logging.info("RSI data not available for plotting.")
|
||||||
|
|
||||||
# Plot 3: Strategy-Specific Indicators
|
# Plot 3: Strategy-Specific Indicators
|
||||||
ax3.clear() # Clear previous plot content if any
|
ax3.clear() # Clear previous plot content if any
|
||||||
if strategy_name == "MarketRegimeStrategy":
|
|
||||||
if 'BBWidth' in processed_data.columns:
|
if 'BBWidth' in processed_data.columns:
|
||||||
sns.lineplot(x=processed_data.index, y='BBWidth', data=processed_data, label='BB Width', ax=ax3)
|
sns.lineplot(x=processed_data.index, y='BBWidth', data=processed_data, label='BB Width', ax=ax3)
|
||||||
|
|
||||||
|
if strategy_name == "MarketRegimeStrategy":
|
||||||
if 'MarketRegime' in processed_data.columns:
|
if 'MarketRegime' in processed_data.columns:
|
||||||
sns.lineplot(x=processed_data.index, y='MarketRegime', data=processed_data, label='Market Regime (Sideways: 1, Trending: 0)', ax=ax3)
|
sns.lineplot(x=processed_data.index, y='MarketRegime', data=processed_data, label='Market Regime (Sideways: 1, Trending: 0)', ax=ax3)
|
||||||
ax3.set_title('Bollinger Bands Width & Market Regime')
|
ax3.set_title('Bollinger Bands Width & Market Regime')
|
||||||
ax3.set_ylabel('Value')
|
ax3.set_ylabel('Value')
|
||||||
elif strategy_name == "CryptoTradingStrategy":
|
elif strategy_name == "CryptoTradingStrategy":
|
||||||
if 'VolumeMA_15m' in processed_data.columns:
|
if 'VolumeMA' in processed_data.columns:
|
||||||
sns.lineplot(x=processed_data.index, y='VolumeMA_15m', data=processed_data, label='Volume MA (15m)', ax=ax3)
|
sns.lineplot(x=processed_data.index, y='VolumeMA', data=processed_data, label='Volume MA', ax=ax3)
|
||||||
if 'volume' in processed_data.columns: # Plot original volume for comparison
|
if 'volume' in processed_data.columns:
|
||||||
sns.lineplot(x=processed_data.index, y='volume', data=processed_data, label='Volume (15m)', ax=ax3, alpha=0.5)
|
sns.lineplot(x=processed_data.index, y='volume', data=processed_data, label='Volume', ax=ax3, alpha=0.5)
|
||||||
ax3.set_title('Volume Analysis (15m)')
|
ax3.set_title('Volume Analysis')
|
||||||
ax3.set_ylabel('Volume')
|
ax3.set_ylabel('Volume')
|
||||||
|
|
||||||
ax3.legend()
|
ax3.legend()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user