best prediction so far
This commit is contained in:
parent
0bbb0e52af
commit
120c366576
18
main.py
18
main.py
@ -205,13 +205,13 @@ if __name__ == '__main__':
|
|||||||
test_timestamps = df['Timestamp'].values[split_idx:]
|
test_timestamps = df['Timestamp'].values[split_idx:]
|
||||||
|
|
||||||
model = CustomXGBoostGPU(X_train, X_test, y_train, y_test)
|
model = CustomXGBoostGPU(X_train, X_test, y_train, y_test)
|
||||||
booster = model.train(
|
booster = model.train(eval_metric='rmse')
|
||||||
colsample_bytree=1.0,
|
# colsample_bytree=1.0,
|
||||||
learning_rate=0.05,
|
# learning_rate=0.05,
|
||||||
max_depth=7,
|
# max_depth=7,
|
||||||
n_estimators=200,
|
# n_estimators=200,
|
||||||
subsample=0.8
|
# subsample=0.8
|
||||||
)
|
# )
|
||||||
model.save_model(f'../data/xgboost_model_all_features.json')
|
model.save_model(f'../data/xgboost_model_all_features.json')
|
||||||
|
|
||||||
test_preds = model.predict(X_test)
|
test_preds = model.predict(X_test)
|
||||||
@ -232,7 +232,7 @@ if __name__ == '__main__':
|
|||||||
predicted_prices.append(predicted_prices[-1] * np.exp(r_))
|
predicted_prices.append(predicted_prices[-1] * np.exp(r_))
|
||||||
predicted_prices = np.array(predicted_prices[1:])
|
predicted_prices = np.array(predicted_prices[1:])
|
||||||
|
|
||||||
mae = mean_absolute_error(actual_prices, predicted_prices)
|
# mae = mean_absolute_error(actual_prices, predicted_prices)
|
||||||
r2 = r2_score(actual_prices, predicted_prices)
|
r2 = r2_score(actual_prices, predicted_prices)
|
||||||
direction_actual = np.sign(np.diff(actual_prices))
|
direction_actual = np.sign(np.diff(actual_prices))
|
||||||
direction_pred = np.sign(np.diff(predicted_prices))
|
direction_pred = np.sign(np.diff(predicted_prices))
|
||||||
@ -247,7 +247,7 @@ if __name__ == '__main__':
|
|||||||
importance = feature_importance_dict.get(feature, 0.0)
|
importance = feature_importance_dict.get(feature, 0.0)
|
||||||
fi_str = format(importance, ".6f")
|
fi_str = format(importance, ".6f")
|
||||||
row = [feature]
|
row = [feature]
|
||||||
for val in [rmse, mae, r2, mape, directional_accuracy]:
|
for val in [rmse, mape, r2, directional_accuracy]:
|
||||||
if isinstance(val, float):
|
if isinstance(val, float):
|
||||||
row.append(format(val, '.10f'))
|
row.append(format(val, '.10f'))
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user