import json
import threading
import time
from datetime import datetime

import dash
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import websocket
import xgboost as xgb
from dash import dcc, html
from dash.dependencies import Input, Output

from okx_client import OKXClient

# --- Prediction utilities (from test_predictor.py) ---


def add_features(df):
    """Add the model's indicator columns to an OHLCV frame, in place.

    Reads the ``Timestamp`` (datetime-like), ``High``, ``Low`` and ``Close``
    columns and adds: ``log_return``, ``rsi``, ``ema_14``, ``sma_50``,
    ``sma_200``, ``atr``, ``roc_10``, ``dpo_20`` and ``hour``.  Rolling
    indicators are NaN until their window is filled (e.g. ``sma_200`` needs
    200 rows).

    Returns:
        The same DataFrame object, mutated, for call chaining.
    """
    df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))

    def calc_rsi(close, window=14):
        # RSI from *simple* rolling means of gains/losses (not Wilder
        # smoothing).  No losses in the window -> rs = inf -> RSI 100;
        # a completely flat window gives 0/0 -> NaN.
        delta = close.diff()
        up = delta.clip(lower=0)
        down = -1 * delta.clip(upper=0)
        ma_up = up.rolling(window=window, min_periods=window).mean()
        ma_down = down.rolling(window=window, min_periods=window).mean()
        rs = ma_up / ma_down
        return 100 - (100 / (1 + rs))

    df['rsi'] = calc_rsi(df['Close'])
    df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
    df['sma_50'] = df['Close'].rolling(window=50).mean()
    df['sma_200'] = df['Close'].rolling(window=200).mean()

    # True range = max(high-low, |high-prev_close|, |low-prev_close|);
    # ATR here is its simple 14-bar rolling mean.
    high_low = df['High'] - df['Low']
    high_close = np.abs(df['High'] - df['Close'].shift(1))
    low_close = np.abs(df['Low'] - df['Close'].shift(1))
    tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
    df['atr'] = tr.rolling(window=14).mean()

    df['roc_10'] = df['Close'].pct_change(periods=10) * 100

    # NOTE(review): .shift(-10) takes the 21-bar SMA from 10 rows *ahead*,
    # i.e. this feature uses future prices (look-ahead).  Consequently the
    # newest 10 rows are always NaN, so live inference on the latest bar
    # cannot use it.  Left unchanged because the saved model was presumably
    # trained on exactly this definition; making it causal needs a retrain.
    df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)

    # Hour of day in whatever timezone the Timestamp column carries.
    df['hour'] = df['Timestamp'].dt.hour
    return df


def load_model_and_features(model_path):
    """Load the XGBoost booster and the ordered feature list it expects.

    Args:
        model_path: Path to a model file readable by ``Booster.load_model``.

    Returns:
        ``(booster, feature_cols)``, where ``feature_cols`` is the column order
        used to build the prediction matrix.

    Raises:
        ValueError: if the saved model records feature names that differ from
            ``feature_cols``.  Prediction would otherwise fail on every call,
            so this error is raised once at load time instead.
    """
    model = xgb.Booster()
    model.load_model(model_path)
    feature_cols = [
        'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
    ]
    saved_names = model.feature_names
    if saved_names is not None and list(saved_names) != feature_cols:
        raise ValueError(
            f"Model {model_path!r} expects features {list(saved_names)}, "
            f"but this script computes {feature_cols}"
        )
    return model, feature_cols


def predict_next(df, model, feature_cols):
    """Predict the next bar's price from the last row of ``df``.

    The model output is interpreted as a log return, so the predicted price is
    ``last_close * exp(pred_log_return)``.

    Returns:
        ``(pred_price, pred_log_return, last_price)`` as Python floats.
    """
    features = df[feature_cols].iloc[[-1]].values.astype(np.float32)
    dmatrix = xgb.DMatrix(features, feature_names=feature_cols)
    pred_log_return = float(model.predict(dmatrix)[0])
    last_price = float(df['Close'].iloc[-1])
    pred_price = last_price * float(np.exp(pred_log_return))
    return pred_price, pred_log_return, last_price


# --- Main live plotting ---
WINDOW = 50  # number of most recent bars shown on the live chart
# NOTE: sma_200 needs >= 200 bars of history, so the bar buffer must retain
# more than WINDOW bars whenever WINDOW < 200 or prediction can never run.
MODEL_PATH = 'data/xgboost_model.json'
ohlcv_bars = []  # shared with the collector thread; entries: [timestamp, open, high, low, close, volume]
bar_lock = threading.Lock()  # guards ohlcv_bars and the bar lists inside it
model, feature_cols = load_model_and_features(MODEL_PATH)

# --- Bar aggregation settings ---
_BAR_SECONDS = 30  # bar size; NOTE(review): confirm it matches the model's training timeframe
# Bars retained for feature computation: sma_200 needs 200 rows, so keep
# some slack beyond that regardless of how many bars the chart shows.
_HISTORY_BARS = max(WINDOW, 250)
_INSTRUMENT = "BTC-USDT"
_RECONNECT_DELAY_MAX = 60.0  # seconds, cap for exponential reconnect backoff


# --- Background thread to collect trades and aggregate OHLCV ---

def _bar_start(ts_ms):
    """Return the naive-UTC start time of the _BAR_SECONDS bucket containing ts_ms."""
    bar_ms = _BAR_SECONDS * 1000
    # Flooring the epoch works for any bar size, unlike replacing the
    # seconds field (which only worked for sizes dividing 60).
    return pd.Timestamp(ts_ms - ts_ms % bar_ms, unit="ms")


def _apply_trade(trade):
    """Fold one trade dict (keys 'ts', 'px', 'sz') into ohlcv_bars.

    Caller must hold bar_lock.  A trade in a newer bucket opens a new bar;
    a trade in the current bar updates high/low/close/volume; a late trade for
    an older buffered bar updates its high/low/volume but not its close.
    Late trades with no matching buffered bar are dropped.
    """
    ts = int(trade['ts'])
    price = float(trade['px'])
    size = float(trade['sz'])
    bar_time = _bar_start(ts)

    if not ohlcv_bars or bar_time > ohlcv_bars[-1][0]:
        ohlcv_bars.append([bar_time, price, price, price, price, size])
        if len(ohlcv_bars) > _HISTORY_BARS:
            del ohlcv_bars[:-_HISTORY_BARS]
        return

    is_current = bar_time == ohlcv_bars[-1][0]
    bar = next((b for b in reversed(ohlcv_bars) if b[0] == bar_time), None)
    if bar is None:
        return  # older than the buffer, or its bucket had no bar: drop it
    bar[2] = max(bar[2], price)  # high
    bar[3] = min(bar[3], price)  # low
    if is_current:
        bar[4] = price  # close: only the newest bar can move its close forward
    bar[5] += size  # volume


def _handle_message(raw):
    """Parse one websocket frame and fold any trades it carries into ohlcv_bars."""
    try:
        data = json.loads(raw)
    except ValueError:
        # Non-JSON frames (e.g. a plain-text keep-alive reply) are not trade
        # data; skip them rather than tearing down the collector.
        return
    if not isinstance(data, dict):
        return
    if data.get('event') == 'error':
        print(f"OKX error event: {data}")
        return
    arg = data.get('arg')
    if not isinstance(arg, dict) or arg.get('channel', '') != 'trades':
        return

    try:
        # Oldest first, so each bar's close ends up being its latest trade.
        trades = sorted(data.get('data') or [], key=lambda t: int(t['ts']))
    except (KeyError, TypeError, ValueError) as e:
        print(f"Skipping malformed trades payload: {e}")
        return

    with bar_lock:
        for trade in trades:
            try:
                _apply_trade(trade)
            except (KeyError, TypeError, ValueError) as e:
                print(f"Skipping malformed trade: {e}")


def _close_quietly(client):
    """Best-effort close of ``client.ws``; any error is ignored."""
    ws = getattr(client, 'ws', None)
    if ws is None:
        return
    try:
        ws.close()
    except Exception:
        pass  # socket already closed/broken; nothing more to do


def ws_collector():
    """Stream trades from OKX forever and aggregate them into ohlcv_bars.

    Any connection, subscription or receive failure is logged and followed by
    a reconnect with capped exponential backoff.  Previously the thread exited,
    which silently froze the chart.
    """
    delay = 1.0
    while True:
        client = None
        try:
            client = OKXClient(authenticate=False)
            client.subscribe_trades(instrument=_INSTRUMENT)
            delay = 1.0  # subscribed successfully: reset the backoff
            while True:
                try:
                    raw = client.ws.recv()
                except websocket.WebSocketTimeoutException:
                    continue  # no frame within the socket timeout; keep waiting
                _handle_message(raw)
        except Exception as e:  # thread boundary: log, then reconnect
            print(f"WebSocket error: {e}; reconnecting in {delay:.0f}s")
        finally:
            _close_quietly(client)
        time.sleep(delay)
        delay = min(delay * 2, _RECONNECT_DELAY_MAX)


# Start the background thread
threading.Thread(target=ws_collector, daemon=True).start()

# --- Dash App ---
app = dash.Dash(__name__)
app.layout = html.Div([
    html.H2('BTC/USDT Price & Prediction (OKX, XGBoost, Trades Aggregated)',
            style={"textAlign": "center", "margin": 0, "padding": 0}),
    dcc.Graph(id='live-graph', animate=False,
              style={"height": "90vh", "width": "100vw", "margin": 0, "padding": 0}),
    dcc.Interval(
        id='interval-component',
        interval=5*1000,  # 5 seconds
        n_intervals=0
    ),
    html.Div(id='prediction-output',
             style={"textAlign": "center", "fontSize": 20, "marginTop": 10})
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0,
          "overflow": "hidden", "backgroundColor": "#f7f7f7"})


@app.callback(
    [Output('live-graph', 'figure'), Output('prediction-output', 'children')],
    [Input('interval-component', 'n_intervals')]
)
def update_graph_live(n):
    """Rebuild the candlestick chart and prediction text on each interval tick.

    Features are computed over the full buffered history; the chart shows
    only the most recent WINDOW bars.
    """
    # Copy each bar while holding the lock: the collector mutates bar lists
    # in place, so copying only the outer list could see a half-updated bar.
    with bar_lock:
        bars = [list(bar) for bar in ohlcv_bars]
    if len(bars) < 2:
        return go.Figure(), "Waiting for data..."

    df = pd.DataFrame(bars, columns=["Timestamp", "Open", "High", "Low", "Close", "Volume"])
    df["Timestamp"] = pd.to_datetime(df["Timestamp"])
    df = add_features(df)

    # Only the newest row is fed to the model, so only its features must be
    # present.  Earlier rows always carry rolling-window warm-up NaNs.
    latest = df[feature_cols].iloc[-1]
    missing = [col for col in feature_cols if pd.isna(latest[col])]
    if missing:
        # NOTE(review): add_features builds dpo_20 with .shift(-10) (values from
        # 10 bars ahead), so dpo_20 is NaN on the newest row and this branch is
        # taken until that feature is made causal and the model retrained.
        pred_text = f"Not enough data for prediction (missing: {', '.join(missing)})."
    else:
        # NOTE(review): the newest bar is still forming; confirm that
        # predicting from a partial bar matches how the model was trained.
        pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
        pred_text = f"Last: {last_price:.2f} | Predicted next: {pred_price:.2f} | LogRet: {pred_log_return:.6f}"

    chart = df.tail(WINDOW)
    fig = go.Figure()
    fig.add_trace(go.Candlestick(
        x=chart["Timestamp"],
        open=chart["Open"],
        high=chart["High"],
        low=chart["Low"],
        close=chart["Close"],
        name='Candlestick',
        increasing_line_color='green',
        decreasing_line_color='red',
        showlegend=False
    ))
    fig.update_layout(title=f'BTC-USDT {_BAR_SECONDS}s OHLCV (Aggregated from Trades)',
                      xaxis_title='Time',
                      yaxis_title='Price (USDT)',
                      xaxis_rangeslider_visible=False)
    return fig, pred_text


if __name__ == '__main__':
    # The debug reloader re-runs this module in a child process, which would
    # load the model again and open a second collector websocket.
    app.run(debug=True, use_reloader=False)