- Split the `aggregation.py` file into a dedicated sub-package, improving modularity and maintainability.
- Moved `TimeframeBucket`, `RealTimeCandleProcessor`, and `BatchCandleProcessor` classes into their respective files within the new `aggregation` sub-package.
- Introduced utility functions for trade aggregation and validation, enhancing code organization.
- Updated import paths throughout the codebase to reflect the new structure, ensuring compatibility.
- Added safety-net tests for the aggregation package to verify core functionality and prevent regressions during refactoring.

These changes enhance the overall architecture of the aggregation module, making it more scalable and easier to manage.
404 lines
16 KiB
Python
404 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Real OKX Data Aggregation Test
|
|
|
|
This script connects to OKX's live WebSocket feed and tests the second-based
|
|
aggregation functionality with real market data. It demonstrates how trades
|
|
are processed into 1s, 5s, 10s, 15s, and 30s candles in real-time.
|
|
|
|
NO DATABASE OPERATIONS - Pure aggregation testing with live data.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from decimal import Decimal
|
|
from typing import Dict, List, Any
|
|
from collections import defaultdict
|
|
|
|
# Import our modules
|
|
from data.common.data_types import StandardizedTrade, CandleProcessingConfig, OHLCVCandle
|
|
from data.common.aggregation.realtime import RealTimeCandleProcessor
|
|
from data.exchanges.okx.websocket import OKXWebSocketClient, OKXSubscription, OKXChannelType
|
|
from data.exchanges.okx.data_processor import OKXDataProcessor
|
|
|
|
# Set up logging
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
|
|
datefmt='%H:%M:%S'
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RealTimeAggregationTester:
    """
    Test real-time second-based aggregation with live OKX data.

    Connects to OKX's public WebSocket trade feed, converts raw trade
    payloads into StandardizedTrade objects, feeds them through a
    RealTimeCandleProcessor, and logs candle-completion statistics.
    No database operations are performed.
    """

    def __init__(self, symbol: str = "BTC-USDT"):
        """
        Args:
            symbol: OKX instrument ID to test (e.g. "BTC-USDT").
        """
        self.symbol = symbol
        self.component_name = f"real_test_{symbol.replace('-', '_').lower()}"

        # WebSocket client; created in _connect_websocket()
        self._ws_client = None

        # Aggregation processor configured with all second timeframes
        self.config = CandleProcessingConfig(
            timeframes=['1s', '5s', '10s', '15s', '30s'],
            auto_save_candles=False,  # Don't save to database
            emit_incomplete_candles=False
        )

        self.processor = RealTimeCandleProcessor(
            symbol=symbol,
            exchange="okx",
            config=self.config,
            component_name=f"{self.component_name}_processor",
            logger=logger
        )

        # Statistics tracking for the whole session
        self.stats = {
            'trades_received': 0,
            'trades_processed': 0,
            'candles_completed': defaultdict(int),  # per-timeframe counts
            'last_trade_time': None,
            'session_start': datetime.now(timezone.utc)
        }

        # Candle tracking for post-run analysis
        self.completed_candles = []
        self.latest_candles = {}  # Latest completed candle per timeframe

        # Receive completed candles from the processor
        self.processor.add_candle_callback(self._on_candle_completed)

        logger.info(f"Initialized real-time aggregation tester for {symbol}")
        logger.info(f"Testing timeframes: {self.config.timeframes}")

    @staticmethod
    def _timeframe_seconds(timeframe: str):
        """Return the bucket length in seconds for an 'Ns'-style timeframe.

        Generalizes the previous hard-coded 1s/5s/10s/15s/30s chain so any
        second-based timeframe gets a correct expected-candle estimate.
        Returns None for timeframes not of the form '<digits>s'.
        """
        if timeframe.endswith('s') and timeframe[:-1].isdigit():
            return int(timeframe[:-1])
        return None

    async def start_test(self, duration_seconds: int = 300):
        """
        Start the real-time aggregation test.

        Connects, subscribes to trades, and monitors aggregation for the
        requested duration. Cleanup and final statistics always run, even
        on error or Ctrl+C.

        Args:
            duration_seconds: How long to run the test (default: 5 minutes)
        """
        try:
            logger.info("=" * 80)
            logger.info("STARTING REAL-TIME OKX AGGREGATION TEST")
            logger.info("=" * 80)
            logger.info(f"Symbol: {self.symbol}")
            logger.info(f"Duration: {duration_seconds} seconds")
            logger.info(f"Timeframes: {', '.join(self.config.timeframes)}")
            logger.info("=" * 80)

            # Connect to OKX WebSocket
            await self._connect_websocket()

            # Subscribe to trades
            await self._subscribe_to_trades()

            # Monitor for specified duration
            await self._monitor_aggregation(duration_seconds)

        except KeyboardInterrupt:
            logger.info("Test interrupted by user")
        except Exception as e:
            logger.error(f"Test failed: {e}")
            raise
        finally:
            await self._cleanup()
            await self._print_final_statistics()

    async def _connect_websocket(self):
        """Connect to OKX WebSocket.

        Raises:
            RuntimeError: if the connection attempt fails.
        """
        logger.info("Connecting to OKX WebSocket...")

        self._ws_client = OKXWebSocketClient(
            component_name=f"{self.component_name}_ws",
            ping_interval=25.0,
            pong_timeout=10.0,
            max_reconnect_attempts=3,
            reconnect_delay=5.0,
            logger=logger
        )

        # Route every incoming message through our handler
        self._ws_client.add_message_callback(self._on_websocket_message)

        # Connect to the public (unauthenticated) endpoint
        if not await self._ws_client.connect(use_public=True):
            raise RuntimeError("Failed to connect to OKX WebSocket")

        logger.info("✅ Connected to OKX WebSocket")

    async def _subscribe_to_trades(self):
        """Subscribe to trade data for the symbol.

        Raises:
            RuntimeError: if the subscription is rejected.
        """
        logger.info(f"Subscribing to trades for {self.symbol}...")

        subscription = OKXSubscription(
            channel=OKXChannelType.TRADES.value,
            inst_id=self.symbol,
            enabled=True
        )

        if not await self._ws_client.subscribe([subscription]):
            raise RuntimeError(f"Failed to subscribe to trades for {self.symbol}")

        logger.info(f"✅ Subscribed to {self.symbol} trades")

    def _on_websocket_message(self, message: Dict[str, Any]):
        """Handle an incoming WebSocket message.

        Filters out everything except trade-channel data for our symbol,
        then forwards each trade payload to _process_trade_data(). Errors
        are logged rather than raised so the feed keeps running.
        """
        try:
            # Only process trade data messages
            if not isinstance(message, dict):
                return

            if 'data' not in message or 'arg' not in message:
                return

            arg = message['arg']
            if arg.get('channel') != 'trades' or arg.get('instId') != self.symbol:
                return

            # Process each trade in the message
            for trade_data in message['data']:
                self._process_trade_data(trade_data)

        except Exception as e:
            logger.error(f"Error processing WebSocket message: {e}")

    def _process_trade_data(self, trade_data: Dict[str, Any]):
        """Convert one raw OKX trade payload and feed it to the aggregator.

        Updates session statistics and logs every 10th trade plus any
        candles the trade completed. Errors are logged, not raised.
        """
        try:
            self.stats['trades_received'] += 1

            # Convert OKX trade to StandardizedTrade
            # (OKX 'ts' is epoch milliseconds)
            trade = StandardizedTrade(
                symbol=trade_data['instId'],
                trade_id=trade_data['tradeId'],
                price=Decimal(trade_data['px']),
                size=Decimal(trade_data['sz']),
                side=trade_data['side'],
                timestamp=datetime.fromtimestamp(int(trade_data['ts']) / 1000, tz=timezone.utc),
                exchange="okx",
                raw_data=trade_data
            )

            # Update statistics
            self.stats['trades_processed'] += 1
            self.stats['last_trade_time'] = trade.timestamp

            # Process through aggregation; returns any candles closed by this trade
            completed_candles = self.processor.process_trade(trade)

            # Log trade details (every 10th trade to limit noise)
            if self.stats['trades_processed'] % 10 == 1:
                logger.info(
                    f"Trade #{self.stats['trades_processed']}: "
                    f"{trade.side.upper()} {trade.size} @ ${trade.price} "
                    f"(ID: {trade.trade_id}) at {trade.timestamp.strftime('%H:%M:%S.%f')[:-3]}"
                )

            # Log completed candles
            if completed_candles:
                logger.info(f"🕯️ Completed {len(completed_candles)} candle(s)")

        except Exception as e:
            logger.error(f"Error processing trade data: {e}")

    def _on_candle_completed(self, candle: OHLCVCandle):
        """Handle a completed candle emitted by the processor.

        Records it in the session statistics, logs OHLCV details, and
        prints a timeframe summary every 10 completed candles.
        """
        try:
            # Update statistics
            self.stats['candles_completed'][candle.timeframe] += 1
            self.completed_candles.append(candle)
            self.latest_candles[candle.timeframe] = candle

            # Calculate candle metrics
            candle_range = candle.high - candle.low
            price_change = candle.close - candle.open
            # Guard against a zero open price to avoid division by zero
            change_percent = (price_change / candle.open * 100) if candle.open > 0 else 0

            # Log candle completion with detailed info
            logger.info(
                f"📊 {candle.timeframe.upper()} CANDLE COMPLETED at {candle.end_time.strftime('%H:%M:%S')}: "
                f"O=${candle.open} H=${candle.high} L=${candle.low} C=${candle.close} "
                f"V={candle.volume} T={candle.trade_count} "
                f"Range=${candle_range:.2f} Change={change_percent:+.2f}%"
            )

            # Show timeframe summary every 10 candles
            total_candles = sum(self.stats['candles_completed'].values())
            if total_candles % 10 == 0:
                self._print_timeframe_summary()

        except Exception as e:
            logger.error(f"Error handling completed candle: {e}")

    async def _monitor_aggregation(self, duration_seconds: int):
        """Monitor the aggregation process for the given duration.

        Sleeps in 1-second ticks, printing a status update every 30 seconds.
        """
        logger.info(f"🔍 Monitoring aggregation for {duration_seconds} seconds...")
        logger.info("Waiting for trade data to start arriving...")

        start_time = datetime.now(timezone.utc)
        last_status_time = start_time
        status_interval = 30  # Print status every 30 seconds

        while (datetime.now(timezone.utc) - start_time).total_seconds() < duration_seconds:
            await asyncio.sleep(1)

            current_time = datetime.now(timezone.utc)

            # Print periodic status
            if (current_time - last_status_time).total_seconds() >= status_interval:
                self._print_status_update(current_time - start_time)
                last_status_time = current_time

        logger.info("⏰ Test duration completed")

    def _print_status_update(self, elapsed_time):
        """Print a periodic status update (trade counts, candle counts, prices)."""
        logger.info("=" * 60)
        logger.info(f"📈 STATUS UPDATE - Elapsed: {elapsed_time.total_seconds():.0f}s")
        logger.info(f"Trades received: {self.stats['trades_received']}")
        logger.info(f"Trades processed: {self.stats['trades_processed']}")

        if self.stats['last_trade_time']:
            logger.info(f"Last trade: {self.stats['last_trade_time'].strftime('%H:%M:%S.%f')[:-3]}")

        # Show candle counts
        total_candles = sum(self.stats['candles_completed'].values())
        logger.info(f"Total candles completed: {total_candles}")

        for timeframe in self.config.timeframes:
            count = self.stats['candles_completed'][timeframe]
            logger.info(f"  {timeframe}: {count} candles")

        # Show current aggregation status
        current_candles = self.processor.get_current_candles(incomplete=True)
        logger.info(f"Current incomplete candles: {len(current_candles)}")

        # Show latest prices from latest candles
        if self.latest_candles:
            logger.info("Latest candle closes:")
            for tf in self.config.timeframes:
                if tf in self.latest_candles:
                    candle = self.latest_candles[tf]
                    logger.info(f"  {tf}: ${candle.close} (at {candle.end_time.strftime('%H:%M:%S')})")

        logger.info("=" * 60)

    def _print_timeframe_summary(self):
        """Print a summary of candle counts per timeframe."""
        logger.info("⚡ TIMEFRAME SUMMARY:")

        total_candles = sum(self.stats['candles_completed'].values())
        for timeframe in self.config.timeframes:
            count = self.stats['candles_completed'][timeframe]
            percentage = (count / total_candles * 100) if total_candles > 0 else 0
            logger.info(f"  {timeframe:>3s}: {count:>3d} candles ({percentage:5.1f}%)")

    async def _cleanup(self):
        """Disconnect the WebSocket and flush any in-progress candles."""
        logger.info("🧹 Cleaning up...")

        if self._ws_client:
            await self._ws_client.disconnect()

        # Force complete any remaining candles for final analysis
        remaining_candles = self.processor.force_complete_all_candles()
        if remaining_candles:
            logger.info(f"🔚 Force completed {len(remaining_candles)} remaining candles")

    async def _print_final_statistics(self):
        """Print comprehensive final statistics for the session."""
        session_duration = datetime.now(timezone.utc) - self.stats['session_start']

        logger.info("")
        logger.info("=" * 80)
        logger.info("📊 FINAL TEST RESULTS")
        logger.info("=" * 80)

        # Basic stats
        logger.info(f"Symbol: {self.symbol}")
        logger.info(f"Session duration: {session_duration.total_seconds():.1f} seconds")
        logger.info(f"Total trades received: {self.stats['trades_received']}")
        logger.info(f"Total trades processed: {self.stats['trades_processed']}")

        # Guard the zero-duration edge case to avoid ZeroDivisionError
        if self.stats['trades_processed'] > 0 and session_duration.total_seconds() > 0:
            trade_rate = self.stats['trades_processed'] / session_duration.total_seconds()
            logger.info(f"Average trade rate: {trade_rate:.2f} trades/second")

        # Candle statistics
        total_candles = sum(self.stats['candles_completed'].values())
        logger.info(f"Total candles completed: {total_candles}")

        logger.info("\nCandles by timeframe:")
        for timeframe in self.config.timeframes:
            count = self.stats['candles_completed'][timeframe]
            percentage = (count / total_candles * 100) if total_candles > 0 else 0

            # Estimate how many candles the session could have produced.
            # Uses _timeframe_seconds() so any 'Ns' timeframe works, not
            # just the five previously hard-coded ones.
            seconds = self._timeframe_seconds(timeframe)
            if seconds:
                expected = int(session_duration.total_seconds() / seconds)
            else:
                expected = "N/A"

            logger.info(f"  {timeframe:>3s}: {count:>3d} candles ({percentage:5.1f}%) - Expected: ~{expected}")

        # Latest candle analysis
        if self.latest_candles:
            logger.info("\nLatest candle closes:")
            for tf in self.config.timeframes:
                if tf in self.latest_candles:
                    candle = self.latest_candles[tf]
                    logger.info(f"  {tf}: ${candle.close}")

        # Processor statistics
        processor_stats = self.processor.get_stats()
        logger.info(f"\nProcessor statistics:")
        logger.info(f"  Trades processed: {processor_stats.get('trades_processed', 0)}")
        logger.info(f"  Candles emitted: {processor_stats.get('candles_emitted', 0)}")
        logger.info(f"  Errors: {processor_stats.get('errors_count', 0)}")

        logger.info("=" * 80)
        logger.info("✅ REAL-TIME AGGREGATION TEST COMPLETED SUCCESSFULLY")
        logger.info("=" * 80)
|
|
|
async def main():
    """Main test function: configure and run the live aggregation tester."""
    # Test configuration
    SYMBOL = "BTC-USDT"  # High-activity pair for good test data
    DURATION = 180       # 3 minutes for good test coverage

    # Announce the run before any network activity starts
    banner_lines = (
        "🚀 Real-Time OKX Second-Based Aggregation Test",
        f"Testing symbol: {SYMBOL}",
        f"Duration: {DURATION} seconds",
        "Press Ctrl+C to stop early\n",
    )
    for line in banner_lines:
        print(line)

    # Create and run tester
    tester = RealTimeAggregationTester(symbol=SYMBOL)
    await tester.start_test(duration_seconds=DURATION)
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async test and report how it ended.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n⏹️ Test stopped by user")
    except Exception as exc:
        # Surface the failure with a full traceback for debugging
        print(f"\n❌ Test failed: {exc}")
        import traceback
        traceback.print_exc()