From 8f96e14b8b6d5788cb739d9830f926ce29b6aeb3 Mon Sep 17 00:00:00 2001 From: Simon Moisy Date: Thu, 29 May 2025 16:28:27 +0800 Subject: [PATCH] init --- .gitignore | 2 + main.py | 152 +++++++++++++++++++++++++++++++ market_db.py | 77 ++++++++++++++++ okx_client.py | 233 +++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | Bin 0 -> 442 bytes visualizer.py | 108 ++++++++++++++++++++++ 6 files changed, 572 insertions(+) create mode 100644 main.py create mode 100644 market_db.py create mode 100644 okx_client.py create mode 100644 requirements.txt create mode 100644 visualizer.py diff --git a/.gitignore b/.gitignore index 0dbf2f2..b3be1b9 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] +/credentials/*.json +/data/ *$py.class # C extensions diff --git a/main.py b/main.py new file mode 100644 index 0000000..5ea7220 --- /dev/null +++ b/main.py @@ -0,0 +1,152 @@ +from okx_client import OKXClient +from market_db import MarketDB +import json +import logging +import threading +from collections import deque +import time +import signal + +latest_book = {'bids': [], 'asks': [], 'timestamp': None} +book_history = deque() +trade_history = deque() + +TRADE_HISTORY_SECONDS = 60 +BOOK_HISTORY_SECONDS = 5 + +shutdown_flag = threading.Event() + +def connect(instrument, max_retries=5): + logging.info(f"Connecting to OKX for instrument: {instrument}") + retries = 0 + backoff = 1 + while not shutdown_flag.is_set(): + try: + client = OKXClient(authenticate=False) + client.subscribe_trades(instrument) + client.subscribe_book(instrument, depth=5, channel="books") + logging.info(f"Subscribed to trades and book for {instrument}") + return client + except Exception as e: + retries += 1 + logging.error(f"Failed to connect to OKX: {e}. Retry {retries}/{max_retries} in {backoff}s.") + if retries >= max_retries: + logging.critical("Max retries reached. Exiting connect loop.") + raise + time.sleep(backoff) + backoff = min(backoff * 2, 60) # exponential backoff, max 60s + return None + +def cleanup(client, db): + if client and hasattr(client, 'ws') and client.ws: + try: + client.ws.close() + except Exception as e: + logging.warning(f"Error closing websocket: {e}") + if db: + db.close() + +def signal_handler(signum, frame): + logging.info(f"Received signal {signum}, shutting down...") + shutdown_flag.set() + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + +def main(): + instruments = [ + "ETH-USDT", + "BTC-USDT", + "SOL-USDT", + "DOGE-USDT", + "TON-USDT", + "ETH-USDC", + "SOPH-USDT", + "PEPE-USDT", + "BTC-USDC", + "UNI-USDT" + ] + + logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') + dbs = {} + clients = {} + try: + for instrument in instruments: + dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db") + logging.info(f"Database initialized for {instrument}") + clients[instrument] = connect(instrument) + + while not shutdown_flag.is_set(): + for instrument in instruments: + client = clients[instrument] + db = dbs[instrument] + try: + data = client.ws.recv() + except Exception as e: + logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...") + cleanup(client, None) + try: + clients[instrument] = connect(instrument) + except Exception as e: + logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.") + continue + continue + + if shutdown_flag.is_set(): + break + if data == '': + continue + + try: + msg = json.loads(data) + except Exception as e: + logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}") + continue + + if 'arg' in msg and msg['arg'].get('channel') == 'trades': + for trade in msg.get('data', []): + db.insert_trade({ + 'instrument': instrument, + 'trade_id': trade.get('tradeId'), + 'price': float(trade.get('px')), + 'size': float(trade.get('sz')), + 'side': trade.get('side'), + 'timestamp': trade.get('ts') + }) + ts = float(trade.get('ts', time.time() * 1000)) + trade_history.append({ + 'price': trade.get('px'), + 'size': trade.get('sz'), + 'side': trade.get('side'), + 'timestamp': ts + }) + elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'): + for book in msg.get('data', []): + db.insert_book({ + 'instrument': instrument, + 'bids': book.get('bids'), + 'asks': book.get('asks'), + 'timestamp': book.get('ts') + }) + latest_book['bids'] = book.get('bids', []) + latest_book['asks'] = book.get('asks', []) + latest_book['timestamp'] = book.get('ts') + ts = float(book.get('ts', time.time() * 1000)) + book_history.append({ + 'bids': book.get('bids', []), + 'asks': book.get('asks', []), + 'timestamp': ts + }) + else: + logging.info(f"Unknown message for {instrument}: {msg}") + except Exception as e: + logging.critical(f"Fatal error in main: {e}") + finally: + for client in clients.values(): + cleanup(client, None) + for db in dbs.values(): + cleanup(None, db) + logging.info('Shutdown complete.') + +if __name__ == '__main__': + main() diff --git a/market_db.py b/market_db.py new file mode 100644 index 0000000..8bc6abf --- /dev/null +++ b/market_db.py @@ -0,0 +1,77 @@ +import sqlite3 +import os +from typing import Any, Dict, List, Optional +import logging + +class MarketDB: + def __init__(self, market: str, db_dir: str = ""): + db_name = f"{market}.db" + db_path = db_name if not db_dir else f"{db_dir.rstrip('/')}/{db_name}" + if db_dir: + os.makedirs(db_dir, exist_ok=True) + self.conn = sqlite3.connect(db_path) + logging.info(f"Connected to database at {db_path}") + self._create_tables() + + def _create_tables(self): + cursor = self.conn.cursor() + cursor.execute(''' + CREATE TABLE IF NOT EXISTS trades ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + instrument TEXT, + trade_id TEXT, + price REAL, + size REAL, + side TEXT, + timestamp TEXT + ) + ''') + cursor.execute(''' + CREATE TABLE IF NOT EXISTS book ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + instrument TEXT, + bids TEXT, + asks TEXT, + timestamp TEXT + ) + ''') + self.conn.commit() + logging.info("Database tables ensured.") + + def insert_trade(self, trade: Dict[str, Any]): + cursor = self.conn.cursor() + cursor.execute(''' + INSERT INTO trades (instrument, trade_id, price, size, side, timestamp) + VALUES (?, ?, ?, ?, ?, ?) + ''', ( + trade.get('instrument'), + trade.get('trade_id'), + trade.get('price'), + trade.get('size'), + trade.get('side'), + trade.get('timestamp') + )) + self.conn.commit() + logging.debug(f"Inserted trade: {trade}") + + def insert_book(self, book: Dict[str, Any]): + cursor = self.conn.cursor() + bids = book.get('bids', []) + asks = book.get('asks', []) + best_bid = next((b for b in bids if float(b[1]) > 0), ['-', '-']) + best_ask = next((a for a in asks if float(a[1]) > 0), ['-', '-']) + cursor.execute(''' + INSERT INTO book (instrument, bids, asks, timestamp) + VALUES (?, ?, ?, ?) + ''', ( + book.get('instrument'), + str(bids), + str(asks), + book.get('timestamp') + )) + self.conn.commit() + logging.debug(f"Inserted book: {book.get('instrument', 'N/A')} ts:{book.get('timestamp', 'N/A')} bid:{best_bid} ask:{best_ask}") + + def close(self): + self.conn.close() + logging.info("Database connection closed.") diff --git a/okx_client.py b/okx_client.py new file mode 100644 index 0000000..52957ad --- /dev/null +++ b/okx_client.py @@ -0,0 +1,233 @@ +import os +import time +import hmac +import hashlib +import base64 +import json +import pandas as pd +import threading +import requests +import websocket +import logging + +class OKXClient: + PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public" + PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private" + REST_URL = "https://www.okx.com" + + def __init__(self, authenticate: bool = True): + self.authenticated = False + self.api_key = None + self.api_secret = None + self.api_passphrase = None + self.ws = None + self.ws_private = None + self._lock = threading.Lock() + self._private_lock = threading.Lock() + + if authenticate: + config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json') + try: + with open(config_path, 'r') as f: + config = json.load(f) + except FileNotFoundError: + raise FileNotFoundError(f"Credentials file not found at {config_path}. Please create it with the required keys.") + except json.JSONDecodeError: + raise ValueError(f"Credentials file at {config_path} is not valid JSON.") + + self.api_key = config.get("OKX_API_KEY") + self.api_secret = config.get("OKX_API_SECRET") + self.api_passphrase = config.get("OKX_API_PASSPHRASE") + + if not self.api_key or not self.api_secret or not self.api_passphrase: + raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.") + + self._authenticate() + self._connect_ws() + + def _connect_ws(self): + if self.ws is None: + self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10) + if self.authenticated and self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None: + self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10) + + def _get_timestamp(self): + return str(round(time.time(), 3)) + + def _sign(self, timestamp, method, request_path, body): + if not body: + body = '' + message = f'{timestamp}{method}{request_path}{body}' + mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256) + return base64.b64encode(mac.digest()).decode() + + def _authenticate(self): + import websocket + timestamp = self._get_timestamp() + sign = self._sign(timestamp, 'GET', '/users/self/verify', '') + login_params = { + "op": "login", + "args": [{ + "apiKey": self.api_key, + "passphrase": self.api_passphrase, + "timestamp": timestamp, + "sign": sign + }] + } + self.ws_private.send(json.dumps(login_params)) + logging.info("Waiting for login response from OKX...") + while True: + try: + resp = self.ws_private.recv() + logging.debug(f"Received from OKX private WS: {resp}") + if not resp: + continue + try: + msg = json.loads(resp) + except Exception: + logging.warning(f"Non-JSON message received: {resp}") + continue + if msg.get("event") == "login": + if msg.get("code") == "0": + logging.info("OKX WebSocket login successful.") + self.authenticated = True + break + else: + raise Exception(f"WebSocket authentication failed: {msg}") + except websocket._exceptions.WebSocketConnectionClosedException as e: + logging.error(f"WebSocket connection closed during authentication: {e}") + raise + except Exception as e: + logging.error(f"Exception during authentication: {e}") + raise + + def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"): + # OKX uses candle1m, candle5m, etc. + tf_map = {"1m": "candle1m", "5m": "candle5m", "15m": "candle15m", "1h": "candle1H"} + channel = tf_map.get(timeframe, f"candle{timeframe}") + params = { + "op": "subscribe", + "args": [{"channel": channel, "instId": instrument}] + } + logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") + self.ws.send(json.dumps(params)) + + def subscribe_trades(self, instrument="BTC-USDT"): + params = { + "op": "subscribe", + "args": [{"channel": "trades", "instId": instrument}] + } + logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") + self.ws.send(json.dumps(params)) + + def subscribe_ticker(self, instrument="BTC-USDT"): + params = { + "op": "subscribe", + "args": [{"channel": "tickers", "instId": instrument}] + } + logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") + self.ws.send(json.dumps(params)) + + def subscribe_book(self, instrument="BTC-USDT", depth=5, channel="books5"): + # OKX supports books5, books50, books-l2-tbt + # channel = "books5" if depth <= 5 else "books50" + params = { + "op": "subscribe", + "args": [{"channel": channel, "instId": instrument}] + } + logging.debug(f"[PUBLIC WS] Sending subscription: {json.dumps(params)}") + self.ws.send(json.dumps(params)) + + def subscribe_user_order(self): + if not self.authenticated: + logging.warning("Attempted to subscribe to user order channel without authentication.") + return + params = { + "op": "subscribe", + "args": [{"channel": "orders", "instType": "SPOT"}] + } + logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") + self.ws_private.send(json.dumps(params)) + + def subscribe_user_trade(self): + if not self.authenticated: + logging.warning("Attempted to subscribe to user trade channel without authentication.") + return + params = { + "op": "subscribe", + "args": [{"channel": "trades", "instType": "SPOT"}] + } + logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") + self.ws_private.send(json.dumps(params)) + + def subscribe_user_balance(self): + if not self.authenticated: + logging.warning("Attempted to subscribe to user balance channel without authentication.") + return + params = { + "op": "subscribe", + "args": [{"channel": "balance_and_position"}] + } + logging.debug(f"[PRIVATE WS] Sending subscription: {json.dumps(params)}") + self.ws_private.send(json.dumps(params)) + + def get_balance(self, currency=None): + url = f"{self.REST_URL}/api/v5/account/balance" + timestamp = self._get_timestamp() + method = "GET" + request_path = "/api/v5/account/balance" + body = '' + sign = self._sign(timestamp, method, request_path, body) + headers = { + "OK-ACCESS-KEY": self.api_key, + "OK-ACCESS-SIGN": sign, + "OK-ACCESS-TIMESTAMP": timestamp, + "OK-ACCESS-PASSPHRASE": self.api_passphrase, + "Content-Type": "application/json" + } + resp = requests.get(url, headers=headers) + if resp.status_code == 200: + data = resp.json() + balances = data.get("data", [{}])[0].get("details", []) + if currency: + return [b for b in balances if b.get("ccy") == currency] + return balances + return [] + + def place_order(self, side, amount, instrument="BTC-USDT"): + url = f"{self.REST_URL}/api/v5/trade/order" + timestamp = self._get_timestamp() + method = "POST" + request_path = "/api/v5/trade/order" + body_dict = { + "instId": instrument, + "tdMode": "cash", + "side": side.lower(), + "ordType": "market", + "sz": str(amount) + } + body = json.dumps(body_dict) + sign = self._sign(timestamp, method, request_path, body) + headers = { + "OK-ACCESS-KEY": self.api_key, + "OK-ACCESS-SIGN": sign, + "OK-ACCESS-TIMESTAMP": timestamp, + "OK-ACCESS-PASSPHRASE": self.api_passphrase, + "Content-Type": "application/json" + } + resp = requests.post(url, headers=headers, data=body) + return resp.json() + + def buy_btc(self, amount, instrument="BTC-USDT"): + return self.place_order("buy", amount, instrument) + + def sell_btc(self, amount, instrument="BTC-USDT"): + return self.place_order("sell", amount, instrument) + + def get_instruments(self): + url = f"{self.REST_URL}/api/v5/public/instruments?instType=SPOT" + resp = requests.get(url) + if resp.status_code == 200: + data = resp.json() + return data.get("data", []) + return [] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3e391e08cdbad83c1d8609675a5118cd7ecb0b22 GIT binary patch literal 442 zcmYL_-EM+F5QWdRiH`zN(6(OqFhFT)N`bXGC2G}i`sS{IH@k|l^Oa|;i{3#Nu4^xQwC%pt1h?09 zaWtQFaq7Xh@Cv-*KJhhorf~whbYK!k1-vIu(5(0qYCEO}H}V!;BX?d BOOK_HISTORY_SECONDS * 1000: + book_history.popleft() + + # Prune trade_history to only keep last TRADE_HISTORY_SECONDS + while trade_history and now - float(trade_history[0]['timestamp']) > TRADE_HISTORY_SECONDS * 1000: + trade_history.popleft() + + # Aggregate bids/asks from book_history + bids_dict = {} + asks_dict = {} + + for book in book_history: + for price, size, *_ in book['bids']: + price = float(price) + size = float(size) + bids_dict[price] = bids_dict.get(price, 0) + size + + for price, size, *_ in book['asks']: + price = float(price) + size = float(size) + asks_dict[price] = asks_dict.get(price, 0) + size + + try: + # Prepare and sort bids/asks + bids = sorted([[p, s] for p, s in bids_dict.items()], reverse=True) + asks = sorted([[p, s] for p, s in asks_dict.items()]) + + # Cumulative sum + bid_prices = [b[0] for b in bids] + bid_sizes = [b[1] for b in bids] + ask_prices = [a[0] for a in asks] + ask_sizes = [a[1] for a in asks] + bid_cumsum = [sum(bid_sizes[:i+1]) for i in range(len(bid_sizes))] + ask_cumsum = [sum(ask_sizes[:i+1]) for i in range(len(ask_sizes))] + except Exception as e: + bid_prices, bid_cumsum, ask_prices, ask_cumsum = [], [], [], [] + + fig = go.Figure() + + # Add order book lines (primary y-axis) + fig.add_trace(go.Scatter( + x=bid_prices, y=bid_cumsum, mode='lines', name='Bids', + line=dict(color='green'), fill='tozeroy', yaxis='y1' + )) + fig.add_trace(go.Scatter( + x=ask_prices, y=ask_cumsum, mode='lines', name='Asks', + line=dict(color='red'), fill='tozeroy', yaxis='y1' + )) + + trade_volume_by_price = defaultdict(float) + + for trade in trade_history: + price_bin = round(float(trade['price']), 2) + trade_volume_by_price[price_bin] += float(trade['size']) + + prices = list(trade_volume_by_price.keys()) + volumes = list(trade_volume_by_price.values()) + + # Sort by price for display + sorted_pairs = sorted(zip(prices, volumes)) + prices = [p for p, v in sorted_pairs] + volumes = [v for p, v in sorted_pairs] + + # Add trade volume bars (secondary y-axis) + fig.add_trace(go.Bar( + x=prices, y=volumes, marker_color='#7ec8e3', name='Trade Volume', + opacity=0.7, yaxis='y2' + )) + + # Update layout for dual y-axes + fig.update_layout( + title='Order Book Depth & Realized Trade Volume by Price', + xaxis=dict(title='Price'), + yaxis=dict(title='Cumulative Size', side='left'), + yaxis2=dict( + title='Traded Volume', + overlaying='y', + side='right', + showgrid=False + ), + template='plotly_dark' + ) + return [fig] + + app.run(debug=True, use_reloader=False)