123 lines
4.7 KiB
Python
123 lines
4.7 KiB
Python
import time
|
|
import json
|
|
import numpy as np
|
|
import pandas as pd
|
|
import requests
|
|
import xgboost as xgb
|
|
from datetime import datetime, timedelta
|
|
|
|
# --- Configuration ---
# OKX public REST endpoint queried by fetch_recent_candles for candlestick data.
OKX_REST_URL = "https://www.okx.com/api/v5/market/candles"
# Instrument id and candle granularity to poll.
SYMBOL = "BTC-USDT"
BAR = "1m"
# Candles requested per poll (one per minute for 1m bars). Must be at least as
# long as the widest rolling window in add_features (the 200-bar SMA).
HIST_MINUTES = 250
# Serialized XGBoost model loaded once at startup.
MODEL_PATH = "data/xgboost_model.json"
|
|
|
|
# --- Fetch recent candles from OKX REST API ---

_OHLCV_COLUMNS = ['Timestamp', 'Open', 'High', 'Low', 'Close', 'Volume']


def _candles_to_frame(candles):
    """Convert raw OKX candle rows (chronological order) into a typed OHLCV frame.

    Args:
        candles: list of OKX candle rows. Only the first six fields of each
            row (ts, o, h, l, c, vol) are used; trailing fields are ignored.

    Returns:
        DataFrame with columns Timestamp (datetime64 from epoch milliseconds),
        Open, High, Low, Close, Volume (numeric; unparseable values -> NaN).
        Empty (but with those columns) when *candles* is empty.
    """
    # Slice positionally so extra trailing fields (volCcy, confirm, ...) are
    # dropped no matter how many OKX sends.
    rows = [row[:len(_OHLCV_COLUMNS)] for row in candles]
    df = pd.DataFrame(rows, columns=_OHLCV_COLUMNS)
    df['Timestamp'] = pd.to_datetime(df['Timestamp'].astype(np.int64), unit='ms')
    for col in _OHLCV_COLUMNS[1:]:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    return df


def fetch_recent_candles(symbol, bar, limit=HIST_MINUTES, timeout=10):
    """Fetch the most recent candles for *symbol* from the OKX REST API.

    Args:
        symbol: OKX instrument id, e.g. 'BTC-USDT'.
        bar: candle granularity, e.g. '1m'.
        limit: number of candles to request.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        DataFrame in chronological order with columns Timestamp, Open, High,
        Low, Close, Volume (empty if OKX returns no candles).

    Raises:
        RuntimeError: OKX answered with a ``code`` other than '0'.
        requests.RequestException: network failure, timeout, or an HTTP error
            status whose body is not JSON.
        ValueError: the body is not JSON even though the HTTP status was OK.

    NOTE(review): the newest row is probably the candle that is still open
    (OKX's ``confirm`` field, dropped here, marks completion) — confirm whether
    callers should exclude it before computing features.
    """
    params = {
        'instId': symbol,
        'bar': bar,
        'limit': str(limit),
    }
    # Without a timeout a stalled connection blocks forever and freezes the
    # polling loop.
    resp = requests.get(OKX_REST_URL, params=params, timeout=timeout)
    try:
        data = resp.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page): report the HTTP status if
        # it is an error, otherwise propagate the parse failure.
        resp.raise_for_status()
        raise
    if data.get('code') != '0':
        raise RuntimeError(f"OKX API error: {data.get('msg')}")
    # OKX returns most recent first; reverse to chronological order.
    candles = (data.get('data') or [])[::-1]
    return _candles_to_frame(candles)
|
|
|
|
# --- Feature Engineering (minimal, real-time) ---


def _rolling_rsi(close, window=14):
    """Return the RSI of *close* using simple rolling means of gains and losses."""
    change = close.diff()
    gains = change.clip(lower=0)
    losses = -1 * change.clip(upper=0)
    avg_gain = gains.rolling(window=window, min_periods=window).mean()
    avg_loss = losses.rolling(window=window, min_periods=window).mean()
    return 100 - (100 / (1 + avg_gain / avg_loss))


def add_features(df):
    """Add the real-time model features to *df* in place and return it.

    Reads the Timestamp (datetime-like), High, Low and Close columns. Appends,
    in this order: log_return, rsi, ema_14, sma_50, sma_200, atr, roc_10,
    dpo_20, hour. Leading rows stay NaN until each rolling window fills.
    """
    close = df['Close']
    prev_close = close.shift(1)

    # One-bar log return (target; not used as a prediction input).
    df['log_return'] = np.log(close / prev_close)

    df['rsi'] = _rolling_rsi(close)
    df['ema_14'] = close.ewm(span=14, adjust=False).mean()
    df['sma_50'] = close.rolling(window=50).mean()
    df['sma_200'] = close.rolling(window=200).mean()

    # ATR: plain 14-bar rolling mean of the true range.
    true_range = pd.concat(
        [
            df['High'] - df['Low'],
            np.abs(df['High'] - prev_close),
            np.abs(df['Low'] - prev_close),
        ],
        axis=1,
    ).max(axis=1)
    df['atr'] = true_range.rolling(window=14).mean()

    # 10-bar rate of change, in percent.
    df['roc_10'] = close.pct_change(periods=10) * 100

    # NOTE(review): shift(-10) pulls the SMA from 10 bars in the future, so the
    # last 10 rows — including the row predict_next uses — are always NaN here.
    # If main.py computes it the same way, training rows also see future prices.
    # Kept unchanged to match the trained model; confirm against main.py.
    df['dpo_20'] = close - close.rolling(window=21).mean().shift(-10)

    df['hour'] = df['Timestamp'].dt.hour
    # Add more features as needed (match with main.py)
    return df
|
|
|
|
# --- Load model and feature columns ---


def load_model_and_features(model_path):
    """Load an XGBoost booster and the ordered feature names it expects.

    Args:
        model_path: path to a model file accepted by ``xgb.Booster.load_model``.

    Returns:
        (model, feature_cols): the loaded booster and the list of feature
        column names, in the order they must be fed to the model.
    """
    model = xgb.Booster()
    model.load_model(model_path)
    # Fallback when the model file carries no feature names; must match the
    # columns add_features produces and the training code in main.py.
    default_cols = [
        'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
    ]
    # Prefer the names stored with the model (present when it was trained on
    # named features): XGBoost rejects a DMatrix whose feature names or their
    # order differ from the booster's, so these are the authoritative list.
    stored_cols = model.feature_names
    feature_cols = list(stored_cols) if stored_cols else default_cols
    return model, feature_cols
|
|
|
|
# --- Predict next log return and price ---


def predict_next(df, model, feature_cols):
    """Score the newest row of *df* and turn the predicted log return into a price.

    Args:
        df: DataFrame containing every column in *feature_cols* plus 'Close'.
        model: loaded XGBoost booster.
        feature_cols: ordered feature names passed to the model.

    Returns:
        (pred_price, pred_log_return, last_price), where
        pred_price = last_price * exp(pred_log_return).
    """
    # Single-row float32 matrix built from the most recent row only.
    newest_features = df[feature_cols].iloc[[-1]].to_numpy().astype(np.float32)
    booster_input = xgb.DMatrix(newest_features, feature_names=feature_cols)
    pred_log_return = model.predict(booster_input)[0]

    last_price = df['Close'].iloc[-1]
    growth_factor = np.exp(pred_log_return)
    return last_price * growth_factor, pred_log_return, last_price
|
|
|
|
def main(poll_seconds=5):
    """Poll OKX and print a next-price prediction each time a new candle appears.

    Runs until interrupted. Fetch failures (network errors, OKX API errors,
    empty responses) are reported and retried on the next poll instead of
    terminating the loop.

    NOTE(review): a "new" candle is detected as soon as it opens, so the last
    row used for prediction probably holds only a few seconds of data —
    confirm this matches how the model was trained.
    """
    print('Fetching recent candles from OKX...')
    # Fail fast at startup: without a first snapshot there is nothing to compare to.
    df = fetch_recent_candles(SYMBOL, BAR)
    model, feature_cols = load_model_and_features(MODEL_PATH)
    print('Waiting for new candle...')
    last_timestamp = df['Timestamp'].iloc[-1]

    while True:
        time.sleep(poll_seconds)
        try:
            new_df = fetch_recent_candles(SYMBOL, BAR, limit=HIST_MINUTES)
        except Exception as exc:  # top-level polling boundary: report and retry
            print(f'Fetch failed ({exc!r}); retrying...')
            continue
        if new_df.empty:
            print('Empty candle response; retrying...')
            continue
        if new_df['Timestamp'].iloc[-1] <= last_timestamp:
            print('No new candle yet...')
            continue

        df = add_features(new_df)
        pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
        print(f"[{df['Timestamp'].iloc[-1]}] Last price: {last_price:.2f} | Predicted next price: {pred_price:.2f} | Predicted log return: {pred_log_return:.6f}")
        last_timestamp = df['Timestamp'].iloc[-1]


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        print('Stopped.')
|