from typing import Dict, Iterable, List, Sequence, Set, Tuple

import numpy as np

from .config import PruningConfig

# Column names that ``build_feature_list`` drops from the model feature set.
# Kept as a list (not a set) so existing code that iterates or indexes it keeps
# working; fast membership lookups are done via a set built at call time.
EXCLUDE_BASE_FEATURES: List[str] = [
    'Timestamp', 'Close', 'log_return_5', 'volatility_5', 'volatility_15',
    'volatility_30', 'bb_bbm', 'bb_bbh', 'bb_bbl', 'stoch_k', 'sma_50',
    'sma_200', 'psar', 'donchian_hband', 'donchian_lband', 'donchian_mband',
    'keltner_hband', 'keltner_lband', 'keltner_mband', 'ichimoku_a',
    'ichimoku_b', 'ichimoku_base_line', 'ichimoku_conversion_line',
    'Open_lag1', 'Open_lag2', 'Open_lag3', 'High_lag1', 'High_lag2',
    'High_lag3', 'Low_lag1', 'Low_lag2', 'Low_lag3', 'Close_lag1',
    'Close_lag2', 'Close_lag3', 'Open_roll_mean_15', 'Open_roll_std_15',
    'Open_roll_min_15', 'Open_roll_max_15', 'Open_roll_mean_30',
    'Open_roll_min_30', 'Open_roll_max_30', 'High_roll_mean_15',
    'High_roll_std_15', 'High_roll_min_15', 'High_roll_max_15',
    'Low_roll_mean_5', 'Low_roll_min_5', 'Low_roll_max_5', 'Low_roll_mean_30',
    'Low_roll_std_30', 'Low_roll_min_30', 'Low_roll_max_30',
    'Close_roll_mean_5', 'Close_roll_min_5', 'Close_roll_max_5',
    'Close_roll_mean_15', 'Close_roll_std_15', 'Close_roll_min_15',
    'Close_roll_max_15', 'Close_roll_mean_30', 'Close_roll_std_30',
    'Close_roll_min_30', 'Close_roll_max_30', 'Volume_roll_max_5',
    'Volume_roll_max_15', 'Volume_roll_max_30', 'supertrend_12_3.0',
    'supertrend_10_1.0', 'supertrend_11_2.0',
]


def build_feature_list(all_columns: Sequence[str]) -> List[str]:
    """Return ``all_columns`` minus every name in ``EXCLUDE_BASE_FEATURES``.

    The relative order of the remaining columns (including any duplicates)
    is preserved.

    NOTE(review): only names listed in ``EXCLUDE_BASE_FEATURES`` are removed.
    A target column is dropped only if it appears in that list (``'Close'``
    does). Confirm that callers remove any other target columns themselves.

    Args:
        all_columns: Candidate column names, e.g. a frame's columns.

    Returns:
        A new list of the columns that are not excluded.
    """
    # Build the set per call (cheap) rather than at import time, so runtime
    # edits to EXCLUDE_BASE_FEATURES are still honoured.
    excluded = set(EXCLUDE_BASE_FEATURES)
    return [col for col in all_columns if col not in excluded]


def prune_features(
    feature_cols: Sequence[str],
    importance_avg: Dict[str, float] | None,
    cfg: PruningConfig,
) -> List[str]:
    """Decide which features to keep using averaged importances and rules.

    Three independent rules add names to a prune set. The result is
    ``feature_cols`` in its original order with those names removed.

    1. Importance: if ``importance_avg`` is given, its names are ranked by
       value (highest first, ties in dict order) and only the first
       ``cfg.top_k`` are kept. Every feature in ``feature_cols`` outside that
       slice is pruned, including features missing from ``importance_avg``.
    2. Known-low: every name in ``cfg.known_low_features`` that is present in
       ``feature_cols`` is pruned.
    3. Volatility dedup: for each window ``w`` in 5, 15 and 30, if
       ``park_vol_{w}`` is present, then ``gk_vol_{w}``, ``rs_vol_{w}`` and
       ``yz_vol_{w}`` are pruned when present.

    Args:
        feature_cols: Candidate feature names, in output order.
        importance_avg: Mapping of feature name to averaged importance, or
            ``None`` to skip the importance rule.
        cfg: Pruning settings. Reads ``top_k`` and ``known_low_features``.

    Returns:
        The kept feature names as a new list.
    """
    # Hoisted once so every membership test below is O(1) instead of a
    # linear scan of ``feature_cols``.
    present: Set[str] = set(feature_cols)
    prune_set: Set[str] = set()

    if importance_avg is not None:
        ranked = sorted(importance_avg.items(), key=lambda kv: kv[1], reverse=True)
        # NOTE(review): names in importance_avg that are not in feature_cols
        # still occupy top_k slots, so fewer than top_k features may survive.
        keep_names = {name for name, _ in ranked[: cfg.top_k]}
        prune_set.update(present - keep_names)

    prune_set.update(name for name in cfg.known_low_features if name in present)

    # If Parkinson vol exists, drop alternatives at same window
    for w in (5, 15, 30):
        if f'park_vol_{w}' in present:
            alternatives = (f'gk_vol_{w}', f'rs_vol_{w}', f'yz_vol_{w}')
            prune_set.update(alt for alt in alternatives if alt in present)

    return [c for c in feature_cols if c not in prune_set]