60 lines
2.6 KiB
Python
Raw Normal View History

from typing import Dict, Iterable, List, Sequence, Set, Tuple
import numpy as np
from .config import PruningConfig
EXCLUDE_BASE_FEATURES: List[str] = [
'Timestamp', 'Close',
'log_return_5', 'volatility_5', 'volatility_15', 'volatility_30',
'bb_bbm', 'bb_bbh', 'bb_bbl', 'stoch_k', 'sma_50', 'sma_200', 'psar',
'donchian_hband', 'donchian_lband', 'donchian_mband', 'keltner_hband', 'keltner_lband',
'keltner_mband', 'ichimoku_a', 'ichimoku_b', 'ichimoku_base_line', 'ichimoku_conversion_line',
'Open_lag1', 'Open_lag2', 'Open_lag3', 'High_lag1', 'High_lag2', 'High_lag3', 'Low_lag1', 'Low_lag2',
'Low_lag3', 'Close_lag1', 'Close_lag2', 'Close_lag3', 'Open_roll_mean_15', 'Open_roll_std_15', 'Open_roll_min_15',
'Open_roll_max_15', 'Open_roll_mean_30', 'Open_roll_min_30', 'Open_roll_max_30', 'High_roll_mean_15', 'High_roll_std_15',
'High_roll_min_15', 'High_roll_max_15', 'Low_roll_mean_5', 'Low_roll_min_5', 'Low_roll_max_5', 'Low_roll_mean_30',
'Low_roll_std_30', 'Low_roll_min_30', 'Low_roll_max_30', 'Close_roll_mean_5', 'Close_roll_min_5', 'Close_roll_max_5',
'Close_roll_mean_15', 'Close_roll_std_15', 'Close_roll_min_15', 'Close_roll_max_15', 'Close_roll_mean_30',
'Close_roll_std_30', 'Close_roll_min_30', 'Close_roll_max_30', 'Volume_roll_max_5', 'Volume_roll_max_15',
'Volume_roll_max_30', 'supertrend_12_3.0', 'supertrend_10_1.0', 'supertrend_11_2.0',
]
def build_feature_list(
    all_columns: Sequence[str],
    exclude: Iterable[str] | None = None,
) -> List[str]:
    """Return the model feature list: ``all_columns`` minus excluded names.

    Only names in the exclusion collection are removed; any target column is
    kept unless it is listed there.

    Args:
        all_columns: Candidate column names. Order and duplicates are kept.
        exclude: Names to drop. When ``None`` (the default), the module-level
            ``EXCLUDE_BASE_FEATURES`` list is used, looked up at call time.

    Returns:
        The entries of ``all_columns`` not in the exclusion collection, in
        their original order.
    """
    # Build a set once so each membership test is O(1) instead of a linear
    # scan of the exclusion list per column.
    excluded = set(EXCLUDE_BASE_FEATURES if exclude is None else exclude)
    return [col for col in all_columns if col not in excluded]
def prune_features(
    feature_cols: Sequence[str],
    importance_avg: Dict[str, float] | None,
    cfg: PruningConfig,
    *,
    vol_windows: Iterable[int] = (5, 15, 30),
) -> List[str]:
    """Decide which features to keep using averaged importances and rules.

    A column is dropped if any of the following rules selects it:

    1. Importance cut (only when ``importance_avg`` is not ``None``): the
       entries of ``importance_avg`` are ranked by value, highest first, and
       only the first ``cfg.top_k`` names survive. Every column of
       ``feature_cols`` outside that set is dropped, including columns that
       have no entry in ``importance_avg`` at all.
    2. Known-low list: names in ``cfg.known_low_features`` that appear in
       ``feature_cols`` are dropped.
    3. Volatility de-duplication: for each window ``w`` in ``vol_windows``,
       if ``park_vol_{w}`` is among ``feature_cols`` then ``gk_vol_{w}``,
       ``rs_vol_{w}`` and ``yz_vol_{w}`` are dropped.

    Args:
        feature_cols: Candidate feature names. Order (and any duplicates) is
            preserved in the result.
        importance_avg: Feature name -> averaged importance, or ``None`` to
            skip rule 1.
        cfg: Pruning settings. This function reads ``cfg.top_k`` and
            ``cfg.known_low_features``.
        vol_windows: Windows checked by rule 3. Defaults to ``(5, 15, 30)``.

    Returns:
        The entries of ``feature_cols`` not selected by any rule, in their
        original order.
    """
    # Build the lookup set once; `name in feature_cols` on a list is O(n).
    present = set(feature_cols)
    prune_set: Set[str] = set()

    if importance_avg is not None:
        ranked = sorted(importance_avg.items(), key=lambda kv: kv[1], reverse=True)
        # NOTE(review): the ranking covers every key of importance_avg, so
        # names absent from feature_cols can take top_k slots and leave fewer
        # than top_k features kept. Confirm this is intended.
        keep_names = {name for name, _ in ranked[: cfg.top_k]}
        prune_set |= present - keep_names

    prune_set |= present.intersection(cfg.known_low_features)

    # If Parkinson vol exists, drop the alternatives at the same window.
    # NOTE(review): "exists" means present in feature_cols, not surviving the
    # rules above. If park_vol_{w} was itself pruned, the alternatives are
    # dropped too and none remains for that window. Confirm this is intended.
    for w in vol_windows:
        if f'park_vol_{w}' in present:
            prune_set |= present.intersection(
                (f'gk_vol_{w}', f'rs_vol_{w}', f'yz_vol_{w}')
            )

    return [c for c in feature_cols if c not in prune_set]