diff --git a/main.py b/main.py index dbe8021..dee375f 100644 --- a/main.py +++ b/main.py @@ -6,6 +6,7 @@ import threading from collections import deque import time import signal +import queue latest_book = {'bids': [], 'asks': [], 'timestamp': None} book_history = deque() @@ -16,6 +17,9 @@ BOOK_HISTORY_SECONDS = 5 shutdown_flag = threading.Event() +PING_INTERVAL = 25 # N seconds, must be < 30 +PONG_TIMEOUT = 25 # Wait this long for pong after ping + def connect(instrument, max_retries=5): logging.info(f"Connecting to OKX for instrument: {instrument}") retries = 0 @@ -53,6 +57,120 @@ def signal_handler(signum, frame): signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) +def initialize_instruments(instruments): + dbs = {} + clients = {} + last_msg_time = {} + ping_sent = {} + pong_queue = {} + for instrument in instruments: + dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db") + logging.info(f"Database initialized for {instrument}") + clients[instrument] = connect(instrument) + last_msg_time[instrument] = time.time() + ping_sent[instrument] = False + pong_queue[instrument] = queue.Queue() + return dbs, clients, last_msg_time, ping_sent, pong_queue + +def handle_message(data, instrument, db, trade_history, book_history, latest_book, pong_queue): + now = time.time() + try: + msg = json.loads(data) + except Exception as e: + logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}") + return False, now + + # Handle pong + if msg.get('event') == 'pong': + pong_queue[instrument].put(True) + return False, now + + if 'arg' in msg and msg['arg'].get('channel') == 'trades': + for trade in msg.get('data', []): + db.insert_trade({ + 'instrument': instrument, + 'trade_id': trade.get('tradeId'), + 'price': float(trade.get('px')), + 'size': float(trade.get('sz')), + 'side': trade.get('side'), + 'timestamp': trade.get('ts') + }) + ts = float(trade.get('ts', time.time() * 1000)) + trade_history.append({ + 'price': trade.get('px'), + 'size': trade.get('sz'), + 'side': trade.get('side'), + 'timestamp': ts + }) + elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'): + for book in msg.get('data', []): + db.insert_book({ + 'instrument': instrument, + 'bids': book.get('bids'), + 'asks': book.get('asks'), + 'timestamp': book.get('ts') + }) + latest_book['bids'] = book.get('bids', []) + latest_book['asks'] = book.get('asks', []) + latest_book['timestamp'] = book.get('ts') + ts = float(book.get('ts', time.time() * 1000)) + book_history.append({ + 'bids': book.get('bids', []), + 'asks': book.get('asks', []), + 'timestamp': ts + }) + else: + logging.info(f"Unknown message for {instrument}: {msg}") + return True, now + +def handle_ping_pong(ws, instrument, last_msg_time, ping_sent, pong_queue): + now = time.time() + time_since_last = now - last_msg_time[instrument] + + if not ping_sent[instrument] and time_since_last > PING_INTERVAL: + try: + ws.send(json.dumps({'op': 'ping'})) + ping_sent[instrument] = True + logging.info(f"Sent ping to {instrument}") + # Wait for pong + pong_received = False + pong_start = time.time() + while time.time() - pong_start < PONG_TIMEOUT: + try: + pong_received = pong_queue[instrument].get(timeout=1) + if pong_received: + logging.info(f"Pong received from {instrument}") + last_msg_time[instrument] = time.time() + ping_sent[instrument] = False + break + except queue.Empty: + continue + if not pong_received: + raise Exception("Pong not received in time") + except Exception as e: + logging.warning(f"No pong from {instrument}: {e}. Reconnecting...") + return False + return True + +def reconnect_instrument(instrument, clients, last_msg_time, ping_sent, pong_queue): + cleanup(clients[instrument], None) + try: + clients[instrument] = connect(instrument) + last_msg_time[instrument] = time.time() + ping_sent[instrument] = False + pong_queue[instrument] = queue.Queue() + return True + except Exception as e: + logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.") + return False + +def shutdown_cleanup(clients, dbs): + for client in clients.values(): + cleanup(client, None) + for db in dbs.values(): + cleanup(None, db) + logging.info('Shutdown complete.') + def main(): instruments = [ "ETH-USDT", @@ -67,91 +185,59 @@ def main(): "UNI-USDT" ] + # Configure logging to both file and stdout logging.basicConfig( level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s', - filename='market_data_collector.log', - filemode='a' + handlers=[ + logging.FileHandler('market_data_collector.log', mode='a'), + logging.StreamHandler() + ] ) - dbs = {} - clients = {} + dbs, clients, last_msg_time, ping_sent, pong_queue = initialize_instruments(instruments) try: - for instrument in instruments: - dbs[instrument] = MarketDB(market=instrument.replace("-", "_"), db_dir="./data/db") - logging.info(f"Database initialized for {instrument}") - clients[instrument] = connect(instrument) - while not shutdown_flag.is_set(): + now = time.time() for instrument in instruments: client = clients[instrument] db = dbs[instrument] + ws = client.ws try: - data = client.ws.recv() + ws.settimeout(1) + try: + data = ws.recv() + except Exception as e: + if isinstance(e, TimeoutError) or 'timed out' in str(e): + data = None + else: + raise + + if shutdown_flag.is_set(): + break + + if data: + processed, msg_time = handle_message( + data, instrument, db, trade_history, book_history, latest_book, pong_queue + ) + if processed: + last_msg_time[instrument] = now + ping_sent[instrument] = False + + # --- Ping/Pong Keepalive Logic --- + if not handle_ping_pong(ws, instrument, last_msg_time, ping_sent, pong_queue): + if not reconnect_instrument(instrument, clients, last_msg_time, ping_sent, pong_queue): + continue + except Exception as e: logging.warning(f"WebSocket disconnected or error for {instrument}: {e}. Reconnecting...") - cleanup(client, None) - try: - clients[instrument] = connect(instrument) - except Exception as e: - logging.critical(f"Could not reconnect {instrument}: {e}. Skipping.") + if not reconnect_instrument(instrument, clients, last_msg_time, ping_sent, pong_queue): continue continue - if shutdown_flag.is_set(): - break - if data == '': - continue - - try: - msg = json.loads(data) - except Exception as e: - logging.warning(f"Failed to parse JSON for {instrument}: {e}, data: {data}") - continue - - if 'arg' in msg and msg['arg'].get('channel') == 'trades': - for trade in msg.get('data', []): - db.insert_trade({ - 'instrument': instrument, - 'trade_id': trade.get('tradeId'), - 'price': float(trade.get('px')), - 'size': float(trade.get('sz')), - 'side': trade.get('side'), - 'timestamp': trade.get('ts') - }) - ts = float(trade.get('ts', time.time() * 1000)) - trade_history.append({ - 'price': trade.get('px'), - 'size': trade.get('sz'), - 'side': trade.get('side'), - 'timestamp': ts - }) - elif 'arg' in msg and msg['arg'].get('channel', '').startswith('books'): - for book in msg.get('data', []): - db.insert_book({ - 'instrument': instrument, - 'bids': book.get('bids'), - 'asks': book.get('asks'), - 'timestamp': book.get('ts') - }) - latest_book['bids'] = book.get('bids', []) - latest_book['asks'] = book.get('asks', []) - latest_book['timestamp'] = book.get('ts') - ts = float(book.get('ts', time.time() * 1000)) - book_history.append({ - 'bids': book.get('bids', []), - 'asks': book.get('asks', []), - 'timestamp': ts - }) - else: - logging.info(f"Unknown message for {instrument}: {msg}") except Exception as e: logging.critical(f"Fatal error in main: {e}") finally: - for client in clients.values(): - cleanup(client, None) - for db in dbs.values(): - cleanup(None, db) - logging.info('Shutdown complete.') + shutdown_cleanup(clients, dbs) if __name__ == '__main__': main()