testing live plot
This commit is contained in:
parent
8f96e14b8b
commit
f534825e53
150
live_plot.py
Normal file
150
live_plot.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import xgboost as xgb
|
||||||
|
from datetime import datetime
|
||||||
|
from okx_client import OKXClient
|
||||||
|
import dash
|
||||||
|
from dash import dcc, html
|
||||||
|
from dash.dependencies import Output, Input
|
||||||
|
import plotly.graph_objs as go
|
||||||
|
import websocket
|
||||||
|
|
||||||
|
# --- Prediction utilities (from test_predictor.py) ---
|
||||||
|
def add_features(df):
|
||||||
|
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
|
||||||
|
def calc_rsi(close, window=14):
|
||||||
|
delta = close.diff()
|
||||||
|
up = delta.clip(lower=0)
|
||||||
|
down = -1 * delta.clip(upper=0)
|
||||||
|
ma_up = up.rolling(window=window, min_periods=window).mean()
|
||||||
|
ma_down = down.rolling(window=window, min_periods=window).mean()
|
||||||
|
rs = ma_up / ma_down
|
||||||
|
return 100 - (100 / (1 + rs))
|
||||||
|
df['rsi'] = calc_rsi(df['Close'])
|
||||||
|
df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
|
||||||
|
df['sma_50'] = df['Close'].rolling(window=50).mean()
|
||||||
|
df['sma_200'] = df['Close'].rolling(window=200).mean()
|
||||||
|
high_low = df['High'] - df['Low']
|
||||||
|
high_close = np.abs(df['High'] - df['Close'].shift(1))
|
||||||
|
low_close = np.abs(df['Low'] - df['Close'].shift(1))
|
||||||
|
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
|
||||||
|
df['atr'] = tr.rolling(window=14).mean()
|
||||||
|
df['roc_10'] = df['Close'].pct_change(periods=10) * 100
|
||||||
|
df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)
|
||||||
|
df['hour'] = df['Timestamp'].dt.hour
|
||||||
|
return df
|
||||||
|
|
||||||
|
def load_model_and_features(model_path):
|
||||||
|
model = xgb.Booster()
|
||||||
|
model.load_model(model_path)
|
||||||
|
feature_cols = [
|
||||||
|
'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
|
||||||
|
]
|
||||||
|
return model, feature_cols
|
||||||
|
|
||||||
|
def predict_next(df, model, feature_cols):
|
||||||
|
X = df[feature_cols].iloc[[-1]].values.astype(np.float32)
|
||||||
|
dmatrix = xgb.DMatrix(X, feature_names=feature_cols)
|
||||||
|
pred_log_return = model.predict(dmatrix)[0]
|
||||||
|
last_price = df['Close'].iloc[-1]
|
||||||
|
pred_price = last_price * np.exp(pred_log_return)
|
||||||
|
return pred_price, pred_log_return, last_price
|
||||||
|
|
||||||
|
# --- Main live plotting ---
|
||||||
|
WINDOW = 50
|
||||||
|
MODEL_PATH = 'data/xgboost_model.json'
|
||||||
|
|
||||||
|
ohlcv_bars = [] # [timestamp, open, high, low, close, volume]
|
||||||
|
bar_lock = threading.Lock()
|
||||||
|
model, feature_cols = load_model_and_features(MODEL_PATH)
|
||||||
|
|
||||||
|
# --- Background thread to collect trades and aggregate OHLCV ---
|
||||||
|
def ws_collector():
|
||||||
|
client = OKXClient(authenticate=False)
|
||||||
|
client.subscribe_trades(instrument="BTC-USDT")
|
||||||
|
current_bar = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
msg = client.ws.recv()
|
||||||
|
data = json.loads(msg)
|
||||||
|
except websocket._exceptions.WebSocketTimeoutException:
|
||||||
|
continue # Just try again
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WebSocket error: {e}")
|
||||||
|
break # or try to reconnect
|
||||||
|
if 'arg' in data and data['arg'].get('channel', '') == 'trades':
|
||||||
|
for trade in data.get('data', []):
|
||||||
|
# trade: {'instId', 'tradeId', 'px', 'sz', 'side', 'ts'}
|
||||||
|
ts = int(trade['ts'])
|
||||||
|
price = float(trade['px'])
|
||||||
|
size = float(trade['sz'])
|
||||||
|
dt = datetime.utcfromtimestamp(ts / 1000)
|
||||||
|
bar_seconds = 30 # or 15, 30, etc.
|
||||||
|
bar_time = dt.replace(second=(dt.second // bar_seconds) * bar_seconds, microsecond=0)
|
||||||
|
with bar_lock:
|
||||||
|
if not ohlcv_bars or ohlcv_bars[-1][0] != bar_time:
|
||||||
|
# New bar
|
||||||
|
ohlcv_bars.append([bar_time, price, price, price, price, size])
|
||||||
|
if len(ohlcv_bars) > WINDOW:
|
||||||
|
ohlcv_bars.pop(0)
|
||||||
|
else:
|
||||||
|
# Update current bar
|
||||||
|
bar = ohlcv_bars[-1]
|
||||||
|
bar[2] = max(bar[2], price) # high
|
||||||
|
bar[3] = min(bar[3], price) # low
|
||||||
|
bar[4] = price # close
|
||||||
|
bar[5] += size # volume
|
||||||
|
|
||||||
|
# Start the background thread
|
||||||
|
threading.Thread(target=ws_collector, daemon=True).start()
|
||||||
|
|
||||||
|
# --- Dash App ---
|
||||||
|
app = dash.Dash(__name__)
|
||||||
|
app.layout = html.Div([
|
||||||
|
html.H2('BTC/USDT Price & Prediction (OKX, XGBoost, Trades Aggregated)', style={"textAlign": "center", "margin": 0, "padding": 0}),
|
||||||
|
dcc.Graph(id='live-graph', animate=False, style={"height": "90vh", "width": "100vw", "margin": 0, "padding": 0}),
|
||||||
|
dcc.Interval(
|
||||||
|
id='interval-component',
|
||||||
|
interval=5*1000, # 5 seconds
|
||||||
|
n_intervals=0
|
||||||
|
),
|
||||||
|
html.Div(id='prediction-output', style={"textAlign": "center", "fontSize": 20, "marginTop": 10})
|
||||||
|
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
|
||||||
|
|
||||||
|
@app.callback(
|
||||||
|
[Output('live-graph', 'figure'), Output('prediction-output', 'children')],
|
||||||
|
[Input('interval-component', 'n_intervals')]
|
||||||
|
)
|
||||||
|
def update_graph_live(n):
|
||||||
|
with bar_lock:
|
||||||
|
bars = list(ohlcv_bars)
|
||||||
|
if len(bars) < 2:
|
||||||
|
return go.Figure(), "Waiting for data..."
|
||||||
|
df = pd.DataFrame(bars, columns=["Timestamp", "Open", "High", "Low", "Close", "Volume"])
|
||||||
|
df["Timestamp"] = pd.to_datetime(df["Timestamp"])
|
||||||
|
df = add_features(df)
|
||||||
|
if df[feature_cols].isnull().any().any():
|
||||||
|
pred_text = "Not enough data for prediction."
|
||||||
|
else:
|
||||||
|
pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
|
||||||
|
pred_text = f"Last: {last_price:.2f} | Predicted next: {pred_price:.2f} | LogRet: {pred_log_return:.6f}"
|
||||||
|
fig = go.Figure()
|
||||||
|
fig.add_trace(go.Candlestick(
|
||||||
|
x=df["Timestamp"],
|
||||||
|
open=df["Open"],
|
||||||
|
high=df["High"],
|
||||||
|
low=df["Low"],
|
||||||
|
close=df["Close"],
|
||||||
|
name='Candlestick',
|
||||||
|
increasing_line_color='green',
|
||||||
|
decreasing_line_color='red',
|
||||||
|
showlegend=False
|
||||||
|
))
|
||||||
|
fig.update_layout(title='BTC-USDT 1m OHLCV (Aggregated from Trades)', xaxis_title='Time', yaxis_title='Price (USDT)', xaxis_rangeslider_visible=False)
|
||||||
|
return fig, pred_text
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(debug=True)
|
||||||
304
main.py
304
main.py
@ -1,152 +1,152 @@
|
|||||||
from okx_client import OKXClient
|
from okx_client import OKXClient
|
||||||
from market_db import MarketDB
|
from market_db import MarketDB
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from collections import deque
|
from collections import deque
|
||||||
import time
|
import time
|
||||||
import signal
|
import signal
|
||||||
|
|
||||||
latest_book = {'bids': [], 'asks': [], 'timestamp': None}
|
latest_book = {'bids': [], 'asks': [], 'timestamp': None}
|
||||||
book_history = deque()
|
book_history = deque()
|
||||||
trade_history = deque()
|
trade_history = deque()
|
||||||
|
|
||||||
TRADE_HISTORY_SECONDS = 60
|
TRADE_HISTORY_SECONDS = 60
|
||||||
BOOK_HISTORY_SECONDS = 5
|
BOOK_HISTORY_SECONDS = 5
|
||||||
|
|
||||||
shutdown_flag = threading.Event()
|
shutdown_flag = threading.Event()
|
||||||
|
|
||||||
def connect(instrument, max_retries=5):
|
def connect(instrument, max_retries=5):
|
||||||
logging.info(f"Connecting to OKX for instrument: {instrument}")
|
logging.info(f"Connecting to OKX for instrument: {instrument}")
|
||||||
retries = 0
|
retries = 0
|
||||||
backoff = 1
|
backoff = 1
|
||||||
while not shutdown_flag.is_set():
|
while not shutdown_flag.is_set():
|
||||||
try:
|
try:
|
||||||
client = OKXClient(authenticate=False)
|
client = OKXClient(authenticate=False)
|
||||||
client.subscribe_trades(instrument)
|
client.subscribe_trades(instrument)
|
||||||
client.subscribe_book(instrument, depth=5, channel="books")
|
client.subscribe_book(instrument, depth=5, channel="books")
|
||||||
logging.info(f"Subscribed to trades and book for {instrument}")
|
logging.info(f"Subscribed to trades and book for {instrument}")
|
||||||
return client
|
return client
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
retries += 1
|
retries += 1
|
||||||
logging.error(f"Failed to connect to OKX: {e}. Retry {retries}/{max_retries} in {backoff}s.")
|
logging.error(f"Failed to connect to OKX: {e}. Retry {retries}/{max_retries} in {backoff}s.")
|
||||||
if retries >= max_retries:
|
if retries >= max_retries:
|
||||||
logging.critical("Max retries reached. Exiting connect loop.")
|
logging.critical("Max retries reached. Exiting connect loop.")
|
||||||
raise
|
raise
|
||||||
time.sleep(backoff)
|
time.sleep(backoff)
|
||||||
backoff = min(backoff * 2, 60) # exponential backoff, max 60s
|
backoff = min(backoff * 2, 60) # exponential backoff, max 60s
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def cleanup(client, db):
|
def cleanup(client, db):
|
||||||
if client and hasattr(client, 'ws') and client.ws:
|
if client and hasattr(client, 'ws') and client.ws:
|
||||||
try:
|
try:
|
||||||
client.ws.close()
|
client.ws.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Error closing websocket: {e}")
|
logging.warning(f"Error closing websocket: {e}")
|
||||||
if db:
|
if db:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
def signal_handler(signum, frame):
|
def signal_handler(signum, frame):
|
||||||
logging.info(f"Received signal {signum}, shutting down...")
|
logging.info(f"Received signal {signum}, shutting down...")
|
||||||
shutdown_flag.set()
|
shutdown_flag.set()
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
instruments = [
|
instruments = [
|
||||||
"ETH-USDT",
|
"ETH-USDT",
|
||||||
"BTC-USDT",
|
"BTC-USDT",
|
||||||
"SOL-USDT",
|
"SOL-USDT",
|
||||||
"DOGE-USDT",
|
"DOGE-USDT",
|
||||||
"TON-USDT",
|
"TON-USDT",
|
||||||
"ETH-USDC",
|
"ETH-USDC",
|
||||||
"SOPH-USDT",
|
"SOPH-USDT",
|
||||||
"PEPE-USDT",
|
"PEPE-USDT",
|
||||||
"BTC-USDC",
|
"BTC-USDC",
|
||||||
"UNI-USDT"
|
"UNI-USDT"
|
||||||
]
|
]
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
||||||
dbs = {}
|
dbs = {}
|
||||||
clients = {}
|
clients = {}
|
||||||
try:
|
try:
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db")
|
dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db")
|
||||||
logging.info(f"Database initialized for {instrument}")
|
logging.info(f"Database initialized for {instrument}")
|
||||||
clients[instrument] = connect(instrument)
|
clients[instrument] = connect(instrument)
|
||||||
|
|
||||||
while not shutdown_flag.is_set():
|
while not shutdown_flag.is_set():
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
client = clients[instrument]
|
client = clients[instrument]
|
||||||
db = dbs[instrument]
|
db = dbs[instrument]
|
||||||
try:
|
try:
|
||||||
data = client.ws.recv()
|
data = client.ws.recv()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...")
|
logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...")
|
||||||
cleanup(client, None)
|
cleanup(client, None)
|
||||||
try:
|
try:
|
||||||
clients[instrument] = connect(instrument)
|
clients[instrument] = connect(instrument)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.")
|
logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.")
|
||||||
continue
|
continue
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if shutdown_flag.is_set():
|
if shutdown_flag.is_set():
|
||||||
break
|
break
|
||||||
if data == '':
|
if data == '':
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
msg = json.loads(data)
|
msg = json.loads(data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}")
|
logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if 'arg' in msg and msg['arg'].get('channel') == 'trades':
|
if 'arg' in msg and msg['arg'].get('channel') == 'trades':
|
||||||
for trade in msg.get('data', []):
|
for trade in msg.get('data', []):
|
||||||
db.insert_trade({
|
db.insert_trade({
|
||||||
'instrument': instrument,
|
'instrument': instrument,
|
||||||
'trade_id': trade.get('tradeId'),
|
'trade_id': trade.get('tradeId'),
|
||||||
'price': float(trade.get('px')),
|
'price': float(trade.get('px')),
|
||||||
'size': float(trade.get('sz')),
|
'size': float(trade.get('sz')),
|
||||||
'side': trade.get('side'),
|
'side': trade.get('side'),
|
||||||
'timestamp': trade.get('ts')
|
'timestamp': trade.get('ts')
|
||||||
})
|
})
|
||||||
ts = float(trade.get('ts', time.time() * 1000))
|
ts = float(trade.get('ts', time.time() * 1000))
|
||||||
trade_history.append({
|
trade_history.append({
|
||||||
'price': trade.get('px'),
|
'price': trade.get('px'),
|
||||||
'size': trade.get('sz'),
|
'size': trade.get('sz'),
|
||||||
'side': trade.get('side'),
|
'side': trade.get('side'),
|
||||||
'timestamp': ts
|
'timestamp': ts
|
||||||
})
|
})
|
||||||
elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'):
|
elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'):
|
||||||
for book in msg.get('data', []):
|
for book in msg.get('data', []):
|
||||||
db.insert_book({
|
db.insert_book({
|
||||||
'instrument': instrument,
|
'instrument': instrument,
|
||||||
'bids': book.get('bids'),
|
'bids': book.get('bids'),
|
||||||
'asks': book.get('asks'),
|
'asks': book.get('asks'),
|
||||||
'timestamp': book.get('ts')
|
'timestamp': book.get('ts')
|
||||||
})
|
})
|
||||||
latest_book['bids'] = book.get('bids', [])
|
latest_book['bids'] = book.get('bids', [])
|
||||||
latest_book['asks'] = book.get('asks', [])
|
latest_book['asks'] = book.get('asks', [])
|
||||||
latest_book['timestamp'] = book.get('ts')
|
latest_book['timestamp'] = book.get('ts')
|
||||||
ts = float(book.get('ts', time.time() * 1000))
|
ts = float(book.get('ts', time.time() * 1000))
|
||||||
book_history.append({
|
book_history.append({
|
||||||
'bids': book.get('bids', []),
|
'bids': book.get('bids', []),
|
||||||
'asks': book.get('asks', []),
|
'asks': book.get('asks', []),
|
||||||
'timestamp': ts
|
'timestamp': ts
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
logging.info(f"Unknown message for {instrument}: {msg}")
|
logging.info(f"Unknown message for {instrument}: {msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.critical(f"Fatal error in main: {e}")
|
logging.critical(f"Fatal error in main: {e}")
|
||||||
finally:
|
finally:
|
||||||
for client in clients.values():
|
for client in clients.values():
|
||||||
cleanup(client, None)
|
cleanup(client, None)
|
||||||
for db in dbs.values():
|
for db in dbs.values():
|
||||||
cleanup(None, db)
|
cleanup(None, db)
|
||||||
logging.info('Shutdown complete.')
|
logging.info('Shutdown complete.')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
154
market_db.py
154
market_db.py
@ -1,77 +1,77 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
class MarketDB:
|
class MarketDB:
|
||||||
def __init__(self, market: str, db_dir: str = ""):
|
def __init__(self, market: str, db_dir: str = ""):
|
||||||
db_name = f"{market}.db"
|
db_name = f"{market}.db"
|
||||||
db_path = db_name if not db_dir else f"{db_dir.rstrip('/')}/{db_name}"
|
db_path = db_name if not db_dir else f"{db_dir.rstrip('/')}/{db_name}"
|
||||||
if db_dir:
|
if db_dir:
|
||||||
os.makedirs(db_dir, exist_ok=True)
|
os.makedirs(db_dir, exist_ok=True)
|
||||||
self.conn = sqlite3.connect(db_path)
|
self.conn = sqlite3.connect(db_path)
|
||||||
logging.info(f"Connected to database at {db_path}")
|
logging.info(f"Connected to database at {db_path}")
|
||||||
self._create_tables()
|
self._create_tables()
|
||||||
|
|
||||||
def _create_tables(self):
|
def _create_tables(self):
|
||||||
cursor = self.conn.cursor()
|
cursor = self.conn.cursor()
|
||||||
cursor.execute('''
|
cursor.execute('''
|
||||||
CREATE TABLE IF NOT EXISTS trades (
|
CREATE TABLE IF NOT EXISTS trades (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
instrument TEXT,
|
instrument TEXT,
|
||||||
trade_id TEXT,
|
trade_id TEXT,
|
||||||
price REAL,
|
price REAL,
|
||||||
size REAL,
|
size REAL,
|
||||||
side TEXT,
|
side TEXT,
|
||||||
timestamp TEXT
|
timestamp TEXT
|
||||||
)
|
)
|
||||||
''')
|
''')
|
||||||
cursor.execute('''
|
cursor.execute('''
|
||||||
CREATE TABLE IF NOT EXISTS book (
|
CREATE TABLE IF NOT EXISTS book (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
instrument TEXT,
|
instrument TEXT,
|
||||||
bids TEXT,
|
bids TEXT,
|
||||||
asks TEXT,
|
asks TEXT,
|
||||||
timestamp TEXT
|
timestamp TEXT
|
||||||
)
|
)
|
||||||
''')
|
''')
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
logging.info("Database tables ensured.")
|
logging.info("Database tables ensured.")
|
||||||
|
|
||||||
def insert_trade(self, trade: Dict[str, Any]):
|
def insert_trade(self, trade: Dict[str, Any]):
|
||||||
cursor = self.conn.cursor()
|
cursor = self.conn.cursor()
|
||||||
cursor.execute('''
|
cursor.execute('''
|
||||||
INSERT INTO trades (instrument, trade_id, price, size, side, timestamp)
|
INSERT INTO trades (instrument, trade_id, price, size, side, timestamp)
|
||||||
VALUES (?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
''', (
|
''', (
|
||||||
trade.get('instrument'),
|
trade.get('instrument'),
|
||||||
trade.get('trade_id'),
|
trade.get('trade_id'),
|
||||||
trade.get('price'),
|
trade.get('price'),
|
||||||
trade.get('size'),
|
trade.get('size'),
|
||||||
trade.get('side'),
|
trade.get('side'),
|
||||||
trade.get('timestamp')
|
trade.get('timestamp')
|
||||||
))
|
))
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
logging.debug(f"Inserted trade: {trade}")
|
logging.debug(f"Inserted trade: {trade}")
|
||||||
|
|
||||||
def insert_book(self, book: Dict[str, Any]):
|
def insert_book(self, book: Dict[str, Any]):
|
||||||
cursor = self.conn.cursor()
|
cursor = self.conn.cursor()
|
||||||
bids = book.get('bids', [])
|
bids = book.get('bids', [])
|
||||||
asks = book.get('asks', [])
|
asks = book.get('asks', [])
|
||||||
best_bid = next((b for b in bids if float(b[1]) > 0), ['-', '-'])
|
best_bid = next((b for b in bids if float(b[1]) > 0), ['-', '-'])
|
||||||
best_ask = next((a for a in asks if float(a[1]) > 0), ['-', '-'])
|
best_ask = next((a for a in asks if float(a[1]) > 0), ['-', '-'])
|
||||||
cursor.execute('''
|
cursor.execute('''
|
||||||
INSERT INTO book (instrument, bids, asks, timestamp)
|
INSERT INTO book (instrument, bids, asks, timestamp)
|
||||||
VALUES (?, ?, ?, ?)
|
VALUES (?, ?, ?, ?)
|
||||||
''', (
|
''', (
|
||||||
book.get('instrument'),
|
book.get('instrument'),
|
||||||
str(bids),
|
str(bids),
|
||||||
str(asks),
|
str(asks),
|
||||||
book.get('timestamp')
|
book.get('timestamp')
|
||||||
))
|
))
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
logging.debug(f"Inserted book: {book.get('instrument', 'N/A')} ts:{book.get('timestamp', 'N/A')} bid:{best_bid} ask:{best_ask}")
|
logging.debug(f"Inserted book: {book.get('instrument', 'N/A')} ts:{book.get('timestamp', 'N/A')} bid:{best_bid} ask:{best_ask}")
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
logging.info("Database connection closed.")
|
logging.info("Database connection closed.")
|
||||||
|
|||||||
466
okx_client.py
466
okx_client.py
@ -1,233 +1,233 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import hmac
|
import hmac
|
||||||
import hashlib
|
import hashlib
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import threading
|
import threading
|
||||||
import requests
|
import requests
|
||||||
import websocket
|
import websocket
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
class OKXClient:
|
class OKXClient:
|
||||||
PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public"
|
PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public"
|
||||||
PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private"
|
PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private"
|
||||||
REST_URL = "https://www.okx.com"
|
REST_URL = "https://www.okx.com"
|
||||||
|
|
||||||
def __init__(self, authenticate: bool = True):
|
def __init__(self, authenticate: bool = True):
|
||||||
self.authenticated = False
|
self.authenticated = False
|
||||||
self.api_key = None
|
self.api_key = None
|
||||||
self.api_secret = None
|
self.api_secret = None
|
||||||
self.api_passphrase = None
|
self.api_passphrase = None
|
||||||
self.ws = None
|
self.ws = None
|
||||||
self.ws_private = None
|
self.ws_private = None
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._private_lock = threading.Lock()
|
self._private_lock = threading.Lock()
|
||||||
|
|
||||||
if authenticate:
|
if authenticate:
|
||||||
config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json')
|
config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json')
|
||||||
try:
|
try:
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise FileNotFoundError(f"Credentials file not found at {config_path}. Please create it with the required keys.")
|
raise FileNotFoundError(f"Credentials file not found at {config_path}. Please create it with the required keys.")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise ValueError(f"Credentials file at {config_path} is not valid JSON.")
|
raise ValueError(f"Credentials file at {config_path} is not valid JSON.")
|
||||||
|
|
||||||
self.api_key = config.get("OKX_API_KEY")
|
self.api_key = config.get("OKX_API_KEY")
|
||||||
self.api_secret = config.get("OKX_API_SECRET")
|
self.api_secret = config.get("OKX_API_SECRET")
|
||||||
self.api_passphrase = config.get("OKX_API_PASSPHRASE")
|
self.api_passphrase = config.get("OKX_API_PASSPHRASE")
|
||||||
|
|
||||||
if not self.api_key or not self.api_secret or not self.api_passphrase:
|
if not self.api_key or not self.api_secret or not self.api_passphrase:
|
||||||
raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.")
|
raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.")
|
||||||
|
|
||||||
self._authenticate()
|
self._authenticate()
|
||||||
self._connect_ws()
|
self._connect_ws()
|
||||||
|
|
||||||
def _connect_ws(self):
|
def _connect_ws(self):
|
||||||
if self.ws is None:
|
if self.ws is None:
|
||||||
self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10)
|
self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10)
|
||||||
if self.authenticated and self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None:
|
if self.authenticated and self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None:
|
||||||
self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10)
|
self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10)
|
||||||
|
|
||||||
def _get_timestamp(self):
|
def _get_timestamp(self):
|
||||||
return str(round(time.time(), 3))
|
return str(round(time.time(), 3))
|
||||||
|
|
||||||
def _sign(self, timestamp, method, request_path, body):
|
def _sign(self, timestamp, method, request_path, body):
|
||||||
if not body:
|
if not body:
|
||||||
body = ''
|
body = ''
|
||||||
message = f'{timestamp}{method}{request_path}{body}'
|
message = f'{timestamp}{method}{request_path}{body}'
|
||||||
mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256)
|
mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256)
|
||||||
return base64.b64encode(mac.digest()).decode()
|
return base64.b64encode(mac.digest()).decode()
|
||||||
|
|
||||||
def _authenticate(self):
|
def _authenticate(self):
|
||||||
import websocket
|
import websocket
|
||||||
timestamp = self._get_timestamp()
|
timestamp = self._get_timestamp()
|
||||||
sign = self._sign(timestamp, 'GET', '/users/self/verify', '')
|
sign = self._sign(timestamp, 'GET', '/users/self/verify', '')
|
||||||
login_params = {
|
login_params = {
|
||||||
"op": "login",
|
"op": "login",
|
||||||
"args": [{
|
"args": [{
|
||||||
"apiKey": self.api_key,
|
"apiKey": self.api_key,
|
||||||
"passphrase": self.api_passphrase,
|
"passphrase": self.api_passphrase,
|
||||||
"timestamp": timestamp,
|
"timestamp": timestamp,
|
||||||
"sign": sign
|
"sign": sign
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
self.ws_private.send(json.dumps(login_params))
|
self.ws_private.send(json.dumps(login_params))
|
||||||
logging.info("Waiting for login response from OKX...")
|
logging.info("Waiting for login response from OKX...")
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
resp = self.ws_private.recv()
|
resp = self.ws_private.recv()
|
||||||
logging.debug(f"Received from OKX private WS: {resp}")
|
logging.debug(f"Received from OKX private WS: {resp}")
|
||||||
if not resp:
|
if not resp:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
msg = json.loads(resp)
|
msg = json.loads(resp)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.warning(f"Non-JSON message received: {resp}")
|
logging.warning(f"Non-JSON message received: {resp}")
|
||||||
continue
|
continue
|
||||||
if msg.get("event") == "login":
|
if msg.get("event") == "login":
|
||||||
if msg.get("code") == "0":
|
if msg.get("code") == "0":
|
||||||
logging.info("OKX WebSocket login successful.")
|
logging.info("OKX WebSocket login successful.")
|
||||||
self.authenticated = True
|
self.authenticated = True
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise Exception(f"WebSocket authentication failed: {msg}")
|
raise Exception(f"WebSocket authentication failed: {msg}")
|
||||||
except websocket._exceptions.WebSocketConnectionClosedException as e:
|
except websocket._exceptions.WebSocketConnectionClosedException as e:
|
||||||
logging.error(f"WebSocket connection closed during authentication: {e}")
|
logging.error(f"WebSocket connection closed during authentication: {e}")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Exception during authentication: {e}")
|
logging.error(f"Exception during authentication: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"):
|
def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"):
|
||||||
# OKX uses candle1m, candle5m, etc.
|
# OKX uses candle1m, candle5m, etc.
|
||||||
tf_map = {"1m": "candle1m", "5m": "candle5m", "15m": "candle15m", "1h": "candle1H"}
|
tf_map = {"1m": "candle1m", "5m": "candle5m", "15m": "candle15m", "1h": "candle1H"}
|
||||||
channel = tf_map.get(timeframe, f"candle{timeframe}")
|
channel = tf_map.get(timeframe, f"candle{timeframe}")
|
||||||
params = {
|
params = {
|
||||||
"op": "subscribe",
|
"op": "subscribe",
|
||||||
"args": [{"channel": channel, "instId": instrument}]
|
"args": [{"channel": channel, "instId": instrument}]
|
||||||
}
|
}
|
||||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||||
self.ws.send(json.dumps(params))
|
self.ws.send(json.dumps(params))
|
||||||
|
|
||||||
def subscribe_trades(self, instrument="BTC-USDT"):
|
def subscribe_trades(self, instrument="BTC-USDT"):
|
||||||
params = {
|
params = {
|
||||||
"op": "subscribe",
|
"op": "subscribe",
|
||||||
"args": [{"channel": "trades", "instId": instrument}]
|
"args": [{"channel": "trades", "instId": instrument}]
|
||||||
}
|
}
|
||||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||||
self.ws.send(json.dumps(params))
|
self.ws.send(json.dumps(params))
|
||||||
|
|
||||||
def subscribe_ticker(self, instrument="BTC-USDT"):
|
def subscribe_ticker(self, instrument="BTC-USDT"):
|
||||||
params = {
|
params = {
|
||||||
"op": "subscribe",
|
"op": "subscribe",
|
||||||
"args": [{"channel": "tickers", "instId": instrument}]
|
"args": [{"channel": "tickers", "instId": instrument}]
|
||||||
}
|
}
|
||||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||||
self.ws.send(json.dumps(params))
|
self.ws.send(json.dumps(params))
|
||||||
|
|
||||||
def subscribe_book(self, instrument="BTC-USDT", depth=5, channel="books5"):
|
def subscribe_book(self, instrument="BTC-USDT", depth=5, channel="books5"):
|
||||||
# OKX supports books5, books50, books-l2-tbt
|
# OKX supports books5, books50, books-l2-tbt
|
||||||
# channel = "books5" if depth <= 5 else "books50"
|
# channel = "books5" if depth <= 5 else "books50"
|
||||||
params = {
|
params = {
|
||||||
"op": "subscribe",
|
"op": "subscribe",
|
||||||
"args": [{"channel": channel, "instId": instrument}]
|
"args": [{"channel": channel, "instId": instrument}]
|
||||||
}
|
}
|
||||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||||
self.ws.send(json.dumps(params))
|
self.ws.send(json.dumps(params))
|
||||||
|
|
||||||
def subscribe_user_order(self):
|
def subscribe_user_order(self):
|
||||||
if not self.authenticated:
|
if not self.authenticated:
|
||||||
logging.warning("Attempted to subscribe to user order channel without authentication.")
|
logging.warning("Attempted to subscribe to user order channel without authentication.")
|
||||||
return
|
return
|
||||||
params = {
|
params = {
|
||||||
"op": "subscribe",
|
"op": "subscribe",
|
||||||
"args": [{"channel": "orders", "instType": "SPOT"}]
|
"args": [{"channel": "orders", "instType": "SPOT"}]
|
||||||
}
|
}
|
||||||
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
||||||
self.ws_private.send(json.dumps(params))
|
self.ws_private.send(json.dumps(params))
|
||||||
|
|
||||||
def subscribe_user_trade(self):
|
def subscribe_user_trade(self):
|
||||||
if not self.authenticated:
|
if not self.authenticated:
|
||||||
logging.warning("Attempted to subscribe to user trade channel without authentication.")
|
logging.warning("Attempted to subscribe to user trade channel without authentication.")
|
||||||
return
|
return
|
||||||
params = {
|
params = {
|
||||||
"op": "subscribe",
|
"op": "subscribe",
|
||||||
"args": [{"channel": "trades", "instType": "SPOT"}]
|
"args": [{"channel": "trades", "instType": "SPOT"}]
|
||||||
}
|
}
|
||||||
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
||||||
self.ws_private.send(json.dumps(params))
|
self.ws_private.send(json.dumps(params))
|
||||||
|
|
||||||
def subscribe_user_balance(self):
|
def subscribe_user_balance(self):
|
||||||
if not self.authenticated:
|
if not self.authenticated:
|
||||||
logging.warning("Attempted to subscribe to user balance channel without authentication.")
|
logging.warning("Attempted to subscribe to user balance channel without authentication.")
|
||||||
return
|
return
|
||||||
params = {
|
params = {
|
||||||
"op": "subscribe",
|
"op": "subscribe",
|
||||||
"args": [{"channel": "balance_and_position"}]
|
"args": [{"channel": "balance_and_position"}]
|
||||||
}
|
}
|
||||||
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
||||||
self.ws_private.send(json.dumps(params))
|
self.ws_private.send(json.dumps(params))
|
||||||
|
|
||||||
def get_balance(self, currency=None):
|
def get_balance(self, currency=None):
|
||||||
url = f"{self.REST_URL}/api/v5/account/balance"
|
url = f"{self.REST_URL}/api/v5/account/balance"
|
||||||
timestamp = self._get_timestamp()
|
timestamp = self._get_timestamp()
|
||||||
method = "GET"
|
method = "GET"
|
||||||
request_path = "/api/v5/account/balance"
|
request_path = "/api/v5/account/balance"
|
||||||
body = ''
|
body = ''
|
||||||
sign = self._sign(timestamp, method, request_path, body)
|
sign = self._sign(timestamp, method, request_path, body)
|
||||||
headers = {
|
headers = {
|
||||||
"OK-ACCESS-KEY": self.api_key,
|
"OK-ACCESS-KEY": self.api_key,
|
||||||
"OK-ACCESS-SIGN": sign,
|
"OK-ACCESS-SIGN": sign,
|
||||||
"OK-ACCESS-TIMESTAMP": timestamp,
|
"OK-ACCESS-TIMESTAMP": timestamp,
|
||||||
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
|
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
resp = requests.get(url, headers=headers)
|
resp = requests.get(url, headers=headers)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
balances = data.get("data", [{}])[0].get("details", [])
|
balances = data.get("data", [{}])[0].get("details", [])
|
||||||
if currency:
|
if currency:
|
||||||
return [b for b in balances if b.get("ccy") == currency]
|
return [b for b in balances if b.get("ccy") == currency]
|
||||||
return balances
|
return balances
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def place_order(self, side, amount, instrument="BTC-USDT"):
|
def place_order(self, side, amount, instrument="BTC-USDT"):
|
||||||
url = f"{self.REST_URL}/api/v5/trade/order"
|
url = f"{self.REST_URL}/api/v5/trade/order"
|
||||||
timestamp = self._get_timestamp()
|
timestamp = self._get_timestamp()
|
||||||
method = "POST"
|
method = "POST"
|
||||||
request_path = "/api/v5/trade/order"
|
request_path = "/api/v5/trade/order"
|
||||||
body_dict = {
|
body_dict = {
|
||||||
"instId": instrument,
|
"instId": instrument,
|
||||||
"tdMode": "cash",
|
"tdMode": "cash",
|
||||||
"side": side.lower(),
|
"side": side.lower(),
|
||||||
"ordType": "market",
|
"ordType": "market",
|
||||||
"sz": str(amount)
|
"sz": str(amount)
|
||||||
}
|
}
|
||||||
body = json.dumps(body_dict)
|
body = json.dumps(body_dict)
|
||||||
sign = self._sign(timestamp, method, request_path, body)
|
sign = self._sign(timestamp, method, request_path, body)
|
||||||
headers = {
|
headers = {
|
||||||
"OK-ACCESS-KEY": self.api_key,
|
"OK-ACCESS-KEY": self.api_key,
|
||||||
"OK-ACCESS-SIGN": sign,
|
"OK-ACCESS-SIGN": sign,
|
||||||
"OK-ACCESS-TIMESTAMP": timestamp,
|
"OK-ACCESS-TIMESTAMP": timestamp,
|
||||||
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
|
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
resp = requests.post(url, headers=headers, data=body)
|
resp = requests.post(url, headers=headers, data=body)
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
def buy_btc(self, amount, instrument="BTC-USDT"):
|
def buy_btc(self, amount, instrument="BTC-USDT"):
|
||||||
return self.place_order("buy", amount, instrument)
|
return self.place_order("buy", amount, instrument)
|
||||||
|
|
||||||
def sell_btc(self, amount, instrument="BTC-USDT"):
|
def sell_btc(self, amount, instrument="BTC-USDT"):
|
||||||
return self.place_order("sell", amount, instrument)
|
return self.place_order("sell", amount, instrument)
|
||||||
|
|
||||||
def get_instruments(self):
|
def get_instruments(self):
|
||||||
url = f"{self.REST_URL}/api/v5/public/instruments?instType=SPOT"
|
url = f"{self.REST_URL}/api/v5/public/instruments?instType=SPOT"
|
||||||
resp = requests.get(url)
|
resp = requests.get(url)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
return data.get("data", [])
|
return data.get("data", [])
|
||||||
return []
|
return []
|
||||||
|
|||||||
122
test_predictor.py
Normal file
122
test_predictor.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
import time
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import requests
|
||||||
|
import xgboost as xgb
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
# --- CONFIG ---
|
||||||
|
OKX_REST_URL = 'https://www.okx.com/api/v5/market/candles'
|
||||||
|
SYMBOL = 'BTC-USDT'
|
||||||
|
BAR = '1m'
|
||||||
|
HIST_MINUTES = 250 # Number of minutes of history to fetch for features
|
||||||
|
MODEL_PATH = 'data/xgboost_model.json'
|
||||||
|
|
||||||
|
# --- Fetch recent candles from OKX REST API ---
|
||||||
|
def fetch_recent_candles(symbol, bar, limit=HIST_MINUTES):
|
||||||
|
params = {
|
||||||
|
'instId': symbol,
|
||||||
|
'bar': bar,
|
||||||
|
'limit': str(limit)
|
||||||
|
}
|
||||||
|
resp = requests.get(OKX_REST_URL, params=params)
|
||||||
|
data = resp.json()
|
||||||
|
if data['code'] != '0':
|
||||||
|
raise Exception(f"OKX API error: {data['msg']}")
|
||||||
|
# OKX returns most recent first, reverse to chronological
|
||||||
|
candles = data['data'][::-1]
|
||||||
|
df = pd.DataFrame(candles)
|
||||||
|
# OKX columns: [ts, o, h, l, c, vol, volCcy, confirm, ...] (see API docs)
|
||||||
|
# We'll use: ts, o, h, l, c, vol
|
||||||
|
col_map = {
|
||||||
|
0: 'Timestamp',
|
||||||
|
1: 'Open',
|
||||||
|
2: 'High',
|
||||||
|
3: 'Low',
|
||||||
|
4: 'Close',
|
||||||
|
5: 'Volume',
|
||||||
|
}
|
||||||
|
df = df.rename(columns={str(k): v for k, v in col_map.items()})
|
||||||
|
# If columns are not named, use integer index
|
||||||
|
for k, v in col_map.items():
|
||||||
|
if v not in df.columns:
|
||||||
|
df[v] = df.iloc[:, k]
|
||||||
|
df = df[['Timestamp', 'Open', 'High', 'Low', 'Close', 'Volume']]
|
||||||
|
df['Timestamp'] = pd.to_datetime(df['Timestamp'].astype(np.int64), unit='ms')
|
||||||
|
for col in ['Open', 'High', 'Low', 'Close', 'Volume']:
|
||||||
|
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||||
|
return df
|
||||||
|
|
||||||
|
# --- Feature Engineering (minimal, real-time) ---
|
||||||
|
def add_features(df):
|
||||||
|
# Log return (target, not used for prediction)
|
||||||
|
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
|
||||||
|
# RSI (14)
|
||||||
|
def calc_rsi(close, window=14):
|
||||||
|
delta = close.diff()
|
||||||
|
up = delta.clip(lower=0)
|
||||||
|
down = -1 * delta.clip(upper=0)
|
||||||
|
ma_up = up.rolling(window=window, min_periods=window).mean()
|
||||||
|
ma_down = down.rolling(window=window, min_periods=window).mean()
|
||||||
|
rs = ma_up / ma_down
|
||||||
|
return 100 - (100 / (1 + rs))
|
||||||
|
df['rsi'] = calc_rsi(df['Close'])
|
||||||
|
# EMA 14
|
||||||
|
df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
|
||||||
|
# SMA 50, 200
|
||||||
|
df['sma_50'] = df['Close'].rolling(window=50).mean()
|
||||||
|
df['sma_200'] = df['Close'].rolling(window=200).mean()
|
||||||
|
# ATR 14
|
||||||
|
high_low = df['High'] - df['Low']
|
||||||
|
high_close = np.abs(df['High'] - df['Close'].shift(1))
|
||||||
|
low_close = np.abs(df['Low'] - df['Close'].shift(1))
|
||||||
|
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
|
||||||
|
df['atr'] = tr.rolling(window=14).mean()
|
||||||
|
# ROC 10
|
||||||
|
df['roc_10'] = df['Close'].pct_change(periods=10) * 100
|
||||||
|
# DPO 20
|
||||||
|
df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)
|
||||||
|
# Hour
|
||||||
|
df['hour'] = df['Timestamp'].dt.hour
|
||||||
|
# Add more features as needed (match with main.py)
|
||||||
|
return df
|
||||||
|
|
||||||
|
# --- Load model and feature columns ---
|
||||||
|
def load_model_and_features(model_path):
|
||||||
|
model = xgb.Booster()
|
||||||
|
model.load_model(model_path)
|
||||||
|
# Try to infer feature names from main.py (hardcoded for now)
|
||||||
|
feature_cols = [
|
||||||
|
'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
|
||||||
|
]
|
||||||
|
return model, feature_cols
|
||||||
|
|
||||||
|
# --- Predict next log return and price ---
|
||||||
|
def predict_next(df, model, feature_cols):
|
||||||
|
# Use the last row for prediction
|
||||||
|
X = df[feature_cols].iloc[[-1]].values.astype(np.float32)
|
||||||
|
dmatrix = xgb.DMatrix(X, feature_names=feature_cols)
|
||||||
|
pred_log_return = model.predict(dmatrix)[0]
|
||||||
|
last_price = df['Close'].iloc[-1]
|
||||||
|
pred_price = last_price * np.exp(pred_log_return)
|
||||||
|
return pred_price, pred_log_return, last_price
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print('Fetching recent candles from OKX...')
|
||||||
|
df = fetch_recent_candles(SYMBOL, BAR)
|
||||||
|
df = add_features(df)
|
||||||
|
model, feature_cols = load_model_and_features(MODEL_PATH)
|
||||||
|
print('Waiting for new candle...')
|
||||||
|
last_timestamp = df['Timestamp'].iloc[-1]
|
||||||
|
while True:
|
||||||
|
time.sleep(5)
|
||||||
|
new_df = fetch_recent_candles(SYMBOL, BAR, limit=HIST_MINUTES)
|
||||||
|
if new_df['Timestamp'].iloc[-1] > last_timestamp:
|
||||||
|
df = new_df
|
||||||
|
df = add_features(df)
|
||||||
|
pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
|
||||||
|
print(f"[{df['Timestamp'].iloc[-1]}] Last price: {last_price:.2f} | Predicted next price: {pred_price:.2f} | Predicted log return: {pred_log_return:.6f}")
|
||||||
|
last_timestamp = df['Timestamp'].iloc[-1]
|
||||||
|
else:
|
||||||
|
print('No new candle yet...')
|
||||||
216
visualizer.py
216
visualizer.py
@ -1,108 +1,108 @@
|
|||||||
import dash
|
import dash
|
||||||
from dash import dcc, html
|
from dash import dcc, html
|
||||||
from dash.dependencies import Output, Input
|
from dash.dependencies import Output, Input
|
||||||
import plotly.graph_objs as go
|
import plotly.graph_objs as go
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
def run_dash(book_history, trade_history, BOOK_HISTORY_SECONDS=5, TRADE_HISTORY_SECONDS=60):
|
def run_dash(book_history, trade_history, BOOK_HISTORY_SECONDS=5, TRADE_HISTORY_SECONDS=60):
|
||||||
app = dash.Dash(__name__)
|
app = dash.Dash(__name__)
|
||||||
app.layout = html.Div([
|
app.layout = html.Div([
|
||||||
html.H1("Order Book Depth Chart", style={"textAlign": "center", "color": "#222"}),
|
html.H1("Order Book Depth Chart", style={"textAlign": "center", "color": "#222"}),
|
||||||
dcc.Graph(id='order-book-graph', style={"height": "90vh", "width": "100vw"}),
|
dcc.Graph(id='order-book-graph', style={"height": "90vh", "width": "100vw"}),
|
||||||
dcc.Interval(id='interval-component', interval=2*1000, n_intervals=0)
|
dcc.Interval(id='interval-component', interval=2*1000, n_intervals=0)
|
||||||
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
|
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
|
||||||
|
|
||||||
@app.callback(
|
@app.callback(
|
||||||
[Output('order-book-graph', 'figure')],
|
[Output('order-book-graph', 'figure')],
|
||||||
[Input('interval-component', 'n_intervals')]
|
[Input('interval-component', 'n_intervals')]
|
||||||
)
|
)
|
||||||
def update_graphs(n):
|
def update_graphs(n):
|
||||||
now = time.time() * 1000 # current time in ms
|
now = time.time() * 1000 # current time in ms
|
||||||
|
|
||||||
# Prune book_history to only keep last BOOK_HISTORY_SECONDS
|
# Prune book_history to only keep last BOOK_HISTORY_SECONDS
|
||||||
while book_history and now - book_history[0]['timestamp'] > BOOK_HISTORY_SECONDS * 1000:
|
while book_history and now - book_history[0]['timestamp'] > BOOK_HISTORY_SECONDS * 1000:
|
||||||
book_history.popleft()
|
book_history.popleft()
|
||||||
|
|
||||||
# Prune trade_history to only keep last TRADE_HISTORY_SECONDS
|
# Prune trade_history to only keep last TRADE_HISTORY_SECONDS
|
||||||
while trade_history and now - float(trade_history[0]['timestamp']) > TRADE_HISTORY_SECONDS * 1000:
|
while trade_history and now - float(trade_history[0]['timestamp']) > TRADE_HISTORY_SECONDS * 1000:
|
||||||
trade_history.popleft()
|
trade_history.popleft()
|
||||||
|
|
||||||
# Aggregate bids/asks from book_history
|
# Aggregate bids/asks from book_history
|
||||||
bids_dict = {}
|
bids_dict = {}
|
||||||
asks_dict = {}
|
asks_dict = {}
|
||||||
|
|
||||||
for book in book_history:
|
for book in book_history:
|
||||||
for price, size, *_ in book['bids']:
|
for price, size, *_ in book['bids']:
|
||||||
price = float(price)
|
price = float(price)
|
||||||
size = float(size)
|
size = float(size)
|
||||||
bids_dict[price] = bids_dict.get(price, 0) + size
|
bids_dict[price] = bids_dict.get(price, 0) + size
|
||||||
|
|
||||||
for price, size, *_ in book['asks']:
|
for price, size, *_ in book['asks']:
|
||||||
price = float(price)
|
price = float(price)
|
||||||
size = float(size)
|
size = float(size)
|
||||||
asks_dict[price] = asks_dict.get(price, 0) + size
|
asks_dict[price] = asks_dict.get(price, 0) + size
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Prepare and sort bids/asks
|
# Prepare and sort bids/asks
|
||||||
bids = sorted([[p, s] for p, s in bids_dict.items()], reverse=True)
|
bids = sorted([[p, s] for p, s in bids_dict.items()], reverse=True)
|
||||||
asks = sorted([[p, s] for p, s in asks_dict.items()])
|
asks = sorted([[p, s] for p, s in asks_dict.items()])
|
||||||
|
|
||||||
# Cumulative sum
|
# Cumulative sum
|
||||||
bid_prices = [b[0] for b in bids]
|
bid_prices = [b[0] for b in bids]
|
||||||
bid_sizes = [b[1] for b in bids]
|
bid_sizes = [b[1] for b in bids]
|
||||||
ask_prices = [a[0] for a in asks]
|
ask_prices = [a[0] for a in asks]
|
||||||
ask_sizes = [a[1] for a in asks]
|
ask_sizes = [a[1] for a in asks]
|
||||||
bid_cumsum = [sum(bid_sizes[:i+1]) for i in range(len(bid_sizes))]
|
bid_cumsum = [sum(bid_sizes[:i+1]) for i in range(len(bid_sizes))]
|
||||||
ask_cumsum = [sum(ask_sizes[:i+1]) for i in range(len(ask_sizes))]
|
ask_cumsum = [sum(ask_sizes[:i+1]) for i in range(len(ask_sizes))]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
bid_prices, bid_cumsum, ask_prices, ask_cumsum = [], [], [], []
|
bid_prices, bid_cumsum, ask_prices, ask_cumsum = [], [], [], []
|
||||||
|
|
||||||
fig = go.Figure()
|
fig = go.Figure()
|
||||||
|
|
||||||
# Add order book lines (primary y-axis)
|
# Add order book lines (primary y-axis)
|
||||||
fig.add_trace(go.Scatter(
|
fig.add_trace(go.Scatter(
|
||||||
x=bid_prices, y=bid_cumsum, mode='lines', name='Bids',
|
x=bid_prices, y=bid_cumsum, mode='lines', name='Bids',
|
||||||
line=dict(color='green'), fill='tozeroy', yaxis='y1'
|
line=dict(color='green'), fill='tozeroy', yaxis='y1'
|
||||||
))
|
))
|
||||||
fig.add_trace(go.Scatter(
|
fig.add_trace(go.Scatter(
|
||||||
x=ask_prices, y=ask_cumsum, mode='lines', name='Asks',
|
x=ask_prices, y=ask_cumsum, mode='lines', name='Asks',
|
||||||
line=dict(color='red'), fill='tozeroy', yaxis='y1'
|
line=dict(color='red'), fill='tozeroy', yaxis='y1'
|
||||||
))
|
))
|
||||||
|
|
||||||
trade_volume_by_price = defaultdict(float)
|
trade_volume_by_price = defaultdict(float)
|
||||||
|
|
||||||
for trade in trade_history:
|
for trade in trade_history:
|
||||||
price_bin = round(float(trade['price']), 2)
|
price_bin = round(float(trade['price']), 2)
|
||||||
trade_volume_by_price[price_bin] += float(trade['size'])
|
trade_volume_by_price[price_bin] += float(trade['size'])
|
||||||
|
|
||||||
prices = list(trade_volume_by_price.keys())
|
prices = list(trade_volume_by_price.keys())
|
||||||
volumes = list(trade_volume_by_price.values())
|
volumes = list(trade_volume_by_price.values())
|
||||||
|
|
||||||
# Sort by price for display
|
# Sort by price for display
|
||||||
sorted_pairs = sorted(zip(prices, volumes))
|
sorted_pairs = sorted(zip(prices, volumes))
|
||||||
prices = [p for p, v in sorted_pairs]
|
prices = [p for p, v in sorted_pairs]
|
||||||
volumes = [v for p, v in sorted_pairs]
|
volumes = [v for p, v in sorted_pairs]
|
||||||
|
|
||||||
# Add trade volume bars (secondary y-axis)
|
# Add trade volume bars (secondary y-axis)
|
||||||
fig.add_trace(go.Bar(
|
fig.add_trace(go.Bar(
|
||||||
x=prices, y=volumes, marker_color='#7ec8e3', name='Trade Volume',
|
x=prices, y=volumes, marker_color='#7ec8e3', name='Trade Volume',
|
||||||
opacity=0.7, yaxis='y2'
|
opacity=0.7, yaxis='y2'
|
||||||
))
|
))
|
||||||
|
|
||||||
# Update layout for dual y-axes
|
# Update layout for dual y-axes
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title='Order Book Depth & Realized Trade Volume by Price',
|
title='Order Book Depth & Realized Trade Volume by Price',
|
||||||
xaxis=dict(title='Price'),
|
xaxis=dict(title='Price'),
|
||||||
yaxis=dict(title='Cumulative Size', side='left'),
|
yaxis=dict(title='Cumulative Size', side='left'),
|
||||||
yaxis2=dict(
|
yaxis2=dict(
|
||||||
title='Traded Volume',
|
title='Traded Volume',
|
||||||
overlaying='y',
|
overlaying='y',
|
||||||
side='right',
|
side='right',
|
||||||
showgrid=False
|
showgrid=False
|
||||||
),
|
),
|
||||||
template='plotly_dark'
|
template='plotly_dark'
|
||||||
)
|
)
|
||||||
return [fig]
|
return [fig]
|
||||||
|
|
||||||
app.run(debug=True, use_reloader=False)
|
app.run(debug=True, use_reloader=False)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user