testing live plot

This commit is contained in:
Simon Moisy 2025-05-30 12:40:49 +08:00
parent 8f96e14b8b
commit f534825e53
6 changed files with 842 additions and 570 deletions

150
live_plot.py Normal file
View File

@ -0,0 +1,150 @@
import json
import threading
import time
import numpy as np
import pandas as pd
import xgboost as xgb
from datetime import datetime
from okx_client import OKXClient
import dash
from dash import dcc, html
from dash.dependencies import Output, Input
import plotly.graph_objs as go
import websocket
# --- Prediction utilities (from test_predictor.py) ---
def add_features(df):
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
def calc_rsi(close, window=14):
delta = close.diff()
up = delta.clip(lower=0)
down = -1 * delta.clip(upper=0)
ma_up = up.rolling(window=window, min_periods=window).mean()
ma_down = down.rolling(window=window, min_periods=window).mean()
rs = ma_up / ma_down
return 100 - (100 / (1 + rs))
df['rsi'] = calc_rsi(df['Close'])
df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
df['sma_50'] = df['Close'].rolling(window=50).mean()
df['sma_200'] = df['Close'].rolling(window=200).mean()
high_low = df['High'] - df['Low']
high_close = np.abs(df['High'] - df['Close'].shift(1))
low_close = np.abs(df['Low'] - df['Close'].shift(1))
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
df['atr'] = tr.rolling(window=14).mean()
df['roc_10'] = df['Close'].pct_change(periods=10) * 100
df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)
df['hour'] = df['Timestamp'].dt.hour
return df
def load_model_and_features(model_path):
model = xgb.Booster()
model.load_model(model_path)
feature_cols = [
'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
]
return model, feature_cols
def predict_next(df, model, feature_cols):
X = df[feature_cols].iloc[[-1]].values.astype(np.float32)
dmatrix = xgb.DMatrix(X, feature_names=feature_cols)
pred_log_return = model.predict(dmatrix)[0]
last_price = df['Close'].iloc[-1]
pred_price = last_price * np.exp(pred_log_return)
return pred_price, pred_log_return, last_price
# --- Main live plotting ---
WINDOW = 50
MODEL_PATH = 'data/xgboost_model.json'
ohlcv_bars = [] # [timestamp, open, high, low, close, volume]
bar_lock = threading.Lock()
model, feature_cols = load_model_and_features(MODEL_PATH)
# --- Background thread to collect trades and aggregate OHLCV ---
def ws_collector():
client = OKXClient(authenticate=False)
client.subscribe_trades(instrument="BTC-USDT")
current_bar = None
while True:
try:
msg = client.ws.recv()
data = json.loads(msg)
except websocket._exceptions.WebSocketTimeoutException:
continue # Just try again
except Exception as e:
print(f"WebSocket error: {e}")
break # or try to reconnect
if 'arg' in data and data['arg'].get('channel', '') == 'trades':
for trade in data.get('data', []):
# trade: {'instId', 'tradeId', 'px', 'sz', 'side', 'ts'}
ts = int(trade['ts'])
price = float(trade['px'])
size = float(trade['sz'])
dt = datetime.utcfromtimestamp(ts / 1000)
bar_seconds = 30 # or 15, 30, etc.
bar_time = dt.replace(second=(dt.second // bar_seconds) * bar_seconds, microsecond=0)
with bar_lock:
if not ohlcv_bars or ohlcv_bars[-1][0] != bar_time:
# New bar
ohlcv_bars.append([bar_time, price, price, price, price, size])
if len(ohlcv_bars) > WINDOW:
ohlcv_bars.pop(0)
else:
# Update current bar
bar = ohlcv_bars[-1]
bar[2] = max(bar[2], price) # high
bar[3] = min(bar[3], price) # low
bar[4] = price # close
bar[5] += size # volume
# Start the background thread
threading.Thread(target=ws_collector, daemon=True).start()
# --- Dash App ---
app = dash.Dash(__name__)
app.layout = html.Div([
html.H2('BTC/USDT Price & Prediction (OKX, XGBoost, Trades Aggregated)', style={"textAlign": "center", "margin": 0, "padding": 0}),
dcc.Graph(id='live-graph', animate=False, style={"height": "90vh", "width": "100vw", "margin": 0, "padding": 0}),
dcc.Interval(
id='interval-component',
interval=5*1000, # 5 seconds
n_intervals=0
),
html.Div(id='prediction-output', style={"textAlign": "center", "fontSize": 20, "marginTop": 10})
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
@app.callback(
[Output('live-graph', 'figure'), Output('prediction-output', 'children')],
[Input('interval-component', 'n_intervals')]
)
def update_graph_live(n):
with bar_lock:
bars = list(ohlcv_bars)
if len(bars) < 2:
return go.Figure(), "Waiting for data..."
df = pd.DataFrame(bars, columns=["Timestamp", "Open", "High", "Low", "Close", "Volume"])
df["Timestamp"] = pd.to_datetime(df["Timestamp"])
df = add_features(df)
if df[feature_cols].isnull().any().any():
pred_text = "Not enough data for prediction."
else:
pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
pred_text = f"Last: {last_price:.2f} | Predicted next: {pred_price:.2f} | LogRet: {pred_log_return:.6f}"
fig = go.Figure()
fig.add_trace(go.Candlestick(
x=df["Timestamp"],
open=df["Open"],
high=df["High"],
low=df["Low"],
close=df["Close"],
name='Candlestick',
increasing_line_color='green',
decreasing_line_color='red',
showlegend=False
))
fig.update_layout(title='BTC-USDT 1m OHLCV (Aggregated from Trades)', xaxis_title='Time', yaxis_title='Price (USDT)', xaxis_rangeslider_visible=False)
return fig, pred_text
if __name__ == '__main__':
app.run(debug=True)

304
main.py
View File

@ -1,152 +1,152 @@
from okx_client import OKXClient
from market_db import MarketDB
import json
import logging
import threading
from collections import deque
import time
import signal
latest_book = {'bids': [], 'asks': [], 'timestamp': None}
book_history = deque()
trade_history = deque()
TRADE_HISTORY_SECONDS = 60
BOOK_HISTORY_SECONDS = 5
shutdown_flag = threading.Event()
def connect(instrument, max_retries=5):
logging.info(f"Connecting to OKX for instrument: {instrument}")
retries = 0
backoff = 1
while not shutdown_flag.is_set():
try:
client = OKXClient(authenticate=False)
client.subscribe_trades(instrument)
client.subscribe_book(instrument, depth=5, channel="books")
logging.info(f"Subscribed to trades and book for {instrument}")
return client
except Exception as e:
retries += 1
logging.error(f"Failed to connect to OKX: {e}. Retry {retries}/{max_retries} in {backoff}s.")
if retries >= max_retries:
logging.critical("Max retries reached. Exiting connect loop.")
raise
time.sleep(backoff)
backoff = min(backoff * 2, 60) # exponential backoff, max 60s
return None
def cleanup(client, db):
if client and hasattr(client, 'ws') and client.ws:
try:
client.ws.close()
except Exception as e:
logging.warning(f"Error closing websocket: {e}")
if db:
db.close()
def signal_handler(signum, frame):
logging.info(f"Received signal {signum}, shutting down...")
shutdown_flag.set()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
def main():
instruments = [
"ETH-USDT",
"BTC-USDT",
"SOL-USDT",
"DOGE-USDT",
"TON-USDT",
"ETH-USDC",
"SOPH-USDT",
"PEPE-USDT",
"BTC-USDC",
"UNI-USDT"
]
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
dbs = {}
clients = {}
try:
for instrument in instruments:
dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db")
logging.info(f"Database initialized for {instrument}")
clients[instrument] = connect(instrument)
while not shutdown_flag.is_set():
for instrument in instruments:
client = clients[instrument]
db = dbs[instrument]
try:
data = client.ws.recv()
except Exception as e:
logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...")
cleanup(client, None)
try:
clients[instrument] = connect(instrument)
except Exception as e:
logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.")
continue
continue
if shutdown_flag.is_set():
break
if data == '':
continue
try:
msg = json.loads(data)
except Exception as e:
logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}")
continue
if 'arg' in msg and msg['arg'].get('channel') == 'trades':
for trade in msg.get('data', []):
db.insert_trade({
'instrument': instrument,
'trade_id': trade.get('tradeId'),
'price': float(trade.get('px')),
'size': float(trade.get('sz')),
'side': trade.get('side'),
'timestamp': trade.get('ts')
})
ts = float(trade.get('ts', time.time() * 1000))
trade_history.append({
'price': trade.get('px'),
'size': trade.get('sz'),
'side': trade.get('side'),
'timestamp': ts
})
elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'):
for book in msg.get('data', []):
db.insert_book({
'instrument': instrument,
'bids': book.get('bids'),
'asks': book.get('asks'),
'timestamp': book.get('ts')
})
latest_book['bids'] = book.get('bids', [])
latest_book['asks'] = book.get('asks', [])
latest_book['timestamp'] = book.get('ts')
ts = float(book.get('ts', time.time() * 1000))
book_history.append({
'bids': book.get('bids', []),
'asks': book.get('asks', []),
'timestamp': ts
})
else:
logging.info(f"Unknown message for {instrument}: {msg}")
except Exception as e:
logging.critical(f"Fatal error in main: {e}")
finally:
for client in clients.values():
cleanup(client, None)
for db in dbs.values():
cleanup(None, db)
logging.info('Shutdown complete.')
if __name__ == '__main__':
main()
from okx_client import OKXClient
from market_db import MarketDB
import json
import logging
import threading
from collections import deque
import time
import signal
latest_book = {'bids': [], 'asks': [], 'timestamp': None}
book_history = deque()
trade_history = deque()
TRADE_HISTORY_SECONDS = 60
BOOK_HISTORY_SECONDS = 5
shutdown_flag = threading.Event()
def connect(instrument, max_retries=5):
logging.info(f"Connecting to OKX for instrument: {instrument}")
retries = 0
backoff = 1
while not shutdown_flag.is_set():
try:
client = OKXClient(authenticate=False)
client.subscribe_trades(instrument)
client.subscribe_book(instrument, depth=5, channel="books")
logging.info(f"Subscribed to trades and book for {instrument}")
return client
except Exception as e:
retries += 1
logging.error(f"Failed to connect to OKX: {e}. Retry {retries}/{max_retries} in {backoff}s.")
if retries >= max_retries:
logging.critical("Max retries reached. Exiting connect loop.")
raise
time.sleep(backoff)
backoff = min(backoff * 2, 60) # exponential backoff, max 60s
return None
def cleanup(client, db):
if client and hasattr(client, 'ws') and client.ws:
try:
client.ws.close()
except Exception as e:
logging.warning(f"Error closing websocket: {e}")
if db:
db.close()
def signal_handler(signum, frame):
logging.info(f"Received signal {signum}, shutting down...")
shutdown_flag.set()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
def main():
instruments = [
"ETH-USDT",
"BTC-USDT",
"SOL-USDT",
"DOGE-USDT",
"TON-USDT",
"ETH-USDC",
"SOPH-USDT",
"PEPE-USDT",
"BTC-USDC",
"UNI-USDT"
]
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
dbs = {}
clients = {}
try:
for instrument in instruments:
dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db")
logging.info(f"Database initialized for {instrument}")
clients[instrument] = connect(instrument)
while not shutdown_flag.is_set():
for instrument in instruments:
client = clients[instrument]
db = dbs[instrument]
try:
data = client.ws.recv()
except Exception as e:
logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...")
cleanup(client, None)
try:
clients[instrument] = connect(instrument)
except Exception as e:
logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.")
continue
continue
if shutdown_flag.is_set():
break
if data == '':
continue
try:
msg = json.loads(data)
except Exception as e:
logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}")
continue
if 'arg' in msg and msg['arg'].get('channel') == 'trades':
for trade in msg.get('data', []):
db.insert_trade({
'instrument': instrument,
'trade_id': trade.get('tradeId'),
'price': float(trade.get('px')),
'size': float(trade.get('sz')),
'side': trade.get('side'),
'timestamp': trade.get('ts')
})
ts = float(trade.get('ts', time.time() * 1000))
trade_history.append({
'price': trade.get('px'),
'size': trade.get('sz'),
'side': trade.get('side'),
'timestamp': ts
})
elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'):
for book in msg.get('data', []):
db.insert_book({
'instrument': instrument,
'bids': book.get('bids'),
'asks': book.get('asks'),
'timestamp': book.get('ts')
})
latest_book['bids'] = book.get('bids', [])
latest_book['asks'] = book.get('asks', [])
latest_book['timestamp'] = book.get('ts')
ts = float(book.get('ts', time.time() * 1000))
book_history.append({
'bids': book.get('bids', []),
'asks': book.get('asks', []),
'timestamp': ts
})
else:
logging.info(f"Unknown message for {instrument}: {msg}")
except Exception as e:
logging.critical(f"Fatal error in main: {e}")
finally:
for client in clients.values():
cleanup(client, None)
for db in dbs.values():
cleanup(None, db)
logging.info('Shutdown complete.')
if __name__ == '__main__':
main()

View File

@ -1,77 +1,77 @@
import sqlite3
import os
from typing import Any, Dict, List, Optional
import logging
class MarketDB:
def __init__(self, market: str, db_dir: str = ""):
db_name = f"{market}.db"
db_path = db_name if not db_dir else f"{db_dir.rstrip('/')}/{db_name}"
if db_dir:
os.makedirs(db_dir, exist_ok=True)
self.conn = sqlite3.connect(db_path)
logging.info(f"Connected to database at {db_path}")
self._create_tables()
def _create_tables(self):
cursor = self.conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
instrument TEXT,
trade_id TEXT,
price REAL,
size REAL,
side TEXT,
timestamp TEXT
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS book (
id INTEGER PRIMARY KEY AUTOINCREMENT,
instrument TEXT,
bids TEXT,
asks TEXT,
timestamp TEXT
)
''')
self.conn.commit()
logging.info("Database tables ensured.")
def insert_trade(self, trade: Dict[str, Any]):
cursor = self.conn.cursor()
cursor.execute('''
INSERT INTO trades (instrument, trade_id, price, size, side, timestamp)
VALUES (?, ?, ?, ?, ?, ?)
''', (
trade.get('instrument'),
trade.get('trade_id'),
trade.get('price'),
trade.get('size'),
trade.get('side'),
trade.get('timestamp')
))
self.conn.commit()
logging.debug(f"Inserted trade: {trade}")
def insert_book(self, book: Dict[str, Any]):
cursor = self.conn.cursor()
bids = book.get('bids', [])
asks = book.get('asks', [])
best_bid = next((b for b in bids if float(b[1]) > 0), ['-', '-'])
best_ask = next((a for a in asks if float(a[1]) > 0), ['-', '-'])
cursor.execute('''
INSERT INTO book (instrument, bids, asks, timestamp)
VALUES (?, ?, ?, ?)
''', (
book.get('instrument'),
str(bids),
str(asks),
book.get('timestamp')
))
self.conn.commit()
logging.debug(f"Inserted book: {book.get('instrument', 'N/A')} ts:{book.get('timestamp', 'N/A')} bid:{best_bid} ask:{best_ask}")
def close(self):
self.conn.close()
logging.info("Database connection closed.")
import sqlite3
import os
from typing import Any, Dict, List, Optional
import logging
class MarketDB:
def __init__(self, market: str, db_dir: str = ""):
db_name = f"{market}.db"
db_path = db_name if not db_dir else f"{db_dir.rstrip('/')}/{db_name}"
if db_dir:
os.makedirs(db_dir, exist_ok=True)
self.conn = sqlite3.connect(db_path)
logging.info(f"Connected to database at {db_path}")
self._create_tables()
def _create_tables(self):
cursor = self.conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
instrument TEXT,
trade_id TEXT,
price REAL,
size REAL,
side TEXT,
timestamp TEXT
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS book (
id INTEGER PRIMARY KEY AUTOINCREMENT,
instrument TEXT,
bids TEXT,
asks TEXT,
timestamp TEXT
)
''')
self.conn.commit()
logging.info("Database tables ensured.")
def insert_trade(self, trade: Dict[str, Any]):
cursor = self.conn.cursor()
cursor.execute('''
INSERT INTO trades (instrument, trade_id, price, size, side, timestamp)
VALUES (?, ?, ?, ?, ?, ?)
''', (
trade.get('instrument'),
trade.get('trade_id'),
trade.get('price'),
trade.get('size'),
trade.get('side'),
trade.get('timestamp')
))
self.conn.commit()
logging.debug(f"Inserted trade: {trade}")
def insert_book(self, book: Dict[str, Any]):
cursor = self.conn.cursor()
bids = book.get('bids', [])
asks = book.get('asks', [])
best_bid = next((b for b in bids if float(b[1]) > 0), ['-', '-'])
best_ask = next((a for a in asks if float(a[1]) > 0), ['-', '-'])
cursor.execute('''
INSERT INTO book (instrument, bids, asks, timestamp)
VALUES (?, ?, ?, ?)
''', (
book.get('instrument'),
str(bids),
str(asks),
book.get('timestamp')
))
self.conn.commit()
logging.debug(f"Inserted book: {book.get('instrument', 'N/A')} ts:{book.get('timestamp', 'N/A')} bid:{best_bid} ask:{best_ask}")
def close(self):
self.conn.close()
logging.info("Database connection closed.")

View File

@ -1,233 +1,233 @@
import os
import time
import hmac
import hashlib
import base64
import json
import pandas as pd
import threading
import requests
import websocket
import logging
class OKXClient:
PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public"
PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private"
REST_URL = "https://www.okx.com"
def __init__(self, authenticate: bool = True):
self.authenticated = False
self.api_key = None
self.api_secret = None
self.api_passphrase = None
self.ws = None
self.ws_private = None
self._lock = threading.Lock()
self._private_lock = threading.Lock()
if authenticate:
config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json')
try:
with open(config_path, 'r') as f:
config = json.load(f)
except FileNotFoundError:
raise FileNotFoundError(f"Credentials file not found at {config_path}. Please create it with the required keys.")
except json.JSONDecodeError:
raise ValueError(f"Credentials file at {config_path} is not valid JSON.")
self.api_key = config.get("OKX_API_KEY")
self.api_secret = config.get("OKX_API_SECRET")
self.api_passphrase = config.get("OKX_API_PASSPHRASE")
if not self.api_key or not self.api_secret or not self.api_passphrase:
raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.")
self._authenticate()
self._connect_ws()
def _connect_ws(self):
if self.ws is None:
self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10)
if self.authenticated and self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None:
self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10)
def _get_timestamp(self):
return str(round(time.time(), 3))
def _sign(self, timestamp, method, request_path, body):
if not body:
body = ''
message = f'{timestamp}{method}{request_path}{body}'
mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256)
return base64.b64encode(mac.digest()).decode()
def _authenticate(self):
import websocket
timestamp = self._get_timestamp()
sign = self._sign(timestamp, 'GET', '/users/self/verify', '')
login_params = {
"op": "login",
"args": [{
"apiKey": self.api_key,
"passphrase": self.api_passphrase,
"timestamp": timestamp,
"sign": sign
}]
}
self.ws_private.send(json.dumps(login_params))
logging.info("Waiting for login response from OKX...")
while True:
try:
resp = self.ws_private.recv()
logging.debug(f"Received from OKX private WS: {resp}")
if not resp:
continue
try:
msg = json.loads(resp)
except Exception:
logging.warning(f"Non-JSON message received: {resp}")
continue
if msg.get("event") == "login":
if msg.get("code") == "0":
logging.info("OKX WebSocket login successful.")
self.authenticated = True
break
else:
raise Exception(f"WebSocket authentication failed: {msg}")
except websocket._exceptions.WebSocketConnectionClosedException as e:
logging.error(f"WebSocket connection closed during authentication: {e}")
raise
except Exception as e:
logging.error(f"Exception during authentication: {e}")
raise
def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"):
# OKX uses candle1m, candle5m, etc.
tf_map = {"1m": "candle1m", "5m": "candle5m", "15m": "candle15m", "1h": "candle1H"}
channel = tf_map.get(timeframe, f"candle{timeframe}")
params = {
"op": "subscribe",
"args": [{"channel": channel, "instId": instrument}]
}
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params))
def subscribe_trades(self, instrument="BTC-USDT"):
params = {
"op": "subscribe",
"args": [{"channel": "trades", "instId": instrument}]
}
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params))
def subscribe_ticker(self, instrument="BTC-USDT"):
params = {
"op": "subscribe",
"args": [{"channel": "tickers", "instId": instrument}]
}
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params))
def subscribe_book(self, instrument="BTC-USDT", depth=5, channel="books5"):
# OKX supports books5, books50, books-l2-tbt
# channel = "books5" if depth <= 5 else "books50"
params = {
"op": "subscribe",
"args": [{"channel": channel, "instId": instrument}]
}
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params))
def subscribe_user_order(self):
if not self.authenticated:
logging.warning("Attempted to subscribe to user order channel without authentication.")
return
params = {
"op": "subscribe",
"args": [{"channel": "orders", "instType": "SPOT"}]
}
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
self.ws_private.send(json.dumps(params))
def subscribe_user_trade(self):
if not self.authenticated:
logging.warning("Attempted to subscribe to user trade channel without authentication.")
return
params = {
"op": "subscribe",
"args": [{"channel": "trades", "instType": "SPOT"}]
}
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
self.ws_private.send(json.dumps(params))
def subscribe_user_balance(self):
if not self.authenticated:
logging.warning("Attempted to subscribe to user balance channel without authentication.")
return
params = {
"op": "subscribe",
"args": [{"channel": "balance_and_position"}]
}
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
self.ws_private.send(json.dumps(params))
def get_balance(self, currency=None):
url = f"{self.REST_URL}/api/v5/account/balance"
timestamp = self._get_timestamp()
method = "GET"
request_path = "/api/v5/account/balance"
body = ''
sign = self._sign(timestamp, method, request_path, body)
headers = {
"OK-ACCESS-KEY": self.api_key,
"OK-ACCESS-SIGN": sign,
"OK-ACCESS-TIMESTAMP": timestamp,
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
"Content-Type": "application/json"
}
resp = requests.get(url, headers=headers)
if resp.status_code == 200:
data = resp.json()
balances = data.get("data", [{}])[0].get("details", [])
if currency:
return [b for b in balances if b.get("ccy") == currency]
return balances
return []
def place_order(self, side, amount, instrument="BTC-USDT"):
url = f"{self.REST_URL}/api/v5/trade/order"
timestamp = self._get_timestamp()
method = "POST"
request_path = "/api/v5/trade/order"
body_dict = {
"instId": instrument,
"tdMode": "cash",
"side": side.lower(),
"ordType": "market",
"sz": str(amount)
}
body = json.dumps(body_dict)
sign = self._sign(timestamp, method, request_path, body)
headers = {
"OK-ACCESS-KEY": self.api_key,
"OK-ACCESS-SIGN": sign,
"OK-ACCESS-TIMESTAMP": timestamp,
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
"Content-Type": "application/json"
}
resp = requests.post(url, headers=headers, data=body)
return resp.json()
def buy_btc(self, amount, instrument="BTC-USDT"):
return self.place_order("buy", amount, instrument)
def sell_btc(self, amount, instrument="BTC-USDT"):
return self.place_order("sell", amount, instrument)
def get_instruments(self):
url = f"{self.REST_URL}/api/v5/public/instruments?instType=SPOT"
resp = requests.get(url)
if resp.status_code == 200:
data = resp.json()
return data.get("data", [])
return []
import os
import time
import hmac
import hashlib
import base64
import json
import pandas as pd
import threading
import requests
import websocket
import logging
class OKXClient:
PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public"
PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private"
REST_URL = "https://www.okx.com"
def __init__(self, authenticate: bool = True):
self.authenticated = False
self.api_key = None
self.api_secret = None
self.api_passphrase = None
self.ws = None
self.ws_private = None
self._lock = threading.Lock()
self._private_lock = threading.Lock()
if authenticate:
config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json')
try:
with open(config_path, 'r') as f:
config = json.load(f)
except FileNotFoundError:
raise FileNotFoundError(f"Credentials file not found at {config_path}. Please create it with the required keys.")
except json.JSONDecodeError:
raise ValueError(f"Credentials file at {config_path} is not valid JSON.")
self.api_key = config.get("OKX_API_KEY")
self.api_secret = config.get("OKX_API_SECRET")
self.api_passphrase = config.get("OKX_API_PASSPHRASE")
if not self.api_key or not self.api_secret or not self.api_passphrase:
raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.")
self._authenticate()
self._connect_ws()
def _connect_ws(self):
if self.ws is None:
self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10)
if self.authenticated and self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None:
self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10)
def _get_timestamp(self):
return str(round(time.time(), 3))
def _sign(self, timestamp, method, request_path, body):
if not body:
body = ''
message = f'{timestamp}{method}{request_path}{body}'
mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256)
return base64.b64encode(mac.digest()).decode()
def _authenticate(self):
import websocket
timestamp = self._get_timestamp()
sign = self._sign(timestamp, 'GET', '/users/self/verify', '')
login_params = {
"op": "login",
"args": [{
"apiKey": self.api_key,
"passphrase": self.api_passphrase,
"timestamp": timestamp,
"sign": sign
}]
}
self.ws_private.send(json.dumps(login_params))
logging.info("Waiting for login response from OKX...")
while True:
try:
resp = self.ws_private.recv()
logging.debug(f"Received from OKX private WS: {resp}")
if not resp:
continue
try:
msg = json.loads(resp)
except Exception:
logging.warning(f"Non-JSON message received: {resp}")
continue
if msg.get("event") == "login":
if msg.get("code") == "0":
logging.info("OKX WebSocket login successful.")
self.authenticated = True
break
else:
raise Exception(f"WebSocket authentication failed: {msg}")
except websocket._exceptions.WebSocketConnectionClosedException as e:
logging.error(f"WebSocket connection closed during authentication: {e}")
raise
except Exception as e:
logging.error(f"Exception during authentication: {e}")
raise
def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"):
# OKX uses candle1m, candle5m, etc.
tf_map = {"1m": "candle1m", "5m": "candle5m", "15m": "candle15m", "1h": "candle1H"}
channel = tf_map.get(timeframe, f"candle{timeframe}")
params = {
"op": "subscribe",
"args": [{"channel": channel, "instId": instrument}]
}
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params))
def subscribe_trades(self, instrument="BTC-USDT"):
params = {
"op": "subscribe",
"args": [{"channel": "trades", "instId": instrument}]
}
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params))
def subscribe_ticker(self, instrument="BTC-USDT"):
params = {
"op": "subscribe",
"args": [{"channel": "tickers", "instId": instrument}]
}
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params))
def subscribe_book(self, instrument="BTC-USDT", depth=5, channel="books5"):
# OKX supports books5, books50, books-l2-tbt
# channel = "books5" if depth <= 5 else "books50"
params = {
"op": "subscribe",
"args": [{"channel": channel, "instId": instrument}]
}
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
self.ws.send(json.dumps(params))
def subscribe_user_order(self):
if not self.authenticated:
logging.warning("Attempted to subscribe to user order channel without authentication.")
return
params = {
"op": "subscribe",
"args": [{"channel": "orders", "instType": "SPOT"}]
}
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
self.ws_private.send(json.dumps(params))
def subscribe_user_trade(self):
if not self.authenticated:
logging.warning("Attempted to subscribe to user trade channel without authentication.")
return
params = {
"op": "subscribe",
"args": [{"channel": "trades", "instType": "SPOT"}]
}
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
self.ws_private.send(json.dumps(params))
def subscribe_user_balance(self):
if not self.authenticated:
logging.warning("Attempted to subscribe to user balance channel without authentication.")
return
params = {
"op": "subscribe",
"args": [{"channel": "balance_and_position"}]
}
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
self.ws_private.send(json.dumps(params))
def get_balance(self, currency=None):
url = f"{self.REST_URL}/api/v5/account/balance"
timestamp = self._get_timestamp()
method = "GET"
request_path = "/api/v5/account/balance"
body = ''
sign = self._sign(timestamp, method, request_path, body)
headers = {
"OK-ACCESS-KEY": self.api_key,
"OK-ACCESS-SIGN": sign,
"OK-ACCESS-TIMESTAMP": timestamp,
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
"Content-Type": "application/json"
}
resp = requests.get(url, headers=headers)
if resp.status_code == 200:
data = resp.json()
balances = data.get("data", [{}])[0].get("details", [])
if currency:
return [b for b in balances if b.get("ccy") == currency]
return balances
return []
def place_order(self, side, amount, instrument="BTC-USDT"):
url = f"{self.REST_URL}/api/v5/trade/order"
timestamp = self._get_timestamp()
method = "POST"
request_path = "/api/v5/trade/order"
body_dict = {
"instId": instrument,
"tdMode": "cash",
"side": side.lower(),
"ordType": "market",
"sz": str(amount)
}
body = json.dumps(body_dict)
sign = self._sign(timestamp, method, request_path, body)
headers = {
"OK-ACCESS-KEY": self.api_key,
"OK-ACCESS-SIGN": sign,
"OK-ACCESS-TIMESTAMP": timestamp,
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
"Content-Type": "application/json"
}
resp = requests.post(url, headers=headers, data=body)
return resp.json()
def buy_btc(self, amount, instrument="BTC-USDT"):
return self.place_order("buy", amount, instrument)
def sell_btc(self, amount, instrument="BTC-USDT"):
return self.place_order("sell", amount, instrument)
def get_instruments(self):
url = f"{self.REST_URL}/api/v5/public/instruments?instType=SPOT"
resp = requests.get(url)
if resp.status_code == 200:
data = resp.json()
return data.get("data", [])
return []

122
test_predictor.py Normal file
View File

@ -0,0 +1,122 @@
import time
import json
import numpy as np
import pandas as pd
import requests
import xgboost as xgb
from datetime import datetime, timedelta
# --- CONFIG ---
OKX_REST_URL = 'https://www.okx.com/api/v5/market/candles'
SYMBOL = 'BTC-USDT'
BAR = '1m'
HIST_MINUTES = 250 # Number of minutes of history to fetch for features
MODEL_PATH = 'data/xgboost_model.json'
# --- Fetch recent candles from OKX REST API ---
def fetch_recent_candles(symbol, bar, limit=HIST_MINUTES):
params = {
'instId': symbol,
'bar': bar,
'limit': str(limit)
}
resp = requests.get(OKX_REST_URL, params=params)
data = resp.json()
if data['code'] != '0':
raise Exception(f"OKX API error: {data['msg']}")
# OKX returns most recent first, reverse to chronological
candles = data['data'][::-1]
df = pd.DataFrame(candles)
# OKX columns: [ts, o, h, l, c, vol, volCcy, confirm, ...] (see API docs)
# We'll use: ts, o, h, l, c, vol
col_map = {
0: 'Timestamp',
1: 'Open',
2: 'High',
3: 'Low',
4: 'Close',
5: 'Volume',
}
df = df.rename(columns={str(k): v for k, v in col_map.items()})
# If columns are not named, use integer index
for k, v in col_map.items():
if v not in df.columns:
df[v] = df.iloc[:, k]
df = df[['Timestamp', 'Open', 'High', 'Low', 'Close', 'Volume']]
df['Timestamp'] = pd.to_datetime(df['Timestamp'].astype(np.int64), unit='ms')
for col in ['Open', 'High', 'Low', 'Close', 'Volume']:
df[col] = pd.to_numeric(df[col], errors='coerce')
return df
# --- Feature Engineering (minimal, real-time) ---
def add_features(df):
# Log return (target, not used for prediction)
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
# RSI (14)
def calc_rsi(close, window=14):
delta = close.diff()
up = delta.clip(lower=0)
down = -1 * delta.clip(upper=0)
ma_up = up.rolling(window=window, min_periods=window).mean()
ma_down = down.rolling(window=window, min_periods=window).mean()
rs = ma_up / ma_down
return 100 - (100 / (1 + rs))
df['rsi'] = calc_rsi(df['Close'])
# EMA 14
df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
# SMA 50, 200
df['sma_50'] = df['Close'].rolling(window=50).mean()
df['sma_200'] = df['Close'].rolling(window=200).mean()
# ATR 14
high_low = df['High'] - df['Low']
high_close = np.abs(df['High'] - df['Close'].shift(1))
low_close = np.abs(df['Low'] - df['Close'].shift(1))
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
df['atr'] = tr.rolling(window=14).mean()
# ROC 10
df['roc_10'] = df['Close'].pct_change(periods=10) * 100
# DPO 20
df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)
# Hour
df['hour'] = df['Timestamp'].dt.hour
# Add more features as needed (match with main.py)
return df
# --- Load model and feature columns ---
def load_model_and_features(model_path):
model = xgb.Booster()
model.load_model(model_path)
# Try to infer feature names from main.py (hardcoded for now)
feature_cols = [
'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
]
return model, feature_cols
# --- Predict next log return and price ---
def predict_next(df, model, feature_cols):
# Use the last row for prediction
X = df[feature_cols].iloc[[-1]].values.astype(np.float32)
dmatrix = xgb.DMatrix(X, feature_names=feature_cols)
pred_log_return = model.predict(dmatrix)[0]
last_price = df['Close'].iloc[-1]
pred_price = last_price * np.exp(pred_log_return)
return pred_price, pred_log_return, last_price
if __name__ == '__main__':
print('Fetching recent candles from OKX...')
df = fetch_recent_candles(SYMBOL, BAR)
df = add_features(df)
model, feature_cols = load_model_and_features(MODEL_PATH)
print('Waiting for new candle...')
last_timestamp = df['Timestamp'].iloc[-1]
while True:
time.sleep(5)
new_df = fetch_recent_candles(SYMBOL, BAR, limit=HIST_MINUTES)
if new_df['Timestamp'].iloc[-1] > last_timestamp:
df = new_df
df = add_features(df)
pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
print(f"[{df['Timestamp'].iloc[-1]}] Last price: {last_price:.2f} | Predicted next price: {pred_price:.2f} | Predicted log return: {pred_log_return:.6f}")
last_timestamp = df['Timestamp'].iloc[-1]
else:
print('No new candle yet...')

View File

@ -1,108 +1,108 @@
import dash
from dash import dcc, html
from dash.dependencies import Output, Input
import plotly.graph_objs as go
import time
from collections import defaultdict
def run_dash(book_history, trade_history, BOOK_HISTORY_SECONDS=5, TRADE_HISTORY_SECONDS=60):
app = dash.Dash(__name__)
app.layout = html.Div([
html.H1("Order Book Depth Chart", style={"textAlign": "center", "color": "#222"}),
dcc.Graph(id='order-book-graph', style={"height": "90vh", "width": "100vw"}),
dcc.Interval(id='interval-component', interval=2*1000, n_intervals=0)
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
@app.callback(
[Output('order-book-graph', 'figure')],
[Input('interval-component', 'n_intervals')]
)
def update_graphs(n):
now = time.time() * 1000 # current time in ms
# Prune book_history to only keep last BOOK_HISTORY_SECONDS
while book_history and now - book_history[0]['timestamp'] > BOOK_HISTORY_SECONDS * 1000:
book_history.popleft()
# Prune trade_history to only keep last TRADE_HISTORY_SECONDS
while trade_history and now - float(trade_history[0]['timestamp']) > TRADE_HISTORY_SECONDS * 1000:
trade_history.popleft()
# Aggregate bids/asks from book_history
bids_dict = {}
asks_dict = {}
for book in book_history:
for price, size, *_ in book['bids']:
price = float(price)
size = float(size)
bids_dict[price] = bids_dict.get(price, 0) + size
for price, size, *_ in book['asks']:
price = float(price)
size = float(size)
asks_dict[price] = asks_dict.get(price, 0) + size
try:
# Prepare and sort bids/asks
bids = sorted([[p, s] for p, s in bids_dict.items()], reverse=True)
asks = sorted([[p, s] for p, s in asks_dict.items()])
# Cumulative sum
bid_prices = [b[0] for b in bids]
bid_sizes = [b[1] for b in bids]
ask_prices = [a[0] for a in asks]
ask_sizes = [a[1] for a in asks]
bid_cumsum = [sum(bid_sizes[:i+1]) for i in range(len(bid_sizes))]
ask_cumsum = [sum(ask_sizes[:i+1]) for i in range(len(ask_sizes))]
except Exception as e:
bid_prices, bid_cumsum, ask_prices, ask_cumsum = [], [], [], []
fig = go.Figure()
# Add order book lines (primary y-axis)
fig.add_trace(go.Scatter(
x=bid_prices, y=bid_cumsum, mode='lines', name='Bids',
line=dict(color='green'), fill='tozeroy', yaxis='y1'
))
fig.add_trace(go.Scatter(
x=ask_prices, y=ask_cumsum, mode='lines', name='Asks',
line=dict(color='red'), fill='tozeroy', yaxis='y1'
))
trade_volume_by_price = defaultdict(float)
for trade in trade_history:
price_bin = round(float(trade['price']), 2)
trade_volume_by_price[price_bin] += float(trade['size'])
prices = list(trade_volume_by_price.keys())
volumes = list(trade_volume_by_price.values())
# Sort by price for display
sorted_pairs = sorted(zip(prices, volumes))
prices = [p for p, v in sorted_pairs]
volumes = [v for p, v in sorted_pairs]
# Add trade volume bars (secondary y-axis)
fig.add_trace(go.Bar(
x=prices, y=volumes, marker_color='#7ec8e3', name='Trade Volume',
opacity=0.7, yaxis='y2'
))
# Update layout for dual y-axes
fig.update_layout(
title='Order Book Depth & Realized Trade Volume by Price',
xaxis=dict(title='Price'),
yaxis=dict(title='Cumulative Size', side='left'),
yaxis2=dict(
title='Traded Volume',
overlaying='y',
side='right',
showgrid=False
),
template='plotly_dark'
)
return [fig]
app.run(debug=True, use_reloader=False)
import dash
from dash import dcc, html
from dash.dependencies import Output, Input
import plotly.graph_objs as go
import time
from collections import defaultdict
def run_dash(book_history, trade_history, BOOK_HISTORY_SECONDS=5, TRADE_HISTORY_SECONDS=60):
app = dash.Dash(__name__)
app.layout = html.Div([
html.H1("Order Book Depth Chart", style={"textAlign": "center", "color": "#222"}),
dcc.Graph(id='order-book-graph', style={"height": "90vh", "width": "100vw"}),
dcc.Interval(id='interval-component', interval=2*1000, n_intervals=0)
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
@app.callback(
[Output('order-book-graph', 'figure')],
[Input('interval-component', 'n_intervals')]
)
def update_graphs(n):
now = time.time() * 1000 # current time in ms
# Prune book_history to only keep last BOOK_HISTORY_SECONDS
while book_history and now - book_history[0]['timestamp'] > BOOK_HISTORY_SECONDS * 1000:
book_history.popleft()
# Prune trade_history to only keep last TRADE_HISTORY_SECONDS
while trade_history and now - float(trade_history[0]['timestamp']) > TRADE_HISTORY_SECONDS * 1000:
trade_history.popleft()
# Aggregate bids/asks from book_history
bids_dict = {}
asks_dict = {}
for book in book_history:
for price, size, *_ in book['bids']:
price = float(price)
size = float(size)
bids_dict[price] = bids_dict.get(price, 0) + size
for price, size, *_ in book['asks']:
price = float(price)
size = float(size)
asks_dict[price] = asks_dict.get(price, 0) + size
try:
# Prepare and sort bids/asks
bids = sorted([[p, s] for p, s in bids_dict.items()], reverse=True)
asks = sorted([[p, s] for p, s in asks_dict.items()])
# Cumulative sum
bid_prices = [b[0] for b in bids]
bid_sizes = [b[1] for b in bids]
ask_prices = [a[0] for a in asks]
ask_sizes = [a[1] for a in asks]
bid_cumsum = [sum(bid_sizes[:i+1]) for i in range(len(bid_sizes))]
ask_cumsum = [sum(ask_sizes[:i+1]) for i in range(len(ask_sizes))]
except Exception as e:
bid_prices, bid_cumsum, ask_prices, ask_cumsum = [], [], [], []
fig = go.Figure()
# Add order book lines (primary y-axis)
fig.add_trace(go.Scatter(
x=bid_prices, y=bid_cumsum, mode='lines', name='Bids',
line=dict(color='green'), fill='tozeroy', yaxis='y1'
))
fig.add_trace(go.Scatter(
x=ask_prices, y=ask_cumsum, mode='lines', name='Asks',
line=dict(color='red'), fill='tozeroy', yaxis='y1'
))
trade_volume_by_price = defaultdict(float)
for trade in trade_history:
price_bin = round(float(trade['price']), 2)
trade_volume_by_price[price_bin] += float(trade['size'])
prices = list(trade_volume_by_price.keys())
volumes = list(trade_volume_by_price.values())
# Sort by price for display
sorted_pairs = sorted(zip(prices, volumes))
prices = [p for p, v in sorted_pairs]
volumes = [v for p, v in sorted_pairs]
# Add trade volume bars (secondary y-axis)
fig.add_trace(go.Bar(
x=prices, y=volumes, marker_color='#7ec8e3', name='Trade Volume',
opacity=0.7, yaxis='y2'
))
# Update layout for dual y-axes
fig.update_layout(
title='Order Book Depth & Realized Trade Volume by Price',
xaxis=dict(title='Price'),
yaxis=dict(title='Cumulative Size', side='left'),
yaxis2=dict(
title='Traded Volume',
overlaying='y',
side='right',
showgrid=False
),
template='plotly_dark'
)
return [fig]
app.run(debug=True, use_reloader=False)