testing live plot
This commit is contained in:
parent
8f96e14b8b
commit
f534825e53
150
live_plot.py
Normal file
150
live_plot.py
Normal file
@ -0,0 +1,150 @@
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xgboost as xgb
|
||||
from datetime import datetime
|
||||
from okx_client import OKXClient
|
||||
import dash
|
||||
from dash import dcc, html
|
||||
from dash.dependencies import Output, Input
|
||||
import plotly.graph_objs as go
|
||||
import websocket
|
||||
|
||||
# --- Prediction utilities (from test_predictor.py) ---
|
||||
def add_features(df):
|
||||
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
|
||||
def calc_rsi(close, window=14):
|
||||
delta = close.diff()
|
||||
up = delta.clip(lower=0)
|
||||
down = -1 * delta.clip(upper=0)
|
||||
ma_up = up.rolling(window=window, min_periods=window).mean()
|
||||
ma_down = down.rolling(window=window, min_periods=window).mean()
|
||||
rs = ma_up / ma_down
|
||||
return 100 - (100 / (1 + rs))
|
||||
df['rsi'] = calc_rsi(df['Close'])
|
||||
df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
|
||||
df['sma_50'] = df['Close'].rolling(window=50).mean()
|
||||
df['sma_200'] = df['Close'].rolling(window=200).mean()
|
||||
high_low = df['High'] - df['Low']
|
||||
high_close = np.abs(df['High'] - df['Close'].shift(1))
|
||||
low_close = np.abs(df['Low'] - df['Close'].shift(1))
|
||||
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
|
||||
df['atr'] = tr.rolling(window=14).mean()
|
||||
df['roc_10'] = df['Close'].pct_change(periods=10) * 100
|
||||
df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)
|
||||
df['hour'] = df['Timestamp'].dt.hour
|
||||
return df
|
||||
|
||||
def load_model_and_features(model_path):
|
||||
model = xgb.Booster()
|
||||
model.load_model(model_path)
|
||||
feature_cols = [
|
||||
'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
|
||||
]
|
||||
return model, feature_cols
|
||||
|
||||
def predict_next(df, model, feature_cols):
|
||||
X = df[feature_cols].iloc[[-1]].values.astype(np.float32)
|
||||
dmatrix = xgb.DMatrix(X, feature_names=feature_cols)
|
||||
pred_log_return = model.predict(dmatrix)[0]
|
||||
last_price = df['Close'].iloc[-1]
|
||||
pred_price = last_price * np.exp(pred_log_return)
|
||||
return pred_price, pred_log_return, last_price
|
||||
|
||||
# --- Main live plotting ---
|
||||
WINDOW = 50
|
||||
MODEL_PATH = 'data/xgboost_model.json'
|
||||
|
||||
ohlcv_bars = [] # [timestamp, open, high, low, close, volume]
|
||||
bar_lock = threading.Lock()
|
||||
model, feature_cols = load_model_and_features(MODEL_PATH)
|
||||
|
||||
# --- Background thread to collect trades and aggregate OHLCV ---
|
||||
def ws_collector():
|
||||
client = OKXClient(authenticate=False)
|
||||
client.subscribe_trades(instrument="BTC-USDT")
|
||||
current_bar = None
|
||||
while True:
|
||||
try:
|
||||
msg = client.ws.recv()
|
||||
data = json.loads(msg)
|
||||
except websocket._exceptions.WebSocketTimeoutException:
|
||||
continue # Just try again
|
||||
except Exception as e:
|
||||
print(f"WebSocket error: {e}")
|
||||
break # or try to reconnect
|
||||
if 'arg' in data and data['arg'].get('channel', '') == 'trades':
|
||||
for trade in data.get('data', []):
|
||||
# trade: {'instId', 'tradeId', 'px', 'sz', 'side', 'ts'}
|
||||
ts = int(trade['ts'])
|
||||
price = float(trade['px'])
|
||||
size = float(trade['sz'])
|
||||
dt = datetime.utcfromtimestamp(ts / 1000)
|
||||
bar_seconds = 30 # or 15, 30, etc.
|
||||
bar_time = dt.replace(second=(dt.second // bar_seconds) * bar_seconds, microsecond=0)
|
||||
with bar_lock:
|
||||
if not ohlcv_bars or ohlcv_bars[-1][0] != bar_time:
|
||||
# New bar
|
||||
ohlcv_bars.append([bar_time, price, price, price, price, size])
|
||||
if len(ohlcv_bars) > WINDOW:
|
||||
ohlcv_bars.pop(0)
|
||||
else:
|
||||
# Update current bar
|
||||
bar = ohlcv_bars[-1]
|
||||
bar[2] = max(bar[2], price) # high
|
||||
bar[3] = min(bar[3], price) # low
|
||||
bar[4] = price # close
|
||||
bar[5] += size # volume
|
||||
|
||||
# Start the background thread
|
||||
threading.Thread(target=ws_collector, daemon=True).start()
|
||||
|
||||
# --- Dash App ---
|
||||
app = dash.Dash(__name__)
|
||||
app.layout = html.Div([
|
||||
html.H2('BTC/USDT Price & Prediction (OKX, XGBoost, Trades Aggregated)', style={"textAlign": "center", "margin": 0, "padding": 0}),
|
||||
dcc.Graph(id='live-graph', animate=False, style={"height": "90vh", "width": "100vw", "margin": 0, "padding": 0}),
|
||||
dcc.Interval(
|
||||
id='interval-component',
|
||||
interval=5*1000, # 5 seconds
|
||||
n_intervals=0
|
||||
),
|
||||
html.Div(id='prediction-output', style={"textAlign": "center", "fontSize": 20, "marginTop": 10})
|
||||
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
|
||||
|
||||
@app.callback(
|
||||
[Output('live-graph', 'figure'), Output('prediction-output', 'children')],
|
||||
[Input('interval-component', 'n_intervals')]
|
||||
)
|
||||
def update_graph_live(n):
|
||||
with bar_lock:
|
||||
bars = list(ohlcv_bars)
|
||||
if len(bars) < 2:
|
||||
return go.Figure(), "Waiting for data..."
|
||||
df = pd.DataFrame(bars, columns=["Timestamp", "Open", "High", "Low", "Close", "Volume"])
|
||||
df["Timestamp"] = pd.to_datetime(df["Timestamp"])
|
||||
df = add_features(df)
|
||||
if df[feature_cols].isnull().any().any():
|
||||
pred_text = "Not enough data for prediction."
|
||||
else:
|
||||
pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
|
||||
pred_text = f"Last: {last_price:.2f} | Predicted next: {pred_price:.2f} | LogRet: {pred_log_return:.6f}"
|
||||
fig = go.Figure()
|
||||
fig.add_trace(go.Candlestick(
|
||||
x=df["Timestamp"],
|
||||
open=df["Open"],
|
||||
high=df["High"],
|
||||
low=df["Low"],
|
||||
close=df["Close"],
|
||||
name='Candlestick',
|
||||
increasing_line_color='green',
|
||||
decreasing_line_color='red',
|
||||
showlegend=False
|
||||
))
|
||||
fig.update_layout(title='BTC-USDT 1m OHLCV (Aggregated from Trades)', xaxis_title='Time', yaxis_title='Price (USDT)', xaxis_rangeslider_visible=False)
|
||||
return fig, pred_text
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True)
|
||||
122
test_predictor.py
Normal file
122
test_predictor.py
Normal file
@ -0,0 +1,122 @@
|
||||
import time
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import requests
|
||||
import xgboost as xgb
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# --- CONFIG ---
|
||||
OKX_REST_URL = 'https://www.okx.com/api/v5/market/candles'
|
||||
SYMBOL = 'BTC-USDT'
|
||||
BAR = '1m'
|
||||
HIST_MINUTES = 250 # Number of minutes of history to fetch for features
|
||||
MODEL_PATH = 'data/xgboost_model.json'
|
||||
|
||||
# --- Fetch recent candles from OKX REST API ---
|
||||
def fetch_recent_candles(symbol, bar, limit=HIST_MINUTES):
|
||||
params = {
|
||||
'instId': symbol,
|
||||
'bar': bar,
|
||||
'limit': str(limit)
|
||||
}
|
||||
resp = requests.get(OKX_REST_URL, params=params)
|
||||
data = resp.json()
|
||||
if data['code'] != '0':
|
||||
raise Exception(f"OKX API error: {data['msg']}")
|
||||
# OKX returns most recent first, reverse to chronological
|
||||
candles = data['data'][::-1]
|
||||
df = pd.DataFrame(candles)
|
||||
# OKX columns: [ts, o, h, l, c, vol, volCcy, confirm, ...] (see API docs)
|
||||
# We'll use: ts, o, h, l, c, vol
|
||||
col_map = {
|
||||
0: 'Timestamp',
|
||||
1: 'Open',
|
||||
2: 'High',
|
||||
3: 'Low',
|
||||
4: 'Close',
|
||||
5: 'Volume',
|
||||
}
|
||||
df = df.rename(columns={str(k): v for k, v in col_map.items()})
|
||||
# If columns are not named, use integer index
|
||||
for k, v in col_map.items():
|
||||
if v not in df.columns:
|
||||
df[v] = df.iloc[:, k]
|
||||
df = df[['Timestamp', 'Open', 'High', 'Low', 'Close', 'Volume']]
|
||||
df['Timestamp'] = pd.to_datetime(df['Timestamp'].astype(np.int64), unit='ms')
|
||||
for col in ['Open', 'High', 'Low', 'Close', 'Volume']:
|
||||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||
return df
|
||||
|
||||
# --- Feature Engineering (minimal, real-time) ---
|
||||
def add_features(df):
|
||||
# Log return (target, not used for prediction)
|
||||
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
|
||||
# RSI (14)
|
||||
def calc_rsi(close, window=14):
|
||||
delta = close.diff()
|
||||
up = delta.clip(lower=0)
|
||||
down = -1 * delta.clip(upper=0)
|
||||
ma_up = up.rolling(window=window, min_periods=window).mean()
|
||||
ma_down = down.rolling(window=window, min_periods=window).mean()
|
||||
rs = ma_up / ma_down
|
||||
return 100 - (100 / (1 + rs))
|
||||
df['rsi'] = calc_rsi(df['Close'])
|
||||
# EMA 14
|
||||
df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
|
||||
# SMA 50, 200
|
||||
df['sma_50'] = df['Close'].rolling(window=50).mean()
|
||||
df['sma_200'] = df['Close'].rolling(window=200).mean()
|
||||
# ATR 14
|
||||
high_low = df['High'] - df['Low']
|
||||
high_close = np.abs(df['High'] - df['Close'].shift(1))
|
||||
low_close = np.abs(df['Low'] - df['Close'].shift(1))
|
||||
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
|
||||
df['atr'] = tr.rolling(window=14).mean()
|
||||
# ROC 10
|
||||
df['roc_10'] = df['Close'].pct_change(periods=10) * 100
|
||||
# DPO 20
|
||||
df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)
|
||||
# Hour
|
||||
df['hour'] = df['Timestamp'].dt.hour
|
||||
# Add more features as needed (match with main.py)
|
||||
return df
|
||||
|
||||
# --- Load model and feature columns ---
|
||||
def load_model_and_features(model_path):
|
||||
model = xgb.Booster()
|
||||
model.load_model(model_path)
|
||||
# Try to infer feature names from main.py (hardcoded for now)
|
||||
feature_cols = [
|
||||
'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
|
||||
]
|
||||
return model, feature_cols
|
||||
|
||||
# --- Predict next log return and price ---
|
||||
def predict_next(df, model, feature_cols):
|
||||
# Use the last row for prediction
|
||||
X = df[feature_cols].iloc[[-1]].values.astype(np.float32)
|
||||
dmatrix = xgb.DMatrix(X, feature_names=feature_cols)
|
||||
pred_log_return = model.predict(dmatrix)[0]
|
||||
last_price = df['Close'].iloc[-1]
|
||||
pred_price = last_price * np.exp(pred_log_return)
|
||||
return pred_price, pred_log_return, last_price
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('Fetching recent candles from OKX...')
|
||||
df = fetch_recent_candles(SYMBOL, BAR)
|
||||
df = add_features(df)
|
||||
model, feature_cols = load_model_and_features(MODEL_PATH)
|
||||
print('Waiting for new candle...')
|
||||
last_timestamp = df['Timestamp'].iloc[-1]
|
||||
while True:
|
||||
time.sleep(5)
|
||||
new_df = fetch_recent_candles(SYMBOL, BAR, limit=HIST_MINUTES)
|
||||
if new_df['Timestamp'].iloc[-1] > last_timestamp:
|
||||
df = new_df
|
||||
df = add_features(df)
|
||||
pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
|
||||
print(f"[{df['Timestamp'].iloc[-1]}] Last price: {last_price:.2f} | Predicted next price: {pred_price:.2f} | Predicted log return: {pred_log_return:.6f}")
|
||||
last_timestamp = df['Timestamp'].iloc[-1]
|
||||
else:
|
||||
print('No new candle yet...')
|
||||
Loading…
x
Reference in New Issue
Block a user