Implement Regime Reversion Strategy and remove regime_detection.py
- Introduced `RegimeReversionStrategy` for ML-based regime detection and mean reversion trading. - Added feature engineering and model training logic within the new strategy. - Removed the deprecated `regime_detection.py` file to streamline the codebase. - Updated the strategy factory to include the new regime strategy configuration.
This commit is contained in:
@@ -1,384 +0,0 @@
|
|||||||
import sys
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import ta
|
|
||||||
from sklearn.ensemble import RandomForestClassifier
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from sklearn.metrics import classification_report, confusion_matrix
|
|
||||||
import plotly.graph_objects as go
|
|
||||||
from plotly.subplots import make_subplots
|
|
||||||
|
|
||||||
from engine.data_manager import DataManager
|
|
||||||
from engine.market import MarketType
|
|
||||||
|
|
||||||
def prepare_data(symbol_a="BTC-USDT", symbol_b="ETH-USDT", timeframe="1h", limit=None, start_date=None, end_date=None):
|
|
||||||
"""
|
|
||||||
Load and align data for two assets to create a pair.
|
|
||||||
"""
|
|
||||||
dm = DataManager()
|
|
||||||
|
|
||||||
print(f"Loading data for {symbol_a} and {symbol_b}...")
|
|
||||||
|
|
||||||
# Helper to load or download
|
|
||||||
def get_df(symbol):
|
|
||||||
try:
|
|
||||||
# Try load first
|
|
||||||
df = dm.load_data("okx", symbol, timeframe, MarketType.SPOT)
|
|
||||||
except Exception:
|
|
||||||
df = dm.download_data("okx", symbol, timeframe, market_type=MarketType.SPOT)
|
|
||||||
|
|
||||||
# If we have start/end dates, ensure we have enough data or re-download
|
|
||||||
if start_date:
|
|
||||||
mask_start = pd.Timestamp(start_date, tz='UTC')
|
|
||||||
if df.index.min() > mask_start:
|
|
||||||
print(f"Local data starts {df.index.min()}, need {mask_start}. Downloading...")
|
|
||||||
df = dm.download_data("okx", symbol, timeframe, start_date=start_date, end_date=end_date, market_type=MarketType.SPOT)
|
|
||||||
return df
|
|
||||||
|
|
||||||
df_a = get_df(symbol_a)
|
|
||||||
df_b = get_df(symbol_b)
|
|
||||||
|
|
||||||
# Filter by date if provided (to match CQ data range)
|
|
||||||
if start_date:
|
|
||||||
df_a = df_a[df_a.index >= pd.Timestamp(start_date, tz='UTC')]
|
|
||||||
df_b = df_b[df_b.index >= pd.Timestamp(start_date, tz='UTC')]
|
|
||||||
|
|
||||||
if end_date:
|
|
||||||
df_a = df_a[df_a.index <= pd.Timestamp(end_date, tz='UTC')]
|
|
||||||
df_b = df_b[df_b.index <= pd.Timestamp(end_date, tz='UTC')]
|
|
||||||
|
|
||||||
# Align DataFrames
|
|
||||||
print("Aligning data...")
|
|
||||||
common_index = df_a.index.intersection(df_b.index)
|
|
||||||
df_a = df_a.loc[common_index].copy()
|
|
||||||
df_b = df_b.loc[common_index].copy()
|
|
||||||
|
|
||||||
if limit:
|
|
||||||
df_a = df_a.tail(limit)
|
|
||||||
df_b = df_b.tail(limit)
|
|
||||||
|
|
||||||
return df_a, df_b
|
|
||||||
|
|
||||||
def load_cryptoquant_data(file_path: str) -> pd.DataFrame | None:
|
|
||||||
"""
|
|
||||||
Load CryptoQuant data and prepare it for merging.
|
|
||||||
"""
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
print(f"Warning: CQ data file {file_path} not found.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
print(f"Loading CryptoQuant data from {file_path}...")
|
|
||||||
df = pd.read_csv(file_path, index_col='timestamp', parse_dates=True)
|
|
||||||
|
|
||||||
# CQ data is usually daily (UTC 00:00).
|
|
||||||
# Ensure index is timezone aware to match market data
|
|
||||||
if df.index.tz is None:
|
|
||||||
df.index = df.index.tz_localize('UTC')
|
|
||||||
|
|
||||||
return df
|
|
||||||
|
|
||||||
def calculate_features(df_a, df_b, cq_df=None, window=24):
|
|
||||||
"""
|
|
||||||
Calculate spread, z-score, and advanced regime features including CQ data.
|
|
||||||
"""
|
|
||||||
# 1. Price Ratio (Spread)
|
|
||||||
spread = df_b['close'] / df_a['close']
|
|
||||||
|
|
||||||
# 2. Rolling Statistics for Z-Score
|
|
||||||
rolling_mean = spread.rolling(window=window).mean()
|
|
||||||
rolling_std = spread.rolling(window=window).std()
|
|
||||||
z_score = (spread - rolling_mean) / rolling_std
|
|
||||||
|
|
||||||
# 3. Spread Momentum / Technicals
|
|
||||||
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
|
|
||||||
spread_roc = spread.pct_change(periods=5) * 100
|
|
||||||
|
|
||||||
# 4. Volume Dynamics
|
|
||||||
vol_ratio = df_b['volume'] / df_a['volume']
|
|
||||||
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
|
|
||||||
|
|
||||||
# 5. Volatility Regime
|
|
||||||
ret_a = df_a['close'].pct_change()
|
|
||||||
ret_b = df_b['close'].pct_change()
|
|
||||||
vol_a = ret_a.rolling(window=window).std()
|
|
||||||
vol_b = ret_b.rolling(window=window).std()
|
|
||||||
vol_spread_ratio = vol_b / vol_a
|
|
||||||
|
|
||||||
# Create feature DataFrame
|
|
||||||
features = pd.DataFrame(index=spread.index)
|
|
||||||
features['spread'] = spread
|
|
||||||
features['z_score'] = z_score
|
|
||||||
features['spread_rsi'] = spread_rsi
|
|
||||||
features['spread_roc'] = spread_roc
|
|
||||||
features['vol_ratio'] = vol_ratio
|
|
||||||
features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
|
|
||||||
features['vol_diff_ratio'] = vol_spread_ratio
|
|
||||||
|
|
||||||
# 6. Merge CryptoQuant Data
|
|
||||||
if cq_df is not None:
|
|
||||||
print("Merging CryptoQuant features...")
|
|
||||||
# Forward fill daily data to hourly timestamps
|
|
||||||
# reindex features to match cq_df range or join
|
|
||||||
|
|
||||||
# Resample CQ to hourly (ffill)
|
|
||||||
# But easier: join features with cq_df using asof or reindex
|
|
||||||
cq_aligned = cq_df.reindex(features.index, method='ffill')
|
|
||||||
|
|
||||||
# Add derived CQ features
|
|
||||||
# Funding Diff: If ETH funding > BTC funding => ETH overheated
|
|
||||||
if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
|
|
||||||
cq_aligned['funding_diff'] = cq_aligned['eth_funding'] - cq_aligned['btc_funding']
|
|
||||||
|
|
||||||
# Inflow Ratio: If ETH inflow >> BTC inflow => ETH dump incoming?
|
|
||||||
if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
|
|
||||||
# Add small epsilon to avoid div by zero
|
|
||||||
cq_aligned['inflow_ratio'] = cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
|
|
||||||
|
|
||||||
features = features.join(cq_aligned)
|
|
||||||
|
|
||||||
# --- Refined Target Definition (Anytime Profit) ---
|
|
||||||
horizon = 6
|
|
||||||
threshold = 0.005 # 0.5% profit target
|
|
||||||
z_threshold = 1.0
|
|
||||||
|
|
||||||
# For Short Spread (Z > 1): Did it drop below target?
|
|
||||||
# We look for the MINIMUM spread in the next 'horizon' periods
|
|
||||||
future_min = features['spread'].rolling(window=horizon).min().shift(-horizon)
|
|
||||||
target_short = features['spread'] * (1 - threshold)
|
|
||||||
success_short = (features['z_score'] > z_threshold) & (future_min < target_short)
|
|
||||||
|
|
||||||
# For Long Spread (Z < -1): Did it rise above target?
|
|
||||||
# We look for the MAXIMUM spread in the next 'horizon' periods
|
|
||||||
future_max = features['spread'].rolling(window=horizon).max().shift(-horizon)
|
|
||||||
target_long = features['spread'] * (1 + threshold)
|
|
||||||
success_long = (features['z_score'] < -z_threshold) & (future_max > target_long)
|
|
||||||
|
|
||||||
conditions = [success_short, success_long]
|
|
||||||
|
|
||||||
features['target'] = np.select(conditions, [1, 1], default=0)
|
|
||||||
|
|
||||||
return features.dropna()
|
|
||||||
|
|
||||||
def train_regime_model(features):
|
|
||||||
"""
|
|
||||||
Train a Random Forest to predict mean reversion success.
|
|
||||||
"""
|
|
||||||
# Define excluded columns (targets, raw prices, intermediates)
|
|
||||||
exclude_cols = ['spread', 'horizon_ret', 'target', 'rolling_mean', 'rolling_std']
|
|
||||||
|
|
||||||
# Auto-select all other numeric columns as features
|
|
||||||
feature_cols = [c for c in features.columns if c not in exclude_cols]
|
|
||||||
|
|
||||||
# Handle NaN/Inf if any slipped through
|
|
||||||
X = features[feature_cols].replace([np.inf, -np.inf], np.nan).fillna(0)
|
|
||||||
y = features['target']
|
|
||||||
|
|
||||||
# Split Data
|
|
||||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=False)
|
|
||||||
|
|
||||||
print(f"\nTraining on {len(X_train)} samples, Testing on {len(X_test)} samples...")
|
|
||||||
print(f"Features used: {feature_cols}")
|
|
||||||
print(f"Class Balance (Target=1): {y.mean():.2%}")
|
|
||||||
|
|
||||||
# Model
|
|
||||||
model = RandomForestClassifier(
|
|
||||||
n_estimators=200,
|
|
||||||
max_depth=6,
|
|
||||||
min_samples_leaf=20,
|
|
||||||
class_weight='balanced_subsample',
|
|
||||||
random_state=42
|
|
||||||
)
|
|
||||||
model.fit(X_train, y_train)
|
|
||||||
|
|
||||||
# Evaluation
|
|
||||||
y_pred = model.predict(X_test)
|
|
||||||
y_prob = model.predict_proba(X_test)[:, 1]
|
|
||||||
|
|
||||||
print("\n--- Model Evaluation ---")
|
|
||||||
print(classification_report(y_test, y_pred))
|
|
||||||
|
|
||||||
# Feature Importance
|
|
||||||
importances = pd.Series(model.feature_importances_, index=feature_cols).sort_values(ascending=False)
|
|
||||||
print("\n--- Feature Importance ---")
|
|
||||||
print(importances)
|
|
||||||
|
|
||||||
return model, X_test, y_test, y_pred, y_prob
|
|
||||||
|
|
||||||
def plot_interactive_results(features, y_test, y_pred, y_prob):
|
|
||||||
"""
|
|
||||||
Create an interactive HTML plot using Plotly.
|
|
||||||
"""
|
|
||||||
print("\nGenerating interactive plot...")
|
|
||||||
|
|
||||||
test_idx = y_test.index
|
|
||||||
test_data = features.loc[test_idx].copy()
|
|
||||||
test_data['prob'] = y_prob
|
|
||||||
test_data['prediction'] = y_pred
|
|
||||||
test_data['actual'] = y_test
|
|
||||||
|
|
||||||
# Create Subplots
|
|
||||||
fig = make_subplots(
|
|
||||||
rows=3, cols=1,
|
|
||||||
shared_xaxes=True,
|
|
||||||
vertical_spacing=0.05,
|
|
||||||
row_heights=[0.5, 0.25, 0.25],
|
|
||||||
subplot_titles=('Spread & Signals', 'Exchange Inflows', 'Z-Score & Probability')
|
|
||||||
)
|
|
||||||
|
|
||||||
# Top: Spread
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(x=test_data.index, y=test_data['spread'], mode='lines', name='Spread', line=dict(color='gray')),
|
|
||||||
row=1, col=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Signals
|
|
||||||
# Separate Long and Short signals for clarity
|
|
||||||
# Logic: If Z-Score was High (>1), we were betting on a SHORT Spread (Reversion Down)
|
|
||||||
# If Z-Score was Low (< -1), we were betting on a LONG Spread (Reversion Up)
|
|
||||||
|
|
||||||
# Correct Short Signals (Green Triangle Down)
|
|
||||||
tp_short = test_data[(test_data['prediction'] == 1) & (test_data['actual'] == 1) & (test_data['z_score'] > 0)]
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(x=tp_short.index, y=tp_short['spread'], mode='markers', name='Win: Short Spread',
|
|
||||||
marker=dict(symbol='triangle-down', size=12, color='green')),
|
|
||||||
row=1, col=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Correct Long Signals (Green Triangle Up)
|
|
||||||
tp_long = test_data[(test_data['prediction'] == 1) & (test_data['actual'] == 1) & (test_data['z_score'] < 0)]
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(x=tp_long.index, y=tp_long['spread'], mode='markers', name='Win: Long Spread',
|
|
||||||
marker=dict(symbol='triangle-up', size=12, color='green')),
|
|
||||||
row=1, col=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# False Short Signals (Red Triangle Down)
|
|
||||||
fp_short = test_data[(test_data['prediction'] == 1) & (test_data['actual'] == 0) & (test_data['z_score'] > 0)]
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(x=fp_short.index, y=fp_short['spread'], mode='markers', name='Loss: Short Spread',
|
|
||||||
marker=dict(symbol='triangle-down', size=10, color='red')),
|
|
||||||
row=1, col=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# False Long Signals (Red Triangle Up)
|
|
||||||
fp_long = test_data[(test_data['prediction'] == 1) & (test_data['actual'] == 0) & (test_data['z_score'] < 0)]
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(x=fp_long.index, y=fp_long['spread'], mode='markers', name='Loss: Long Spread',
|
|
||||||
marker=dict(symbol='triangle-up', size=10, color='red')),
|
|
||||||
row=1, col=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Middle: Inflows (BTC vs ETH)
|
|
||||||
if 'btc_inflow' in test_data.columns:
|
|
||||||
fig.add_trace(
|
|
||||||
go.Bar(x=test_data.index, y=test_data['btc_inflow'], name='BTC Inflow', marker_color='orange', opacity=0.6),
|
|
||||||
row=2, col=1
|
|
||||||
)
|
|
||||||
if 'eth_inflow' in test_data.columns:
|
|
||||||
fig.add_trace(
|
|
||||||
go.Bar(x=test_data.index, y=test_data['eth_inflow'], name='ETH Inflow', marker_color='purple', opacity=0.6),
|
|
||||||
row=2, col=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Bottom: Z-Score
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(x=test_data.index, y=test_data['z_score'], mode='lines', name='Z-Score', line=dict(color='blue'), opacity=0.5),
|
|
||||||
row=3, col=1
|
|
||||||
)
|
|
||||||
fig.add_hline(y=2, line_dash="dash", line_color="red", row=3, col=1)
|
|
||||||
fig.add_hline(y=-2, line_dash="dash", line_color="green", row=3, col=1)
|
|
||||||
|
|
||||||
# Probability (Secondary Y for Row 3)
|
|
||||||
fig.add_trace(
|
|
||||||
go.Scatter(x=test_data.index, y=test_data['prob'], mode='lines', name='Prob', line=dict(color='cyan', width=1.5), yaxis='y4'),
|
|
||||||
row=3, col=1
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_layout(
|
|
||||||
title='Regime Detection Analysis (with CryptoQuant)',
|
|
||||||
autosize=True,
|
|
||||||
height=None,
|
|
||||||
hovermode='x unified',
|
|
||||||
yaxis4=dict(title='Probability', overlaying='y3', side='right', range=[0, 1], showgrid=False),
|
|
||||||
template="plotly_dark",
|
|
||||||
margin=dict(l=10, r=10, t=40, b=10),
|
|
||||||
barmode='group'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update all x-axes to ensure spikes are visible everywhere
|
|
||||||
fig.update_xaxes(
|
|
||||||
showspikes=True,
|
|
||||||
spikemode='across',
|
|
||||||
spikesnap='cursor',
|
|
||||||
showline=False,
|
|
||||||
showgrid=True,
|
|
||||||
spikedash='dot',
|
|
||||||
spikecolor='white', # Make it bright to see
|
|
||||||
spikethickness=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.update_layout(
|
|
||||||
title='Regime Detection Analysis (with CryptoQuant)',
|
|
||||||
autosize=True,
|
|
||||||
height=None,
|
|
||||||
hovermode='x unified', # Keep unified hover for data reading
|
|
||||||
yaxis4=dict(title='Probability', overlaying='y3', side='right', range=[0, 1], showgrid=False),
|
|
||||||
template="plotly_dark",
|
|
||||||
margin=dict(l=10, r=10, t=40, b=10),
|
|
||||||
barmode='group'
|
|
||||||
)
|
|
||||||
|
|
||||||
output_path = "research/regime_results.html"
|
|
||||||
fig.write_html(
|
|
||||||
output_path,
|
|
||||||
config={'responsive': True, 'scrollZoom': True},
|
|
||||||
include_plotlyjs='cdn',
|
|
||||||
full_html=True,
|
|
||||||
default_height='100vh',
|
|
||||||
default_width='100%'
|
|
||||||
)
|
|
||||||
print(f"Interactive plot saved to {output_path}")
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# 1. Load CQ Data first to determine valid date range
|
|
||||||
cq_path = "data/cq_training_data.csv"
|
|
||||||
cq_df = load_cryptoquant_data(cq_path)
|
|
||||||
|
|
||||||
start_date = None
|
|
||||||
end_date = None
|
|
||||||
|
|
||||||
if cq_df is not None and not cq_df.empty:
|
|
||||||
start_date = cq_df.index.min().strftime('%Y-%m-%d')
|
|
||||||
end_date = cq_df.index.max().strftime('%Y-%m-%d')
|
|
||||||
print(f"CryptoQuant Data Range: {start_date} to {end_date}")
|
|
||||||
|
|
||||||
# 2. Get Market Data (Aligned to CQ range)
|
|
||||||
df_btc, df_eth = prepare_data(
|
|
||||||
"BTC-USDT", "ETH-USDT",
|
|
||||||
timeframe="1h",
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Calculate Features
|
|
||||||
print("Calculating advanced regime features...")
|
|
||||||
data = calculate_features(df_btc, df_eth, cq_df=cq_df, window=24)
|
|
||||||
|
|
||||||
if data.empty:
|
|
||||||
print("Error: No overlapping data found between Price and CryptoQuant data.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 4. Train & Evaluate
|
|
||||||
model, X_test, y_test, y_pred, y_prob = train_regime_model(data)
|
|
||||||
|
|
||||||
# 5. Plot
|
|
||||||
plot_interactive_results(data, y_test, y_pred, y_prob)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -36,6 +36,7 @@ def _build_registry() -> dict[str, StrategyConfig]:
|
|||||||
# Import here to avoid circular imports
|
# Import here to avoid circular imports
|
||||||
from strategies.examples import MaCrossStrategy, RsiStrategy
|
from strategies.examples import MaCrossStrategy, RsiStrategy
|
||||||
from strategies.supertrend import MetaSupertrendStrategy
|
from strategies.supertrend import MetaSupertrendStrategy
|
||||||
|
from strategies.regime_strategy import RegimeReversionStrategy
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"rsi": StrategyConfig(
|
"rsi": StrategyConfig(
|
||||||
@@ -76,6 +77,19 @@ def _build_registry() -> dict[str, StrategyConfig]:
|
|||||||
'period3': 12, 'multiplier3': 1.0
|
'period3': 12, 'multiplier3': 1.0
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
|
"regime": StrategyConfig(
|
||||||
|
strategy_class=RegimeReversionStrategy,
|
||||||
|
default_params={
|
||||||
|
'horizon': 96,
|
||||||
|
'z_window': 24,
|
||||||
|
'stop_loss': 0.06,
|
||||||
|
'take_profit': 0.05
|
||||||
|
},
|
||||||
|
grid_params={
|
||||||
|
'horizon': [72, 96, 120],
|
||||||
|
'stop_loss': [0.04, 0.06, 0.08]
|
||||||
|
}
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
280
strategies/regime_strategy.py
Normal file
280
strategies/regime_strategy.py
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import ta
|
||||||
|
import vectorbt as vbt
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
|
||||||
|
from strategies.base import BaseStrategy
|
||||||
|
from engine.market import MarketType
|
||||||
|
from engine.data_manager import DataManager
|
||||||
|
from engine.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
class RegimeReversionStrategy(BaseStrategy):
|
||||||
|
"""
|
||||||
|
ML-Based Regime Detection & Mean Reversion Strategy.
|
||||||
|
|
||||||
|
Logic:
|
||||||
|
1. Tracks the BTC/ETH Spread and its Z-Score (24h window).
|
||||||
|
2. Uses a Random Forest model to predict if an extreme Z-Score will revert profitably.
|
||||||
|
3. Features: Spread Technicals (RSI, ROC) + On-Chain Flows (Inflow, Funding).
|
||||||
|
4. Entry: When Model Probability > 0.5.
|
||||||
|
5. Exit: Z-Score reversion to 0 or SL/TP.
|
||||||
|
|
||||||
|
Walk-Forward Training:
|
||||||
|
- Trains on first `train_ratio` of data (default 70%)
|
||||||
|
- Generates signals only for remaining test period (30%)
|
||||||
|
- Eliminates look-ahead bias for realistic backtest results
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_path: str = "data/regime_model.pkl",
|
||||||
|
horizon: int = 96, # 4 Days based on research
|
||||||
|
z_window: int = 24,
|
||||||
|
stop_loss: float = 0.06, # 6% to survive 2% avg MAE
|
||||||
|
take_profit: float = 0.05, # Swing target
|
||||||
|
train_ratio: float = 0.7 # Walk-forward: train on first 70%
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model_path = model_path
|
||||||
|
self.horizon = horizon
|
||||||
|
self.z_window = z_window
|
||||||
|
self.stop_loss = stop_loss
|
||||||
|
self.take_profit = take_profit
|
||||||
|
self.train_ratio = train_ratio
|
||||||
|
|
||||||
|
# Default Strategy Config
|
||||||
|
self.default_market_type = MarketType.PERPETUAL
|
||||||
|
self.default_leverage = 1
|
||||||
|
|
||||||
|
self.dm = DataManager()
|
||||||
|
self.model = None
|
||||||
|
self.feature_cols = None
|
||||||
|
self.train_end_idx = None # Will store the training cutoff point
|
||||||
|
|
||||||
|
def run(self, close, **kwargs):
|
||||||
|
"""
|
||||||
|
Execute the strategy logic.
|
||||||
|
We assume this strategy is run on ETH-USDT (the active asset).
|
||||||
|
We will fetch BTC-USDT internally to calculate the spread.
|
||||||
|
"""
|
||||||
|
# 1. Identify Context
|
||||||
|
# We need BTC data aligned with the incoming ETH 'close' series
|
||||||
|
start_date = close.index.min()
|
||||||
|
end_date = close.index.max()
|
||||||
|
|
||||||
|
logger.info("Fetching BTC context data...")
|
||||||
|
try:
|
||||||
|
# Load BTC data (Context) - Must match the timeframe of the backtest
|
||||||
|
# Research was done on 1h candles, so strategy should be run on 1h
|
||||||
|
df_btc = self.dm.load_data("okx", "BTC-USDT", "1h", MarketType.SPOT)
|
||||||
|
|
||||||
|
# Align BTC to ETH (close)
|
||||||
|
df_btc = df_btc.reindex(close.index, method='ffill')
|
||||||
|
btc_close = df_btc['close']
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load BTC context: {e}")
|
||||||
|
empty = self.create_empty_signals(close)
|
||||||
|
return empty, empty, empty, empty
|
||||||
|
|
||||||
|
# 2. Construct DataFrames for Feature Engineering
|
||||||
|
# We need volume/high/low for features, but 'run' signature primarily gives 'close'.
|
||||||
|
# kwargs might have high/low/volume if passed by Backtester.run_strategy
|
||||||
|
eth_vol = kwargs.get('volume')
|
||||||
|
|
||||||
|
if eth_vol is None:
|
||||||
|
logger.warning("Volume data missing. Feature calculation might fail.")
|
||||||
|
# Fallback or error handling
|
||||||
|
eth_vol = pd.Series(0, index=close.index)
|
||||||
|
|
||||||
|
# Construct dummy dfs for prepare_features
|
||||||
|
# We only really need Close and Volume for the current feature set
|
||||||
|
df_a = pd.DataFrame({'close': btc_close, 'volume': df_btc['volume']})
|
||||||
|
df_b = pd.DataFrame({'close': close, 'volume': eth_vol})
|
||||||
|
|
||||||
|
# 3. Load On-Chain Data (CryptoQuant)
|
||||||
|
# We use the saved CSV for training/inference
|
||||||
|
# In a live setting, this would query the API for recent data
|
||||||
|
cq_df = None
|
||||||
|
try:
|
||||||
|
cq_path = "data/cq_training_data.csv"
|
||||||
|
cq_df = pd.read_csv(cq_path, index_col='timestamp', parse_dates=True)
|
||||||
|
if cq_df.index.tz is None:
|
||||||
|
cq_df.index = cq_df.index.tz_localize('UTC')
|
||||||
|
except Exception:
|
||||||
|
logger.warning("CryptoQuant data not found. Running without on-chain features.")
|
||||||
|
|
||||||
|
# 4. Calculate Features
|
||||||
|
features = self.prepare_features(df_a, df_b, cq_df)
|
||||||
|
|
||||||
|
# 5. Walk-Forward Split
|
||||||
|
# Train on first `train_ratio` of data, test on remainder
|
||||||
|
n_samples = len(features)
|
||||||
|
train_size = int(n_samples * self.train_ratio)
|
||||||
|
|
||||||
|
train_features = features.iloc[:train_size]
|
||||||
|
test_features = features.iloc[train_size:]
|
||||||
|
|
||||||
|
train_end_date = train_features.index[-1]
|
||||||
|
test_start_date = test_features.index[0]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Walk-Forward Split: Train={len(train_features)} bars "
|
||||||
|
f"(until {train_end_date.strftime('%Y-%m-%d')}), "
|
||||||
|
f"Test={len(test_features)} bars "
|
||||||
|
f"(from {test_start_date.strftime('%Y-%m-%d')})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Train Model on Training Period ONLY
|
||||||
|
if self.model is None:
|
||||||
|
logger.info("Training Regime Model on training period only...")
|
||||||
|
self.model, self.feature_cols = self.train_model(train_features)
|
||||||
|
|
||||||
|
# 7. Predict on TEST Period ONLY
|
||||||
|
# Use valid columns only
|
||||||
|
X_test = test_features[self.feature_cols].fillna(0)
|
||||||
|
X_test = X_test.replace([np.inf, -np.inf], 0)
|
||||||
|
|
||||||
|
# Predict Probabilities for test period
|
||||||
|
probs = self.model.predict_proba(X_test)[:, 1]
|
||||||
|
|
||||||
|
# 8. Generate Entry Signals (TEST period only)
|
||||||
|
# If Z > 1 (Spread High, ETH Expensive) -> Short ETH
|
||||||
|
# If Z < -1 (Spread Low, ETH Cheap) -> Long ETH
|
||||||
|
|
||||||
|
short_signal_test = (probs > 0.5) & (test_features['z_score'].values > 1.0)
|
||||||
|
long_signal_test = (probs > 0.5) & (test_features['z_score'].values < -1.0)
|
||||||
|
|
||||||
|
# Create full-length signal series (False for training period)
|
||||||
|
long_entries = pd.Series(False, index=close.index)
|
||||||
|
short_entries = pd.Series(False, index=close.index)
|
||||||
|
|
||||||
|
# Map test signals to their correct indices
|
||||||
|
test_idx = test_features.index
|
||||||
|
for i, idx in enumerate(test_idx):
|
||||||
|
if idx in close.index:
|
||||||
|
long_entries.loc[idx] = bool(long_signal_test[i])
|
||||||
|
short_entries.loc[idx] = bool(short_signal_test[i])
|
||||||
|
|
||||||
|
# 9. Generate Exits
|
||||||
|
# Exit when Z-Score crosses back through 0 (mean reversion complete)
|
||||||
|
z_reindexed = features['z_score'].reindex(close.index, fill_value=0)
|
||||||
|
|
||||||
|
# Exit Long when Z > 0, Exit Short when Z < 0
|
||||||
|
long_exits = z_reindexed > 0
|
||||||
|
short_exits = z_reindexed < 0
|
||||||
|
|
||||||
|
# Log signal counts for verification
|
||||||
|
n_long = long_entries.sum()
|
||||||
|
n_short = short_entries.sum()
|
||||||
|
logger.info(f"Generated {n_long} long signals, {n_short} short signals (test period only)")
|
||||||
|
|
||||||
|
return long_entries, long_exits, short_entries, short_exits
|
||||||
|
|
||||||
|
def prepare_features(self, df_btc, df_eth, cq_df=None):
|
||||||
|
"""Replicate research feature engineering"""
|
||||||
|
# Align
|
||||||
|
common = df_btc.index.intersection(df_eth.index)
|
||||||
|
df_a = df_btc.loc[common].copy()
|
||||||
|
df_b = df_eth.loc[common].copy()
|
||||||
|
|
||||||
|
# Spread
|
||||||
|
spread = df_b['close'] / df_a['close']
|
||||||
|
|
||||||
|
# Z-Score
|
||||||
|
rolling_mean = spread.rolling(window=self.z_window).mean()
|
||||||
|
rolling_std = spread.rolling(window=self.z_window).std()
|
||||||
|
z_score = (spread - rolling_mean) / rolling_std
|
||||||
|
|
||||||
|
# Technicals
|
||||||
|
spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
|
||||||
|
spread_roc = spread.pct_change(periods=5) * 100
|
||||||
|
spread_change_1h = spread.pct_change(periods=1)
|
||||||
|
|
||||||
|
# Volume
|
||||||
|
vol_ratio = df_b['volume'] / df_a['volume']
|
||||||
|
vol_ratio_ma = vol_ratio.rolling(window=12).mean()
|
||||||
|
|
||||||
|
# Volatility
|
||||||
|
ret_a = df_a['close'].pct_change()
|
||||||
|
ret_b = df_b['close'].pct_change()
|
||||||
|
vol_a = ret_a.rolling(window=self.z_window).std()
|
||||||
|
vol_b = ret_b.rolling(window=self.z_window).std()
|
||||||
|
vol_spread_ratio = vol_b / vol_a
|
||||||
|
|
||||||
|
features = pd.DataFrame(index=spread.index)
|
||||||
|
features['spread'] = spread
|
||||||
|
features['z_score'] = z_score
|
||||||
|
features['spread_rsi'] = spread_rsi
|
||||||
|
features['spread_roc'] = spread_roc
|
||||||
|
features['spread_change_1h'] = spread_change_1h
|
||||||
|
features['vol_ratio'] = vol_ratio
|
||||||
|
features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
|
||||||
|
features['vol_diff_ratio'] = vol_spread_ratio
|
||||||
|
|
||||||
|
# CQ Merge
|
||||||
|
if cq_df is not None:
|
||||||
|
cq_aligned = cq_df.reindex(features.index, method='ffill')
|
||||||
|
if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
|
||||||
|
cq_aligned['funding_diff'] = cq_aligned['eth_funding'] - cq_aligned['btc_funding']
|
||||||
|
if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
|
||||||
|
cq_aligned['inflow_ratio'] = cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)
|
||||||
|
features = features.join(cq_aligned)
|
||||||
|
|
||||||
|
return features.dropna()
|
||||||
|
|
||||||
|
def train_model(self, train_features):
|
||||||
|
"""
|
||||||
|
Train Random Forest on training data only.
|
||||||
|
|
||||||
|
This method receives ONLY the training subset of features,
|
||||||
|
ensuring no look-ahead bias. The model learns from historical
|
||||||
|
patterns and is then applied to unseen test data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_features: DataFrame containing features for training period only
|
||||||
|
"""
|
||||||
|
threshold = 0.005
|
||||||
|
horizon = self.horizon
|
||||||
|
|
||||||
|
# Define targets using ONLY training data
|
||||||
|
# For Short Spread (Z > 1): Did spread drop below target within horizon?
|
||||||
|
future_min = train_features['spread'].rolling(window=horizon).min().shift(-horizon)
|
||||||
|
target_short = train_features['spread'] * (1 - threshold)
|
||||||
|
success_short = (train_features['z_score'] > 1.0) & (future_min < target_short)
|
||||||
|
|
||||||
|
# For Long Spread (Z < -1): Did spread rise above target within horizon?
|
||||||
|
future_max = train_features['spread'].rolling(window=horizon).max().shift(-horizon)
|
||||||
|
target_long = train_features['spread'] * (1 + threshold)
|
||||||
|
success_long = (train_features['z_score'] < -1.0) & (future_max > target_long)
|
||||||
|
|
||||||
|
targets = np.select([success_short, success_long], [1, 1], default=0)
|
||||||
|
|
||||||
|
# Build model
|
||||||
|
model = RandomForestClassifier(
|
||||||
|
n_estimators=300, max_depth=5, min_samples_leaf=30,
|
||||||
|
class_weight={0: 1, 1: 3}, random_state=42
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exclude non-feature columns
|
||||||
|
exclude = ['spread']
|
||||||
|
cols = [c for c in train_features.columns if c not in exclude]
|
||||||
|
|
||||||
|
# Clean features
|
||||||
|
X_train = train_features[cols].fillna(0)
|
||||||
|
X_train = X_train.replace([np.inf, -np.inf], 0)
|
||||||
|
|
||||||
|
# Remove rows with NaN targets (from rolling window at end of training period)
|
||||||
|
valid_mask = ~np.isnan(targets) & ~np.isinf(targets)
|
||||||
|
# Also check for rows where future data doesn't exist (shift created NaNs)
|
||||||
|
valid_mask = valid_mask & (future_min.notna().values) & (future_max.notna().values)
|
||||||
|
|
||||||
|
X_train_clean = X_train[valid_mask]
|
||||||
|
targets_clean = targets[valid_mask]
|
||||||
|
|
||||||
|
logger.info(f"Training on {len(X_train_clean)} valid samples (removed {len(X_train) - len(X_train_clean)} with incomplete future data)")
|
||||||
|
|
||||||
|
model.fit(X_train_clean, targets_clean)
|
||||||
|
return model, cols
|
||||||
Reference in New Issue
Block a user