29 lines
736 B
Python
29 lines
736 B
Python
|
|
from typing import Dict, List, Tuple
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from custom_xgboost import CustomXGBoostGPU
|
||
|
|
|
||
|
|
|
||
|
|
def train_model(
    X_train: np.ndarray,
    X_test: np.ndarray,
    y_train: np.ndarray,
    y_test: np.ndarray,
    eval_metric: str = 'rmse',
) -> "CustomXGBoostGPU":
    """Train the XGBoost model and return the fitted wrapper.

    Constructs a ``CustomXGBoostGPU`` from the four arrays and calls its
    ``train`` method with ``eval_metric``.

    Args:
        X_train: Training feature matrix.
        X_test: Held-out feature matrix; handed to the wrapper's constructor
            (presumably used for evaluation inside the wrapper -- confirm).
        y_train: Training targets.
        y_test: Held-out targets; handed to the wrapper's constructor.
        eval_metric: Metric name forwarded unchanged to ``model.train``.
            Defaults to ``'rmse'``.

    Returns:
        The ``CustomXGBoostGPU`` instance after ``train`` has been called on it.
    """
    # The wrapper's constructor takes the splits positionally in this order.
    model = CustomXGBoostGPU(X_train, X_test, y_train, y_test)
    model.train(eval_metric=eval_metric)
    return model
|
||
|
|
|
||
|
|
|
||
|
|
def predict(model: CustomXGBoostGPU, X: np.ndarray) -> np.ndarray:
    """Return the trained wrapper's predictions for the feature matrix ``X``.

    Thin delegation to ``model.predict``. ``X`` is passed through unchanged,
    and whatever the wrapper produces is returned as-is.
    """
    predictions = model.predict(X)
    return predictions
|
||
|
|
|
||
|
|
|
||
|
|
def get_feature_importance(
    model: CustomXGBoostGPU,
    feature_names: List[str],
) -> Dict[str, float]:
    """Return the feature-importance mapping reported by a trained wrapper.

    Delegates to ``model.get_feature_importance``, passing ``feature_names``
    through unchanged.

    Args:
        model: A ``CustomXGBoostGPU`` instance, presumably already trained --
            confirm what the wrapper requires before this call.
        feature_names: Names forwarded to the wrapper; presumably ordered to
            match the feature-matrix columns -- TODO confirm.

    Returns:
        The mapping returned by the wrapper (annotated as feature name to
        importance score).
    """
    return model.get_feature_importance(feature_names)
|
||
|
|
|
||
|
|
|