testing live plot
This commit is contained in:
parent
8f96e14b8b
commit
f534825e53
150
live_plot.py
Normal file
150
live_plot.py
Normal file
@ -0,0 +1,150 @@
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xgboost as xgb
|
||||
from datetime import datetime
|
||||
from okx_client import OKXClient
|
||||
import dash
|
||||
from dash import dcc, html
|
||||
from dash.dependencies import Output, Input
|
||||
import plotly.graph_objs as go
|
||||
import websocket
|
||||
|
||||
# --- Prediction utilities (from test_predictor.py) ---
|
||||
def add_features(df):
|
||||
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
|
||||
def calc_rsi(close, window=14):
|
||||
delta = close.diff()
|
||||
up = delta.clip(lower=0)
|
||||
down = -1 * delta.clip(upper=0)
|
||||
ma_up = up.rolling(window=window, min_periods=window).mean()
|
||||
ma_down = down.rolling(window=window, min_periods=window).mean()
|
||||
rs = ma_up / ma_down
|
||||
return 100 - (100 / (1 + rs))
|
||||
df['rsi'] = calc_rsi(df['Close'])
|
||||
df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
|
||||
df['sma_50'] = df['Close'].rolling(window=50).mean()
|
||||
df['sma_200'] = df['Close'].rolling(window=200).mean()
|
||||
high_low = df['High'] - df['Low']
|
||||
high_close = np.abs(df['High'] - df['Close'].shift(1))
|
||||
low_close = np.abs(df['Low'] - df['Close'].shift(1))
|
||||
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
|
||||
df['atr'] = tr.rolling(window=14).mean()
|
||||
df['roc_10'] = df['Close'].pct_change(periods=10) * 100
|
||||
df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)
|
||||
df['hour'] = df['Timestamp'].dt.hour
|
||||
return df
|
||||
|
||||
def load_model_and_features(model_path):
|
||||
model = xgb.Booster()
|
||||
model.load_model(model_path)
|
||||
feature_cols = [
|
||||
'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
|
||||
]
|
||||
return model, feature_cols
|
||||
|
||||
def predict_next(df, model, feature_cols):
|
||||
X = df[feature_cols].iloc[[-1]].values.astype(np.float32)
|
||||
dmatrix = xgb.DMatrix(X, feature_names=feature_cols)
|
||||
pred_log_return = model.predict(dmatrix)[0]
|
||||
last_price = df['Close'].iloc[-1]
|
||||
pred_price = last_price * np.exp(pred_log_return)
|
||||
return pred_price, pred_log_return, last_price
|
||||
|
||||
# --- Main live plotting ---
|
||||
WINDOW = 50
|
||||
MODEL_PATH = 'data/xgboost_model.json'
|
||||
|
||||
ohlcv_bars = [] # [timestamp, open, high, low, close, volume]
|
||||
bar_lock = threading.Lock()
|
||||
model, feature_cols = load_model_and_features(MODEL_PATH)
|
||||
|
||||
# --- Background thread to collect trades and aggregate OHLCV ---
|
||||
def ws_collector():
|
||||
client = OKXClient(authenticate=False)
|
||||
client.subscribe_trades(instrument="BTC-USDT")
|
||||
current_bar = None
|
||||
while True:
|
||||
try:
|
||||
msg = client.ws.recv()
|
||||
data = json.loads(msg)
|
||||
except websocket._exceptions.WebSocketTimeoutException:
|
||||
continue # Just try again
|
||||
except Exception as e:
|
||||
print(f"WebSocket error: {e}")
|
||||
break # or try to reconnect
|
||||
if 'arg' in data and data['arg'].get('channel', '') == 'trades':
|
||||
for trade in data.get('data', []):
|
||||
# trade: {'instId', 'tradeId', 'px', 'sz', 'side', 'ts'}
|
||||
ts = int(trade['ts'])
|
||||
price = float(trade['px'])
|
||||
size = float(trade['sz'])
|
||||
dt = datetime.utcfromtimestamp(ts / 1000)
|
||||
bar_seconds = 30 # or 15, 30, etc.
|
||||
bar_time = dt.replace(second=(dt.second // bar_seconds) * bar_seconds, microsecond=0)
|
||||
with bar_lock:
|
||||
if not ohlcv_bars or ohlcv_bars[-1][0] != bar_time:
|
||||
# New bar
|
||||
ohlcv_bars.append([bar_time, price, price, price, price, size])
|
||||
if len(ohlcv_bars) > WINDOW:
|
||||
ohlcv_bars.pop(0)
|
||||
else:
|
||||
# Update current bar
|
||||
bar = ohlcv_bars[-1]
|
||||
bar[2] = max(bar[2], price) # high
|
||||
bar[3] = min(bar[3], price) # low
|
||||
bar[4] = price # close
|
||||
bar[5] += size # volume
|
||||
|
||||
# Start the background thread
|
||||
threading.Thread(target=ws_collector, daemon=True).start()
|
||||
|
||||
# --- Dash App ---
|
||||
app = dash.Dash(__name__)
|
||||
app.layout = html.Div([
|
||||
html.H2('BTC/USDT Price & Prediction (OKX, XGBoost, Trades Aggregated)', style={"textAlign": "center", "margin": 0, "padding": 0}),
|
||||
dcc.Graph(id='live-graph', animate=False, style={"height": "90vh", "width": "100vw", "margin": 0, "padding": 0}),
|
||||
dcc.Interval(
|
||||
id='interval-component',
|
||||
interval=5*1000, # 5 seconds
|
||||
n_intervals=0
|
||||
),
|
||||
html.Div(id='prediction-output', style={"textAlign": "center", "fontSize": 20, "marginTop": 10})
|
||||
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
|
||||
|
||||
@app.callback(
|
||||
[Output('live-graph', 'figure'), Output('prediction-output', 'children')],
|
||||
[Input('interval-component', 'n_intervals')]
|
||||
)
|
||||
def update_graph_live(n):
|
||||
with bar_lock:
|
||||
bars = list(ohlcv_bars)
|
||||
if len(bars) < 2:
|
||||
return go.Figure(), "Waiting for data..."
|
||||
df = pd.DataFrame(bars, columns=["Timestamp", "Open", "High", "Low", "Close", "Volume"])
|
||||
df["Timestamp"] = pd.to_datetime(df["Timestamp"])
|
||||
df = add_features(df)
|
||||
if df[feature_cols].isnull().any().any():
|
||||
pred_text = "Not enough data for prediction."
|
||||
else:
|
||||
pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
|
||||
pred_text = f"Last: {last_price:.2f} | Predicted next: {pred_price:.2f} | LogRet: {pred_log_return:.6f}"
|
||||
fig = go.Figure()
|
||||
fig.add_trace(go.Candlestick(
|
||||
x=df["Timestamp"],
|
||||
open=df["Open"],
|
||||
high=df["High"],
|
||||
low=df["Low"],
|
||||
close=df["Close"],
|
||||
name='Candlestick',
|
||||
increasing_line_color='green',
|
||||
decreasing_line_color='red',
|
||||
showlegend=False
|
||||
))
|
||||
fig.update_layout(title='BTC-USDT 1m OHLCV (Aggregated from Trades)', xaxis_title='Time', yaxis_title='Price (USDT)', xaxis_rangeslider_visible=False)
|
||||
return fig, pred_text
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True)
|
||||
304
main.py
304
main.py
@ -1,152 +1,152 @@
|
||||
from okx_client import OKXClient
|
||||
from market_db import MarketDB
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
import time
|
||||
import signal
|
||||
|
||||
latest_book = {'bids': [], 'asks': [], 'timestamp': None}
|
||||
book_history = deque()
|
||||
trade_history = deque()
|
||||
|
||||
TRADE_HISTORY_SECONDS = 60
|
||||
BOOK_HISTORY_SECONDS = 5
|
||||
|
||||
shutdown_flag = threading.Event()
|
||||
|
||||
def connect(instrument, max_retries=5):
|
||||
logging.info(f"Connecting to OKX for instrument: {instrument}")
|
||||
retries = 0
|
||||
backoff = 1
|
||||
while not shutdown_flag.is_set():
|
||||
try:
|
||||
client = OKXClient(authenticate=False)
|
||||
client.subscribe_trades(instrument)
|
||||
client.subscribe_book(instrument, depth=5, channel="books")
|
||||
logging.info(f"Subscribed to trades and book for {instrument}")
|
||||
return client
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
logging.error(f"Failed to connect to OKX: {e}. Retry {retries}/{max_retries} in {backoff}s.")
|
||||
if retries >= max_retries:
|
||||
logging.critical("Max retries reached. Exiting connect loop.")
|
||||
raise
|
||||
time.sleep(backoff)
|
||||
backoff = min(backoff * 2, 60) # exponential backoff, max 60s
|
||||
return None
|
||||
|
||||
def cleanup(client, db):
|
||||
if client and hasattr(client, 'ws') and client.ws:
|
||||
try:
|
||||
client.ws.close()
|
||||
except Exception as e:
|
||||
logging.warning(f"Error closing websocket: {e}")
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logging.info(f"Received signal {signum}, shutting down...")
|
||||
shutdown_flag.set()
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
def main():
|
||||
instruments = [
|
||||
"ETH-USDT",
|
||||
"BTC-USDT",
|
||||
"SOL-USDT",
|
||||
"DOGE-USDT",
|
||||
"TON-USDT",
|
||||
"ETH-USDC",
|
||||
"SOPH-USDT",
|
||||
"PEPE-USDT",
|
||||
"BTC-USDC",
|
||||
"UNI-USDT"
|
||||
]
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
||||
dbs = {}
|
||||
clients = {}
|
||||
try:
|
||||
for instrument in instruments:
|
||||
dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db")
|
||||
logging.info(f"Database initialized for {instrument}")
|
||||
clients[instrument] = connect(instrument)
|
||||
|
||||
while not shutdown_flag.is_set():
|
||||
for instrument in instruments:
|
||||
client = clients[instrument]
|
||||
db = dbs[instrument]
|
||||
try:
|
||||
data = client.ws.recv()
|
||||
except Exception as e:
|
||||
logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...")
|
||||
cleanup(client, None)
|
||||
try:
|
||||
clients[instrument] = connect(instrument)
|
||||
except Exception as e:
|
||||
logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.")
|
||||
continue
|
||||
continue
|
||||
|
||||
if shutdown_flag.is_set():
|
||||
break
|
||||
if data == '':
|
||||
continue
|
||||
|
||||
try:
|
||||
msg = json.loads(data)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}")
|
||||
continue
|
||||
|
||||
if 'arg' in msg and msg['arg'].get('channel') == 'trades':
|
||||
for trade in msg.get('data', []):
|
||||
db.insert_trade({
|
||||
'instrument': instrument,
|
||||
'trade_id': trade.get('tradeId'),
|
||||
'price': float(trade.get('px')),
|
||||
'size': float(trade.get('sz')),
|
||||
'side': trade.get('side'),
|
||||
'timestamp': trade.get('ts')
|
||||
})
|
||||
ts = float(trade.get('ts', time.time() * 1000))
|
||||
trade_history.append({
|
||||
'price': trade.get('px'),
|
||||
'size': trade.get('sz'),
|
||||
'side': trade.get('side'),
|
||||
'timestamp': ts
|
||||
})
|
||||
elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'):
|
||||
for book in msg.get('data', []):
|
||||
db.insert_book({
|
||||
'instrument': instrument,
|
||||
'bids': book.get('bids'),
|
||||
'asks': book.get('asks'),
|
||||
'timestamp': book.get('ts')
|
||||
})
|
||||
latest_book['bids'] = book.get('bids', [])
|
||||
latest_book['asks'] = book.get('asks', [])
|
||||
latest_book['timestamp'] = book.get('ts')
|
||||
ts = float(book.get('ts', time.time() * 1000))
|
||||
book_history.append({
|
||||
'bids': book.get('bids', []),
|
||||
'asks': book.get('asks', []),
|
||||
'timestamp': ts
|
||||
})
|
||||
else:
|
||||
logging.info(f"Unknown message for {instrument}: {msg}")
|
||||
except Exception as e:
|
||||
logging.critical(f"Fatal error in main: {e}")
|
||||
finally:
|
||||
for client in clients.values():
|
||||
cleanup(client, None)
|
||||
for db in dbs.values():
|
||||
cleanup(None, db)
|
||||
logging.info('Shutdown complete.')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
from okx_client import OKXClient
|
||||
from market_db import MarketDB
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
import time
|
||||
import signal
|
||||
|
||||
latest_book = {'bids': [], 'asks': [], 'timestamp': None}
|
||||
book_history = deque()
|
||||
trade_history = deque()
|
||||
|
||||
TRADE_HISTORY_SECONDS = 60
|
||||
BOOK_HISTORY_SECONDS = 5
|
||||
|
||||
shutdown_flag = threading.Event()
|
||||
|
||||
def connect(instrument, max_retries=5):
|
||||
logging.info(f"Connecting to OKX for instrument: {instrument}")
|
||||
retries = 0
|
||||
backoff = 1
|
||||
while not shutdown_flag.is_set():
|
||||
try:
|
||||
client = OKXClient(authenticate=False)
|
||||
client.subscribe_trades(instrument)
|
||||
client.subscribe_book(instrument, depth=5, channel="books")
|
||||
logging.info(f"Subscribed to trades and book for {instrument}")
|
||||
return client
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
logging.error(f"Failed to connect to OKX: {e}. Retry {retries}/{max_retries} in {backoff}s.")
|
||||
if retries >= max_retries:
|
||||
logging.critical("Max retries reached. Exiting connect loop.")
|
||||
raise
|
||||
time.sleep(backoff)
|
||||
backoff = min(backoff * 2, 60) # exponential backoff, max 60s
|
||||
return None
|
||||
|
||||
def cleanup(client, db):
|
||||
if client and hasattr(client, 'ws') and client.ws:
|
||||
try:
|
||||
client.ws.close()
|
||||
except Exception as e:
|
||||
logging.warning(f"Error closing websocket: {e}")
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logging.info(f"Received signal {signum}, shutting down...")
|
||||
shutdown_flag.set()
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
def main():
|
||||
instruments = [
|
||||
"ETH-USDT",
|
||||
"BTC-USDT",
|
||||
"SOL-USDT",
|
||||
"DOGE-USDT",
|
||||
"TON-USDT",
|
||||
"ETH-USDC",
|
||||
"SOPH-USDT",
|
||||
"PEPE-USDT",
|
||||
"BTC-USDC",
|
||||
"UNI-USDT"
|
||||
]
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
||||
dbs = {}
|
||||
clients = {}
|
||||
try:
|
||||
for instrument in instruments:
|
||||
dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db")
|
||||
logging.info(f"Database initialized for {instrument}")
|
||||
clients[instrument] = connect(instrument)
|
||||
|
||||
while not shutdown_flag.is_set():
|
||||
for instrument in instruments:
|
||||
client = clients[instrument]
|
||||
db = dbs[instrument]
|
||||
try:
|
||||
data = client.ws.recv()
|
||||
except Exception as e:
|
||||
logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...")
|
||||
cleanup(client, None)
|
||||
try:
|
||||
clients[instrument] = connect(instrument)
|
||||
except Exception as e:
|
||||
logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.")
|
||||
continue
|
||||
continue
|
||||
|
||||
if shutdown_flag.is_set():
|
||||
break
|
||||
if data == '':
|
||||
continue
|
||||
|
||||
try:
|
||||
msg = json.loads(data)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}")
|
||||
continue
|
||||
|
||||
if 'arg' in msg and msg['arg'].get('channel') == 'trades':
|
||||
for trade in msg.get('data', []):
|
||||
db.insert_trade({
|
||||
'instrument': instrument,
|
||||
'trade_id': trade.get('tradeId'),
|
||||
'price': float(trade.get('px')),
|
||||
'size': float(trade.get('sz')),
|
||||
'side': trade.get('side'),
|
||||
'timestamp': trade.get('ts')
|
||||
})
|
||||
ts = float(trade.get('ts', time.time() * 1000))
|
||||
trade_history.append({
|
||||
'price': trade.get('px'),
|
||||
'size': trade.get('sz'),
|
||||
'side': trade.get('side'),
|
||||
'timestamp': ts
|
||||
})
|
||||
elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'):
|
||||
for book in msg.get('data', []):
|
||||
db.insert_book({
|
||||
'instrument': instrument,
|
||||
'bids': book.get('bids'),
|
||||
'asks': book.get('asks'),
|
||||
'timestamp': book.get('ts')
|
||||
})
|
||||
latest_book['bids'] = book.get('bids', [])
|
||||
latest_book['asks'] = book.get('asks', [])
|
||||
latest_book['timestamp'] = book.get('ts')
|
||||
ts = float(book.get('ts', time.time() * 1000))
|
||||
book_history.append({
|
||||
'bids': book.get('bids', []),
|
||||
'asks': book.get('asks', []),
|
||||
'timestamp': ts
|
||||
})
|
||||
else:
|
||||
logging.info(f"Unknown message for {instrument}: {msg}")
|
||||
except Exception as e:
|
||||
logging.critical(f"Fatal error in main: {e}")
|
||||
finally:
|
||||
for client in clients.values():
|
||||
cleanup(client, None)
|
||||
for db in dbs.values():
|
||||
cleanup(None, db)
|
||||
logging.info('Shutdown complete.')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
154
market_db.py
154
market_db.py
@ -1,77 +1,77 @@
|
||||
import sqlite3
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
import logging
|
||||
|
||||
class MarketDB:
|
||||
def __init__(self, market: str, db_dir: str = ""):
|
||||
db_name = f"{market}.db"
|
||||
db_path = db_name if not db_dir else f"{db_dir.rstrip('/')}/{db_name}"
|
||||
if db_dir:
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
self.conn = sqlite3.connect(db_path)
|
||||
logging.info(f"Connected to database at {db_path}")
|
||||
self._create_tables()
|
||||
|
||||
def _create_tables(self):
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
instrument TEXT,
|
||||
trade_id TEXT,
|
||||
price REAL,
|
||||
size REAL,
|
||||
side TEXT,
|
||||
timestamp TEXT
|
||||
)
|
||||
''')
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS book (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
instrument TEXT,
|
||||
bids TEXT,
|
||||
asks TEXT,
|
||||
timestamp TEXT
|
||||
)
|
||||
''')
|
||||
self.conn.commit()
|
||||
logging.info("Database tables ensured.")
|
||||
|
||||
def insert_trade(self, trade: Dict[str, Any]):
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT INTO trades (instrument, trade_id, price, size, side, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
trade.get('instrument'),
|
||||
trade.get('trade_id'),
|
||||
trade.get('price'),
|
||||
trade.get('size'),
|
||||
trade.get('side'),
|
||||
trade.get('timestamp')
|
||||
))
|
||||
self.conn.commit()
|
||||
logging.debug(f"Inserted trade: {trade}")
|
||||
|
||||
def insert_book(self, book: Dict[str, Any]):
|
||||
cursor = self.conn.cursor()
|
||||
bids = book.get('bids', [])
|
||||
asks = book.get('asks', [])
|
||||
best_bid = next((b for b in bids if float(b[1]) > 0), ['-', '-'])
|
||||
best_ask = next((a for a in asks if float(a[1]) > 0), ['-', '-'])
|
||||
cursor.execute('''
|
||||
INSERT INTO book (instrument, bids, asks, timestamp)
|
||||
VALUES (?, ?, ?, ?)
|
||||
''', (
|
||||
book.get('instrument'),
|
||||
str(bids),
|
||||
str(asks),
|
||||
book.get('timestamp')
|
||||
))
|
||||
self.conn.commit()
|
||||
logging.debug(f"Inserted book: {book.get('instrument', 'N/A')} ts:{book.get('timestamp', 'N/A')} bid:{best_bid} ask:{best_ask}")
|
||||
|
||||
def close(self):
|
||||
self.conn.close()
|
||||
logging.info("Database connection closed.")
|
||||
import sqlite3
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
import logging
|
||||
|
||||
class MarketDB:
|
||||
def __init__(self, market: str, db_dir: str = ""):
|
||||
db_name = f"{market}.db"
|
||||
db_path = db_name if not db_dir else f"{db_dir.rstrip('/')}/{db_name}"
|
||||
if db_dir:
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
self.conn = sqlite3.connect(db_path)
|
||||
logging.info(f"Connected to database at {db_path}")
|
||||
self._create_tables()
|
||||
|
||||
def _create_tables(self):
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
instrument TEXT,
|
||||
trade_id TEXT,
|
||||
price REAL,
|
||||
size REAL,
|
||||
side TEXT,
|
||||
timestamp TEXT
|
||||
)
|
||||
''')
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS book (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
instrument TEXT,
|
||||
bids TEXT,
|
||||
asks TEXT,
|
||||
timestamp TEXT
|
||||
)
|
||||
''')
|
||||
self.conn.commit()
|
||||
logging.info("Database tables ensured.")
|
||||
|
||||
def insert_trade(self, trade: Dict[str, Any]):
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT INTO trades (instrument, trade_id, price, size, side, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
trade.get('instrument'),
|
||||
trade.get('trade_id'),
|
||||
trade.get('price'),
|
||||
trade.get('size'),
|
||||
trade.get('side'),
|
||||
trade.get('timestamp')
|
||||
))
|
||||
self.conn.commit()
|
||||
logging.debug(f"Inserted trade: {trade}")
|
||||
|
||||
def insert_book(self, book: Dict[str, Any]):
|
||||
cursor = self.conn.cursor()
|
||||
bids = book.get('bids', [])
|
||||
asks = book.get('asks', [])
|
||||
best_bid = next((b for b in bids if float(b[1]) > 0), ['-', '-'])
|
||||
best_ask = next((a for a in asks if float(a[1]) > 0), ['-', '-'])
|
||||
cursor.execute('''
|
||||
INSERT INTO book (instrument, bids, asks, timestamp)
|
||||
VALUES (?, ?, ?, ?)
|
||||
''', (
|
||||
book.get('instrument'),
|
||||
str(bids),
|
||||
str(asks),
|
||||
book.get('timestamp')
|
||||
))
|
||||
self.conn.commit()
|
||||
logging.debug(f"Inserted book: {book.get('instrument', 'N/A')} ts:{book.get('timestamp', 'N/A')} bid:{best_bid} ask:{best_ask}")
|
||||
|
||||
def close(self):
|
||||
self.conn.close()
|
||||
logging.info("Database connection closed.")
|
||||
|
||||
466
okx_client.py
466
okx_client.py
@ -1,233 +1,233 @@
|
||||
import os
|
||||
import time
|
||||
import hmac
|
||||
import hashlib
|
||||
import base64
|
||||
import json
|
||||
import pandas as pd
|
||||
import threading
|
||||
import requests
|
||||
import websocket
|
||||
import logging
|
||||
|
||||
class OKXClient:
|
||||
PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public"
|
||||
PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private"
|
||||
REST_URL = "https://www.okx.com"
|
||||
|
||||
def __init__(self, authenticate: bool = True):
|
||||
self.authenticated = False
|
||||
self.api_key = None
|
||||
self.api_secret = None
|
||||
self.api_passphrase = None
|
||||
self.ws = None
|
||||
self.ws_private = None
|
||||
self._lock = threading.Lock()
|
||||
self._private_lock = threading.Lock()
|
||||
|
||||
if authenticate:
|
||||
config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json')
|
||||
try:
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"Credentials file not found at {config_path}. Please create it with the required keys.")
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Credentials file at {config_path} is not valid JSON.")
|
||||
|
||||
self.api_key = config.get("OKX_API_KEY")
|
||||
self.api_secret = config.get("OKX_API_SECRET")
|
||||
self.api_passphrase = config.get("OKX_API_PASSPHRASE")
|
||||
|
||||
if not self.api_key or not self.api_secret or not self.api_passphrase:
|
||||
raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.")
|
||||
|
||||
self._authenticate()
|
||||
self._connect_ws()
|
||||
|
||||
def _connect_ws(self):
|
||||
if self.ws is None:
|
||||
self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10)
|
||||
if self.authenticated and self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None:
|
||||
self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10)
|
||||
|
||||
def _get_timestamp(self):
|
||||
return str(round(time.time(), 3))
|
||||
|
||||
def _sign(self, timestamp, method, request_path, body):
|
||||
if not body:
|
||||
body = ''
|
||||
message = f'{timestamp}{method}{request_path}{body}'
|
||||
mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256)
|
||||
return base64.b64encode(mac.digest()).decode()
|
||||
|
||||
def _authenticate(self):
|
||||
import websocket
|
||||
timestamp = self._get_timestamp()
|
||||
sign = self._sign(timestamp, 'GET', '/users/self/verify', '')
|
||||
login_params = {
|
||||
"op": "login",
|
||||
"args": [{
|
||||
"apiKey": self.api_key,
|
||||
"passphrase": self.api_passphrase,
|
||||
"timestamp": timestamp,
|
||||
"sign": sign
|
||||
}]
|
||||
}
|
||||
self.ws_private.send(json.dumps(login_params))
|
||||
logging.info("Waiting for login response from OKX...")
|
||||
while True:
|
||||
try:
|
||||
resp = self.ws_private.recv()
|
||||
logging.debug(f"Received from OKX private WS: {resp}")
|
||||
if not resp:
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(resp)
|
||||
except Exception:
|
||||
logging.warning(f"Non-JSON message received: {resp}")
|
||||
continue
|
||||
if msg.get("event") == "login":
|
||||
if msg.get("code") == "0":
|
||||
logging.info("OKX WebSocket login successful.")
|
||||
self.authenticated = True
|
||||
break
|
||||
else:
|
||||
raise Exception(f"WebSocket authentication failed: {msg}")
|
||||
except websocket._exceptions.WebSocketConnectionClosedException as e:
|
||||
logging.error(f"WebSocket connection closed during authentication: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logging.error(f"Exception during authentication: {e}")
|
||||
raise
|
||||
|
||||
def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"):
|
||||
# OKX uses candle1m, candle5m, etc.
|
||||
tf_map = {"1m": "candle1m", "5m": "candle5m", "15m": "candle15m", "1h": "candle1H"}
|
||||
channel = tf_map.get(timeframe, f"candle{timeframe}")
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": channel, "instId": instrument}]
|
||||
}
|
||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws.send(json.dumps(params))
|
||||
|
||||
def subscribe_trades(self, instrument="BTC-USDT"):
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": "trades", "instId": instrument}]
|
||||
}
|
||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws.send(json.dumps(params))
|
||||
|
||||
def subscribe_ticker(self, instrument="BTC-USDT"):
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": "tickers", "instId": instrument}]
|
||||
}
|
||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws.send(json.dumps(params))
|
||||
|
||||
def subscribe_book(self, instrument="BTC-USDT", depth=5, channel="books5"):
|
||||
# OKX supports books5, books50, books-l2-tbt
|
||||
# channel = "books5" if depth <= 5 else "books50"
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": channel, "instId": instrument}]
|
||||
}
|
||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws.send(json.dumps(params))
|
||||
|
||||
def subscribe_user_order(self):
|
||||
if not self.authenticated:
|
||||
logging.warning("Attempted to subscribe to user order channel without authentication.")
|
||||
return
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": "orders", "instType": "SPOT"}]
|
||||
}
|
||||
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws_private.send(json.dumps(params))
|
||||
|
||||
def subscribe_user_trade(self):
|
||||
if not self.authenticated:
|
||||
logging.warning("Attempted to subscribe to user trade channel without authentication.")
|
||||
return
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": "trades", "instType": "SPOT"}]
|
||||
}
|
||||
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws_private.send(json.dumps(params))
|
||||
|
||||
def subscribe_user_balance(self):
|
||||
if not self.authenticated:
|
||||
logging.warning("Attempted to subscribe to user balance channel without authentication.")
|
||||
return
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": "balance_and_position"}]
|
||||
}
|
||||
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws_private.send(json.dumps(params))
|
||||
|
||||
def get_balance(self, currency=None):
|
||||
url = f"{self.REST_URL}/api/v5/account/balance"
|
||||
timestamp = self._get_timestamp()
|
||||
method = "GET"
|
||||
request_path = "/api/v5/account/balance"
|
||||
body = ''
|
||||
sign = self._sign(timestamp, method, request_path, body)
|
||||
headers = {
|
||||
"OK-ACCESS-KEY": self.api_key,
|
||||
"OK-ACCESS-SIGN": sign,
|
||||
"OK-ACCESS-TIMESTAMP": timestamp,
|
||||
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
resp = requests.get(url, headers=headers)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
balances = data.get("data", [{}])[0].get("details", [])
|
||||
if currency:
|
||||
return [b for b in balances if b.get("ccy") == currency]
|
||||
return balances
|
||||
return []
|
||||
|
||||
def place_order(self, side, amount, instrument="BTC-USDT"):
|
||||
url = f"{self.REST_URL}/api/v5/trade/order"
|
||||
timestamp = self._get_timestamp()
|
||||
method = "POST"
|
||||
request_path = "/api/v5/trade/order"
|
||||
body_dict = {
|
||||
"instId": instrument,
|
||||
"tdMode": "cash",
|
||||
"side": side.lower(),
|
||||
"ordType": "market",
|
||||
"sz": str(amount)
|
||||
}
|
||||
body = json.dumps(body_dict)
|
||||
sign = self._sign(timestamp, method, request_path, body)
|
||||
headers = {
|
||||
"OK-ACCESS-KEY": self.api_key,
|
||||
"OK-ACCESS-SIGN": sign,
|
||||
"OK-ACCESS-TIMESTAMP": timestamp,
|
||||
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
resp = requests.post(url, headers=headers, data=body)
|
||||
return resp.json()
|
||||
|
||||
def buy_btc(self, amount, instrument="BTC-USDT"):
|
||||
return self.place_order("buy", amount, instrument)
|
||||
|
||||
def sell_btc(self, amount, instrument="BTC-USDT"):
|
||||
return self.place_order("sell", amount, instrument)
|
||||
|
||||
def get_instruments(self):
|
||||
url = f"{self.REST_URL}/api/v5/public/instruments?instType=SPOT"
|
||||
resp = requests.get(url)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
return data.get("data", [])
|
||||
return []
|
||||
import os
|
||||
import time
|
||||
import hmac
|
||||
import hashlib
|
||||
import base64
|
||||
import json
|
||||
import pandas as pd
|
||||
import threading
|
||||
import requests
|
||||
import websocket
|
||||
import logging
|
||||
|
||||
class OKXClient:
|
||||
PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public"
|
||||
PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private"
|
||||
REST_URL = "https://www.okx.com"
|
||||
|
||||
def __init__(self, authenticate: bool = True):
|
||||
self.authenticated = False
|
||||
self.api_key = None
|
||||
self.api_secret = None
|
||||
self.api_passphrase = None
|
||||
self.ws = None
|
||||
self.ws_private = None
|
||||
self._lock = threading.Lock()
|
||||
self._private_lock = threading.Lock()
|
||||
|
||||
if authenticate:
|
||||
config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json')
|
||||
try:
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"Credentials file not found at {config_path}. Please create it with the required keys.")
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Credentials file at {config_path} is not valid JSON.")
|
||||
|
||||
self.api_key = config.get("OKX_API_KEY")
|
||||
self.api_secret = config.get("OKX_API_SECRET")
|
||||
self.api_passphrase = config.get("OKX_API_PASSPHRASE")
|
||||
|
||||
if not self.api_key or not self.api_secret or not self.api_passphrase:
|
||||
raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.")
|
||||
|
||||
self._authenticate()
|
||||
self._connect_ws()
|
||||
|
||||
def _connect_ws(self):
|
||||
if self.ws is None:
|
||||
self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10)
|
||||
if self.authenticated and self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None:
|
||||
self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10)
|
||||
|
||||
def _get_timestamp(self):
|
||||
return str(round(time.time(), 3))
|
||||
|
||||
def _sign(self, timestamp, method, request_path, body):
|
||||
if not body:
|
||||
body = ''
|
||||
message = f'{timestamp}{method}{request_path}{body}'
|
||||
mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256)
|
||||
return base64.b64encode(mac.digest()).decode()
|
||||
|
||||
def _authenticate(self):
|
||||
import websocket
|
||||
timestamp = self._get_timestamp()
|
||||
sign = self._sign(timestamp, 'GET', '/users/self/verify', '')
|
||||
login_params = {
|
||||
"op": "login",
|
||||
"args": [{
|
||||
"apiKey": self.api_key,
|
||||
"passphrase": self.api_passphrase,
|
||||
"timestamp": timestamp,
|
||||
"sign": sign
|
||||
}]
|
||||
}
|
||||
self.ws_private.send(json.dumps(login_params))
|
||||
logging.info("Waiting for login response from OKX...")
|
||||
while True:
|
||||
try:
|
||||
resp = self.ws_private.recv()
|
||||
logging.debug(f"Received from OKX private WS: {resp}")
|
||||
if not resp:
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(resp)
|
||||
except Exception:
|
||||
logging.warning(f"Non-JSON message received: {resp}")
|
||||
continue
|
||||
if msg.get("event") == "login":
|
||||
if msg.get("code") == "0":
|
||||
logging.info("OKX WebSocket login successful.")
|
||||
self.authenticated = True
|
||||
break
|
||||
else:
|
||||
raise Exception(f"WebSocket authentication failed: {msg}")
|
||||
except websocket._exceptions.WebSocketConnectionClosedException as e:
|
||||
logging.error(f"WebSocket connection closed during authentication: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logging.error(f"Exception during authentication: {e}")
|
||||
raise
|
||||
|
||||
def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"):
|
||||
# OKX uses candle1m, candle5m, etc.
|
||||
tf_map = {"1m": "candle1m", "5m": "candle5m", "15m": "candle15m", "1h": "candle1H"}
|
||||
channel = tf_map.get(timeframe, f"candle{timeframe}")
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": channel, "instId": instrument}]
|
||||
}
|
||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws.send(json.dumps(params))
|
||||
|
||||
def subscribe_trades(self, instrument="BTC-USDT"):
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": "trades", "instId": instrument}]
|
||||
}
|
||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws.send(json.dumps(params))
|
||||
|
||||
def subscribe_ticker(self, instrument="BTC-USDT"):
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": "tickers", "instId": instrument}]
|
||||
}
|
||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws.send(json.dumps(params))
|
||||
|
||||
def subscribe_book(self, instrument="BTC-USDT", depth=5, channel="books5"):
|
||||
# OKX supports books5, books50, books-l2-tbt
|
||||
# channel = "books5" if depth <= 5 else "books50"
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": channel, "instId": instrument}]
|
||||
}
|
||||
logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws.send(json.dumps(params))
|
||||
|
||||
def subscribe_user_order(self):
|
||||
if not self.authenticated:
|
||||
logging.warning("Attempted to subscribe to user order channel without authentication.")
|
||||
return
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": "orders", "instType": "SPOT"}]
|
||||
}
|
||||
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws_private.send(json.dumps(params))
|
||||
|
||||
def subscribe_user_trade(self):
|
||||
if not self.authenticated:
|
||||
logging.warning("Attempted to subscribe to user trade channel without authentication.")
|
||||
return
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": "trades", "instType": "SPOT"}]
|
||||
}
|
||||
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws_private.send(json.dumps(params))
|
||||
|
||||
def subscribe_user_balance(self):
|
||||
if not self.authenticated:
|
||||
logging.warning("Attempted to subscribe to user balance channel without authentication.")
|
||||
return
|
||||
params = {
|
||||
"op": "subscribe",
|
||||
"args": [{"channel": "balance_and_position"}]
|
||||
}
|
||||
logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}")
|
||||
self.ws_private.send(json.dumps(params))
|
||||
|
||||
def get_balance(self, currency=None):
|
||||
url = f"{self.REST_URL}/api/v5/account/balance"
|
||||
timestamp = self._get_timestamp()
|
||||
method = "GET"
|
||||
request_path = "/api/v5/account/balance"
|
||||
body = ''
|
||||
sign = self._sign(timestamp, method, request_path, body)
|
||||
headers = {
|
||||
"OK-ACCESS-KEY": self.api_key,
|
||||
"OK-ACCESS-SIGN": sign,
|
||||
"OK-ACCESS-TIMESTAMP": timestamp,
|
||||
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
resp = requests.get(url, headers=headers)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
balances = data.get("data", [{}])[0].get("details", [])
|
||||
if currency:
|
||||
return [b for b in balances if b.get("ccy") == currency]
|
||||
return balances
|
||||
return []
|
||||
|
||||
def place_order(self, side, amount, instrument="BTC-USDT"):
|
||||
url = f"{self.REST_URL}/api/v5/trade/order"
|
||||
timestamp = self._get_timestamp()
|
||||
method = "POST"
|
||||
request_path = "/api/v5/trade/order"
|
||||
body_dict = {
|
||||
"instId": instrument,
|
||||
"tdMode": "cash",
|
||||
"side": side.lower(),
|
||||
"ordType": "market",
|
||||
"sz": str(amount)
|
||||
}
|
||||
body = json.dumps(body_dict)
|
||||
sign = self._sign(timestamp, method, request_path, body)
|
||||
headers = {
|
||||
"OK-ACCESS-KEY": self.api_key,
|
||||
"OK-ACCESS-SIGN": sign,
|
||||
"OK-ACCESS-TIMESTAMP": timestamp,
|
||||
"OK-ACCESS-PASSPHRASE": self.api_passphrase,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
resp = requests.post(url, headers=headers, data=body)
|
||||
return resp.json()
|
||||
|
||||
def buy_btc(self, amount, instrument="BTC-USDT"):
|
||||
return self.place_order("buy", amount, instrument)
|
||||
|
||||
def sell_btc(self, amount, instrument="BTC-USDT"):
|
||||
return self.place_order("sell", amount, instrument)
|
||||
|
||||
def get_instruments(self):
|
||||
url = f"{self.REST_URL}/api/v5/public/instruments?instType=SPOT"
|
||||
resp = requests.get(url)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
return data.get("data", [])
|
||||
return []
|
||||
|
||||
122
test_predictor.py
Normal file
122
test_predictor.py
Normal file
@ -0,0 +1,122 @@
|
||||
import time
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import requests
|
||||
import xgboost as xgb
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# --- CONFIG ---
|
||||
OKX_REST_URL = 'https://www.okx.com/api/v5/market/candles'
|
||||
SYMBOL = 'BTC-USDT'
|
||||
BAR = '1m'
|
||||
HIST_MINUTES = 250 # Number of minutes of history to fetch for features
|
||||
MODEL_PATH = 'data/xgboost_model.json'
|
||||
|
||||
# --- Fetch recent candles from OKX REST API ---
|
||||
def fetch_recent_candles(symbol, bar, limit=HIST_MINUTES):
|
||||
params = {
|
||||
'instId': symbol,
|
||||
'bar': bar,
|
||||
'limit': str(limit)
|
||||
}
|
||||
resp = requests.get(OKX_REST_URL, params=params)
|
||||
data = resp.json()
|
||||
if data['code'] != '0':
|
||||
raise Exception(f"OKX API error: {data['msg']}")
|
||||
# OKX returns most recent first, reverse to chronological
|
||||
candles = data['data'][::-1]
|
||||
df = pd.DataFrame(candles)
|
||||
# OKX columns: [ts, o, h, l, c, vol, volCcy, confirm, ...] (see API docs)
|
||||
# We'll use: ts, o, h, l, c, vol
|
||||
col_map = {
|
||||
0: 'Timestamp',
|
||||
1: 'Open',
|
||||
2: 'High',
|
||||
3: 'Low',
|
||||
4: 'Close',
|
||||
5: 'Volume',
|
||||
}
|
||||
df = df.rename(columns={str(k): v for k, v in col_map.items()})
|
||||
# If columns are not named, use integer index
|
||||
for k, v in col_map.items():
|
||||
if v not in df.columns:
|
||||
df[v] = df.iloc[:, k]
|
||||
df = df[['Timestamp', 'Open', 'High', 'Low', 'Close', 'Volume']]
|
||||
df['Timestamp'] = pd.to_datetime(df['Timestamp'].astype(np.int64), unit='ms')
|
||||
for col in ['Open', 'High', 'Low', 'Close', 'Volume']:
|
||||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||
return df
|
||||
|
||||
# --- Feature Engineering (minimal, real-time) ---
|
||||
def add_features(df):
|
||||
# Log return (target, not used for prediction)
|
||||
df['log_return'] = np.log(df['Close'] / df['Close'].shift(1))
|
||||
# RSI (14)
|
||||
def calc_rsi(close, window=14):
|
||||
delta = close.diff()
|
||||
up = delta.clip(lower=0)
|
||||
down = -1 * delta.clip(upper=0)
|
||||
ma_up = up.rolling(window=window, min_periods=window).mean()
|
||||
ma_down = down.rolling(window=window, min_periods=window).mean()
|
||||
rs = ma_up / ma_down
|
||||
return 100 - (100 / (1 + rs))
|
||||
df['rsi'] = calc_rsi(df['Close'])
|
||||
# EMA 14
|
||||
df['ema_14'] = df['Close'].ewm(span=14, adjust=False).mean()
|
||||
# SMA 50, 200
|
||||
df['sma_50'] = df['Close'].rolling(window=50).mean()
|
||||
df['sma_200'] = df['Close'].rolling(window=200).mean()
|
||||
# ATR 14
|
||||
high_low = df['High'] - df['Low']
|
||||
high_close = np.abs(df['High'] - df['Close'].shift(1))
|
||||
low_close = np.abs(df['Low'] - df['Close'].shift(1))
|
||||
tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
|
||||
df['atr'] = tr.rolling(window=14).mean()
|
||||
# ROC 10
|
||||
df['roc_10'] = df['Close'].pct_change(periods=10) * 100
|
||||
# DPO 20
|
||||
df['dpo_20'] = df['Close'] - df['Close'].rolling(window=21).mean().shift(-10)
|
||||
# Hour
|
||||
df['hour'] = df['Timestamp'].dt.hour
|
||||
# Add more features as needed (match with main.py)
|
||||
return df
|
||||
|
||||
# --- Load model and feature columns ---
|
||||
def load_model_and_features(model_path):
|
||||
model = xgb.Booster()
|
||||
model.load_model(model_path)
|
||||
# Try to infer feature names from main.py (hardcoded for now)
|
||||
feature_cols = [
|
||||
'rsi', 'ema_14', 'sma_50', 'sma_200', 'atr', 'roc_10', 'dpo_20', 'hour'
|
||||
]
|
||||
return model, feature_cols
|
||||
|
||||
# --- Predict next log return and price ---
|
||||
def predict_next(df, model, feature_cols):
|
||||
# Use the last row for prediction
|
||||
X = df[feature_cols].iloc[[-1]].values.astype(np.float32)
|
||||
dmatrix = xgb.DMatrix(X, feature_names=feature_cols)
|
||||
pred_log_return = model.predict(dmatrix)[0]
|
||||
last_price = df['Close'].iloc[-1]
|
||||
pred_price = last_price * np.exp(pred_log_return)
|
||||
return pred_price, pred_log_return, last_price
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('Fetching recent candles from OKX...')
|
||||
df = fetch_recent_candles(SYMBOL, BAR)
|
||||
df = add_features(df)
|
||||
model, feature_cols = load_model_and_features(MODEL_PATH)
|
||||
print('Waiting for new candle...')
|
||||
last_timestamp = df['Timestamp'].iloc[-1]
|
||||
while True:
|
||||
time.sleep(5)
|
||||
new_df = fetch_recent_candles(SYMBOL, BAR, limit=HIST_MINUTES)
|
||||
if new_df['Timestamp'].iloc[-1] > last_timestamp:
|
||||
df = new_df
|
||||
df = add_features(df)
|
||||
pred_price, pred_log_return, last_price = predict_next(df, model, feature_cols)
|
||||
print(f"[{df['Timestamp'].iloc[-1]}] Last price: {last_price:.2f} | Predicted next price: {pred_price:.2f} | Predicted log return: {pred_log_return:.6f}")
|
||||
last_timestamp = df['Timestamp'].iloc[-1]
|
||||
else:
|
||||
print('No new candle yet...')
|
||||
216
visualizer.py
216
visualizer.py
@ -1,108 +1,108 @@
|
||||
import dash
|
||||
from dash import dcc, html
|
||||
from dash.dependencies import Output, Input
|
||||
import plotly.graph_objs as go
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
def run_dash(book_history, trade_history, BOOK_HISTORY_SECONDS=5, TRADE_HISTORY_SECONDS=60):
|
||||
app = dash.Dash(__name__)
|
||||
app.layout = html.Div([
|
||||
html.H1("Order Book Depth Chart", style={"textAlign": "center", "color": "#222"}),
|
||||
dcc.Graph(id='order-book-graph', style={"height": "90vh", "width": "100vw"}),
|
||||
dcc.Interval(id='interval-component', interval=2*1000, n_intervals=0)
|
||||
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
|
||||
|
||||
@app.callback(
|
||||
[Output('order-book-graph', 'figure')],
|
||||
[Input('interval-component', 'n_intervals')]
|
||||
)
|
||||
def update_graphs(n):
|
||||
now = time.time() * 1000 # current time in ms
|
||||
|
||||
# Prune book_history to only keep last BOOK_HISTORY_SECONDS
|
||||
while book_history and now - book_history[0]['timestamp'] > BOOK_HISTORY_SECONDS * 1000:
|
||||
book_history.popleft()
|
||||
|
||||
# Prune trade_history to only keep last TRADE_HISTORY_SECONDS
|
||||
while trade_history and now - float(trade_history[0]['timestamp']) > TRADE_HISTORY_SECONDS * 1000:
|
||||
trade_history.popleft()
|
||||
|
||||
# Aggregate bids/asks from book_history
|
||||
bids_dict = {}
|
||||
asks_dict = {}
|
||||
|
||||
for book in book_history:
|
||||
for price, size, *_ in book['bids']:
|
||||
price = float(price)
|
||||
size = float(size)
|
||||
bids_dict[price] = bids_dict.get(price, 0) + size
|
||||
|
||||
for price, size, *_ in book['asks']:
|
||||
price = float(price)
|
||||
size = float(size)
|
||||
asks_dict[price] = asks_dict.get(price, 0) + size
|
||||
|
||||
try:
|
||||
# Prepare and sort bids/asks
|
||||
bids = sorted([[p, s] for p, s in bids_dict.items()], reverse=True)
|
||||
asks = sorted([[p, s] for p, s in asks_dict.items()])
|
||||
|
||||
# Cumulative sum
|
||||
bid_prices = [b[0] for b in bids]
|
||||
bid_sizes = [b[1] for b in bids]
|
||||
ask_prices = [a[0] for a in asks]
|
||||
ask_sizes = [a[1] for a in asks]
|
||||
bid_cumsum = [sum(bid_sizes[:i+1]) for i in range(len(bid_sizes))]
|
||||
ask_cumsum = [sum(ask_sizes[:i+1]) for i in range(len(ask_sizes))]
|
||||
except Exception as e:
|
||||
bid_prices, bid_cumsum, ask_prices, ask_cumsum = [], [], [], []
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
# Add order book lines (primary y-axis)
|
||||
fig.add_trace(go.Scatter(
|
||||
x=bid_prices, y=bid_cumsum, mode='lines', name='Bids',
|
||||
line=dict(color='green'), fill='tozeroy', yaxis='y1'
|
||||
))
|
||||
fig.add_trace(go.Scatter(
|
||||
x=ask_prices, y=ask_cumsum, mode='lines', name='Asks',
|
||||
line=dict(color='red'), fill='tozeroy', yaxis='y1'
|
||||
))
|
||||
|
||||
trade_volume_by_price = defaultdict(float)
|
||||
|
||||
for trade in trade_history:
|
||||
price_bin = round(float(trade['price']), 2)
|
||||
trade_volume_by_price[price_bin] += float(trade['size'])
|
||||
|
||||
prices = list(trade_volume_by_price.keys())
|
||||
volumes = list(trade_volume_by_price.values())
|
||||
|
||||
# Sort by price for display
|
||||
sorted_pairs = sorted(zip(prices, volumes))
|
||||
prices = [p for p, v in sorted_pairs]
|
||||
volumes = [v for p, v in sorted_pairs]
|
||||
|
||||
# Add trade volume bars (secondary y-axis)
|
||||
fig.add_trace(go.Bar(
|
||||
x=prices, y=volumes, marker_color='#7ec8e3', name='Trade Volume',
|
||||
opacity=0.7, yaxis='y2'
|
||||
))
|
||||
|
||||
# Update layout for dual y-axes
|
||||
fig.update_layout(
|
||||
title='Order Book Depth & Realized Trade Volume by Price',
|
||||
xaxis=dict(title='Price'),
|
||||
yaxis=dict(title='Cumulative Size', side='left'),
|
||||
yaxis2=dict(
|
||||
title='Traded Volume',
|
||||
overlaying='y',
|
||||
side='right',
|
||||
showgrid=False
|
||||
),
|
||||
template='plotly_dark'
|
||||
)
|
||||
return [fig]
|
||||
|
||||
app.run(debug=True, use_reloader=False)
|
||||
import dash
|
||||
from dash import dcc, html
|
||||
from dash.dependencies import Output, Input
|
||||
import plotly.graph_objs as go
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
def run_dash(book_history, trade_history, BOOK_HISTORY_SECONDS=5, TRADE_HISTORY_SECONDS=60):
|
||||
app = dash.Dash(__name__)
|
||||
app.layout = html.Div([
|
||||
html.H1("Order Book Depth Chart", style={"textAlign": "center", "color": "#222"}),
|
||||
dcc.Graph(id='order-book-graph', style={"height": "90vh", "width": "100vw"}),
|
||||
dcc.Interval(id='interval-component', interval=2*1000, n_intervals=0)
|
||||
], style={"height": "100vh", "width": "100vw", "margin": 0, "padding": 0, "overflow": "hidden", "backgroundColor": "#f7f7f7"})
|
||||
|
||||
@app.callback(
|
||||
[Output('order-book-graph', 'figure')],
|
||||
[Input('interval-component', 'n_intervals')]
|
||||
)
|
||||
def update_graphs(n):
|
||||
now = time.time() * 1000 # current time in ms
|
||||
|
||||
# Prune book_history to only keep last BOOK_HISTORY_SECONDS
|
||||
while book_history and now - book_history[0]['timestamp'] > BOOK_HISTORY_SECONDS * 1000:
|
||||
book_history.popleft()
|
||||
|
||||
# Prune trade_history to only keep last TRADE_HISTORY_SECONDS
|
||||
while trade_history and now - float(trade_history[0]['timestamp']) > TRADE_HISTORY_SECONDS * 1000:
|
||||
trade_history.popleft()
|
||||
|
||||
# Aggregate bids/asks from book_history
|
||||
bids_dict = {}
|
||||
asks_dict = {}
|
||||
|
||||
for book in book_history:
|
||||
for price, size, *_ in book['bids']:
|
||||
price = float(price)
|
||||
size = float(size)
|
||||
bids_dict[price] = bids_dict.get(price, 0) + size
|
||||
|
||||
for price, size, *_ in book['asks']:
|
||||
price = float(price)
|
||||
size = float(size)
|
||||
asks_dict[price] = asks_dict.get(price, 0) + size
|
||||
|
||||
try:
|
||||
# Prepare and sort bids/asks
|
||||
bids = sorted([[p, s] for p, s in bids_dict.items()], reverse=True)
|
||||
asks = sorted([[p, s] for p, s in asks_dict.items()])
|
||||
|
||||
# Cumulative sum
|
||||
bid_prices = [b[0] for b in bids]
|
||||
bid_sizes = [b[1] for b in bids]
|
||||
ask_prices = [a[0] for a in asks]
|
||||
ask_sizes = [a[1] for a in asks]
|
||||
bid_cumsum = [sum(bid_sizes[:i+1]) for i in range(len(bid_sizes))]
|
||||
ask_cumsum = [sum(ask_sizes[:i+1]) for i in range(len(ask_sizes))]
|
||||
except Exception as e:
|
||||
bid_prices, bid_cumsum, ask_prices, ask_cumsum = [], [], [], []
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
# Add order book lines (primary y-axis)
|
||||
fig.add_trace(go.Scatter(
|
||||
x=bid_prices, y=bid_cumsum, mode='lines', name='Bids',
|
||||
line=dict(color='green'), fill='tozeroy', yaxis='y1'
|
||||
))
|
||||
fig.add_trace(go.Scatter(
|
||||
x=ask_prices, y=ask_cumsum, mode='lines', name='Asks',
|
||||
line=dict(color='red'), fill='tozeroy', yaxis='y1'
|
||||
))
|
||||
|
||||
trade_volume_by_price = defaultdict(float)
|
||||
|
||||
for trade in trade_history:
|
||||
price_bin = round(float(trade['price']), 2)
|
||||
trade_volume_by_price[price_bin] += float(trade['size'])
|
||||
|
||||
prices = list(trade_volume_by_price.keys())
|
||||
volumes = list(trade_volume_by_price.values())
|
||||
|
||||
# Sort by price for display
|
||||
sorted_pairs = sorted(zip(prices, volumes))
|
||||
prices = [p for p, v in sorted_pairs]
|
||||
volumes = [v for p, v in sorted_pairs]
|
||||
|
||||
# Add trade volume bars (secondary y-axis)
|
||||
fig.add_trace(go.Bar(
|
||||
x=prices, y=volumes, marker_color='#7ec8e3', name='Trade Volume',
|
||||
opacity=0.7, yaxis='y2'
|
||||
))
|
||||
|
||||
# Update layout for dual y-axes
|
||||
fig.update_layout(
|
||||
title='Order Book Depth & Realized Trade Volume by Price',
|
||||
xaxis=dict(title='Price'),
|
||||
yaxis=dict(title='Cumulative Size', side='left'),
|
||||
yaxis2=dict(
|
||||
title='Traded Volume',
|
||||
overlaying='y',
|
||||
side='right',
|
||||
showgrid=False
|
||||
),
|
||||
template='plotly_dark'
|
||||
)
|
||||
return [fig]
|
||||
|
||||
app.run(debug=True, use_reloader=False)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user