2025-06-02 12:35:19 +08:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
"""
|
|
|
|
|
Quick OKX Aggregation Test

A simplified version for quick testing of different symbol and timeframe combinations.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
import logging
|
|
|
|
|
import sys
|
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
|
from decimal import Decimal
|
|
|
|
|
from typing import Dict, List, Any
|
|
|
|
|
|
|
|
|
|
# Import our modules
|
|
|
|
|
from data.common.data_types import StandardizedTrade, CandleProcessingConfig, OHLCVCandle
|
2025-06-07 01:17:22 +08:00
|
|
|
from data.common.aggregation.realtime import RealTimeCandleProcessor
|
2025-06-02 12:35:19 +08:00
|
|
|
from data.exchanges.okx.websocket import OKXWebSocketClient, OKXSubscription, OKXChannelType
|
|
|
|
|
|
|
|
|
|
# Set up minimal logging
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s: %(message)s', datefmt='%H:%M:%S')
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuickAggregationTester:
    """Quick tester for real-time aggregation.

    Connects to the OKX public WebSocket, subscribes to the trades channel of
    one symbol, feeds every trade into a ``RealTimeCandleProcessor`` and keeps
    simple counters (trades processed, candles received per timeframe) that
    are logged periodically and once more when the run ends.
    """

    # Seconds between two periodic status log lines while monitoring.
    _STATUS_INTERVAL_SECONDS = 5

    # A trade is logged when ``trade_count % _TRADE_LOG_EVERY == 1``,
    # i.e. the 1st, 21st, 41st, ... processed trade.
    _TRADE_LOG_EVERY = 20

    # Seconds per timeframe unit suffix, used by ``_expected_candles``.
    _UNIT_SECONDS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}

    def __init__(self, symbol: str, timeframes: List[str]):
        """Create the candle processor and reset the counters.

        Args:
            symbol: Instrument id to subscribe to; incoming messages are
                matched against it via their ``instId`` field.
            timeframes: Timeframe labels handed to ``CandleProcessingConfig``
                and used as keys of the per-timeframe candle counters.
        """
        self.symbol = symbol
        self.timeframes = timeframes
        self.ws_client = None

        # Create processor
        config = CandleProcessingConfig(timeframes=timeframes, auto_save_candles=False)
        self.processor = RealTimeCandleProcessor(symbol, "okx", config, logger=logger)
        self.processor.add_candle_callback(self._on_candle)

        # Stats
        self.trade_count = 0
        self.candle_counts = {tf: 0 for tf in timeframes}

        logger.info(f"Testing {symbol} with timeframes: {', '.join(timeframes)}")

    async def run(self, duration: int = 60):
        """Run the test for the specified duration.

        Connects, subscribes, then logs a status line every
        ``_STATUS_INTERVAL_SECONDS`` until ``duration`` seconds have elapsed,
        and finally logs the summary. Any exception is logged (with its
        traceback) rather than propagated, and the WebSocket client is always
        disconnected if it was created.

        Args:
            duration: Monitoring time in seconds. A value <= 0 skips the
                monitoring loop and goes straight to the final statistics.
        """
        try:
            # Connect and subscribe
            await self._setup_websocket()
            await self._subscribe()

            logger.info(f"🔍 Monitoring for {duration} seconds...")
            start_time = datetime.now(timezone.utc)

            # Monitor. The sleep is capped at the time remaining so the run
            # does not overshoot ``duration`` by up to a full status interval
            # (the final trade rate is computed from ``duration``).
            while True:
                elapsed = (datetime.now(timezone.utc) - start_time).total_seconds()
                remaining = duration - elapsed
                if remaining <= 0:
                    break
                await asyncio.sleep(min(self._STATUS_INTERVAL_SECONDS, remaining))
                self._print_quick_status()

            # Final stats
            self._print_final_stats(duration)

        except Exception as e:
            # Top-level boundary of the test run: keep the traceback so the
            # failure can be diagnosed, but do not propagate.
            logger.exception(f"Test failed: {e}")
        finally:
            if self.ws_client:
                await self.ws_client.disconnect()

    async def _setup_websocket(self):
        """Create the WebSocket client, register the message handler and connect.

        Raises:
            RuntimeError: If ``connect`` returns a falsy value.
        """
        self.ws_client = OKXWebSocketClient("quick_test", logger=logger)
        self.ws_client.add_message_callback(self._on_message)

        if not await self.ws_client.connect(use_public=True):
            raise RuntimeError("Failed to connect")

        logger.info("✅ Connected to OKX")

    async def _subscribe(self):
        """Subscribe to the trades channel for ``self.symbol``.

        Raises:
            RuntimeError: If ``subscribe`` returns a falsy value.
        """
        # NOTE(review): the meaning of the third positional argument is
        # defined by OKXSubscription, which is not visible here - confirm.
        subscription = OKXSubscription("trades", self.symbol, True)
        if not await self.ws_client.subscribe([subscription]):
            raise RuntimeError("Failed to subscribe")

        logger.info(f"✅ Subscribed to {self.symbol} trades")

    def _on_message(self, message: Dict[str, Any]):
        """Handle a WebSocket message.

        Only dict messages that carry a ``data`` list and whose ``arg``
        identifies the ``trades`` channel for ``self.symbol`` are processed;
        everything else is ignored. Errors are logged, never raised, so a bad
        message cannot break the client's callback dispatch.
        """
        try:
            if not isinstance(message, dict) or 'data' not in message:
                return

            arg = message.get('arg', {})
            if arg.get('channel') != 'trades' or arg.get('instId') != self.symbol:
                return

            for trade_data in message['data']:
                self._process_trade(trade_data)

        except Exception as e:
            logger.error(f"Message processing error: {e}")

    def _process_trade(self, trade_data: Dict[str, Any]):
        """Convert one raw trade dict to a ``StandardizedTrade`` and aggregate it.

        Reads the ``instId``, ``tradeId``, ``px``, ``sz``, ``side`` and ``ts``
        keys. Errors (missing keys, unparsable numbers, processor failures)
        are logged and the trade is skipped.
        """
        try:
            # Create standardized trade
            trade = StandardizedTrade(
                symbol=trade_data['instId'],
                trade_id=trade_data['tradeId'],
                price=Decimal(trade_data['px']),
                size=Decimal(trade_data['sz']),
                side=trade_data['side'],
                # ``ts`` is divided by 1000, i.e. treated as epoch milliseconds.
                timestamp=datetime.fromtimestamp(int(trade_data['ts']) / 1000, tz=timezone.utc),
                exchange="okx",
                raw_data=trade_data
            )

            # Process through aggregation
            self.processor.process_trade(trade)

            # Count only after the trade was parsed and aggregated, so that
            # "Trades processed" does not include malformed/failed trades.
            self.trade_count += 1

            # Log every 20th trade
            if self.trade_count % self._TRADE_LOG_EVERY == 1:
                logger.info(f"Trade #{self.trade_count}: {trade.side} {trade.size} @ ${trade.price}")

        except Exception as e:
            logger.error(f"Trade processing error: {e}")

    def _on_candle(self, candle: "OHLCVCandle"):
        """Handle a completed candle: count it and log a one-line summary."""
        # ``.get`` keeps the callback from raising KeyError if the processor
        # ever emits a timeframe that was not registered in ``__init__``.
        self.candle_counts[candle.timeframe] = self.candle_counts.get(candle.timeframe, 0) + 1

        # Calculate metrics
        change = candle.close - candle.open
        # Guard against division by zero for a zero (or negative) open.
        change_pct = (change / candle.open * 100) if candle.open > 0 else 0

        logger.info(
            f"🕯️ {candle.timeframe.upper()} at {candle.end_time.strftime('%H:%M:%S')}: "
            f"${candle.close} ({change_pct:+.2f}%) V={candle.volume} T={candle.trade_count}"
        )

    def _print_quick_status(self):
        """Log a one-line status: trade count plus candle counts per timeframe."""
        total_candles = sum(self.candle_counts.values())
        candle_summary = ", ".join(f"{tf}:{count}" for tf, count in self.candle_counts.items())
        logger.info(f"📊 Trades: {self.trade_count} | Candles: {total_candles} ({candle_summary})")

    def _print_final_stats(self, duration: int):
        """Log the final statistics.

        Args:
            duration: Nominal run length in seconds, used for the trade rate
                and for the expected candle counts.
        """
        logger.info("=" * 50)
        logger.info("📊 FINAL RESULTS")
        logger.info(f"Duration: {duration}s")
        logger.info(f"Trades processed: {self.trade_count}")

        # A zero/negative duration has no meaningful rate; report 0.0 instead
        # of raising ZeroDivisionError.
        trade_rate = self.trade_count / duration if duration > 0 else 0.0
        logger.info(f"Trade rate: {trade_rate:.1f}/sec")

        total_candles = sum(self.candle_counts.values())
        logger.info(f"Total candles: {total_candles}")

        for tf in self.timeframes:
            count = self.candle_counts[tf]
            expected = self._expected_candles(tf, duration)
            logger.info(f" {tf}: {count} candles (expected ~{expected})")

        logger.info("=" * 50)

    def _expected_candles(self, timeframe: str, duration: int) -> int:
        """Calculate the expected number of candles for a run.

        The timeframe is parsed as ``<positive integer><unit>`` where unit is
        one of ``s``, ``m``, ``h``, ``d`` (e.g. ``'5s'``, ``'1m'``, ``'4h'``).

        Args:
            timeframe: Timeframe label.
            duration: Run length in seconds.

        Returns:
            ``duration`` floor-divided by the timeframe length in seconds, or
            0 when the label cannot be parsed.
        """
        unit_seconds = self._UNIT_SECONDS.get(timeframe[-1:])
        count_text = timeframe[:-1]
        # ``isdecimal`` guarantees ``int()`` will accept the text.
        if unit_seconds is None or not count_text.isdecimal():
            return 0

        count = int(count_text)
        if count == 0:
            return 0
        return duration // (count * unit_seconds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def main():
    """Parse the positional command-line arguments and run one test.

    Positional arguments (all optional):
        1. symbol      - defaults to "BTC-USDT"
        2. duration    - seconds, parsed with ``int``; defaults to 60
        3. timeframes  - comma-separated list; defaults to the second-based
                         timeframes 1s, 5s, 10s, 15s, 30s
    """
    cli_args = sys.argv[1:]

    symbol = cli_args[0] if cli_args else "BTC-USDT"
    duration = int(cli_args[1]) if len(cli_args) >= 2 else 60

    if len(cli_args) >= 3:
        timeframes = cli_args[2].split(',')
    else:
        # No explicit list given: exercise every second-based timeframe.
        timeframes = ['1s', '5s', '10s', '15s', '30s']

    # Start-up banner; the last entry keeps its trailing newline so a blank
    # line separates the banner from the log output that follows.
    banner_lines = (
        "🚀 Quick Aggregation Test",
        f"Symbol: {symbol}",
        f"Duration: {duration} seconds",
        f"Timeframes: {', '.join(timeframes)}",
        "Press Ctrl+C to stop early\n",
    )
    for banner_line in banner_lines:
        print(banner_line)

    # Build the tester and run it for the requested duration.
    await QuickAggregationTester(symbol, timeframes).run(duration)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: drive the async main() and report how the run ended.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the advertised way to end a run early; not an error.
        print("\n⏹️ Test stopped")
    except Exception as exc:
        # Imported lazily: only needed on this failure path.
        import traceback

        print(f"\n❌ Error: {exc}")
        traceback.print_exc()
|