"""Live BTC-USDT dashboard: OKX trades aggregated into OHLCV bars and plotted with an XGBoost next-price prediction."""
import json
|
|
import threading
|
|
import time
|
|
import numpy as np
|
|
import pandas as pd
|
|
import xgboost as xgb
|
|
from datetime import datetime
|
|
from okx_client import OKXClient
|
|
import dash
|
|
from dash import dcc, html
|
|
from dash.dependencies import Output, Input
|
|
import plotly.graph_objs as go
|
|
import websocket
|
|
|
|
# --- Prediction utilities (from test_predictor.py) ---
|
|
def _calc_rsi(close, window=14):
    """Return a simple-moving-average RSI of *close* over *window* rows.

    Gains and losses are averaged with a plain rolling mean that requires a
    full window, so the first *window* rows of the result are NaN.
    """
    delta = close.diff()
    gains = delta.clip(lower=0)
    losses = -1 * delta.clip(upper=0)
    avg_gain = gains.rolling(window=window, min_periods=window).mean()
    avg_loss = losses.rolling(window=window, min_periods=window).mean()
    rs = avg_gain / avg_loss
    return 100 - (100 / (1 + rs))


def add_features(df):
    """Add the model's technical-indicator columns to *df* and return it.

    The frame is modified in place (the returned object is *df* itself).
    Columns written: ``log_return``, ``rsi``, ``ema_14``, ``sma_50``,
    ``sma_200``, ``atr``, ``roc_10``, ``dpo_20`` and ``hour``.

    Args:
        df: DataFrame with ``Timestamp``, ``High``, ``Low`` and ``Close``
            columns. ``Timestamp`` may already be datetime-typed or hold
            anything ``pandas.to_datetime`` can parse.

    Returns:
        The same DataFrame, with the feature columns added.

    Raises:
        KeyError: if a required column is absent. The check runs before any
            column is written, so a failing call leaves *df* untouched.
    """
    required = ('Timestamp', 'High', 'Low', 'Close')
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(f"add_features: missing required column(s): {missing}")

    close = df['Close']
    prev_close = close.shift(1)

    df['log_return'] = np.log(close / prev_close)
    df['rsi'] = _calc_rsi(close)
    df['ema_14'] = close.ewm(span=14, adjust=False).mean()
    df['sma_50'] = close.rolling(window=50).mean()
    df['sma_200'] = close.rolling(window=200).mean()

    # True range: the widest of the bar's own range and its two gaps
    # against the previous close; ATR is its 14-row rolling mean.
    true_range = pd.concat(
        [
            df['High'] - df['Low'],
            np.abs(df['High'] - prev_close),
            np.abs(df['Low'] - prev_close),
        ],
        axis=1,
    ).max(axis=1)
    df['atr'] = true_range.rolling(window=14).mean()

    df['roc_10'] = close.pct_change(periods=10) * 100

    # NOTE(review): shift(-10) aligns each row with the rolling mean taken
    # 10 rows *later*, so this feature looks ahead and is NaN on the last
    # 10 rows of any frame. Left unchanged because a trained model
    # presumably expects this exact definition - confirm before altering.
    df['dpo_20'] = close - close.rolling(window=21).mean().shift(-10)

    # Coerce so that string/object timestamps work too; the Timestamp
    # column itself is left exactly as the caller supplied it.
    df['hour'] = pd.to_datetime(df['Timestamp']).dt.hour
    return df
|
|
|
|
def load_model_and_features(model_path):
    """Load a saved XGBoost booster and pair it with its input column names.

    Args:
        model_path: Path of the model file passed to ``Booster.load_model``.

    Returns:
        A ``(booster, feature_cols)`` tuple, where ``feature_cols`` is the
        ordered list of DataFrame columns fed to the model.
    """
    booster = xgb.Booster()
    booster.load_model(model_path)

    # Order matters: predict_next hands these names to the DMatrix as-is.
    feature_cols = [
        'rsi',
        'ema_14',
        'sma_50',
        'sma_200',
        'atr',
        'roc_10',
        'dpo_20',
        'hour',
    ]
    return booster, feature_cols
|
|
|
|
def predict_next(df, model, feature_cols):
    """Predict the next close from the newest row of *df*.

    The model's first output is interpreted as a log return and applied to
    the most recent ``Close`` value.

    Args:
        df: DataFrame holding ``Close`` plus every column in *feature_cols*.
        model: Object whose ``predict`` accepts an ``xgb.DMatrix``.
        feature_cols: Names of the feature columns, in model order.

    Returns:
        ``(pred_price, pred_log_return, last_price)``.
    """
    # Double brackets keep the selection two-dimensional (one row).
    newest_features = df[feature_cols].iloc[[-1]]
    feature_matrix = newest_features.values.astype(np.float32)

    predictions = model.predict(
        xgb.DMatrix(feature_matrix, feature_names=feature_cols)
    )
    pred_log_return = predictions[0]

    last_price = df['Close'].iloc[-1]
    return last_price * np.exp(pred_log_return), pred_log_return, last_price
|
|
|
|
# --- Main live plotting ---
|
|
# Model file handed to load_model_and_features below.
MODEL_PATH = 'data/xgboost_model.json'

# Maximum number of aggregated bars kept in memory.
# NOTE(review): add_features computes a 200-row rolling mean (sma_200),
# which can never be filled from only 50 bars - confirm the intended size.
WINDOW = 50

# Bar buffer shared between the collector thread and the Dash callback;
# every access is expected to hold bar_lock.
bar_lock = threading.Lock()
ohlcv_bars = []  # each entry: [timestamp, open, high, low, close, volume]

# Loaded once, when the module is imported.
model, feature_cols = load_model_and_features(MODEL_PATH)
|
|
|
|
# --- Background thread to collect trades and aggregate OHLCV ---
|
|
def ws_collector():
    """Collect BTC-USDT trades from OKX and fold them into OHLCV bars.

    Meant to run as the target of a background thread. Each trade from the
    ``trades`` channel is assigned to a fixed-length time bucket; the bucket
    either opens a new bar at the end of ``ohlcv_bars`` or updates the newest
    one. At most ``WINDOW`` bars are retained. All buffer access happens
    under ``bar_lock``.

    The loop ends (after printing the error) when receiving from the socket
    fails with anything other than a timeout. Frames that are not valid JSON
    are skipped rather than treated as fatal.
    """
    # Local import: the module only imports the `datetime` class itself.
    from datetime import timezone

    bar_seconds = 30  # bar length; any positive number of seconds works
    bar_ms = bar_seconds * 1000

    client = OKXClient(authenticate=False)
    client.subscribe_trades(instrument="BTC-USDT")

    while True:
        try:
            msg = client.ws.recv()
        except websocket.WebSocketTimeoutException:
            continue  # Nothing arrived in time; just try again
        except Exception as e:
            print(f"WebSocket error: {e}")
            break  # or try to reconnect

        try:
            data = json.loads(msg)
        except json.JSONDecodeError:
            # Not JSON (for example a plain-text keep-alive reply); one bad
            # frame must not take the whole collector down.
            continue
        except Exception as e:
            print(f"WebSocket error: {e}")
            break

        if not isinstance(data, dict):
            continue
        if 'arg' not in data or data['arg'].get('channel', '') != 'trades':
            continue

        for trade in data.get('data', []):
            # trade: {'instId', 'tradeId', 'px', 'sz', 'side', 'ts'}
            ts = int(trade['ts'])  # used as epoch milliseconds
            price = float(trade['px'])
            size = float(trade['sz'])

            # Floor to the start of the bucket using epoch arithmetic, then
            # build a naive UTC datetime (same value the bars held before).
            bucket_start_ms = ts - ts % bar_ms
            bar_time = datetime.fromtimestamp(
                bucket_start_ms / 1000, tz=timezone.utc
            ).replace(tzinfo=None)

            with bar_lock:
                if not ohlcv_bars or ohlcv_bars[-1][0] != bar_time:
                    # New bar
                    ohlcv_bars.append([bar_time, price, price, price, price, size])
                    if len(ohlcv_bars) > WINDOW:
                        ohlcv_bars.pop(0)
                else:
                    # Update current bar
                    bar = ohlcv_bars[-1]
                    bar[2] = max(bar[2], price)  # high
                    bar[3] = min(bar[3], price)  # low
                    bar[4] = price  # close
                    bar[5] += size  # volume
|
|
|
|
# Launch the trade collector in the background. It is a daemon thread, so it
# does not keep the process alive once the main thread exits.
_collector_thread = threading.Thread(target=ws_collector, daemon=True)
_collector_thread.start()
|
|
|
|
# --- Dash App ---
|
|
app = dash.Dash(__name__)

# Page pieces are named individually, then assembled top to bottom.
_page_title = html.H2(
    'BTC/USDT Price & Prediction (OKX, XGBoost, Trades Aggregated)',
    style={"textAlign": "center", "margin": 0, "padding": 0},
)

_price_chart = dcc.Graph(
    id='live-graph',
    animate=False,
    style={"height": "90vh", "width": "100vw", "margin": 0, "padding": 0},
)

# Fires the refresh callback on a fixed period.
_refresh_timer = dcc.Interval(
    id='interval-component',
    interval=5 * 1000,  # milliseconds, i.e. every 5 seconds
    n_intervals=0,
)

_prediction_line = html.Div(
    id='prediction-output',
    style={"textAlign": "center", "fontSize": 20, "marginTop": 10},
)

app.layout = html.Div(
    [_page_title, _price_chart, _refresh_timer, _prediction_line],
    style={
        "height": "100vh",
        "width": "100vw",
        "margin": 0,
        "padding": 0,
        "overflow": "hidden",
        "backgroundColor": "#f7f7f7",
    },
)
|
|
|
|
@app.callback(
    [Output('live-graph', 'figure'), Output('prediction-output', 'children')],
    [Input('interval-component', 'n_intervals')]
)
def update_graph_live(n):
    """Rebuild the candlestick figure and the prediction text.

    Triggered by the interval component; *n* (the tick counter) is not used.

    Returns:
        ``(figure, text)``: a Plotly figure with one candlestick trace and a
        status line. With fewer than two bars an empty figure and a waiting
        message are returned instead.
    """
    with bar_lock:
        # Copy each bar as well as the outer list: the collector keeps
        # mutating the newest bar in place once the lock is released.
        bars = [list(bar) for bar in ohlcv_bars]

    if len(bars) < 2:
        return go.Figure(), "Waiting for data..."

    df = pd.DataFrame(bars, columns=["Timestamp", "Open", "High", "Low", "Close", "Volume"])
    df["Timestamp"] = pd.to_datetime(df["Timestamp"])
    df = add_features(df)

    # Only the newest row is fed to the model, so only that row has to be
    # complete. Rolling indicators are always NaN on the earliest rows, which
    # made a whole-frame check reject every frame.
    # NOTE(review): add_features' dpo_20 uses a forward shift and sma_200
    # needs 200 rows, so the newest row may still never be complete with the
    # current settings - confirm the feature definitions and WINDOW.
    newest_features = df[feature_cols].iloc[-1]
    if newest_features.isnull().any():
        pred_text = "Not enough data for prediction."
    else:
        pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
        pred_text = f"Last: {last_price:.2f} | Predicted next: {pred_price:.2f} | LogRet: {pred_log_return:.6f}"

    fig = go.Figure()
    fig.add_trace(go.Candlestick(
        x=df["Timestamp"],
        open=df["Open"],
        high=df["High"],
        low=df["Low"],
        close=df["Close"],
        name='Candlestick',
        increasing_line_color='green',
        decreasing_line_color='red',
        showlegend=False
    ))
    # The collector aggregates trades into 30-second bars, so label them so.
    fig.update_layout(
        title='BTC-USDT 30s OHLCV (Aggregated from Trades)',
        xaxis_title='Time',
        yaxis_title='Price (USDT)',
        xaxis_rangeslider_visible=False,
    )
    return fig, pred_text
|
|
|
|
def main():
    """Start the Dash development server with debug mode enabled."""
    # NOTE(review): debug mode usually enables an auto-reloader that runs the
    # module in two processes, which would start the module-level collector
    # thread twice - confirm whether that matters here.
    app.run(debug=True)


if __name__ == '__main__':
    main()