- Extend regime detection to the top 10 cryptocurrencies (45 pairs)
- Dynamic pair selection based on divergence score (|z_score| * probability)
- Universal ML model trained on all pairs
- Correlation-based filtering to avoid redundant positions
- Funding rate integration from OKX for all 10 assets
- ATR-based dynamic stop-loss and take-profit
- Walk-forward training with a 70/30 split

Performance: +35.69% return (vs. +28.66% baseline), 63.6% win rate

174 lines · 5.4 KiB · Python
"""
|
|
Correlation Filter for Multi-Pair Divergence Strategy.
|
|
|
|
Calculates rolling correlation matrix and filters pairs
|
|
to avoid highly correlated positions.
|
|
"""
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
from engine.logging_config import get_logger
|
|
from .config import MultiPairConfig
|
|
from .pair_scanner import TradingPair
|
|
|
|
# Module logger, created via the project's engine.logging_config helper.
logger = get_logger(__name__)
|
|
|
|
|
|
class CorrelationFilter:
|
|
"""
|
|
Calculates and filters based on asset correlations.
|
|
|
|
Uses rolling correlation of returns to identify assets
|
|
moving together, avoiding redundant positions.
|
|
"""
|
|
|
|
def __init__(self, config: MultiPairConfig):
|
|
self.config = config
|
|
self._correlation_matrix: pd.DataFrame | None = None
|
|
self._last_update_idx: int = -1
|
|
|
|
def calculate_correlation_matrix(
|
|
self,
|
|
price_data: dict[str, pd.Series],
|
|
current_idx: int | None = None
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Calculate rolling correlation matrix between all assets.
|
|
|
|
Args:
|
|
price_data: Dictionary mapping asset symbols to price series
|
|
current_idx: Current bar index (for caching)
|
|
|
|
Returns:
|
|
Correlation matrix DataFrame
|
|
"""
|
|
# Use cached if recent
|
|
if (
|
|
current_idx is not None
|
|
and self._correlation_matrix is not None
|
|
and current_idx - self._last_update_idx < 24 # Update every 24 bars
|
|
):
|
|
return self._correlation_matrix
|
|
|
|
# Calculate returns
|
|
returns = {}
|
|
for symbol, prices in price_data.items():
|
|
returns[symbol] = prices.pct_change()
|
|
|
|
returns_df = pd.DataFrame(returns)
|
|
|
|
# Rolling correlation
|
|
window = self.config.correlation_window
|
|
|
|
# Get latest correlation (last row of rolling correlation)
|
|
if len(returns_df) >= window:
|
|
rolling_corr = returns_df.rolling(window=window).corr()
|
|
# Extract last timestamp correlation matrix
|
|
last_idx = returns_df.index[-1]
|
|
corr_matrix = rolling_corr.loc[last_idx]
|
|
else:
|
|
# Fallback to full-period correlation if not enough data
|
|
corr_matrix = returns_df.corr()
|
|
|
|
self._correlation_matrix = corr_matrix
|
|
if current_idx is not None:
|
|
self._last_update_idx = current_idx
|
|
|
|
return corr_matrix
|
|
|
|
def filter_pairs(
|
|
self,
|
|
pairs: list[TradingPair],
|
|
current_position_asset: str | None,
|
|
price_data: dict[str, pd.Series],
|
|
current_idx: int | None = None
|
|
) -> list[TradingPair]:
|
|
"""
|
|
Filter pairs based on correlation with current position.
|
|
|
|
If we have an open position in an asset, exclude pairs where
|
|
either asset is highly correlated with the held asset.
|
|
|
|
Args:
|
|
pairs: List of candidate pairs
|
|
current_position_asset: Currently held asset (or None)
|
|
price_data: Dictionary of price series by symbol
|
|
current_idx: Current bar index for caching
|
|
|
|
Returns:
|
|
Filtered list of pairs
|
|
"""
|
|
if current_position_asset is None:
|
|
return pairs
|
|
|
|
corr_matrix = self.calculate_correlation_matrix(price_data, current_idx)
|
|
threshold = self.config.correlation_threshold
|
|
|
|
filtered = []
|
|
for pair in pairs:
|
|
# Check correlation of base and quote with held asset
|
|
base_corr = self._get_correlation(
|
|
corr_matrix, pair.base_asset, current_position_asset
|
|
)
|
|
quote_corr = self._get_correlation(
|
|
corr_matrix, pair.quote_asset, current_position_asset
|
|
)
|
|
|
|
# Filter if either asset highly correlated with position
|
|
if abs(base_corr) > threshold or abs(quote_corr) > threshold:
|
|
logger.debug(
|
|
"Filtered %s: base_corr=%.2f, quote_corr=%.2f (held: %s)",
|
|
pair.name, base_corr, quote_corr, current_position_asset
|
|
)
|
|
continue
|
|
|
|
filtered.append(pair)
|
|
|
|
if len(filtered) < len(pairs):
|
|
logger.info(
|
|
"Correlation filter: %d/%d pairs remaining (held: %s)",
|
|
len(filtered), len(pairs), current_position_asset
|
|
)
|
|
|
|
return filtered
|
|
|
|
def _get_correlation(
|
|
self,
|
|
corr_matrix: pd.DataFrame,
|
|
asset1: str,
|
|
asset2: str
|
|
) -> float:
|
|
"""
|
|
Get correlation between two assets from matrix.
|
|
|
|
Args:
|
|
corr_matrix: Correlation matrix
|
|
asset1: First asset symbol
|
|
asset2: Second asset symbol
|
|
|
|
Returns:
|
|
Correlation coefficient (-1 to 1), or 0 if not found
|
|
"""
|
|
if asset1 == asset2:
|
|
return 1.0
|
|
|
|
try:
|
|
return corr_matrix.loc[asset1, asset2]
|
|
except KeyError:
|
|
return 0.0
|
|
|
|
def get_correlation_report(
|
|
self,
|
|
price_data: dict[str, pd.Series]
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Generate a readable correlation report.
|
|
|
|
Args:
|
|
price_data: Dictionary of price series
|
|
|
|
Returns:
|
|
Correlation matrix as DataFrame
|
|
"""
|
|
return self.calculate_correlation_matrix(price_data)
|