MarketDataCollector/live_plot.py
2025-05-30 12:40:49 +08:00

150 lines
6.1 KiB
Python

import json
import threading
import time
import numpy as np
import pandas as pd
import xgboost as xgb
from datetime import datetime
from okx_client import OKXClient
import dash
from dash import dcc, html
from dash.dependencies import Output, Input
import plotly.graph_objs as go
import websocket
# --- Prediction utilities (from test_predictor.py) ---
def add_features(df):
    """Add the model's technical-indicator columns to ``df`` in place.

    Reads the ``Close``, ``High``, ``Low`` and ``Timestamp`` columns
    (``Timestamp`` must support the ``.dt`` accessor) and writes
    ``log_return``, ``rsi``, ``ema_14``, ``sma_50``, ``sma_200``, ``atr``,
    ``roc_10``, ``dpo_20`` and ``hour``. Rows without enough history for a
    rolling window are left as NaN.

    Returns the same DataFrame object that was passed in.
    """
    close = df['Close']
    prev_close = close.shift(1)

    df['log_return'] = np.log(close / prev_close)

    def _rsi(prices, window=14):
        # RSI built from simple rolling means of gains and losses; a value
        # is produced only once `window` price changes are available.
        change = prices.diff()
        gains = change.clip(lower=0)
        losses = -1 * change.clip(upper=0)
        avg_gain = gains.rolling(window=window, min_periods=window).mean()
        avg_loss = losses.rolling(window=window, min_periods=window).mean()
        strength = avg_gain / avg_loss
        return 100 - (100 / (1 + strength))

    df['rsi'] = _rsi(close)
    df['ema_14'] = close.ewm(span=14, adjust=False).mean()
    df['sma_50'] = close.rolling(window=50).mean()
    df['sma_200'] = close.rolling(window=200).mean()

    # True range: the widest of the bar's own range and its two gaps
    # against the previous close; ATR is its 14-bar simple mean.
    true_range = pd.concat(
        [
            df['High'] - df['Low'],
            (df['High'] - prev_close).abs(),
            (df['Low'] - prev_close).abs(),
        ],
        axis=1,
    ).max(axis=1)
    df['atr'] = true_range.rolling(window=14).mean()

    df['roc_10'] = close.pct_change(periods=10) * 100

    # The 21-bar mean is shifted 10 rows towards earlier rows before being
    # subtracted, so the final 10 rows of dpo_20 are always NaN.
    shifted_sma = close.rolling(window=21).mean().shift(-10)
    df['dpo_20'] = close - shifted_sma

    df['hour'] = df['Timestamp'].dt.hour
    return df
def load_model_and_features(model_path):
    """Load a saved XGBoost booster and return it with its feature names.

    Args:
        model_path: Location of the saved model, passed to
            ``Booster.load_model``.

    Returns:
        ``(booster, feature_names)`` where ``feature_names`` lists the
        columns fed to the model, in order.
    """
    booster = xgb.Booster()
    booster.load_model(model_path)
    feature_names = [
        'rsi',
        'ema_14',
        'sma_50',
        'sma_200',
        'atr',
        'roc_10',
        'dpo_20',
        'hour',
    ]
    return booster, feature_names
def predict_next(df, model, feature_cols):
    """Predict the next price from the most recent row of ``df``.

    The model output is treated as a log return and applied to the latest
    ``Close`` value.

    Args:
        df: DataFrame holding ``Close`` and every column in ``feature_cols``.
        model: Object exposing ``predict(DMatrix)``.
        feature_cols: Ordered feature column names.

    Returns:
        ``(pred_price, pred_log_return, last_price)``.
    """
    # Double brackets keep the selection two-dimensional (one row).
    latest_row = df[feature_cols].iloc[[-1]]
    features = latest_row.values.astype(np.float32)
    matrix = xgb.DMatrix(features, feature_names=feature_cols)
    predictions = model.predict(matrix)
    pred_log_return = predictions[0]
    last_price = df['Close'].iloc[-1]
    growth = np.exp(pred_log_return)
    return last_price * growth, pred_log_return, last_price
# --- Main live plotting ---
# Maximum number of OHLCV bars kept in the shared buffer; the collector drops
# the oldest bar once this is exceeded.
# NOTE(review): add_features() computes a 200-bar rolling mean (sma_200), which
# needs more rows than this buffer holds — confirm the intended size.
WINDOW = 50
# Path of the saved XGBoost model loaded below when this module is imported.
MODEL_PATH = 'data/xgboost_model.json'
# Bar buffer shared by the collector thread and the Dash callback.
ohlcv_bars = [] # [timestamp, open, high, low, close, volume]
# Guards every read and write of ohlcv_bars.
bar_lock = threading.Lock()
# Runs at import time: a missing or unreadable model file stops start-up here.
model, feature_cols = load_model_and_features(MODEL_PATH)
# --- Background thread to collect trades and aggregate OHLCV ---
def ws_collector():
    """Collect BTC-USDT trades and aggregate them into fixed-width OHLCV bars.

    Runs forever (intended as a daemon-thread target). Each trade is bucketed
    by its timestamp into a ``bar_seconds``-wide bar; the newest bar in the
    module-level ``ohlcv_bars`` list is updated in place, or a new bar is
    appended and the oldest dropped once more than ``WINDOW`` bars are held.
    All access to ``ohlcv_bars`` happens under ``bar_lock``.

    Receive timeouts are retried. Frames that are not JSON objects are
    skipped rather than treated as fatal. Any other receive error is printed
    and ends the loop (no reconnect is attempted).
    """
    # Function-scope import keeps this block self-contained; the file's
    # top-level import only brings in `datetime`.
    from datetime import timezone

    # Bar width in seconds. The bucketing below rounds within the minute,
    # so this should divide 60 evenly (e.g. 15 or 30).
    bar_seconds = 30

    client = OKXClient(authenticate=False)
    client.subscribe_trades(instrument="BTC-USDT")
    while True:
        try:
            msg = client.ws.recv()
        except websocket.WebSocketTimeoutException:
            continue  # Just try again
        except Exception as e:
            print(f"WebSocket error: {e}")
            break  # or try to reconnect

        # Parse separately from recv() so a single unparseable frame
        # (presumably e.g. a plain-text keepalive reply) does not terminate
        # the collector for the rest of the process lifetime.
        try:
            data = json.loads(msg)
        except ValueError:
            continue
        if not isinstance(data, dict):
            continue

        if 'arg' in data and data['arg'].get('channel', '') == 'trades':
            for trade in data.get('data', []):
                # trade: {'instId', 'tradeId', 'px', 'sz', 'side', 'ts'}
                ts = int(trade['ts'])
                price = float(trade['px'])
                size = float(trade['sz'])
                # `ts` is in milliseconds. Build an aware UTC datetime
                # (utcfromtimestamp is deprecated), then drop tzinfo so the
                # stored bar time stays a naive UTC datetime as before.
                dt = datetime.fromtimestamp(ts / 1000, tz=timezone.utc)
                bar_time = dt.replace(
                    second=(dt.second // bar_seconds) * bar_seconds,
                    microsecond=0,
                    tzinfo=None,
                )
                with bar_lock:
                    if not ohlcv_bars or ohlcv_bars[-1][0] != bar_time:
                        # New bar
                        ohlcv_bars.append([bar_time, price, price, price, price, size])
                        if len(ohlcv_bars) > WINDOW:
                            ohlcv_bars.pop(0)
                    else:
                        # Update current bar
                        bar = ohlcv_bars[-1]
                        bar[2] = max(bar[2], price)  # high
                        bar[3] = min(bar[3], price)  # low
                        bar[4] = price  # close
                        bar[5] += size  # volume
# Start the background thread
# This runs at import time. daemon=True means the collector thread does not
# keep the interpreter alive when the main thread exits.
threading.Thread(target=ws_collector, daemon=True).start()
# --- Dash App ---
app = dash.Dash(__name__)

# The page is assembled from four pieces: a heading, the chart, a timer that
# fires the refresh callback, and a text line for the latest prediction.
_heading = html.H2(
    'BTC/USDT Price & Prediction (OKX, XGBoost, Trades Aggregated)',
    style={"textAlign": "center", "margin": 0, "padding": 0},
)
_chart = dcc.Graph(
    id='live-graph',
    animate=False,
    style={"height": "90vh", "width": "100vw", "margin": 0, "padding": 0},
)
_refresh_timer = dcc.Interval(
    id='interval-component',
    interval=5 * 1000,  # 5 seconds
    n_intervals=0,
)
_prediction_line = html.Div(
    id='prediction-output',
    style={"textAlign": "center", "fontSize": 20, "marginTop": 10},
)

app.layout = html.Div(
    [_heading, _chart, _refresh_timer, _prediction_line],
    style={
        "height": "100vh",
        "width": "100vw",
        "margin": 0,
        "padding": 0,
        "overflow": "hidden",
        "backgroundColor": "#f7f7f7",
    },
)
@app.callback(
    [Output('live-graph', 'figure'), Output('prediction-output', 'children')],
    [Input('interval-component', 'n_intervals')]
)
def update_graph_live(n):
    """Rebuild the candlestick chart and prediction text from the bar buffer.

    Triggered by the interval component; ``n`` (the tick count) is unused.

    Returns:
        ``(figure, text)``. With fewer than two bars an empty figure and a
        waiting message are returned. Otherwise the figure plots every
        buffered bar and the text is either the latest prediction or a
        "not enough data" notice.
    """
    with bar_lock:
        # Copy each bar, not just the outer list: the collector mutates the
        # newest bar in place, so a shallow copy could still change while
        # the DataFrame is being built after the lock is released.
        bars = [list(bar) for bar in ohlcv_bars]
    if len(bars) < 2:
        return go.Figure(), "Waiting for data..."

    df = pd.DataFrame(bars, columns=["Timestamp", "Open", "High", "Low", "Close", "Volume"])
    df["Timestamp"] = pd.to_datetime(df["Timestamp"])
    df = add_features(df)

    # predict_next() only feeds the most recent row to the model, so only
    # that row has to be complete. Earlier rows always contain NaN from
    # rolling-window warm-up, so checking the whole frame could never pass.
    # NOTE(review): with the current feature definitions (dpo_20 uses a
    # backward shift, sma_200 needs 200 bars) the last row may still hold
    # NaN — confirm feature/window settings against the training pipeline.
    if df[feature_cols].iloc[-1].isnull().any():
        pred_text = "Not enough data for prediction."
    else:
        pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
        pred_text = f"Last: {last_price:.2f} | Predicted next: {pred_price:.2f} | LogRet: {pred_log_return:.6f}"

    fig = go.Figure()
    fig.add_trace(go.Candlestick(
        x=df["Timestamp"],
        open=df["Open"],
        high=df["High"],
        low=df["Low"],
        close=df["Close"],
        name='Candlestick',
        increasing_line_color='green',
        decreasing_line_color='red',
        showlegend=False
    ))
    # The collector aggregates trades into 30-second bars, so label them as
    # such (the title previously claimed 1-minute bars).
    fig.update_layout(
        title='BTC-USDT 30s OHLCV (Aggregated from Trades)',
        xaxis_title='Time',
        yaxis_title='Price (USDT)',
        xaxis_rangeslider_visible=False,
    )
    return fig, pred_text
# Script entry point: serve the Dash app with debug mode enabled.
# NOTE(review): debug mode typically enables an auto-reloader that re-executes
# this module in a child process; the import-time model load and collector
# thread above would then run in both processes — confirm, and consider
# disabling the reloader if duplicate collectors are observed.
if __name__ == '__main__':
    app.run(debug=True)