testing live plot

This commit is contained in:
Simon Moisy 2025-05-30 12:40:49 +08:00
parent 8f96e14b8b
commit f534825e53
6 changed files with 842 additions and 570 deletions

150
live_plot.py Normal file
View File

@ -0,0 +1,150 @@
import json
import threading
import time
import numpy as np
import pandas as pd
import xgboost as xgb
from datetime import datetime
from okx_client import OKXClient
import dash
from dash import dcc, html
from dash.dependencies import Output, Input
import plotly.graph_objs as go
import websocket
# --- Prediction utilities (from test_predictor.py) ---
def add_features(df):
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
def calc_rsi(close, window=14):
delta = close.diff()
up = delta.clip(lower=0)
down = -1 * delta.clip(upper=0)
ma_up = up.rolling(window=window, min_periods=window).mean()
ma_down = down.rolling(window=window, min_periods=window).mean()
rs = ma_up / ma_down
return 100 - (100 / (1 + rs))
df['rsi'] = calc_rsi(df['Close'])
df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
df['sma_50'] = df['Close'].rolling(window=50).mean()
df['sma_200'] = df['Close'].rolling(window=200).mean()
high_low = df['High'] - df['Low']
high_close = np.abs(df['High'] - df['Close'].shift(1))
low_close = np.abs(df['Low'] - df['Close'].shift(1))
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
df['atr'] = tr.rolling(window=14).mean()
df['roc_10'] = df['Close'].pct_change(periods=10) * 100
df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)
df['hour'] = df['Timestamp'].dt.hour
return df
def load_model_and_features(model_path):
model = xgb.Booster()
model.load_model(model_path)
feature_cols = [
'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
]
return model, feature_cols
def predict_next(df, model, feature_cols):
X = df[feature_cols].iloc[[-1]].values.astype(np.float32)
dmatrix = xgb.DMatrix(X, feature_names=feature_cols)
pred_log_return = model.predict(dmatrix)[0]
last_price = df['Close'].iloc[-1]
pred_price = last_price * np.exp(pred_log_return)
return pred_price, pred_log_return, last_price
# --- Main live plotting ---
WINDOW = 50
MODEL_PATH = 'data/xgboost_model.json'
ohlcv_bars = [] # [timestamp, open, high, low, close, volume]
bar_lock = threading.Lock()
model, feature_cols = load_model_and_features(MODEL_PATH)
# --- Background thread to collect trades and aggregate OHLCV ---
def ws_collector():
client = OKXClient(authenticate=False)
client.subscribe_trades(instrument="BTC-USDT")
current_bar = None
while True:
try:
msg = client.ws.recv()
data = json.loads(msg)
except websocket._exceptions.WebSocketTimeoutException:
continue # Just try again
except Exception as e:
print(f"WebSocket error: {e}")
break # or try to reconnect
if 'arg' in data and data['arg'].get('channel', '') == 'trades':
for trade in data.get('data', []):
# trade: {'instId', 'tradeId', 'px', 'sz', 'side', 'ts'}
ts = int(trade['ts'])
price = float(trade['px'])
size = float(trade['sz'])
dt = datetime.utcfromtimestamp(ts / 1000)
bar_seconds = 30 # or 15, 30, etc.
bar_time = dt.replace(second=(dt.second // bar_seconds) * bar_seconds, microsecond=0)
with bar_lock:
if not ohlcv_bars or ohlcv_bars[-1][0] != bar_time:
# New bar
ohlcv_bars.append([bar_time, price, price, price, price, size])
if len(ohlcv_bars) > WINDOW:
ohlcv_bars.pop(0)
else:
# Update current bar
bar = ohlcv_bars[-1]
bar[2] = max(bar[2], price) # high
bar[3] = min(bar[3], price) # low
bar[4] = price # close
bar[5] += size # volume
# Start the background thread
threading.Thread(target=ws_collector, daemon=True).start()
# --- Dash App ---
app = dash.Dash(__name__)
app.layout = html.Div([
html.H2('BTC/USDT Price & Prediction (OKX, XGBoost, Trades Aggregated)', style={"textAlign": "center", "margin": 0, "padding": 0}),
dcc.Graph(id='live-graph', animate=False, style={"height": "90vh", "width": "100vw", "margin": 0, "padding": 0}),
dcc.Interval(
id='interval-component',
interval=5*1000, # 5 seconds
n_intervals=0
),
html.Div(id='prediction-output', style={"textAlign": "center", "fontSize": 20, "marginTop": 10})
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
@app.callback(
[Output('live-graph', 'figure'), Output('prediction-output', 'children')],
[Input('interval-component', 'n_intervals')]
)
def update_graph_live(n):
with bar_lock:
bars = list(ohlcv_bars)
if len(bars) < 2:
return go.Figure(), "Waiting for data..."
df = pd.DataFrame(bars, columns=["Timestamp", "Open", "High", "Low", "Close", "Volume"])
df["Timestamp"] = pd.to_datetime(df["Timestamp"])
df = add_features(df)
if df[feature_cols].isnull().any().any():
pred_text = "Not enough data for prediction."
else:
pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
pred_text = f"Last: {last_price:.2f} | Predicted next: {pred_price:.2f} | LogRet: {pred_log_return:.6f}"
fig = go.Figure()
fig.add_trace(go.Candlestick(
x=df["Timestamp"],
open=df["Open"],
high=df["High"],
low=df["Low"],
close=df["Close"],
name='Candlestick',
increasing_line_color='green',
decreasing_line_color='red',
showlegend=False
))
fig.update_layout(title='BTC-USDT 1m OHLCV (Aggregated from Trades)', xaxis_title='Time', yaxis_title='Price (USDT)', xaxis_rangeslider_visible=False)
return fig, pred_text
if __name__ == '__main__':
app.run(debug=True)

304
main.py
View File

@ -1,152 +1,152 @@
from okx_client import OKXClient from okx_client import OKXClient
from market_db import MarketDB from market_db import MarketDB
import json import json
import logging import logging
import threading import threading
from collections import deque from collections import deque
import time import time
import signal import signal
latest_book = {'bids': [], 'asks': [], 'timestamp': None} latest_book = {'bids': [], 'asks': [], 'timestamp': None}
book_history = deque() book_history = deque()
trade_history = deque() trade_history = deque()
TRADE_HISTORY_SECONDS = 60 TRADE_HISTORY_SECONDS = 60
BOOK_HISTORY_SECONDS = 5 BOOK_HISTORY_SECONDS = 5
shutdown_flag = threading.Event() shutdown_flag = threading.Event()
def connect(instrument, max_retries=5): def connect(instrument, max_retries=5):
logging.info(f"Connecting to OKX for instrument: {instrument}") logging.info(f"Connecting to OKX for instrument: {instrument}")
retries = 0 retries = 0
backoff = 1 backoff = 1
while not shutdown_flag.is_set(): while not shutdown_flag.is_set():
try: try:
client = OKXClient(authenticate=False) client = OKXClient(authenticate=False)
client.subscribe_trades(instrument) client.subscribe_trades(instrument)
client.subscribe_book(instrument, depth=5, channel="books") client.subscribe_book(instrument, depth=5, channel="books")
logging.info(f"Subscribed to trades and book for {instrument}") logging.info(f"Subscribed to trades and book for {instrument}")
return client return client
except Exception as e: except Exception as e:
retries += 1 retries += 1
logging.error(f"Failed to connect to OKX: {e}. Retry {retries}/{max_retries} in {backoff}s.") logging.error(f"Failed to connect to OKX: {e}. Retry {retries}/{max_retries} in {backoff}s.")
if retries >= max_retries: if retries >= max_retries:
logging.critical("Max retries reached. Exiting connect loop.") logging.critical("Max retries reached. Exiting connect loop.")
raise raise
time.sleep(backoff) time.sleep(backoff)
backoff = min(backoff * 2, 60) # exponential backoff, max 60s backoff = min(backoff * 2, 60) # exponential backoff, max 60s
return None return None
def cleanup(client, db): def cleanup(client, db):
if client and hasattr(client, 'ws') and client.ws: if client and hasattr(client, 'ws') and client.ws:
try: try:
client.ws.close() client.ws.close()
except Exception as e: except Exception as e:
logging.warning(f"Error closing websocket: {e}") logging.warning(f"Error closing websocket: {e}")
if db: if db:
db.close() db.close()
def signal_handler(signum, frame): def signal_handler(signum, frame):
logging.info(f"Received signal {signum}, shutting down...") logging.info(f"Received signal {signum}, shutting down...")
shutdown_flag.set() shutdown_flag.set()
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
def main(): def main():
instruments = [ instruments = [
"ETH-USDT", "ETH-USDT",
"BTC-USDT", "BTC-USDT",
"SOL-USDT", "SOL-USDT",
"DOGE-USDT", "DOGE-USDT",
"TON-USDT", "TON-USDT",
"ETH-USDC", "ETH-USDC",
"SOPH-USDT", "SOPH-USDT",
"PEPE-USDT", "PEPE-USDT",
"BTC-USDC", "BTC-USDC",
"UNI-USDT" "UNI-USDT"
] ]
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
dbs = {} dbs = {}
clients = {} clients = {}
try: try:
for instrument in instruments: for instrument in instruments:
dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db") dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db")
logging.info(f"Database initialized for {instrument}") logging.info(f"Database initialized for {instrument}")
clients[instrument] = connect(instrument) clients[instrument] = connect(instrument)
while not shutdown_flag.is_set(): while not shutdown_flag.is_set():
for instrument in instruments: for instrument in instruments:
client = clients[instrument] client = clients[instrument]
db = dbs[instrument] db = dbs[instrument]
try: try:
data = client.ws.recv() data = client.ws.recv()
except Exception as e: except Exception as e:
logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...") logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...")
cleanup(client, None) cleanup(client, None)
try: try:
clients[instrument] = connect(instrument) clients[instrument] = connect(instrument)
except Exception as e: except Exception as e:
logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.") logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.")
continue continue
continue continue
if shutdown_flag.is_set(): if shutdown_flag.is_set():
break break
if data == '': if data == '':
continue continue
try: try:
msg = json.loads(data) msg = json.loads(data)
except Exception as e: except Exception as e:
logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}") logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}")
continue continue
if 'arg' in msg and msg['arg'].get('channel') == 'trades': if 'arg' in msg and msg['arg'].get('channel') == 'trades':
for trade in msg.get('data', []): for trade in msg.get('data', []):
db.insert_trade({ db.insert_trade({
'instrument': instrument, 'instrument': instrument,
'trade_id': trade.get('tradeId'), 'trade_id': trade.get('tradeId'),
'price': float(trade.get('px')), 'price': float(trade.get('px')),
'size': float(trade.get('sz')), 'size': float(trade.get('sz')),
'side': trade.get('side'), 'side': trade.get('side'),
'timestamp': trade.get('ts') 'timestamp': trade.get('ts')
}) })
ts = float(trade.get('ts', time.time() * 1000)) ts = float(trade.get('ts', time.time() * 1000))
trade_history.append({ trade_history.append({
'price': trade.get('px'), 'price': trade.get('px'),
'size': trade.get('sz'), 'size': trade.get('sz'),
'side': trade.get('side'), 'side': trade.get('side'),
'timestamp': ts 'timestamp': ts
}) })
elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'): elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'):
for book in msg.get('data', []): for book in msg.get('data', []):
db.insert_book({ db.insert_book({
'instrument': instrument, 'instrument': instrument,
'bids': book.get('bids'), 'bids': book.get('bids'),
'asks': book.get('asks'), 'asks': book.get('asks'),
'timestamp': book.get('ts') 'timestamp': book.get('ts')
}) })
latest_book['bids'] = book.get('bids', []) latest_book['bids'] = book.get('bids', [])
latest_book['asks'] = book.get('asks', []) latest_book['asks'] = book.get('asks', [])
latest_book['timestamp'] = book.get('ts') latest_book['timestamp'] = book.get('ts')
ts = float(book.get('ts', time.time() * 1000)) ts = float(book.get('ts', time.time() * 1000))
book_history.append({ book_history.append({
'bids': book.get('bids', []), 'bids': book.get('bids', []),
'asks': book.get('asks', []), 'asks': book.get('asks', []),
'timestamp': ts 'timestamp': ts
}) })
else: else:
logging.info(f"Unknown message for {instrument}: {msg}") logging.info(f"Unknown message for {instrument}: {msg}")
except Exception as e: except Exception as e:
logging.critical(f"Fatal error in main: {e}") logging.critical(f"Fatal error in main: {e}")
finally: finally:
for client in clients.values(): for client in clients.values():
cleanup(client, None) cleanup(client, None)
for db in dbs.values(): for db in dbs.values():
cleanup(None, db) cleanup(None, db)
logging.info('Shutdown complete.') logging.info('Shutdown complete.')
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -1,77 +1,77 @@
import sqlite3 import sqlite3
import os import os
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import logging import logging
class MarketDB: class MarketDB:
def __init__(self, market: str, db_dir: str = ""): def __init__(self, market: str, db_dir: str = ""):
db_name = f"{market}.db" db_name = f"{market}.db"
db_path = db_name if not db_dir else f"{db_dir.rstrip('/')}/{db_name}" db_path = db_name if not db_dir else f"{db_dir.rstrip('/')}/{db_name}"
if db_dir: if db_dir:
os.makedirs(db_dir, exist_ok=True) os.makedirs(db_dir, exist_ok=True)
self.conn = sqlite3.connect(db_path) self.conn = sqlite3.connect(db_path)
logging.info(f"Connected to database at {db_path}") logging.info(f"Connected to database at {db_path}")
self._create_tables() self._create_tables()
def _create_tables(self): def _create_tables(self):
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute(''' cursor.execute('''
CREATE TABLE IF NOT EXISTS trades ( CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
instrument TEXT, instrument TEXT,
trade_id TEXT, trade_id TEXT,
price REAL, price REAL,
size REAL, size REAL,
side TEXT, side TEXT,
timestamp TEXT timestamp TEXT
) )
''') ''')
cursor.execute(''' cursor.execute('''
CREATE TABLE IF NOT EXISTS book ( CREATE TABLE IF NOT EXISTS book (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
instrument TEXT, instrument TEXT,
bids TEXT, bids TEXT,
asks TEXT, asks TEXT,
timestamp TEXT timestamp TEXT
) )
''') ''')
self.conn.commit() self.conn.commit()
logging.info("Database tables ensured.") logging.info("Database tables ensured.")
def insert_trade(self, trade: Dict[str, Any]): def insert_trade(self, trade: Dict[str, Any]):
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute(''' cursor.execute('''
INSERT INTO trades (instrument, trade_id, price, size, side, timestamp) INSERT INTO trades (instrument, trade_id, price, size, side, timestamp)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
''', ( ''', (
trade.get('instrument'), trade.get('instrument'),
trade.get('trade_id'), trade.get('trade_id'),
trade.get('price'), trade.get('price'),
trade.get('size'), trade.get('size'),
trade.get('side'), trade.get('side'),
trade.get('timestamp') trade.get('timestamp')
)) ))
self.conn.commit() self.conn.commit()
logging.debug(f"Inserted trade: {trade}") logging.debug(f"Inserted trade: {trade}")
def insert_book(self, book: Dict[str, Any]): def insert_book(self, book: Dict[str, Any]):
cursor = self.conn.cursor() cursor = self.conn.cursor()
bids = book.get('bids', []) bids = book.get('bids', [])
asks = book.get('asks', []) asks = book.get('asks', [])
best_bid = next((b for b in bids if float(b[1]) > 0), ['-', '-']) best_bid = next((b for b in bids if float(b[1]) > 0), ['-', '-'])
best_ask = next((a for a in asks if float(a[1]) > 0), ['-', '-']) best_ask = next((a for a in asks if float(a[1]) > 0), ['-', '-'])
cursor.execute(''' cursor.execute('''
INSERT INTO book (instrument, bids, asks, timestamp) INSERT INTO book (instrument, bids, asks, timestamp)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
''', ( ''', (
book.get('instrument'), book.get('instrument'),
str(bids), str(bids),
str(asks), str(asks),
book.get('timestamp') book.get('timestamp')
)) ))
self.conn.commit() self.conn.commit()
logging.debug(f"Inserted book: {book.get('instrument', 'N/A')} ts:{book.get('timestamp', 'N/A')} bid:{best_bid} ask:{best_ask}") logging.debug(f"Inserted book: {book.get('instrument', 'N/A')} ts:{book.get('timestamp', 'N/A')} bid:{best_bid} ask:{best_ask}")
def close(self): def close(self):
self.conn.close() self.conn.close()
logging.info("Database connection closed.") logging.info("Database connection closed.")

View File

@ -1,233 +1,233 @@
import os import os
import time import time
import hmac import hmac
import hashlib import hashlib
import base64 import base64
import json import json
import pandas as pd import pandas as pd
import threading import threading
import requests import requests
import websocket import websocket
import logging import logging
class OKXClient: class OKXClient:
PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public" PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public"
PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private" PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private"
REST_URL = "https://www.okx.com" REST_URL = "https://www.okx.com"
def __init__(self, authenticate: bool = True): def __init__(self, authenticate: bool = True):
self.authenticated = False self.authenticated = False
self.api_key = None self.api_key = None
self.api_secret = None self.api_secret = None
self.api_passphrase = None self.api_passphrase = None
self.ws = None self.ws = None
self.ws_private = None self.ws_private = None
self._lock = threading.Lock() self._lock = threading.Lock()
self._private_lock = threading.Lock() self._private_lock = threading.Lock()
if authenticate: if authenticate:
config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json') config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json')
try: try:
with open(config_path, 'r') as f: with open(config_path, 'r') as f:
config = json.load(f) config = json.load(f)
except FileNotFoundError: except FileNotFoundError:
raise FileNotFoundError(f"Credentials file not found at {config_path}. Please create it with the required keys.") raise FileNotFoundError(f"Credentials file not found at {config_path}. Please create it with the required keys.")
except json.JSONDecodeError: except json.JSONDecodeError:
raise ValueError(f"Credentials file at {config_path} is not valid JSON.") raise ValueError(f"Credentials file at {config_path} is not valid JSON.")
self.api_key = config.get("OKX_API_KEY") self.api_key = config.get("OKX_API_KEY")
self.api_secret = config.get("OKX_API_SECRET") self.api_secret = config.get("OKX_API_SECRET")
self.api_passphrase = config.get("OKX_API_PASSPHRASE") self.api_passphrase = config.get("OKX_API_PASSPHRASE")
if not self.api_key or not self.api_secret or not self.api_passphrase: if not self.api_key or not self.api_secret or not self.api_passphrase:
raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.") raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.")
self._authenticate() self._authenticate()
self._connect_ws() self._connect_ws()
def _connect_ws(self): def _connect_ws(self):
if self.ws is None: if self.ws is None:
self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10) self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10)
if self.authenticated and self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None: if self.authenticated and self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None:
self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10) self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10)
def _get_timestamp(self): def _get_timestamp(self):
return str(round(time.time(), 3)) return str(round(time.time(), 3))
def _sign(self, timestamp, method, request_path, body): def _sign(self, timestamp, method, request_path, body):
if not body: if not body:
body = '' body = ''
message = f'{timestamp}{method}{request_path}{body}' message = f'{timestamp}{method}{request_path}{body}'
mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256) mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256)
return base64.b64encode(mac.digest()).decode() return base64.b64encode(mac.digest()).decode()
def _authenticate(self): def _authenticate(self):
import websocket import websocket
timestamp = self._get_timestamp() timestamp = self._get_timestamp()
sign = self._sign(timestamp, 'GET', '/users/self/verify', '') sign = self._sign(timestamp, 'GET', '/users/self/verify', '')
login_params = { login_params = {
"op": "login", "op": "login",
"args": [{ "args": [{
"apiKey": self.api_key, "apiKey": self.api_key,
"passphrase": self.api_passphrase, "passphrase": self.api_passphrase,
"timestamp": timestamp, "timestamp": timestamp,
"sign": sign "sign": sign
}] }]
} }
self.ws_private.send(json.dumps(login_params)) self.ws_private.send(json.dumps(login_params))
logging.info("Waiting for login response from OKX...") logging.info("Waiting for login response from OKX...")
while True: while True:
try: try:
resp = self.ws_private.recv() resp = self.ws_private.recv()
logging.debug(f"Received from OKX private WS: {resp}") logging.debug(f"Received from OKX private WS: {resp}")
if not resp: if not resp:
continue continue
try: try:
msg = json.loads(resp) msg = json.loads(resp)
except Exception: except Exception:
logging.warning(f"Non-JSON message received: {resp}") logging.warning(f"Non-JSON message received: {resp}")
continue continue
if msg.get("event") == "login": if msg.get("event") == "login":
if msg.get("code") == "0": if msg.get("code") == "0":
logging.info("OKX WebSocket login successful.") logging.info("OKX WebSocket login successful.")
self.authenticated = True self.authenticated = True
break break
else: else:
raise Exception(f"WebSocket authentication failed: {msg}") raise Exception(f"WebSocket authentication failed: {msg}")
except websocket._exceptions.WebSocketConnectionClosedException as e: except websocket._exceptions.WebSocketConnectionClosedException as e:
logging.error(f"WebSocket connection closed during authentication: {e}") logging.error(f"WebSocket connection closed during authentication: {e}")
raise raise
except Exception as e: except Exception as e:
logging.error(f"Exception during authentication: {e}") logging.error(f"Exception during authentication: {e}")
raise raise
def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"): def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"):
# OKX uses candle1m, candle5m, etc. # OKX uses candle1m, candle5m, etc.
tf_map = {"1m": "candle1m", "5m": "candle5m", "15m": "candle15m", "1h": "candle1H"} tf_map = {"1m": "candle1m", "5m": "candle5m", "15m": "candle15m", "1h": "candle1H"}
channel = tf_map.get(timeframe, f"candle{timeframe}") channel = tf_map.get(timeframe, f"candle{timeframe}")
params = { params = {
"op": "subscribe", "op": "subscribe",
"args": [{"channel": channel, "instId": instrument}] "args": [{"channel": channel, "instId": instrument}]
} }
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params)) self.ws.send(json.dumps(params))
def subscribe_trades(self, instrument="BTC-USDT"): def subscribe_trades(self, instrument="BTC-USDT"):
params = { params = {
"op": "subscribe", "op": "subscribe",
"args": [{"channel": "trades", "instId": instrument}] "args": [{"channel": "trades", "instId": instrument}]
} }
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params)) self.ws.send(json.dumps(params))
def subscribe_ticker(self, instrument="BTC-USDT"): def subscribe_ticker(self, instrument="BTC-USDT"):
params = { params = {
"op": "subscribe", "op": "subscribe",
"args": [{"channel": "tickers", "instId": instrument}] "args": [{"channel": "tickers", "instId": instrument}]
} }
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params)) self.ws.send(json.dumps(params))
def subscribe_book(self, instrument="BTC-USDT", depth=5, channel="books5"): def subscribe_book(self, instrument="BTC-USDT", depth=5, channel="books5"):
# OKX supports books5, books50, books-l2-tbt # OKX supports books5, books50, books-l2-tbt
# channel = "books5" if depth <= 5 else "books50" # channel = "books5" if depth <= 5 else "books50"
params = { params = {
"op": "subscribe", "op": "subscribe",
"args": [{"channel": channel, "instId": instrument}] "args": [{"channel": channel, "instId": instrument}]
} }
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params)) self.ws.send(json.dumps(params))
def subscribe_user_order(self): def subscribe_user_order(self):
if not self.authenticated: if not self.authenticated:
logging.warning("Attempted to subscribe to user order channel without authentication.") logging.warning("Attempted to subscribe to user order channel without authentication.")
return return
params = { params = {
"op": "subscribe", "op": "subscribe",
"args": [{"channel": "orders", "instType": "SPOT"}] "args": [{"channel": "orders", "instType": "SPOT"}]
} }
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
self.ws_private.send(json.dumps(params)) self.ws_private.send(json.dumps(params))
def subscribe_user_trade(self): def subscribe_user_trade(self):
if not self.authenticated: if not self.authenticated:
logging.warning("Attempted to subscribe to user trade channel without authentication.") logging.warning("Attempted to subscribe to user trade channel without authentication.")
return return
params = { params = {
"op": "subscribe", "op": "subscribe",
"args": [{"channel": "trades", "instType": "SPOT"}] "args": [{"channel": "trades", "instType": "SPOT"}]
} }
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
self.ws_private.send(json.dumps(params)) self.ws_private.send(json.dumps(params))
def subscribe_user_balance(self): def subscribe_user_balance(self):
if not self.authenticated: if not self.authenticated:
logging.warning("Attempted to subscribe to user balance channel without authentication.") logging.warning("Attempted to subscribe to user balance channel without authentication.")
return return
params = { params = {
"op": "subscribe", "op": "subscribe",
"args": [{"channel": "balance_and_position"}] "args": [{"channel": "balance_and_position"}]
} }
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
self.ws_private.send(json.dumps(params)) self.ws_private.send(json.dumps(params))
def get_balance(self, currency=None): def get_balance(self, currency=None):
url = f"{self.REST_URL}/api/v5/account/balance" url = f"{self.REST_URL}/api/v5/account/balance"
timestamp = self._get_timestamp() timestamp = self._get_timestamp()
method = "GET" method = "GET"
request_path = "/api/v5/account/balance" request_path = "/api/v5/account/balance"
body = '' body = ''
sign = self._sign(timestamp, method, request_path, body) sign = self._sign(timestamp, method, request_path, body)
headers = { headers = {
"OK-ACCESS-KEY": self.api_key, "OK-ACCESS-KEY": self.api_key,
"OK-ACCESS-SIGN": sign, "OK-ACCESS-SIGN": sign,
"OK-ACCESS-TIMESTAMP": timestamp, "OK-ACCESS-TIMESTAMP": timestamp,
"OK-ACCESS-PASSPHRASE": self.api_passphrase, "OK-ACCESS-PASSPHRASE": self.api_passphrase,
"Content-Type": "application/json" "Content-Type": "application/json"
} }
resp = requests.get(url, headers=headers) resp = requests.get(url, headers=headers)
if resp.status_code == 200: if resp.status_code == 200:
data = resp.json() data = resp.json()
balances = data.get("data", [{}])[0].get("details", []) balances = data.get("data", [{}])[0].get("details", [])
if currency: if currency:
return [b for b in balances if b.get("ccy") == currency] return [b for b in balances if b.get("ccy") == currency]
return balances return balances
return [] return []
def place_order(self, side, amount, instrument="BTC-USDT"): def place_order(self, side, amount, instrument="BTC-USDT"):
url = f"{self.REST_URL}/api/v5/trade/order" url = f"{self.REST_URL}/api/v5/trade/order"
timestamp = self._get_timestamp() timestamp = self._get_timestamp()
method = "POST" method = "POST"
request_path = "/api/v5/trade/order" request_path = "/api/v5/trade/order"
body_dict = { body_dict = {
"instId": instrument, "instId": instrument,
"tdMode": "cash", "tdMode": "cash",
"side": side.lower(), "side": side.lower(),
"ordType": "market", "ordType": "market",
"sz": str(amount) "sz": str(amount)
} }
body = json.dumps(body_dict) body = json.dumps(body_dict)
sign = self._sign(timestamp, method, request_path, body) sign = self._sign(timestamp, method, request_path, body)
headers = { headers = {
"OK-ACCESS-KEY": self.api_key, "OK-ACCESS-KEY": self.api_key,
"OK-ACCESS-SIGN": sign, "OK-ACCESS-SIGN": sign,
"OK-ACCESS-TIMESTAMP": timestamp, "OK-ACCESS-TIMESTAMP": timestamp,
"OK-ACCESS-PASSPHRASE": self.api_passphrase, "OK-ACCESS-PASSPHRASE": self.api_passphrase,
"Content-Type": "application/json" "Content-Type": "application/json"
} }
resp = requests.post(url, headers=headers, data=body) resp = requests.post(url, headers=headers, data=body)
return resp.json() return resp.json()
def buy_btc(self, amount, instrument="BTC-USDT"): def buy_btc(self, amount, instrument="BTC-USDT"):
return self.place_order("buy", amount, instrument) return self.place_order("buy", amount, instrument)
def sell_btc(self, amount, instrument="BTC-USDT"): def sell_btc(self, amount, instrument="BTC-USDT"):
return self.place_order("sell", amount, instrument) return self.place_order("sell", amount, instrument)
def get_instruments(self): def get_instruments(self):
url = f"{self.REST_URL}/api/v5/public/instruments?instType=SPOT" url = f"{self.REST_URL}/api/v5/public/instruments?instType=SPOT"
resp = requests.get(url) resp = requests.get(url)
if resp.status_code == 200: if resp.status_code == 200:
data = resp.json() data = resp.json()
return data.get("data", []) return data.get("data", [])
return [] return []

122
test_predictor.py Normal file
View File

@ -0,0 +1,122 @@
import time
import json
import numpy as np
import pandas as pd
import requests
import xgboost as xgb
from datetime import datetime, timedelta
# --- CONFIG ---
OKX_REST_URL = 'https://www.okx.com/api/v5/market/candles'
SYMBOL = 'BTC-USDT'
BAR = '1m'
HIST_MINUTES = 250 # Number of minutes of history to fetch for features
MODEL_PATH = 'data/xgboost_model.json'
# --- Fetch recent candles from OKX REST API ---
def fetch_recent_candles(symbol, bar, limit=HIST_MINUTES):
params = {
'instId': symbol,
'bar': bar,
'limit': str(limit)
}
resp = requests.get(OKX_REST_URL, params=params)
data = resp.json()
if data['code'] != '0':
raise Exception(f"OKX API error: {data['msg']}")
# OKX returns most recent first, reverse to chronological
candles = data['data'][::-1]
df = pd.DataFrame(candles)
# OKX columns: [ts, o, h, l, c, vol, volCcy, confirm, ...] (see API docs)
# We'll use: ts, o, h, l, c, vol
col_map = {
0: 'Timestamp',
1: 'Open',
2: 'High',
3: 'Low',
4: 'Close',
5: 'Volume',
}
df = df.rename(columns={str(k): v for k, v in col_map.items()})
# If columns are not named, use integer index
for k, v in col_map.items():
if v not in df.columns:
df[v] = df.iloc[:, k]
df = df[['Timestamp', 'Open', 'High', 'Low', 'Close', 'Volume']]
df['Timestamp'] = pd.to_datetime(df['Timestamp'].astype(np.int64), unit='ms')
for col in ['Open', 'High', 'Low', 'Close', 'Volume']:
df[col] = pd.to_numeric(df[col], errors='coerce')
return df
# --- Feature Engineering (minimal, real-time) ---
def add_features(df):
# Log return (target, not used for prediction)
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
# RSI (14)
def calc_rsi(close, window=14):
delta = close.diff()
up = delta.clip(lower=0)
down = -1 * delta.clip(upper=0)
ma_up = up.rolling(window=window, min_periods=window).mean()
ma_down = down.rolling(window=window, min_periods=window).mean()
rs = ma_up / ma_down
return 100 - (100 / (1 + rs))
df['rsi'] = calc_rsi(df['Close'])
# EMA 14
df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
# SMA 50, 200
df['sma_50'] = df['Close'].rolling(window=50).mean()
df['sma_200'] = df['Close'].rolling(window=200).mean()
# ATR 14
high_low = df['High'] - df['Low']
high_close = np.abs(df['High'] - df['Close'].shift(1))
low_close = np.abs(df['Low'] - df['Close'].shift(1))
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
df['atr'] = tr.rolling(window=14).mean()
# ROC 10
df['roc_10'] = df['Close'].pct_change(periods=10) * 100
# DPO 20
df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)
# Hour
df['hour'] = df['Timestamp'].dt.hour
# Add more features as needed (match with main.py)
return df
# --- Load model and feature columns ---
def load_model_and_features(model_path):
model = xgb.Booster()
model.load_model(model_path)
# Try to infer feature names from main.py (hardcoded for now)
feature_cols = [
'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
]
return model, feature_cols
# --- Predict next log return and price ---
def predict_next(df, model, feature_cols):
# Use the last row for prediction
X = df[feature_cols].iloc[[-1]].values.astype(np.float32)
dmatrix = xgb.DMatrix(X, feature_names=feature_cols)
pred_log_return = model.predict(dmatrix)[0]
last_price = df['Close'].iloc[-1]
pred_price = last_price * np.exp(pred_log_return)
return pred_price, pred_log_return, last_price
if __name__ == '__main__':
print('Fetching recent candles from OKX...')
df = fetch_recent_candles(SYMBOL, BAR)
df = add_features(df)
model, feature_cols = load_model_and_features(MODEL_PATH)
print('Waiting for new candle...')
last_timestamp = df['Timestamp'].iloc[-1]
while True:
time.sleep(5)
new_df = fetch_recent_candles(SYMBOL, BAR, limit=HIST_MINUTES)
if new_df['Timestamp'].iloc[-1] > last_timestamp:
df = new_df
df = add_features(df)
pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
print(f"[{df['Timestamp'].iloc[-1]}] Last price: {last_price:.2f} | Predicted next price: {pred_price:.2f} | Predicted log return: {pred_log_return:.6f}")
last_timestamp = df['Timestamp'].iloc[-1]
else:
print('No new candle yet...')

View File

@ -1,108 +1,108 @@
import dash import dash
from dash import dcc, html from dash import dcc, html
from dash.dependencies import Output, Input from dash.dependencies import Output, Input
import plotly.graph_objs as go import plotly.graph_objs as go
import time import time
from collections import defaultdict from collections import defaultdict
def run_dash(book_history, trade_history, BOOK_HISTORY_SECONDS=5, TRADE_HISTORY_SECONDS=60): def run_dash(book_history, trade_history, BOOK_HISTORY_SECONDS=5, TRADE_HISTORY_SECONDS=60):
app = dash.Dash(__name__) app = dash.Dash(__name__)
app.layout = html.Div([ app.layout = html.Div([
html.H1("Order Book Depth Chart", style={"textAlign": "center", "color": "#222"}), html.H1("Order Book Depth Chart", style={"textAlign": "center", "color": "#222"}),
dcc.Graph(id='order-book-graph', style={"height": "90vh", "width": "100vw"}), dcc.Graph(id='order-book-graph', style={"height": "90vh", "width": "100vw"}),
dcc.Interval(id='interval-component', interval=2*1000, n_intervals=0) dcc.Interval(id='interval-component', interval=2*1000, n_intervals=0)
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"}) ], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
@app.callback( @app.callback(
[Output('order-book-graph', 'figure')], [Output('order-book-graph', 'figure')],
[Input('interval-component', 'n_intervals')] [Input('interval-component', 'n_intervals')]
) )
def update_graphs(n): def update_graphs(n):
now = time.time() * 1000 # current time in ms now = time.time() * 1000 # current time in ms
# Prune book_history to only keep last BOOK_HISTORY_SECONDS # Prune book_history to only keep last BOOK_HISTORY_SECONDS
while book_history and now - book_history[0]['timestamp'] > BOOK_HISTORY_SECONDS * 1000: while book_history and now - book_history[0]['timestamp'] > BOOK_HISTORY_SECONDS * 1000:
book_history.popleft() book_history.popleft()
# Prune trade_history to only keep last TRADE_HISTORY_SECONDS # Prune trade_history to only keep last TRADE_HISTORY_SECONDS
while trade_history and now - float(trade_history[0]['timestamp']) > TRADE_HISTORY_SECONDS * 1000: while trade_history and now - float(trade_history[0]['timestamp']) > TRADE_HISTORY_SECONDS * 1000:
trade_history.popleft() trade_history.popleft()
# Aggregate bids/asks from book_history # Aggregate bids/asks from book_history
bids_dict = {} bids_dict = {}
asks_dict = {} asks_dict = {}
for book in book_history: for book in book_history:
for price, size, *_ in book['bids']: for price, size, *_ in book['bids']:
price = float(price) price = float(price)
size = float(size) size = float(size)
bids_dict[price] = bids_dict.get(price, 0) + size bids_dict[price] = bids_dict.get(price, 0) + size
for price, size, *_ in book['asks']: for price, size, *_ in book['asks']:
price = float(price) price = float(price)
size = float(size) size = float(size)
asks_dict[price] = asks_dict.get(price, 0) + size asks_dict[price] = asks_dict.get(price, 0) + size
try: try:
# Prepare and sort bids/asks # Prepare and sort bids/asks
bids = sorted([[p, s] for p, s in bids_dict.items()], reverse=True) bids = sorted([[p, s] for p, s in bids_dict.items()], reverse=True)
asks = sorted([[p, s] for p, s in asks_dict.items()]) asks = sorted([[p, s] for p, s in asks_dict.items()])
# Cumulative sum # Cumulative sum
bid_prices = [b[0] for b in bids] bid_prices = [b[0] for b in bids]
bid_sizes = [b[1] for b in bids] bid_sizes = [b[1] for b in bids]
ask_prices = [a[0] for a in asks] ask_prices = [a[0] for a in asks]
ask_sizes = [a[1] for a in asks] ask_sizes = [a[1] for a in asks]
bid_cumsum = [sum(bid_sizes[:i+1]) for i in range(len(bid_sizes))] bid_cumsum = [sum(bid_sizes[:i+1]) for i in range(len(bid_sizes))]
ask_cumsum = [sum(ask_sizes[:i+1]) for i in range(len(ask_sizes))] ask_cumsum = [sum(ask_sizes[:i+1]) for i in range(len(ask_sizes))]
except Exception as e: except Exception as e:
bid_prices, bid_cumsum, ask_prices, ask_cumsum = [], [], [], [] bid_prices, bid_cumsum, ask_prices, ask_cumsum = [], [], [], []
fig = go.Figure() fig = go.Figure()
# Add order book lines (primary y-axis) # Add order book lines (primary y-axis)
fig.add_trace(go.Scatter( fig.add_trace(go.Scatter(
x=bid_prices, y=bid_cumsum, mode='lines', name='Bids', x=bid_prices, y=bid_cumsum, mode='lines', name='Bids',
line=dict(color='green'), fill='tozeroy', yaxis='y1' line=dict(color='green'), fill='tozeroy', yaxis='y1'
)) ))
fig.add_trace(go.Scatter( fig.add_trace(go.Scatter(
x=ask_prices, y=ask_cumsum, mode='lines', name='Asks', x=ask_prices, y=ask_cumsum, mode='lines', name='Asks',
line=dict(color='red'), fill='tozeroy', yaxis='y1' line=dict(color='red'), fill='tozeroy', yaxis='y1'
)) ))
trade_volume_by_price = defaultdict(float) trade_volume_by_price = defaultdict(float)
for trade in trade_history: for trade in trade_history:
price_bin = round(float(trade['price']), 2) price_bin = round(float(trade['price']), 2)
trade_volume_by_price[price_bin] += float(trade['size']) trade_volume_by_price[price_bin] += float(trade['size'])
prices = list(trade_volume_by_price.keys()) prices = list(trade_volume_by_price.keys())
volumes = list(trade_volume_by_price.values()) volumes = list(trade_volume_by_price.values())
# Sort by price for display # Sort by price for display
sorted_pairs = sorted(zip(prices, volumes)) sorted_pairs = sorted(zip(prices, volumes))
prices = [p for p, v in sorted_pairs] prices = [p for p, v in sorted_pairs]
volumes = [v for p, v in sorted_pairs] volumes = [v for p, v in sorted_pairs]
# Add trade volume bars (secondary y-axis) # Add trade volume bars (secondary y-axis)
fig.add_trace(go.Bar( fig.add_trace(go.Bar(
x=prices, y=volumes, marker_color='#7ec8e3', name='Trade Volume', x=prices, y=volumes, marker_color='#7ec8e3', name='Trade Volume',
opacity=0.7, yaxis='y2' opacity=0.7, yaxis='y2'
)) ))
# Update layout for dual y-axes # Update layout for dual y-axes
fig.update_layout( fig.update_layout(
title='Order Book Depth & Realized Trade Volume by Price', title='Order Book Depth & Realized Trade Volume by Price',
xaxis=dict(title='Price'), xaxis=dict(title='Price'),
yaxis=dict(title='Cumulative Size', side='left'), yaxis=dict(title='Cumulative Size', side='left'),
yaxis2=dict( yaxis2=dict(
title='Traded Volume', title='Traded Volume',
overlaying='y', overlaying='y',
side='right', side='right',
showgrid=False showgrid=False
), ),
template='plotly_dark' template='plotly_dark'
) )
return [fig] return [fig]
app.run(debug=True, use_reloader=False) app.run(debug=True, use_reloader=False)