60 lines
2.6 KiB
Python
60 lines
2.6 KiB
Python
|
|
from typing import Dict, Iterable, List, Sequence, Set, Tuple
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from .config import PruningConfig
|
||
|
|
|
||
|
|
|
||
|
|
# Column names removed from the model feature set by ``build_feature_list``.
# The element order below matches the original definition exactly.
# Judging only by the names, these look like identifiers, raw price-level
# series and redundant derived columns; the reason each one is excluded is not
# recorded here. NOTE(review): confirm the rationale with the feature owner.
EXCLUDE_BASE_FEATURES: List[str] = (
    # identifier / raw price
    ['Timestamp', 'Close']
    # short-window return and volatility
    + ['log_return_5', 'volatility_5', 'volatility_15', 'volatility_30']
    # indicator lines
    + [
        'bb_bbm', 'bb_bbh', 'bb_bbl',
        'stoch_k',
        'sma_50', 'sma_200',
        'psar',
        'donchian_hband', 'donchian_lband', 'donchian_mband',
        'keltner_hband', 'keltner_lband', 'keltner_mband',
        'ichimoku_a', 'ichimoku_b', 'ichimoku_base_line', 'ichimoku_conversion_line',
    ]
    # OHLC lags
    + [
        'Open_lag1', 'Open_lag2', 'Open_lag3',
        'High_lag1', 'High_lag2', 'High_lag3',
        'Low_lag1', 'Low_lag2', 'Low_lag3',
        'Close_lag1', 'Close_lag2', 'Close_lag3',
    ]
    # rolling-window statistics
    + [
        'Open_roll_mean_15', 'Open_roll_std_15', 'Open_roll_min_15', 'Open_roll_max_15',
        'Open_roll_mean_30', 'Open_roll_min_30', 'Open_roll_max_30',
        'High_roll_mean_15', 'High_roll_std_15', 'High_roll_min_15', 'High_roll_max_15',
        'Low_roll_mean_5', 'Low_roll_min_5', 'Low_roll_max_5',
        'Low_roll_mean_30', 'Low_roll_std_30', 'Low_roll_min_30', 'Low_roll_max_30',
        'Close_roll_mean_5', 'Close_roll_min_5', 'Close_roll_max_5',
        'Close_roll_mean_15', 'Close_roll_std_15', 'Close_roll_min_15', 'Close_roll_max_15',
        'Close_roll_mean_30', 'Close_roll_std_30', 'Close_roll_min_30', 'Close_roll_max_30',
        'Volume_roll_max_5', 'Volume_roll_max_15', 'Volume_roll_max_30',
    ]
    # supertrend variants
    + ['supertrend_12_3.0', 'supertrend_10_1.0', 'supertrend_11_2.0']
)
|
||
|
|
|
||
|
|
|
||
|
|
def build_feature_list(all_columns: Sequence[str]) -> List[str]:
|
||
|
|
"""Return the model feature list by excluding base columns and targets."""
|
||
|
|
return [col for col in all_columns if col not in EXCLUDE_BASE_FEATURES]
|
||
|
|
|
||
|
|
|
||
|
|
def prune_features(
|
||
|
|
feature_cols: Sequence[str],
|
||
|
|
importance_avg: Dict[str, float] | None,
|
||
|
|
cfg: PruningConfig,
|
||
|
|
) -> List[str]:
|
||
|
|
"""Decide which features to keep using averaged importances and rules."""
|
||
|
|
prune_set: Set[str] = set()
|
||
|
|
|
||
|
|
if importance_avg is not None:
|
||
|
|
sorted_feats = sorted(importance_avg.items(), key=lambda kv: kv[1], reverse=True)
|
||
|
|
keep_names = set(name for name, _ in sorted_feats[: cfg.top_k])
|
||
|
|
for name in feature_cols:
|
||
|
|
if name not in keep_names:
|
||
|
|
prune_set.add(name)
|
||
|
|
|
||
|
|
for name in cfg.known_low_features:
|
||
|
|
if name in feature_cols:
|
||
|
|
prune_set.add(name)
|
||
|
|
|
||
|
|
# If Parkinson vol exists, drop alternatives at same window
|
||
|
|
for w in [5, 15, 30]:
|
||
|
|
park = f'park_vol_{w}'
|
||
|
|
if park in feature_cols:
|
||
|
|
for alt in [f'gk_vol_{w}', f'rs_vol_{w}', f'yz_vol_{w}']:
|
||
|
|
if alt in feature_cols:
|
||
|
|
prune_set.add(alt)
|
||
|
|
|
||
|
|
kept = [c for c in feature_cols if c not in prune_set]
|
||
|
|
return kept
|
||
|
|
|
||
|
|
|