Files
lowkey_backtest/research/regime_detection.py
Simon Moisy e6d69ed04d Add CryptoQuant client and regime detection analysis
- Introduced `CryptoQuantClient` for fetching data from the CryptoQuant API.
- Added `regime_detection.py` for advanced regime detection analysis using machine learning.
- Updated dependencies in `pyproject.toml` and `uv.lock` to include `scikit-learn`, `matplotlib`, `plotly`, `requests`, and `python-dotenv`.
- Enhanced `.gitignore` to exclude `regime_results.html` and CSV files.
- Created an interactive HTML plot for regime detection results and saved it as `regime_results.html`.
2026-01-13 16:13:57 +08:00

385 lines
14 KiB
Python

import sys
import os
from pathlib import Path
# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pandas as pd
import numpy as np
import ta
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from engine.data_manager import DataManager
from engine.market import MarketType
def prepare_data(symbol_a="BTC-USDT", symbol_b="ETH-USDT", timeframe="1h", limit=None, start_date=None, end_date=None):
    """
    Build a timestamp-aligned pair of OKX spot frames for two symbols.

    Each symbol is read through ``DataManager.load_data``. If loading raises,
    it falls back to ``download_data``. When ``start_date`` is given and the
    loaded history begins after it, the symbol is downloaded again for the
    ``start_date``..``end_date`` window.

    Args:
        symbol_a: First symbol of the pair.
        symbol_b: Second symbol of the pair.
        timeframe: Candle timeframe forwarded to ``DataManager``.
        limit: If truthy, keep only the last ``limit`` aligned rows.
        start_date: Optional inclusive lower bound, interpreted as UTC.
        end_date: Optional inclusive upper bound, interpreted as UTC.

    Returns:
        Tuple ``(df_a, df_b)`` restricted to the timestamps both share.
    """
    manager = DataManager()
    print(f"Loading data for {symbol_a} and {symbol_b}...")

    def fetch(symbol):
        # Prefer whatever is stored locally; download only when loading fails.
        try:
            frame = manager.load_data("okx", symbol, timeframe, MarketType.SPOT)
        except Exception:
            frame = manager.download_data("okx", symbol, timeframe, market_type=MarketType.SPOT)

        if not start_date:
            return frame

        # The local history may not reach back far enough: refetch the requested window.
        wanted_start = pd.Timestamp(start_date, tz='UTC')
        if frame.index.min() > wanted_start:
            print(f"Local data starts {frame.index.min()}, need {wanted_start}. Downloading...")
            frame = manager.download_data(
                "okx",
                symbol,
                timeframe,
                start_date=start_date,
                end_date=end_date,
                market_type=MarketType.SPOT,
            )
        return frame

    df_a = fetch(symbol_a)
    df_b = fetch(symbol_b)

    # Clip to the requested window (e.g. to match the CryptoQuant date range).
    if start_date:
        lower = pd.Timestamp(start_date, tz='UTC')
        df_a, df_b = df_a[df_a.index >= lower], df_b[df_b.index >= lower]
    if end_date:
        upper = pd.Timestamp(end_date, tz='UTC')
        df_a, df_b = df_a[df_a.index <= upper], df_b[df_b.index <= upper]

    print("Aligning data...")
    shared = df_a.index.intersection(df_b.index)
    df_a = df_a.loc[shared].copy()
    df_b = df_b.loc[shared].copy()

    if limit:
        df_a, df_b = df_a.tail(limit), df_b.tail(limit)
    return df_a, df_b
def load_cryptoquant_data(file_path: str) -> pd.DataFrame | None:
"""
Load CryptoQuant data and prepare it for merging.
"""
if not os.path.exists(file_path):
print(f"Warning: CQ data file {file_path} not found.")
return None
print(f"Loading CryptoQuant data from {file_path}...")
df = pd.read_csv(file_path, index_col='timestamp', parse_dates=True)
# CQ data is usually daily (UTC 00:00).
# Ensure index is timezone aware to match market data
if df.index.tz is None:
df.index = df.index.tz_localize('UTC')
return df
def calculate_features(df_a, df_b, cq_df=None, window=24, horizon=6, threshold=0.005, z_threshold=1.0):
    """
    Build the regime-detection feature matrix and its binary ``target``.

    Features come from the price ratio ``spread = close_b / close_a``:
    rolling z-score, RSI(14) and 5-bar rate of change. Further features are
    the volume ratio (raw and relative to its 12-bar mean) and the ratio of
    rolling return volatilities.

    If ``cq_df`` is given, its columns are forward-filled onto the price index
    and joined. Two derived columns are added when their inputs exist:
    ``funding_diff`` and ``inflow_ratio``.

    The target is 1 in two cases, otherwise 0:
    - Short setup: ``z_score > z_threshold`` and the spread falls below
      ``spread * (1 - threshold)`` at some point in the next ``horizon`` bars.
    - Long setup: ``z_score < -z_threshold`` and the spread rises above
      ``spread * (1 + threshold)`` at some point in the next ``horizon`` bars.

    Args:
        df_a, df_b: Aligned frames with ``close`` and ``volume`` columns.
        cq_df: Optional CryptoQuant frame with a datetime index.
        window: Lookback for the z-score and volatility statistics.
        horizon: Number of future bars inspected when labelling.
        threshold: Fractional spread move required to count as a success.
        z_threshold: Absolute z-score required before a setup is considered.

    Returns:
        Feature DataFrame including ``target``. Rows with any NaN are dropped.
        The final ``horizon`` rows are also dropped: they have no complete
        look-ahead window, so their label is unknown.
    """
    # 1. Price ratio (spread)
    spread = df_b['close'] / df_a['close']

    # 2. Rolling statistics for the z-score
    rolling_mean = spread.rolling(window=window).mean()
    rolling_std = spread.rolling(window=window).std()
    z_score = (spread - rolling_mean) / rolling_std

    # 3. Spread momentum / technicals
    spread_rsi = ta.momentum.RSIIndicator(spread, window=14).rsi()
    spread_roc = spread.pct_change(periods=5) * 100

    # 4. Volume dynamics
    vol_ratio = df_b['volume'] / df_a['volume']
    vol_ratio_ma = vol_ratio.rolling(window=12).mean()

    # 5. Volatility regime: ratio of rolling return volatilities
    vol_a = df_a['close'].pct_change().rolling(window=window).std()
    vol_b = df_b['close'].pct_change().rolling(window=window).std()

    features = pd.DataFrame(index=spread.index)
    features['spread'] = spread
    features['z_score'] = z_score
    features['spread_rsi'] = spread_rsi
    features['spread_roc'] = spread_roc
    features['vol_ratio'] = vol_ratio
    features['vol_ratio_rel'] = vol_ratio / vol_ratio_ma
    features['vol_diff_ratio'] = vol_b / vol_a

    # 6. Merge CryptoQuant data
    if cq_df is not None:
        print("Merging CryptoQuant features...")
        # reindex(method='ffill') requires a monotonic index; CSV rows may arrive unsorted.
        cq_sorted = cq_df.sort_index()
        # Carry each CQ observation forward onto every price timestamp.
        # NOTE(review): if a CQ row stamped at 00:00 summarises that whole day, this exposes the
        # value to every hour of the same day (look-ahead). Consider lagging cq_df by one period.
        # Confirm the CQ timestamp semantics.
        cq_aligned = cq_sorted.reindex(features.index, method='ffill')

        # Funding diff: positive when ETH funding exceeds BTC funding.
        if 'btc_funding' in cq_aligned.columns and 'eth_funding' in cq_aligned.columns:
            cq_aligned['funding_diff'] = cq_aligned['eth_funding'] - cq_aligned['btc_funding']

        # Inflow ratio: ETH exchange inflow relative to BTC.
        if 'btc_inflow' in cq_aligned.columns and 'eth_inflow' in cq_aligned.columns:
            # +1 in the denominator avoids division by zero when BTC inflow is 0.
            cq_aligned['inflow_ratio'] = cq_aligned['eth_inflow'] / (cq_aligned['btc_inflow'] + 1)

        features = features.join(cq_aligned)

    # --- Target: "anytime profit" within the next `horizon` bars ---
    current = features['spread']
    # rolling(h).min() at bar t+h covers bars t+1..t+h. Shifting back by h aligns that
    # strictly-future window with bar t. The same applies to max().
    future_min = current.rolling(window=horizon).min().shift(-horizon)
    future_max = current.rolling(window=horizon).max().shift(-horizon)

    success_short = (features['z_score'] > z_threshold) & (future_min < current * (1 - threshold))
    success_long = (features['z_score'] < -z_threshold) & (future_max > current * (1 + threshold))
    features['target'] = (success_short | success_long).astype(int)

    # Without a full look-ahead window the outcome is unknown. Comparisons against NaN
    # evaluate to False and would silently label these rows 0, so drop them instead.
    has_future = future_min.notna() & future_max.notna()
    return features[has_future].dropna()
def train_regime_model(features, test_size=0.3):
    """
    Train a Random Forest that predicts whether a mean-reversion setup succeeds.

    The inputs are every numeric (or boolean) column except ``target``, the
    raw ``spread`` and a few intermediate names. The split is chronological
    (``shuffle=False``), so the held-out rows come from the end of ``features``.

    Args:
        features: Feature frame containing a ``target`` column.
        test_size: Fraction of rows, taken from the end, held out for testing.

    Returns:
        Tuple ``(model, X_test, y_test, y_pred, y_prob)``. ``y_prob`` is the
        predicted probability of class 1 for each test row; it is all zeros if
        class 1 never appears in the training rows.
    """
    # Targets, raw prices and intermediates that must not be used as inputs.
    exclude_cols = {'spread', 'horizon_ret', 'target', 'rolling_mean', 'rolling_std'}
    # Only numeric/boolean columns can be fed to the forest. A stray text column
    # (e.g. from the CQ CSV) would otherwise make fit() fail.
    usable = features.select_dtypes(include=['number', 'bool']).columns
    feature_cols = [c for c in usable if c not in exclude_cols]

    # Guard against NaN/inf that survived feature construction (e.g. zero-volume ratios).
    X = features[feature_cols].replace([np.inf, -np.inf], np.nan).fillna(0)
    y = features['target']

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, shuffle=False)
    print(f"\nTraining on {len(X_train)} samples, Testing on {len(X_test)} samples...")
    print(f"Features used: {feature_cols}")
    print(f"Class Balance (Target=1): {y.mean():.2%}")

    model = RandomForestClassifier(
        n_estimators=200,
        max_depth=6,
        min_samples_leaf=20,
        class_weight='balanced_subsample',
        random_state=42,
    )
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    # predict_proba has one column per class seen during fit. Look up class 1 explicitly so a
    # single-class training window yields zeros instead of an IndexError on [:, 1].
    proba = model.predict_proba(X_test)
    classes = list(model.classes_)
    y_prob = proba[:, classes.index(1)] if 1 in classes else np.zeros(len(X_test))

    print("\n--- Model Evaluation ---")
    print(classification_report(y_test, y_pred))
    print("Confusion Matrix (rows=actual, cols=predicted):")
    print(confusion_matrix(y_test, y_pred))

    importances = pd.Series(model.feature_importances_, index=feature_cols).sort_values(ascending=False)
    print("\n--- Feature Importance ---")
    print(importances)

    return model, X_test, y_test, y_pred, y_prob
def plot_interactive_results(features, y_test, y_pred, y_prob, output_path="research/regime_results.html"):
    """
    Write an interactive Plotly HTML report covering the test period.

    The report has three rows:
    1. Spread with trade markers.
    2. BTC/ETH exchange inflows, when those columns exist.
    3. Z-score with reference bands at +/-2, plus the model probability on a
       secondary y-axis fixed to [0, 1].

    Markers are drawn only where ``prediction == 1``. Green marks an actual
    success and red a failure. Triangle-down marks rows with ``z_score > 0``
    (short-spread setup); triangle-up marks rows with ``z_score < 0`` (long-spread setup).

    Args:
        features: Feature frame; rows are selected by ``y_test.index``.
        y_test: Actual labels for the test rows.
        y_pred: Predicted labels, positionally aligned with ``y_test``.
        y_prob: Predicted class-1 probabilities, positionally aligned with ``y_test``.
        output_path: Destination HTML file. Missing parent directories are created.
    """
    print("\nGenerating interactive plot...")
    test_data = features.loc[y_test.index].copy()
    test_data['prob'] = y_prob
    test_data['prediction'] = y_pred
    test_data['actual'] = y_test

    # Row 3 gets a real secondary y-axis. Passing yaxis='y4' to a trace added with row/col
    # does not work: the subplot grid reference overwrites it with the row's own axis.
    fig = make_subplots(
        rows=3, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        row_heights=[0.5, 0.25, 0.25],
        subplot_titles=('Spread & Signals', 'Exchange Inflows', 'Z-Score & Probability'),
        specs=[[{}], [{}], [{"secondary_y": True}]],
    )

    # Top: spread
    fig.add_trace(
        go.Scatter(x=test_data.index, y=test_data['spread'], mode='lines', name='Spread', line=dict(color='gray')),
        row=1, col=1,
    )

    # Signals. High z (>0 side) means we bet on the spread reverting down (short);
    # low z (<0 side) means we bet on it reverting up (long).
    predicted = test_data['prediction'] == 1
    won = test_data['actual'] == 1
    lost = test_data['actual'] == 0
    above = test_data['z_score'] > 0
    below = test_data['z_score'] < 0
    # (row mask, legend name, marker symbol, marker size, colour)
    signal_specs = [
        (predicted & won & above, 'Win: Short Spread', 'triangle-down', 12, 'green'),
        (predicted & won & below, 'Win: Long Spread', 'triangle-up', 12, 'green'),
        (predicted & lost & above, 'Loss: Short Spread', 'triangle-down', 10, 'red'),
        (predicted & lost & below, 'Loss: Long Spread', 'triangle-up', 10, 'red'),
    ]
    for mask, name, symbol, size, color in signal_specs:
        subset = test_data[mask]
        fig.add_trace(
            go.Scatter(x=subset.index, y=subset['spread'], mode='markers', name=name,
                       marker=dict(symbol=symbol, size=size, color=color)),
            row=1, col=1,
        )

    # Middle: exchange inflows (BTC vs ETH), if CryptoQuant data was merged
    if 'btc_inflow' in test_data.columns:
        fig.add_trace(
            go.Bar(x=test_data.index, y=test_data['btc_inflow'], name='BTC Inflow', marker_color='orange', opacity=0.6),
            row=2, col=1,
        )
    if 'eth_inflow' in test_data.columns:
        fig.add_trace(
            go.Bar(x=test_data.index, y=test_data['eth_inflow'], name='ETH Inflow', marker_color='purple', opacity=0.6),
            row=2, col=1,
        )

    # Bottom: z-score with reference bands at +/-2 (primary axis)
    fig.add_trace(
        go.Scatter(x=test_data.index, y=test_data['z_score'], mode='lines', name='Z-Score', line=dict(color='blue'), opacity=0.5),
        row=3, col=1,
    )
    fig.add_hline(y=2, line_dash="dash", line_color="red", row=3, col=1)
    fig.add_hline(y=-2, line_dash="dash", line_color="green", row=3, col=1)

    # Model probability on the secondary axis of row 3
    fig.add_trace(
        go.Scatter(x=test_data.index, y=test_data['prob'], mode='lines', name='Prob', line=dict(color='cyan', width=1.5)),
        row=3, col=1, secondary_y=True,
    )
    fig.update_yaxes(title_text='Probability', range=[0, 1], showgrid=False, row=3, col=1, secondary_y=True)

    fig.update_layout(
        title='Regime Detection Analysis (with CryptoQuant)',
        autosize=True,
        height=None,
        hovermode='x unified',  # unified hover for reading all rows at once
        template="plotly_dark",
        margin=dict(l=10, r=10, t=40, b=10),
        barmode='group',
    )
    # Spike lines on every x-axis so the cursor position is visible across all rows.
    fig.update_xaxes(
        showspikes=True,
        spikemode='across',
        spikesnap='cursor',
        showline=False,
        showgrid=True,
        spikedash='dot',
        spikecolor='white',
        spikethickness=1,
    )

    # The default path is relative to the CWD; make sure its directory exists.
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    fig.write_html(
        str(output_file),
        config={'responsive': True, 'scrollZoom': True},
        include_plotlyjs='cdn',
        full_html=True,
        default_height='100vh',
        default_width='100%',
    )
    print(f"Interactive plot saved to {output_path}")
def main():
    """
    Run the end-to-end regime-detection study.

    Steps:
    1. Load the CryptoQuant CSV, if present, and use its index to pick the date range.
    2. Fetch aligned BTC/ETH 1h prices for that range.
    3. Build features.
    4. Train and evaluate the model.
    5. Write the interactive HTML report.

    Returns early with an error message if no feature rows remain.
    """
    cq_df = load_cryptoquant_data("data/cq_training_data.csv")

    # When CQ data is available, its span bounds the price history we request.
    start_date = end_date = None
    if cq_df is not None and not cq_df.empty:
        first_ts, last_ts = cq_df.index.min(), cq_df.index.max()
        start_date, end_date = first_ts.strftime('%Y-%m-%d'), last_ts.strftime('%Y-%m-%d')
        print(f"CryptoQuant Data Range: {start_date} to {end_date}")

    df_btc, df_eth = prepare_data(
        "BTC-USDT",
        "ETH-USDT",
        timeframe="1h",
        start_date=start_date,
        end_date=end_date,
    )

    print("Calculating advanced regime features...")
    data = calculate_features(df_btc, df_eth, cq_df=cq_df, window=24)
    if data.empty:
        print("Error: No overlapping data found between Price and CryptoQuant data.")
        return

    _model, _x_test, y_test, y_pred, y_prob = train_regime_model(data)
    plot_interactive_results(data, y_test, y_pred, y_prob)
if __name__ == "__main__":
    # Execute the pipeline only when run as a script, not when imported.
    main()