From f534825e537461eb298e6d3b1c548d5fbe1e7491 Mon Sep 17 00:00:00 2001
From: Simon Moisy
Date: Fri, 30 May 2025 12:40:49 +0800
Subject: [PATCH] testing live plot

---
 live_plot.py      | 150 +++++++++++++++
 main.py           | 304 +++++++++++++++---------------
 market_db.py      | 154 +++++++--------
 okx_client.py     | 466 +++++++++++++++++++++++-----------------
 test_predictor.py | 122 ++++++++++++
 visualizer.py     | 216 ++++++++++-----------
 6 files changed, 842 insertions(+), 570 deletions(-)
 create mode 100644 live_plot.py
 create mode 100644 test_predictor.py

diff --git a/live_plot.py b/live_plot.py
new file mode 100644
index 0000000..af1b6a3
--- /dev/null
+++ b/live_plot.py
@@ -0,0 +1,150 @@
+import json
+import threading
+import time
+import numpy as np
+import pandas as pd
+import xgboost as xgb
+from datetime import datetime
+from okx_client import OKXClient
+import dash
+from dash import dcc, html
+from dash.dependencies import Output, Input
+import plotly.graph_objs as go
+import websocket
+
+# --- Prediction utilities (from test_predictor.py) ---
+def add_features(df):
+    df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
+    def calc_rsi(close, window=14):
+        delta = close.diff()
+        up = delta.clip(lower=0)
+        down = -1 * delta.clip(upper=0)
+        ma_up = up.rolling(window=window, min_periods=window).mean()
+        ma_down = down.rolling(window=window, min_periods=window).mean()
+        rs = ma_up / ma_down
+        return 100 - (100 / (1 + rs))
+    df['rsi'] = calc_rsi(df['Close'])
+    df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
+    df['sma_50'] = df['Close'].rolling(window=50).mean()
+    df['sma_200'] = df['Close'].rolling(window=200).mean()
+    high_low = df['High'] - df['Low']
+    high_close = np.abs(df['High'] - df['Close'].shift(1))
+    low_close = np.abs(df['Low'] - df['Close'].shift(1))
+    tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
+    df['atr'] = tr.rolling(window=14).mean()
+    df['roc_10'] = df['Close'].pct_change(periods=10) * 100
+    df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)
+    df['hour'] = df['Timestamp'].dt.hour
+    return df
+
+def load_model_and_features(model_path):
+    model = xgb.Booster()
+    model.load_model(model_path)
+    feature_cols = [
+        'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
+    ]
+    return model, feature_cols
+
+def predict_next(df, model, feature_cols):
+    X = df[feature_cols].iloc[[-1]].values.astype(np.float32)
+    dmatrix = xgb.DMatrix(X, feature_names=feature_cols)
+    pred_log_return = model.predict(dmatrix)[0]
+    last_price = df['Close'].iloc[-1]
+    pred_price = last_price * np.exp(pred_log_return)
+    return pred_price, pred_log_return, last_price
+
+# --- Main live plotting ---
+WINDOW = 50
+MODEL_PATH = 'data/xgboost_model.json'
+
+ohlcv_bars = [] # [timestamp, open, high, low, close, volume]
+bar_lock = threading.Lock()
+model, feature_cols = load_model_and_features(MODEL_PATH)
+
+# --- Background thread to collect trades and aggregate OHLCV ---
+def ws_collector():
+    client = OKXClient(authenticate=False)
+    client.subscribe_trades(instrument="BTC-USDT")
+    current_bar = None
+    while True:
+        try:
+            msg = client.ws.recv()
+            data = json.loads(msg)
+        except websocket._exceptions.WebSocketTimeoutException:
+            continue # Just try again
+        except Exception as e:
+            print(f"WebSocket error: {e}")
+            break # or try to reconnect
+        if 'arg' in data and data['arg'].get('channel', '') == 'trades':
+            for trade in data.get('data', []):
+                # trade: {'instId', 'tradeId', 'px', 'sz', 'side', 'ts'}
+                ts = int(trade['ts'])
+                price = float(trade['px'])
+                size = float(trade['sz'])
+                dt = datetime.utcfromtimestamp(ts / 1000)
+                bar_seconds = 30 # or 15, 30, etc.
+                bar_time = dt.replace(second=(dt.second // bar_seconds) * bar_seconds, microsecond=0)
+                with bar_lock:
+                    if not ohlcv_bars or ohlcv_bars[-1][0] != bar_time:
+                        # New bar
+                        ohlcv_bars.append([bar_time, price, price, price, price, size])
+                        if len(ohlcv_bars) > WINDOW:
+                            ohlcv_bars.pop(0)
+                    else:
+                        # Update current bar
+                        bar = ohlcv_bars[-1]
+                        bar[2] = max(bar[2], price) # high
+                        bar[3] = min(bar[3], price) # low
+                        bar[4] = price # close
+                        bar[5] += size # volume
+
+# Start the background thread
+threading.Thread(target=ws_collector, daemon=True).start()
+
+# --- Dash App ---
+app = dash.Dash(__name__)
+app.layout = html.Div([
+    html.H2('BTC/USDT Price & Prediction (OKX, XGBoost, Trades Aggregated)', style={"textAlign": "center", "margin": 0, "padding": 0}),
+    dcc.Graph(id='live-graph', animate=False, style={"height": "90vh", "width": "100vw", "margin": 0, "padding": 0}),
+    dcc.Interval(
+        id='interval-component',
+        interval=5*1000, # 5 seconds
+        n_intervals=0
+    ),
+    html.Div(id='prediction-output', style={"textAlign": "center", "fontSize": 20, "marginTop": 10})
+], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
+
+@app.callback(
+    [Output('live-graph', 'figure'), Output('prediction-output', 'children')],
+    [Input('interval-component', 'n_intervals')]
+)
+def update_graph_live(n):
+    with bar_lock:
+        bars = list(ohlcv_bars)
+    if len(bars) < 2:
+        return go.Figure(), "Waiting for data..."
+    df = pd.DataFrame(bars, columns=["Timestamp", "Open", "High", "Low", "Close", "Volume"])
+    df["Timestamp"] = pd.to_datetime(df["Timestamp"])
+    df = add_features(df)
+    if df[feature_cols].isnull().any().any():
+        pred_text = "Not enough data for prediction."
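+        # Note: this is a whole-frame NaN check. Rolling features keep their warm-up rows NaN
+        # (sma_200 alone needs 200 closes while only WINDOW = 50 bars are retained, and
+        # dpo_20's shift(-10) blanks the last 10 rows), so this branch is effectively always taken;
+        # checking only the latest row would be the less strict alternative.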
+ else: + pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols) + pred_text = f"Last: {last_price:.2f} | Predicted next: {pred_price:.2f} | LogRet: {pred_log_return:.6f}" + fig = go.Figure() + fig.add_trace(go.Candlestick( + x=df["Timestamp"], + open=df["Open"], + high=df["High"], + low=df["Low"], + close=df["Close"], + name='Candlestick', + increasing_line_color='green', + decreasing_line_color='red', + showlegend=False + )) + fig.update_layout(title='BTC-USDT 1m OHLCV (Aggregated from Trades)', xaxis_title='Time', yaxis_title='Price (USDT)', xaxis_rangeslider_visible=False) + return fig, pred_text + +if __name__ == '__main__': + app.run(debug=True) \ No newline at end of file diff --git a/main.py b/main.py index 5ea7220..fa67c9b 100644 --- a/main.py +++ b/main.py @@ -1,152 +1,152 @@ -from okx_client import OKXClient -from market_db import MarketDB -import json -import logging -import threading -from collections import deque -import time -import signal - -latest_book = {'bids': [], 'asks': [], 'timestamp': None} -book_history = deque() -trade_history = deque() - -TRADE_HISTORY_SECONDS = 60 -BOOK_HISTORY_SECONDS = 5 - -shutdown_flag = threading.Event() - -def connect(instrument, max_retries=5): - logging.info(f"Connecting to OKX for instrument: {instrument}") - retries = 0 - backoff = 1 - while not shutdown_flag.is_set(): - try: - client = OKXClient(authenticate=False) - client.subscribe_trades(instrument) - client.subscribe_book(instrument, depth=5, channel="books") - logging.info(f"Subscribed to trades and book for {instrument}") - return client - except Exception as e: - retries += 1 - logging.error(f"Failed to connect to OKX: {e}. Retry {retries}/{max_retries} in {backoff}s.") - if retries >= max_retries: - logging.critical("Max retries reached. Exiting connect loop.") - raise - time.sleep(backoff) - backoff = min(backoff * 2, 60) # exponential backoff, max 60s - return None - -def cleanup(client, db): - if client and hasattr(client, 'ws') and client.ws: - try: - client.ws.close() - except Exception as e: - logging.warning(f"Error closing websocket: {e}") - if db: - db.close() - -def signal_handler(signum, frame): - logging.info(f"Received signal {signum}, shutting down...") - shutdown_flag.set() - -signal.signal(signal.SIGINT, signal_handler) -signal.signal(signal.SIGTERM, signal_handler) - -def main(): - instruments = [ - "ETH-USDT", - "BTC-USDT", - "SOL-USDT", - "DOGE-USDT", - "TON-USDT", - "ETH-USDC", - "SOPH-USDT", - "PEPE-USDT", - "BTC-USDC", - "UNI-USDT" - ] - - logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') - dbs = {} - clients = {} - try: - for instrument in instruments: - dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db") - logging.info(f"Database initialized for {instrument}") - clients[instrument] = connect(instrument) - - while not shutdown_flag.is_set(): - for instrument in instruments: - client = clients[instrument] - db = dbs[instrument] - try: - data = client.ws.recv() - except Exception as e: - logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...") - cleanup(client, None) - try: - clients[instrument] = connect(instrument) - except Exception as e: - logging.critical(f"Could not reconnect {instrument}: {e}. 
Skipping.") - continue - continue - - if shutdown_flag.is_set(): - break - if data == '': - continue - - try: - msg = json.loads(data) - except Exception as e: - logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}") - continue - - if 'arg' in msg and msg['arg'].get('channel') == 'trades': - for trade in msg.get('data', []): - db.insert_trade({ - 'instrument': instrument, - 'trade_id': trade.get('tradeId'), - 'price': float(trade.get('px')), - 'size': float(trade.get('sz')), - 'side': trade.get('side'), - 'timestamp': trade.get('ts') - }) - ts = float(trade.get('ts', time.time() * 1000)) - trade_history.append({ - 'price': trade.get('px'), - 'size': trade.get('sz'), - 'side': trade.get('side'), - 'timestamp': ts - }) - elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'): - for book in msg.get('data', []): - db.insert_book({ - 'instrument': instrument, - 'bids': book.get('bids'), - 'asks': book.get('asks'), - 'timestamp': book.get('ts') - }) - latest_book['bids'] = book.get('bids', []) - latest_book['asks'] = book.get('asks', []) - latest_book['timestamp'] = book.get('ts') - ts = float(book.get('ts', time.time() * 1000)) - book_history.append({ - 'bids': book.get('bids', []), - 'asks': book.get('asks', []), - 'timestamp': ts - }) - else: - logging.info(f"Unknown message for {instrument}: {msg}") - except Exception as e: - logging.critical(f"Fatal error in main: {e}") - finally: - for client in clients.values(): - cleanup(client, None) - for db in dbs.values(): - cleanup(None, db) - logging.info('Shutdown complete.') - -if __name__ == '__main__': - main() +from okx_client import OKXClient +from market_db import MarketDB +import json +import logging +import threading +from collections import deque +import time +import signal + +latest_book = {'bids': [], 'asks': [], 'timestamp': None} +book_history = deque() +trade_history = deque() + +TRADE_HISTORY_SECONDS = 60 +BOOK_HISTORY_SECONDS = 5 + +shutdown_flag = threading.Event() + +def connect(instrument, max_retries=5): + logging.info(f"Connecting to OKX for instrument: {instrument}") + retries = 0 + backoff = 1 + while not shutdown_flag.is_set(): + try: + client = OKXClient(authenticate=False) + client.subscribe_trades(instrument) + client.subscribe_book(instrument, depth=5, channel="books") + logging.info(f"Subscribed to trades and book for {instrument}") + return client + except Exception as e: + retries += 1 + logging.error(f"Failed to connect to OKX: {e}. Retry {retries}/{max_retries} in {backoff}s.") + if retries >= max_retries: + logging.critical("Max retries reached. 
Exiting connect loop.") + raise + time.sleep(backoff) + backoff = min(backoff * 2, 60) # exponential backoff, max 60s + return None + +def cleanup(client, db): + if client and hasattr(client, 'ws') and client.ws: + try: + client.ws.close() + except Exception as e: + logging.warning(f"Error closing websocket: {e}") + if db: + db.close() + +def signal_handler(signum, frame): + logging.info(f"Received signal {signum}, shutting down...") + shutdown_flag.set() + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + +def main(): + instruments = [ + "ETH-USDT", + "BTC-USDT", + "SOL-USDT", + "DOGE-USDT", + "TON-USDT", + "ETH-USDC", + "SOPH-USDT", + "PEPE-USDT", + "BTC-USDC", + "UNI-USDT" + ] + + logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') + dbs = {} + clients = {} + try: + for instrument in instruments: + dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db") + logging.info(f"Database initialized for {instrument}") + clients[instrument] = connect(instrument) + + while not shutdown_flag.is_set(): + for instrument in instruments: + client = clients[instrument] + db = dbs[instrument] + try: + data = client.ws.recv() + except Exception as e: + logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...") + cleanup(client, None) + try: + clients[instrument] = connect(instrument) + except Exception as e: + logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.") + continue + continue + + if shutdown_flag.is_set(): + break + if data == '': + continue + + try: + msg = json.loads(data) + except Exception as e: + logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}") + continue + + if 'arg' in msg and msg['arg'].get('channel') == 'trades': + for trade in msg.get('data', []): + db.insert_trade({ + 'instrument': instrument, + 'trade_id': trade.get('tradeId'), + 'price': float(trade.get('px')), + 'size': float(trade.get('sz')), + 'side': trade.get('side'), + 'timestamp': trade.get('ts') + }) + ts = float(trade.get('ts', time.time() * 1000)) + trade_history.append({ + 'price': trade.get('px'), + 'size': trade.get('sz'), + 'side': trade.get('side'), + 'timestamp': ts + }) + elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'): + for book in msg.get('data', []): + db.insert_book({ + 'instrument': instrument, + 'bids': book.get('bids'), + 'asks': book.get('asks'), + 'timestamp': book.get('ts') + }) + latest_book['bids'] = book.get('bids', []) + latest_book['asks'] = book.get('asks', []) + latest_book['timestamp'] = book.get('ts') + ts = float(book.get('ts', time.time() * 1000)) + book_history.append({ + 'bids': book.get('bids', []), + 'asks': book.get('asks', []), + 'timestamp': ts + }) + else: + logging.info(f"Unknown message for {instrument}: {msg}") + except Exception as e: + logging.critical(f"Fatal error in main: {e}") + finally: + for client in clients.values(): + cleanup(client, None) + for db in dbs.values(): + cleanup(None, db) + logging.info('Shutdown complete.') + +if __name__ == '__main__': + main() diff --git a/market_db.py b/market_db.py index 8bc6abf..461f4d9 100644 --- a/market_db.py +++ b/market_db.py @@ -1,77 +1,77 @@ -import sqlite3 -import os -from typing import Any, Dict, List, Optional -import logging - -class MarketDB: - def __init__(self, market: str, db_dir: str = ""): - db_name = f"{market}.db" - db_path = db_name if not db_dir else f"{db_dir.rstrip('/')}/{db_name}" - if db_dir: - 
os.makedirs(db_dir, exist_ok=True) - self.conn = sqlite3.connect(db_path) - logging.info(f"Connected to database at {db_path}") - self._create_tables() - - def _create_tables(self): - cursor = self.conn.cursor() - cursor.execute(''' - CREATE TABLE IF NOT EXISTS trades ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - instrument TEXT, - trade_id TEXT, - price REAL, - size REAL, - side TEXT, - timestamp TEXT - ) - ''') - cursor.execute(''' - CREATE TABLE IF NOT EXISTS book ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - instrument TEXT, - bids TEXT, - asks TEXT, - timestamp TEXT - ) - ''') - self.conn.commit() - logging.info("Database tables ensured.") - - def insert_trade(self, trade: Dict[str, Any]): - cursor = self.conn.cursor() - cursor.execute(''' - INSERT INTO trades (instrument, trade_id, price, size, side, timestamp) - VALUES (?, ?, ?, ?, ?, ?) - ''', ( - trade.get('instrument'), - trade.get('trade_id'), - trade.get('price'), - trade.get('size'), - trade.get('side'), - trade.get('timestamp') - )) - self.conn.commit() - logging.debug(f"Inserted trade: {trade}") - - def insert_book(self, book: Dict[str, Any]): - cursor = self.conn.cursor() - bids = book.get('bids', []) - asks = book.get('asks', []) - best_bid = next((b for b in bids if float(b[1]) > 0), ['-', '-']) - best_ask = next((a for a in asks if float(a[1]) > 0), ['-', '-']) - cursor.execute(''' - INSERT INTO book (instrument, bids, asks, timestamp) - VALUES (?, ?, ?, ?) - ''', ( - book.get('instrument'), - str(bids), - str(asks), - book.get('timestamp') - )) - self.conn.commit() - logging.debug(f"Inserted book: {book.get('instrument', 'N/A')} ts:{book.get('timestamp', 'N/A')} bid:{best_bid} ask:{best_ask}") - - def close(self): - self.conn.close() - logging.info("Database connection closed.") +import sqlite3 +import os +from typing import Any, Dict, List, Optional +import logging + +class MarketDB: + def __init__(self, market: str, db_dir: str = ""): + db_name = f"{market}.db" + db_path = db_name if not db_dir else f"{db_dir.rstrip('/')}/{db_name}" + if db_dir: + os.makedirs(db_dir, exist_ok=True) + self.conn = sqlite3.connect(db_path) + logging.info(f"Connected to database at {db_path}") + self._create_tables() + + def _create_tables(self): + cursor = self.conn.cursor() + cursor.execute(''' + CREATE TABLE IF NOT EXISTS trades ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + instrument TEXT, + trade_id TEXT, + price REAL, + size REAL, + side TEXT, + timestamp TEXT + ) + ''') + cursor.execute(''' + CREATE TABLE IF NOT EXISTS book ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + instrument TEXT, + bids TEXT, + asks TEXT, + timestamp TEXT + ) + ''') + self.conn.commit() + logging.info("Database tables ensured.") + + def insert_trade(self, trade: Dict[str, Any]): + cursor = self.conn.cursor() + cursor.execute(''' + INSERT INTO trades (instrument, trade_id, price, size, side, timestamp) + VALUES (?, ?, ?, ?, ?, ?) + ''', ( + trade.get('instrument'), + trade.get('trade_id'), + trade.get('price'), + trade.get('size'), + trade.get('side'), + trade.get('timestamp') + )) + self.conn.commit() + logging.debug(f"Inserted trade: {trade}") + + def insert_book(self, book: Dict[str, Any]): + cursor = self.conn.cursor() + bids = book.get('bids', []) + asks = book.get('asks', []) + best_bid = next((b for b in bids if float(b[1]) > 0), ['-', '-']) + best_ask = next((a for a in asks if float(a[1]) > 0), ['-', '-']) + cursor.execute(''' + INSERT INTO book (instrument, bids, asks, timestamp) + VALUES (?, ?, ?, ?) 
+ ''', ( + book.get('instrument'), + str(bids), + str(asks), + book.get('timestamp') + )) + self.conn.commit() + logging.debug(f"Inserted book: {book.get('instrument', 'N/A')} ts:{book.get('timestamp', 'N/A')} bid:{best_bid} ask:{best_ask}") + + def close(self): + self.conn.close() + logging.info("Database connection closed.") diff --git a/okx_client.py b/okx_client.py index 52957ad..cb8a577 100644 --- a/okx_client.py +++ b/okx_client.py @@ -1,233 +1,233 @@ -import os -import time -import hmac -import hashlib -import base64 -import json -import pandas as pd -import threading -import requests -import websocket -import logging - -class OKXClient: - PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public" - PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private" - REST_URL = "https://www.okx.com" - - def __init__(self, authenticate: bool = True): - self.authenticated = False - self.api_key = None - self.api_secret = None - self.api_passphrase = None - self.ws = None - self.ws_private = None - self._lock = threading.Lock() - self._private_lock = threading.Lock() - - if authenticate: - config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json') - try: - with open(config_path, 'r') as f: - config = json.load(f) - except FileNotFoundError: - raise FileNotFoundError(f"Credentials file not found at {config_path}. Please create it with the required keys.") - except json.JSONDecodeError: - raise ValueError(f"Credentials file at {config_path} is not valid JSON.") - - self.api_key = config.get("OKX_API_KEY") - self.api_secret = config.get("OKX_API_SECRET") - self.api_passphrase = config.get("OKX_API_PASSPHRASE") - - if not self.api_key or not self.api_secret or not self.api_passphrase: - raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.") - - self._authenticate() - self._connect_ws() - - def _connect_ws(self): - if self.ws is None: - self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10) - if self.authenticated and self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None: - self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10) - - def _get_timestamp(self): - return str(round(time.time(), 3)) - - def _sign(self, timestamp, method, request_path, body): - if not body: - body = '' - message = f'{timestamp}{method}{request_path}{body}' - mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256) - return base64.b64encode(mac.digest()).decode() - - def _authenticate(self): - import websocket - timestamp = self._get_timestamp() - sign = self._sign(timestamp, 'GET', '/users/self/verify', '') - login_params = { - "op": "login", - "args": [{ - "apiKey": self.api_key, - "passphrase": self.api_passphrase, - "timestamp": timestamp, - "sign": sign - }] - } - self.ws_private.send(json.dumps(login_params)) - logging.info("Waiting for login response from OKX...") - while True: - try: - resp = self.ws_private.recv() - logging.debug(f"Received from OKX private WS: {resp}") - if not resp: - continue - try: - msg = json.loads(resp) - except Exception: - logging.warning(f"Non-JSON message received: {resp}") - continue - if msg.get("event") == "login": - if msg.get("code") == "0": - logging.info("OKX WebSocket login successful.") - self.authenticated = True - break - else: - raise Exception(f"WebSocket authentication failed: {msg}") - except websocket._exceptions.WebSocketConnectionClosedException as e: - logging.error(f"WebSocket connection closed during 
authentication: {e}") - raise - except Exception as e: - logging.error(f"Exception during authentication: {e}") - raise - - def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"): - # OKX uses candle1m, candle5m, etc. - tf_map = {"1m": "candle1m", "5m": "candle5m", "15m": "candle15m", "1h": "candle1H"} - channel = tf_map.get(timeframe, f"candle{timeframe}") - params = { - "op": "subscribe", - "args": [{"channel": channel, "instId": instrument}] - } - logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") - self.ws.send(json.dumps(params)) - - def subscribe_trades(self, instrument="BTC-USDT"): - params = { - "op": "subscribe", - "args": [{"channel": "trades", "instId": instrument}] - } - logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") - self.ws.send(json.dumps(params)) - - def subscribe_ticker(self, instrument="BTC-USDT"): - params = { - "op": "subscribe", - "args": [{"channel": "tickers", "instId": instrument}] - } - logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") - self.ws.send(json.dumps(params)) - - def subscribe_book(self, instrument="BTC-USDT", depth=5, channel="books5"): - # OKX supports books5, books50, books-l2-tbt - # channel = "books5" if depth <= 5 else "books50" - params = { - "op": "subscribe", - "args": [{"channel": channel, "instId": instrument}] - } - logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") - self.ws.send(json.dumps(params)) - - def subscribe_user_order(self): - if not self.authenticated: - logging.warning("Attempted to subscribe to user order channel without authentication.") - return - params = { - "op": "subscribe", - "args": [{"channel": "orders", "instType": "SPOT"}] - } - logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") - self.ws_private.send(json.dumps(params)) - - def subscribe_user_trade(self): - if not self.authenticated: - logging.warning("Attempted to subscribe to user trade channel without authentication.") - return - params = { - "op": "subscribe", - "args": [{"channel": "trades", "instType": "SPOT"}] - } - logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") - self.ws_private.send(json.dumps(params)) - - def subscribe_user_balance(self): - if not self.authenticated: - logging.warning("Attempted to subscribe to user balance channel without authentication.") - return - params = { - "op": "subscribe", - "args": [{"channel": "balance_and_position"}] - } - logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") - self.ws_private.send(json.dumps(params)) - - def get_balance(self, currency=None): - url = f"{self.REST_URL}/api/v5/account/balance" - timestamp = self._get_timestamp() - method = "GET" - request_path = "/api/v5/account/balance" - body = '' - sign = self._sign(timestamp, method, request_path, body) - headers = { - "OK-ACCESS-KEY": self.api_key, - "OK-ACCESS-SIGN": sign, - "OK-ACCESS-TIMESTAMP": timestamp, - "OK-ACCESS-PASSPHRASE": self.api_passphrase, - "Content-Type": "application/json" - } - resp = requests.get(url, headers=headers) - if resp.status_code == 200: - data = resp.json() - balances = data.get("data", [{}])[0].get("details", []) - if currency: - return [b for b in balances if b.get("ccy") == currency] - return balances - return [] - - def place_order(self, side, amount, instrument="BTC-USDT"): - url = f"{self.REST_URL}/api/v5/trade/order" - timestamp = self._get_timestamp() - method = "POST" - request_path = "/api/v5/trade/order" - body_dict = { - "instId": 
instrument, - "tdMode": "cash", - "side": side.lower(), - "ordType": "market", - "sz": str(amount) - } - body = json.dumps(body_dict) - sign = self._sign(timestamp, method, request_path, body) - headers = { - "OK-ACCESS-KEY": self.api_key, - "OK-ACCESS-SIGN": sign, - "OK-ACCESS-TIMESTAMP": timestamp, - "OK-ACCESS-PASSPHRASE": self.api_passphrase, - "Content-Type": "application/json" - } - resp = requests.post(url, headers=headers, data=body) - return resp.json() - - def buy_btc(self, amount, instrument="BTC-USDT"): - return self.place_order("buy", amount, instrument) - - def sell_btc(self, amount, instrument="BTC-USDT"): - return self.place_order("sell", amount, instrument) - - def get_instruments(self): - url = f"{self.REST_URL}/api/v5/public/instruments?instType=SPOT" - resp = requests.get(url) - if resp.status_code == 200: - data = resp.json() - return data.get("data", []) - return [] +import os +import time +import hmac +import hashlib +import base64 +import json +import pandas as pd +import threading +import requests +import websocket +import logging + +class OKXClient: + PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public" + PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private" + REST_URL = "https://www.okx.com" + + def __init__(self, authenticate: bool = True): + self.authenticated = False + self.api_key = None + self.api_secret = None + self.api_passphrase = None + self.ws = None + self.ws_private = None + self._lock = threading.Lock() + self._private_lock = threading.Lock() + + if authenticate: + config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json') + try: + with open(config_path, 'r') as f: + config = json.load(f) + except FileNotFoundError: + raise FileNotFoundError(f"Credentials file not found at {config_path}. 
Please create it with the required keys.") + except json.JSONDecodeError: + raise ValueError(f"Credentials file at {config_path} is not valid JSON.") + + self.api_key = config.get("OKX_API_KEY") + self.api_secret = config.get("OKX_API_SECRET") + self.api_passphrase = config.get("OKX_API_PASSPHRASE") + + if not self.api_key or not self.api_secret or not self.api_passphrase: + raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.") + + self._authenticate() + self._connect_ws() + + def _connect_ws(self): + if self.ws is None: + self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10) + if self.authenticated and self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None: + self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10) + + def _get_timestamp(self): + return str(round(time.time(), 3)) + + def _sign(self, timestamp, method, request_path, body): + if not body: + body = '' + message = f'{timestamp}{method}{request_path}{body}' + mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256) + return base64.b64encode(mac.digest()).decode() + + def _authenticate(self): + import websocket + timestamp = self._get_timestamp() + sign = self._sign(timestamp, 'GET', '/users/self/verify', '') + login_params = { + "op": "login", + "args": [{ + "apiKey": self.api_key, + "passphrase": self.api_passphrase, + "timestamp": timestamp, + "sign": sign + }] + } + self.ws_private.send(json.dumps(login_params)) + logging.info("Waiting for login response from OKX...") + while True: + try: + resp = self.ws_private.recv() + logging.debug(f"Received from OKX private WS: {resp}") + if not resp: + continue + try: + msg = json.loads(resp) + except Exception: + logging.warning(f"Non-JSON message received: {resp}") + continue + if msg.get("event") == "login": + if msg.get("code") == "0": + logging.info("OKX WebSocket login successful.") + self.authenticated = True + break + else: + raise Exception(f"WebSocket authentication failed: {msg}") + except websocket._exceptions.WebSocketConnectionClosedException as e: + logging.error(f"WebSocket connection closed during authentication: {e}") + raise + except Exception as e: + logging.error(f"Exception during authentication: {e}") + raise + + def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"): + # OKX uses candle1m, candle5m, etc. 
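+        # For example, subscribing to 1m BTC-USDT candles sends:
+        #   {"op": "subscribe", "args": [{"channel": "candle1m", "instId": "BTC-USDT"}]}
+        # Timeframes missing from tf_map fall back to f"candle{timeframe}" below.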
+ tf_map = {"1m": "candle1m", "5m": "candle5m", "15m": "candle15m", "1h": "candle1H"} + channel = tf_map.get(timeframe, f"candle{timeframe}") + params = { + "op": "subscribe", + "args": [{"channel": channel, "instId": instrument}] + } + logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") + self.ws.send(json.dumps(params)) + + def subscribe_trades(self, instrument="BTC-USDT"): + params = { + "op": "subscribe", + "args": [{"channel": "trades", "instId": instrument}] + } + logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") + self.ws.send(json.dumps(params)) + + def subscribe_ticker(self, instrument="BTC-USDT"): + params = { + "op": "subscribe", + "args": [{"channel": "tickers", "instId": instrument}] + } + logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") + self.ws.send(json.dumps(params)) + + def subscribe_book(self, instrument="BTC-USDT", depth=5, channel="books5"): + # OKX supports books5, books50, books-l2-tbt + # channel = "books5" if depth <= 5 else "books50" + params = { + "op": "subscribe", + "args": [{"channel": channel, "instId": instrument}] + } + logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") + self.ws.send(json.dumps(params)) + + def subscribe_user_order(self): + if not self.authenticated: + logging.warning("Attempted to subscribe to user order channel without authentication.") + return + params = { + "op": "subscribe", + "args": [{"channel": "orders", "instType": "SPOT"}] + } + logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") + self.ws_private.send(json.dumps(params)) + + def subscribe_user_trade(self): + if not self.authenticated: + logging.warning("Attempted to subscribe to user trade channel without authentication.") + return + params = { + "op": "subscribe", + "args": [{"channel": "trades", "instType": "SPOT"}] + } + logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") + self.ws_private.send(json.dumps(params)) + + def subscribe_user_balance(self): + if not self.authenticated: + logging.warning("Attempted to subscribe to user balance channel without authentication.") + return + params = { + "op": "subscribe", + "args": [{"channel": "balance_and_position"}] + } + logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") + self.ws_private.send(json.dumps(params)) + + def get_balance(self, currency=None): + url = f"{self.REST_URL}/api/v5/account/balance" + timestamp = self._get_timestamp() + method = "GET" + request_path = "/api/v5/account/balance" + body = '' + sign = self._sign(timestamp, method, request_path, body) + headers = { + "OK-ACCESS-KEY": self.api_key, + "OK-ACCESS-SIGN": sign, + "OK-ACCESS-TIMESTAMP": timestamp, + "OK-ACCESS-PASSPHRASE": self.api_passphrase, + "Content-Type": "application/json" + } + resp = requests.get(url, headers=headers) + if resp.status_code == 200: + data = resp.json() + balances = data.get("data", [{}])[0].get("details", []) + if currency: + return [b for b in balances if b.get("ccy") == currency] + return balances + return [] + + def place_order(self, side, amount, instrument="BTC-USDT"): + url = f"{self.REST_URL}/api/v5/trade/order" + timestamp = self._get_timestamp() + method = "POST" + request_path = "/api/v5/trade/order" + body_dict = { + "instId": instrument, + "tdMode": "cash", + "side": side.lower(), + "ordType": "market", + "sz": str(amount) + } + body = json.dumps(body_dict) + sign = self._sign(timestamp, method, request_path, body) + headers = { + "OK-ACCESS-KEY": self.api_key, + 
"OK-ACCESS-SIGN": sign, + "OK-ACCESS-TIMESTAMP": timestamp, + "OK-ACCESS-PASSPHRASE": self.api_passphrase, + "Content-Type": "application/json" + } + resp = requests.post(url, headers=headers, data=body) + return resp.json() + + def buy_btc(self, amount, instrument="BTC-USDT"): + return self.place_order("buy", amount, instrument) + + def sell_btc(self, amount, instrument="BTC-USDT"): + return self.place_order("sell", amount, instrument) + + def get_instruments(self): + url = f"{self.REST_URL}/api/v5/public/instruments?instType=SPOT" + resp = requests.get(url) + if resp.status_code == 200: + data = resp.json() + return data.get("data", []) + return [] diff --git a/test_predictor.py b/test_predictor.py new file mode 100644 index 0000000..dad87d8 --- /dev/null +++ b/test_predictor.py @@ -0,0 +1,122 @@ +import time +import json +import numpy as np +import pandas as pd +import requests +import xgboost as xgb +from datetime import datetime, timedelta + +# --- CONFIG --- +OKX_REST_URL = 'https://www.okx.com/api/v5/market/candles' +SYMBOL = 'BTC-USDT' +BAR = '1m' +HIST_MINUTES = 250 # Number of minutes of history to fetch for features +MODEL_PATH = 'data/xgboost_model.json' + +# --- Fetch recent candles from OKX REST API --- +def fetch_recent_candles(symbol, bar, limit=HIST_MINUTES): + params = { + 'instId': symbol, + 'bar': bar, + 'limit': str(limit) + } + resp = requests.get(OKX_REST_URL, params=params) + data = resp.json() + if data['code'] != '0': + raise Exception(f"OKX API error: {data['msg']}") + # OKX returns most recent first, reverse to chronological + candles = data['data'][::-1] + df = pd.DataFrame(candles) + # OKX columns: [ts, o, h, l, c, vol, volCcy, confirm, ...] (see API docs) + # We'll use: ts, o, h, l, c, vol + col_map = { + 0: 'Timestamp', + 1: 'Open', + 2: 'High', + 3: 'Low', + 4: 'Close', + 5: 'Volume', + } + df = df.rename(columns={str(k): v for k, v in col_map.items()}) + # If columns are not named, use integer index + for k, v in col_map.items(): + if v not in df.columns: + df[v] = df.iloc[:, k] + df = df[['Timestamp', 'Open', 'High', 'Low', 'Close', 'Volume']] + df['Timestamp'] = pd.to_datetime(df['Timestamp'].astype(np.int64), unit='ms') + for col in ['Open', 'High', 'Low', 'Close', 'Volume']: + df[col] = pd.to_numeric(df[col], errors='coerce') + return df + +# --- Feature Engineering (minimal, real-time) --- +def add_features(df): + # Log return (target, not used for prediction) + df['log_return'] = np.log(df['Close'] / df['Close'].shift(1)) + # RSI (14) + def calc_rsi(close, window=14): + delta = close.diff() + up = delta.clip(lower=0) + down = -1 * delta.clip(upper=0) + ma_up = up.rolling(window=window, min_periods=window).mean() + ma_down = down.rolling(window=window, min_periods=window).mean() + rs = ma_up / ma_down + return 100 - (100 / (1 + rs)) + df['rsi'] = calc_rsi(df['Close']) + # EMA 14 + df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean() + # SMA 50, 200 + df['sma_50'] = df['Close'].rolling(window=50).mean() + df['sma_200'] = df['Close'].rolling(window=200).mean() + # ATR 14 + high_low = df['High'] - df['Low'] + high_close = np.abs(df['High'] - df['Close'].shift(1)) + low_close = np.abs(df['Low'] - df['Close'].shift(1)) + tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1) + df['atr'] = tr.rolling(window=14).mean() + # ROC 10 + df['roc_10'] = df['Close'].pct_change(periods=10) * 100 + # DPO 20 + df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10) + # Hour + df['hour'] = df['Timestamp'].dt.hour + # Add more 
features as needed (match with main.py) + return df + +# --- Load model and feature columns --- +def load_model_and_features(model_path): + model = xgb.Booster() + model.load_model(model_path) + # Try to infer feature names from main.py (hardcoded for now) + feature_cols = [ + 'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour' + ] + return model, feature_cols + +# --- Predict next log return and price --- +def predict_next(df, model, feature_cols): + # Use the last row for prediction + X = df[feature_cols].iloc[[-1]].values.astype(np.float32) + dmatrix = xgb.DMatrix(X, feature_names=feature_cols) + pred_log_return = model.predict(dmatrix)[0] + last_price = df['Close'].iloc[-1] + pred_price = last_price * np.exp(pred_log_return) + return pred_price, pred_log_return, last_price + +if __name__ == '__main__': + print('Fetching recent candles from OKX...') + df = fetch_recent_candles(SYMBOL, BAR) + df = add_features(df) + model, feature_cols = load_model_and_features(MODEL_PATH) + print('Waiting for new candle...') + last_timestamp = df['Timestamp'].iloc[-1] + while True: + time.sleep(5) + new_df = fetch_recent_candles(SYMBOL, BAR, limit=HIST_MINUTES) + if new_df['Timestamp'].iloc[-1] > last_timestamp: + df = new_df + df = add_features(df) + pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols) + print(f"[{df['Timestamp'].iloc[-1]}] Last price: {last_price:.2f} | Predicted next price: {pred_price:.2f} | Predicted log return: {pred_log_return:.6f}") + last_timestamp = df['Timestamp'].iloc[-1] + else: + print('No new candle yet...') diff --git a/visualizer.py b/visualizer.py index 573a360..2ecb11e 100644 --- a/visualizer.py +++ b/visualizer.py @@ -1,108 +1,108 @@ -import dash -from dash import dcc, html -from dash.dependencies import Output, Input -import plotly.graph_objs as go -import time -from collections import defaultdict - -def run_dash(book_history, trade_history, BOOK_HISTORY_SECONDS=5, TRADE_HISTORY_SECONDS=60): - app = dash.Dash(__name__) - app.layout = html.Div([ - html.H1("Order Book Depth Chart", style={"textAlign": "center", "color": "#222"}), - dcc.Graph(id='order-book-graph', style={"height": "90vh", "width": "100vw"}), - dcc.Interval(id='interval-component', interval=2*1000, n_intervals=0) - ], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"}) - - @app.callback( - [Output('order-book-graph', 'figure')], - [Input('interval-component', 'n_intervals')] - ) - def update_graphs(n): - now = time.time() * 1000 # current time in ms - - # Prune book_history to only keep last BOOK_HISTORY_SECONDS - while book_history and now - book_history[0]['timestamp'] > BOOK_HISTORY_SECONDS * 1000: - book_history.popleft() - - # Prune trade_history to only keep last TRADE_HISTORY_SECONDS - while trade_history and now - float(trade_history[0]['timestamp']) > TRADE_HISTORY_SECONDS * 1000: - trade_history.popleft() - - # Aggregate bids/asks from book_history - bids_dict = {} - asks_dict = {} - - for book in book_history: - for price, size, *_ in book['bids']: - price = float(price) - size = float(size) - bids_dict[price] = bids_dict.get(price, 0) + size - - for price, size, *_ in book['asks']: - price = float(price) - size = float(size) - asks_dict[price] = asks_dict.get(price, 0) + size - - try: - # Prepare and sort bids/asks - bids = sorted([[p, s] for p, s in bids_dict.items()], reverse=True) - asks = sorted([[p, s] for p, s in asks_dict.items()]) - - # Cumulative sum - bid_prices = 
[b[0] for b in bids] - bid_sizes = [b[1] for b in bids] - ask_prices = [a[0] for a in asks] - ask_sizes = [a[1] for a in asks] - bid_cumsum = [sum(bid_sizes[:i+1]) for i in range(len(bid_sizes))] - ask_cumsum = [sum(ask_sizes[:i+1]) for i in range(len(ask_sizes))] - except Exception as e: - bid_prices, bid_cumsum, ask_prices, ask_cumsum = [], [], [], [] - - fig = go.Figure() - - # Add order book lines (primary y-axis) - fig.add_trace(go.Scatter( - x=bid_prices, y=bid_cumsum, mode='lines', name='Bids', - line=dict(color='green'), fill='tozeroy', yaxis='y1' - )) - fig.add_trace(go.Scatter( - x=ask_prices, y=ask_cumsum, mode='lines', name='Asks', - line=dict(color='red'), fill='tozeroy', yaxis='y1' - )) - - trade_volume_by_price = defaultdict(float) - - for trade in trade_history: - price_bin = round(float(trade['price']), 2) - trade_volume_by_price[price_bin] += float(trade['size']) - - prices = list(trade_volume_by_price.keys()) - volumes = list(trade_volume_by_price.values()) - - # Sort by price for display - sorted_pairs = sorted(zip(prices, volumes)) - prices = [p for p, v in sorted_pairs] - volumes = [v for p, v in sorted_pairs] - - # Add trade volume bars (secondary y-axis) - fig.add_trace(go.Bar( - x=prices, y=volumes, marker_color='#7ec8e3', name='Trade Volume', - opacity=0.7, yaxis='y2' - )) - - # Update layout for dual y-axes - fig.update_layout( - title='Order Book Depth & Realized Trade Volume by Price', - xaxis=dict(title='Price'), - yaxis=dict(title='Cumulative Size', side='left'), - yaxis2=dict( - title='Traded Volume', - overlaying='y', - side='right', - showgrid=False - ), - template='plotly_dark' - ) - return [fig] - - app.run(debug=True, use_reloader=False) +import dash +from dash import dcc, html +from dash.dependencies import Output, Input +import plotly.graph_objs as go +import time +from collections import defaultdict + +def run_dash(book_history, trade_history, BOOK_HISTORY_SECONDS=5, TRADE_HISTORY_SECONDS=60): + app = dash.Dash(__name__) + app.layout = html.Div([ + html.H1("Order Book Depth Chart", style={"textAlign": "center", "color": "#222"}), + dcc.Graph(id='order-book-graph', style={"height": "90vh", "width": "100vw"}), + dcc.Interval(id='interval-component', interval=2*1000, n_intervals=0) + ], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"}) + + @app.callback( + [Output('order-book-graph', 'figure')], + [Input('interval-component', 'n_intervals')] + ) + def update_graphs(n): + now = time.time() * 1000 # current time in ms + + # Prune book_history to only keep last BOOK_HISTORY_SECONDS + while book_history and now - book_history[0]['timestamp'] > BOOK_HISTORY_SECONDS * 1000: + book_history.popleft() + + # Prune trade_history to only keep last TRADE_HISTORY_SECONDS + while trade_history and now - float(trade_history[0]['timestamp']) > TRADE_HISTORY_SECONDS * 1000: + trade_history.popleft() + + # Aggregate bids/asks from book_history + bids_dict = {} + asks_dict = {} + + for book in book_history: + for price, size, *_ in book['bids']: + price = float(price) + size = float(size) + bids_dict[price] = bids_dict.get(price, 0) + size + + for price, size, *_ in book['asks']: + price = float(price) + size = float(size) + asks_dict[price] = asks_dict.get(price, 0) + size + + try: + # Prepare and sort bids/asks + bids = sorted([[p, s] for p, s in bids_dict.items()], reverse=True) + asks = sorted([[p, s] for p, s in asks_dict.items()]) + + # Cumulative sum + bid_prices = [b[0] for b in bids] + 
bid_sizes = [b[1] for b in bids] + ask_prices = [a[0] for a in asks] + ask_sizes = [a[1] for a in asks] + bid_cumsum = [sum(bid_sizes[:i+1]) for i in range(len(bid_sizes))] + ask_cumsum = [sum(ask_sizes[:i+1]) for i in range(len(ask_sizes))] + except Exception as e: + bid_prices, bid_cumsum, ask_prices, ask_cumsum = [], [], [], [] + + fig = go.Figure() + + # Add order book lines (primary y-axis) + fig.add_trace(go.Scatter( + x=bid_prices, y=bid_cumsum, mode='lines', name='Bids', + line=dict(color='green'), fill='tozeroy', yaxis='y1' + )) + fig.add_trace(go.Scatter( + x=ask_prices, y=ask_cumsum, mode='lines', name='Asks', + line=dict(color='red'), fill='tozeroy', yaxis='y1' + )) + + trade_volume_by_price = defaultdict(float) + + for trade in trade_history: + price_bin = round(float(trade['price']), 2) + trade_volume_by_price[price_bin] += float(trade['size']) + + prices = list(trade_volume_by_price.keys()) + volumes = list(trade_volume_by_price.values()) + + # Sort by price for display + sorted_pairs = sorted(zip(prices, volumes)) + prices = [p for p, v in sorted_pairs] + volumes = [v for p, v in sorted_pairs] + + # Add trade volume bars (secondary y-axis) + fig.add_trace(go.Bar( + x=prices, y=volumes, marker_color='#7ec8e3', name='Trade Volume', + opacity=0.7, yaxis='y2' + )) + + # Update layout for dual y-axes + fig.update_layout( + title='Order Book Depth & Realized Trade Volume by Price', + xaxis=dict(title='Price'), + yaxis=dict(title='Cumulative Size', side='left'), + yaxis2=dict( + title='Traded Volume', + overlaying='y', + side='right', + showgrid=False + ), + template='plotly_dark' + ) + return [fig] + + app.run(debug=True, use_reloader=False)
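
Both new scripts hard-code feature_cols, and test_predictor.py notes the names are "hardcoded for now" to match the trained model. A minimal sanity check, assuming the booster saved at data/xgboost_model.json was trained with named features (otherwise Booster.feature_names is None and the comparison is skipped), could be run before going live:

import xgboost as xgb

# Feature list hard-coded in live_plot.py and test_predictor.py
FEATURE_COLS = ['rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour']

booster = xgb.Booster()
booster.load_model('data/xgboost_model.json')  # same path both scripts use

saved = booster.feature_names  # None if the model was saved without feature names
print('model features :', saved)
print('script features:', FEATURE_COLS)
if saved is not None and list(saved) != FEATURE_COLS:
    raise SystemExit('feature_cols does not match the features stored in the model')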