"""
Correlation Filter for Multi-Pair Divergence Strategy.

Calculates a correlation matrix of asset returns over the most recent
window and filters pairs to avoid highly correlated positions.
"""

import pandas as pd
import numpy as np

from engine.logging_config import get_logger

from .config import MultiPairConfig
from .pair_scanner import TradingPair

logger = get_logger(__name__)


class CorrelationFilter:
    """
    Calculates and filters based on asset correlations.

    Uses the correlation of percentage returns over the most recent
    ``config.correlation_window`` bars to identify assets moving together,
    avoiding redundant positions.
    """

    # Number of bars a cached correlation matrix is reused before it is
    # recomputed (only applies when callers pass ``current_idx``).
    CACHE_REFRESH_BARS: int = 24

    def __init__(self, config: MultiPairConfig):
        self.config = config
        # Most recently computed matrix (None until the first calculation).
        self._correlation_matrix: pd.DataFrame | None = None
        # Bar index at which the cached matrix was computed (-1 = never).
        self._last_update_idx: int = -1

    def calculate_correlation_matrix(
        self,
        price_data: dict[str, pd.Series],
        current_idx: int | None = None
    ) -> pd.DataFrame:
        """
        Calculate the correlation matrix of returns between all assets.

        Args:
            price_data: Dictionary mapping asset symbols to price series
            current_idx: Current bar index (for caching). When None the
                matrix is always recomputed.

        Returns:
            Correlation matrix DataFrame (symbols on both axes). Entries are
            NaN for asset pairs lacking a full window of overlapping returns
            when the windowed branch is used.
        """
        if self._cache_is_fresh(price_data, current_idx):
            return self._correlation_matrix

        returns_df = pd.DataFrame(
            {symbol: prices.pct_change() for symbol, prices in price_data.items()}
        )
        # pct_change() leaves the first row NaN for every asset. Dropping it
        # before the length check prevents a window that contains that row
        # from turning the whole matrix into NaN (which would silently
        # disable the filter when exactly `window` bars are available).
        returns_df = returns_df.iloc[1:]

        window = self.config.correlation_window

        if len(returns_df) >= window:
            # Only the latest window is needed, so correlate the trailing
            # `window` rows directly instead of building the rolling
            # correlation for every timestamp. min_periods=window keeps the
            # strict "full window of observations" requirement per pair.
            corr_matrix = returns_df.iloc[-window:].corr(min_periods=window)
        else:
            # Fallback to full-period correlation if not enough data
            corr_matrix = returns_df.corr()

        self._correlation_matrix = corr_matrix
        if current_idx is not None:
            self._last_update_idx = current_idx

        return corr_matrix

    def _cache_is_fresh(
        self,
        price_data: dict[str, pd.Series],
        current_idx: int | None
    ) -> bool:
        """Return True when the cached matrix may be reused for this call."""
        cached = self._correlation_matrix
        if current_idx is None or cached is None:
            return False

        # A negative distance means the bar index was reset (e.g. a new run);
        # the cached matrix then belongs to different data and is stale.
        elapsed = current_idx - self._last_update_idx
        if not 0 <= elapsed < self.CACHE_REFRESH_BARS:
            return False

        # Recompute when a requested symbol is missing from the cached matrix.
        return all(symbol in cached.columns for symbol in price_data)

    def filter_pairs(
        self,
        pairs: list[TradingPair],
        current_position_asset: str | None,
        price_data: dict[str, pd.Series],
        current_idx: int | None = None
    ) -> list[TradingPair]:
        """
        Filter pairs based on correlation with current position.

        If we have an open position in an asset, exclude pairs where
        either asset is highly correlated with the held asset.

        Args:
            pairs: List of candidate pairs
            current_position_asset: Currently held asset (or None)
            price_data: Dictionary of price series by symbol
            current_idx: Current bar index for caching

        Returns:
            Filtered list of pairs. The input list itself is returned
            unchanged when no asset is currently held.
        """
        if current_position_asset is None:
            return pairs

        corr_matrix = self.calculate_correlation_matrix(price_data, current_idx)
        threshold = self.config.correlation_threshold

        filtered = []
        for pair in pairs:
            # Check correlation of base and quote with held asset
            base_corr = self._get_correlation(
                corr_matrix, pair.base_asset, current_position_asset
            )
            quote_corr = self._get_correlation(
                corr_matrix, pair.quote_asset, current_position_asset
            )

            # Filter if either asset highly correlated with position
            # (absolute value: strong negative correlation counts too).
            if abs(base_corr) > threshold or abs(quote_corr) > threshold:
                logger.debug(
                    "Filtered %s: base_corr=%.2f, quote_corr=%.2f (held: %s)",
                    pair.name, base_corr, quote_corr, current_position_asset
                )
                continue

            filtered.append(pair)

        if len(filtered) < len(pairs):
            logger.info(
                "Correlation filter: %d/%d pairs remaining (held: %s)",
                len(filtered), len(pairs), current_position_asset
            )

        return filtered

    def _get_correlation(
        self,
        corr_matrix: pd.DataFrame,
        asset1: str,
        asset2: str
    ) -> float:
        """
        Get correlation between two assets from matrix.

        Args:
            corr_matrix: Correlation matrix
            asset1: First asset symbol
            asset2: Second asset symbol

        Returns:
            Correlation coefficient (-1 to 1); 1.0 when both symbols are the
            same; 0.0 if either symbol is missing from the matrix or the
            stored value is NaN.
        """
        if asset1 == asset2:
            return 1.0

        try:
            value = corr_matrix.loc[asset1, asset2]
        except KeyError:
            return 0.0

        # An undefined correlation (NaN) is reported as "no correlation" so
        # callers always receive a usable number, as documented.
        if pd.isna(value):
            return 0.0
        return float(value)

    def get_correlation_report(
        self,
        price_data: dict[str, pd.Series]
    ) -> pd.DataFrame:
        """
        Generate a readable correlation report.

        Always recomputes the matrix (no bar index is passed, so the cache
        is bypassed) and stores it as the cached matrix.

        Args:
            price_data: Dictionary of price series

        Returns:
            Correlation matrix as DataFrame
        """
        return self.calculate_correlation_matrix(price_data)