# MarketDataCollector/test_predictor.py
# 2025-05-30 12:40:49 +08:00
#
# 123 lines
# 4.7 KiB
# Python

import time
import json
import numpy as np
import pandas as pd
import requests
import xgboost as xgb
from datetime import datetime, timedelta
# --- CONFIG ---
# REST endpoint the candle requests are sent to.
OKX_REST_URL = "https://www.okx.com/api/v5/market/candles"
# Instrument id and bar size passed as the ``instId`` / ``bar`` query parameters.
SYMBOL = "BTC-USDT"
BAR = "1m"
# Number of minutes of history to fetch for features.
HIST_MINUTES = 250
# File the XGBoost model is loaded from.
MODEL_PATH = "data/xgboost_model.json"
# --- Fetch recent candles from OKX REST API ---
def fetch_recent_candles(symbol, bar, limit=HIST_MINUTES):
    """Fetch the most recent candles for an instrument from the OKX REST API.

    Args:
        symbol: Instrument id, sent as the ``instId`` query parameter.
        bar: Candle granularity, sent as the ``bar`` query parameter.
        limit: Number of candles to request.

    Returns:
        A DataFrame in chronological order (oldest first) with the columns
        ``Timestamp`` (datetime parsed from epoch milliseconds), ``Open``,
        ``High``, ``Low``, ``Close`` and ``Volume`` (numeric; values that
        cannot be parsed become NaN).

    Raises:
        requests.RequestException: On transport failures or a timeout.
        RuntimeError: If the response is not JSON, OKX reports a non-zero
            error code, or no candles are returned.
    """
    params = {
        'instId': symbol,
        'bar': bar,
        'limit': str(limit),
    }
    # Without a timeout a stalled connection would block the caller forever.
    resp = requests.get(OKX_REST_URL, params=params, timeout=10)
    try:
        data = resp.json()
    except ValueError as err:
        raise RuntimeError(
            f"OKX API error: non-JSON response (HTTP {resp.status_code})"
        ) from err

    if data.get('code') != '0':
        raise RuntimeError(f"OKX API error: {data.get('msg')}")

    candles = data.get('data') or []
    if not candles:
        raise RuntimeError(f"OKX API error: no candles returned for {symbol} ({bar})")

    price_cols = ['Open', 'High', 'Low', 'Close', 'Volume']
    # OKX rows are positional: [ts, o, h, l, c, vol, volCcy, confirm, ...];
    # only the first six fields are used. The API returns the most recent
    # candle first, so reverse into chronological order.
    df = pd.DataFrame(
        [row[:6] for row in reversed(candles)],
        columns=['Timestamp'] + price_cols,
    )
    df['Timestamp'] = pd.to_datetime(df['Timestamp'].astype(np.int64), unit='ms')
    for col in price_cols:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    return df
# --- Feature Engineering (minimal, real-time) ---
def add_features(df):
    """Add the technical-indicator columns used for prediction to ``df``.

    The frame is modified in place and the same object is returned. It must
    provide ``Timestamp`` (datetime-like), ``High``, ``Low`` and ``Close``.

    Columns added, in order: ``log_return``, ``rsi``, ``ema_14``, ``sma_50``,
    ``sma_200``, ``atr``, ``roc_10``, ``dpo_20``, ``hour``. Rolling-window
    columns are NaN on rows that lack a full window of history.
    """
    close = df['Close']
    prev_close = close.shift(1)

    # One-bar log return (target, not used for prediction).
    df['log_return'] = np.log(close / prev_close)

    # RSI(14) built from simple rolling means of gains and losses.
    rsi_window = 14
    change = close.diff()
    gains = change.clip(lower=0)
    losses = -1 * change.clip(upper=0)
    avg_gain = gains.rolling(window=rsi_window, min_periods=rsi_window).mean()
    avg_loss = losses.rolling(window=rsi_window, min_periods=rsi_window).mean()
    df['rsi'] = 100 - (100 / (1 + avg_gain / avg_loss))

    # Moving averages: EMA(14), SMA(50), SMA(200).
    df['ema_14'] = close.ewm(span=14, adjust=False).mean()
    for span in (50, 200):
        df[f'sma_{span}'] = close.rolling(window=span).mean()

    # ATR(14): rolling mean of the true range. The row-wise max skips NaN, so
    # the first row (no previous close) falls back to High - Low.
    true_range = pd.concat(
        [
            df['High'] - df['Low'],
            (df['High'] - prev_close).abs(),
            (df['Low'] - prev_close).abs(),
        ],
        axis=1,
    ).max(axis=1)
    df['atr'] = true_range.rolling(window=14).mean()

    # ROC(10), expressed in percent.
    df['roc_10'] = close.pct_change(periods=10) * 100

    # DPO(20): close minus the 21-bar mean taken 10 rows *ahead*. Because of
    # the negative shift the final 10 rows, including the newest one, are NaN.
    # NOTE(review): confirm this matches how the model was trained.
    centred_mean = close.rolling(window=21).mean().shift(-10)
    df['dpo_20'] = close - centred_mean

    # Hour of day taken from the candle timestamp.
    df['hour'] = df['Timestamp'].dt.hour
    return df
# --- Load model and feature columns ---
def load_model_and_features(model_path):
    """Load the XGBoost booster and the feature column order to feed it.

    Args:
        model_path: Path passed to ``Booster.load_model``.

    Returns:
        A ``(model, feature_cols)`` tuple. ``feature_cols`` is the list of
        feature names stored in the loaded model when it carries any;
        otherwise it is the built-in default list below.
    """
    model = xgb.Booster()
    model.load_model(model_path)

    # Fallback for models saved without feature names.
    default_cols = [
        'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
    ]
    # Prefer the names (and order) recorded in the model itself so the
    # prediction matrix cannot drift from what the booster was trained on.
    stored_names = getattr(model, 'feature_names', None)
    feature_cols = list(stored_names) if stored_names else default_cols
    return model, feature_cols
# --- Predict next log return and price ---
def predict_next(df, model, feature_cols):
    """Score the newest row of ``df`` and derive the implied next price.

    Args:
        df: Frame containing ``Close`` and every column in ``feature_cols``.
        model: Object whose ``predict`` accepts an ``xgb.DMatrix``.
        feature_cols: Ordered feature column names.

    Returns:
        ``(pred_price, pred_log_return, last_price)``, where ``pred_price``
        is ``last_price * exp(pred_log_return)``.
    """
    # Double brackets keep the selection two-dimensional (one row, n features).
    newest_row = df[feature_cols].iloc[[-1]]
    features = newest_row.values.astype(np.float32)
    booster_input = xgb.DMatrix(features, feature_names=feature_cols)

    predictions = model.predict(booster_input)
    pred_log_return = predictions[0]

    last_price = df['Close'].iloc[-1]
    return last_price * np.exp(pred_log_return), pred_log_return, last_price
def main():
    """Poll OKX and print a prediction each time a newer candle shows up.

    Runs until interrupted. A failed fetch is reported and retried on the
    next poll instead of terminating the process.
    """
    poll_seconds = 5

    print('Fetching recent candles from OKX...')
    df = fetch_recent_candles(SYMBOL, BAR)
    model, feature_cols = load_model_and_features(MODEL_PATH)
    print('Waiting for new candle...')
    last_timestamp = df['Timestamp'].iloc[-1]

    try:
        while True:
            time.sleep(poll_seconds)
            try:
                new_df = fetch_recent_candles(SYMBOL, BAR, limit=HIST_MINUTES)
            except Exception as err:
                # Top-level boundary: a transient network/API failure should
                # not kill a long-running poller; report it and try again.
                print(f'Fetch failed, will retry: {err}')
                continue

            newest = new_df['Timestamp'].iloc[-1]
            if newest <= last_timestamp:
                print('No new candle yet...')
                continue

            # NOTE(review): the newest row presumably is the still-forming
            # candle, so features are computed from a partial bar - confirm
            # this is intended.
            df = add_features(new_df)
            pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
            print(f"[{newest}] Last price: {last_price:.2f} | Predicted next price: {pred_price:.2f} | Predicted log return: {pred_log_return:.6f}")
            last_timestamp = newest
    except KeyboardInterrupt:
        print('Stopped.')


if __name__ == '__main__':
    main()