29 lines
736 B
Python
Raw Permalink Normal View History

from typing import Dict, List, Tuple
import numpy as np
from custom_xgboost import CustomXGBoostGPU
def train_model(
    X_train: np.ndarray,
    X_test: np.ndarray,
    y_train: np.ndarray,
    y_test: np.ndarray,
    eval_metric: str = 'rmse',
) -> "CustomXGBoostGPU":
    """Train the XGBoost model and return the fitted wrapper.

    All four arrays are handed to the ``CustomXGBoostGPU`` constructor in the
    order (X_train, X_test, y_train, y_test); what the wrapper does with the
    test split (e.g. evaluation during training) is not visible here —
    confirm against ``CustomXGBoostGPU``.

    Args:
        X_train: Feature array for the training split.
        X_test: Feature array for the held-out split.
        y_train: Target array for the training split.
        y_test: Target array for the held-out split.
        eval_metric: Metric name forwarded verbatim to
            ``CustomXGBoostGPU.train``. Defaults to ``'rmse'``.

    Returns:
        The ``CustomXGBoostGPU`` instance after ``train`` has been called on it.
    """
    model = CustomXGBoostGPU(X_train, X_test, y_train, y_test)
    model.train(eval_metric=eval_metric)
    return model
def predict(model: CustomXGBoostGPU, X: np.ndarray) -> np.ndarray:
    """Run inference on ``X`` with an already-trained model.

    This is a thin delegate: ``X`` is forwarded unchanged to
    ``model.predict``, and whatever that call returns is passed back.
    """
    predictions = model.predict(X)
    return predictions
def get_feature_importance(model: CustomXGBoostGPU, feature_names: List[str]) -> Dict[str, float]:
    """Return the model's feature importances as a ``Dict[str, float]``.

    Delegates to ``model.get_feature_importance`` and passes ``feature_names``
    through unchanged. Presumably the names should line up one-to-one with
    the training feature columns; confirm against ``CustomXGBoostGPU``.
    """
    importances = model.get_feature_importance(feature_names)
    return importances