- Split the `aggregation.py` file into a dedicated sub-package, improving modularity and maintainability. - Moved `TimeframeBucket`, `RealTimeCandleProcessor`, and `BatchCandleProcessor` classes into their respective files within the new `aggregation` sub-package. - Introduced utility functions for trade aggregation and validation, enhancing code organization. - Updated import paths throughout the codebase to reflect the new structure, ensuring compatibility. - Added safety net tests for the aggregation package to verify core functionality and prevent regressions during refactoring. These changes enhance the overall architecture of the aggregation module, making it more scalable and easier to manage.
212 lines
7.5 KiB
Python
212 lines
7.5 KiB
Python
#!/usr/bin/env python3
"""
Quick OKX Aggregation Test

A simplified version for quick testing of different symbols and timeframe combinations.
"""

import asyncio
import logging
import sys
import traceback
from datetime import datetime, timezone
from decimal import Decimal
from typing import Dict, List, Any

# Import our modules
from data.common.data_types import StandardizedTrade, CandleProcessingConfig, OHLCVCandle
from data.common.aggregation.realtime import RealTimeCandleProcessor
from data.exchanges.okx.websocket import OKXWebSocketClient, OKXSubscription, OKXChannelType

# Set up minimal logging: INFO level, time-of-day prefix only, so the
# periodic status lines stay compact on a terminal.
logging.basicConfig(level=logging.INFO, format='%(asctime)s: %(message)s', datefmt='%H:%M:%S')

# Module-level logger shared by the tester class and handed to the
# processor and WebSocket client it creates.
logger = logging.getLogger(__name__)
class QuickAggregationTester:
    """Quick tester for real-time aggregation.

    Connects to the OKX public WebSocket, subscribes to the trade channel of
    one symbol, feeds every received trade into a ``RealTimeCandleProcessor``
    and counts the candles the processor reports back per timeframe.
    """

    # Seconds between periodic status lines while monitoring.
    STATUS_INTERVAL = 5

    # Seconds per timeframe unit suffix; used only to estimate how many
    # candles a run of a given length should have produced.
    _UNIT_SECONDS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}

    def __init__(self, symbol: str, timeframes: List[str]):
        """Create the candle processor and zeroed statistics.

        Args:
            symbol: Instrument id to monitor (compared against ``instId`` of
                incoming messages).
            timeframes: Timeframe labels passed to ``CandleProcessingConfig``
                and used as keys of the per-timeframe candle counters.
        """
        self.symbol = symbol
        self.timeframes = timeframes
        self.ws_client = None  # created lazily in _setup_websocket()

        # Create processor; candles are only counted here, never persisted.
        config = CandleProcessingConfig(timeframes=timeframes, auto_save_candles=False)
        self.processor = RealTimeCandleProcessor(symbol, "okx", config, logger=logger)
        self.processor.add_candle_callback(self._on_candle)

        # Stats
        self.trade_count = 0
        self.candle_counts = dict.fromkeys(timeframes, 0)

        logger.info(f"Testing {symbol} with timeframes: {', '.join(timeframes)}")

    async def run(self, duration: int = 60):
        """Run the test for the specified duration.

        Connects, subscribes, prints a status line every ``STATUS_INTERVAL``
        seconds and the final statistics once ``duration`` seconds elapsed.
        Failures are logged (with traceback) rather than raised; the
        WebSocket client is always disconnected if it was created.

        Args:
            duration: Monitoring time in seconds.
        """
        try:
            # Connect and subscribe
            await self._setup_websocket()
            await self._subscribe()

            logger.info(f"🔍 Monitoring for {duration} seconds...")

            # Use the event loop's monotonic clock so system clock
            # adjustments cannot shorten or lengthen the run.
            loop = asyncio.get_running_loop()
            deadline = loop.time() + duration

            # Monitor; the last sleep is clipped so the run never overshoots
            # the requested duration by up to a full status interval.
            remaining = duration
            while remaining > 0:
                await asyncio.sleep(min(self.STATUS_INTERVAL, remaining))
                self._print_quick_status()
                remaining = deadline - loop.time()

            # Final stats
            self._print_final_stats(duration)

        except Exception as e:
            # Top-level boundary of the test run: keep the traceback.
            logger.exception(f"Test failed: {e}")
        finally:
            if self.ws_client:
                await self.ws_client.disconnect()

    async def _setup_websocket(self):
        """Create the WebSocket client, register the handler and connect.

        Raises:
            RuntimeError: If the client reports a failed connection.
        """
        self.ws_client = OKXWebSocketClient("quick_test", logger=logger)
        self.ws_client.add_message_callback(self._on_message)

        if not await self.ws_client.connect(use_public=True):
            raise RuntimeError("Failed to connect")

        logger.info("✅ Connected to OKX")

    async def _subscribe(self):
        """Subscribe to the trades channel of ``self.symbol``.

        Raises:
            RuntimeError: If the client reports a failed subscription.
        """
        # NOTE(review): arguments are positional; the meaning of the trailing
        # ``True`` is defined by OKXSubscription - confirm against its signature.
        subscription = OKXSubscription("trades", self.symbol, True)
        if not await self.ws_client.subscribe([subscription]):
            raise RuntimeError("Failed to subscribe")

        logger.info(f"✅ Subscribed to {self.symbol} trades")

    def _on_message(self, message: Dict[str, Any]):
        """Handle a WebSocket message.

        Only dict messages that carry a ``data`` list and whose ``arg``
        names the ``trades`` channel for this symbol are processed; anything
        else is ignored. Errors are logged, never raised to the client.
        """
        try:
            if not isinstance(message, dict) or 'data' not in message:
                return

            arg = message.get('arg', {})
            if arg.get('channel') != 'trades' or arg.get('instId') != self.symbol:
                return

            for trade_data in message['data']:
                self._process_trade(trade_data)

        except Exception as e:
            logger.error(f"Message processing error: {e}")

    def _process_trade(self, trade_data: Dict[str, Any]):
        """Convert one raw trade dict and feed it to the processor.

        Reads the keys ``instId``, ``tradeId``, ``px``, ``sz``, ``side`` and
        ``ts`` (interpreted as epoch milliseconds, UTC). Errors are logged,
        never raised, so one bad trade does not stop the stream.
        """
        try:
            self.trade_count += 1

            # Create standardized trade
            trade = StandardizedTrade(
                symbol=trade_data['instId'],
                trade_id=trade_data['tradeId'],
                price=Decimal(trade_data['px']),
                size=Decimal(trade_data['sz']),
                side=trade_data['side'],
                timestamp=datetime.fromtimestamp(int(trade_data['ts']) / 1000, tz=timezone.utc),
                exchange="okx",
                raw_data=trade_data
            )

            # Process through aggregation
            self.processor.process_trade(trade)

            # Log every 20th trade (the 1st, 21st, 41st, ...)
            if self.trade_count % 20 == 1:
                logger.info(f"Trade #{self.trade_count}: {trade.side} {trade.size} @ ${trade.price}")

        except Exception as e:
            logger.error(f"Trade processing error: {e}")

    def _on_candle(self, candle: "OHLCVCandle"):
        """Handle a completed candle: count it and log a one-line summary."""
        # .get() so a timeframe that was not pre-registered is counted
        # instead of raising KeyError inside the processor callback.
        self.candle_counts[candle.timeframe] = self.candle_counts.get(candle.timeframe, 0) + 1

        # Calculate metrics; guard against a zero open price.
        change = candle.close - candle.open
        change_pct = (change / candle.open * 100) if candle.open > 0 else 0

        logger.info(
            f"🕯️ {candle.timeframe.upper()} at {candle.end_time.strftime('%H:%M:%S')}: "
            f"${candle.close} ({change_pct:+.2f}%) V={candle.volume} T={candle.trade_count}"
        )

    def _print_quick_status(self):
        """Log the running trade count and per-timeframe candle counts."""
        total_candles = sum(self.candle_counts.values())
        candle_summary = ", ".join(f"{tf}:{count}" for tf, count in self.candle_counts.items())
        logger.info(f"📊 Trades: {self.trade_count} | Candles: {total_candles} ({candle_summary})")

    def _print_final_stats(self, duration: int):
        """Log the final statistics.

        Args:
            duration: Nominal run length in seconds; used for the trade rate
                and the expected candle counts.
        """
        logger.info("=" * 50)
        logger.info("📊 FINAL RESULTS")
        logger.info(f"Duration: {duration}s")
        logger.info(f"Trades processed: {self.trade_count}")

        # A zero/negative duration has no meaningful rate; avoid dividing by 0.
        trade_rate = self.trade_count / duration if duration > 0 else 0.0
        logger.info(f"Trade rate: {trade_rate:.1f}/sec")

        total_candles = sum(self.candle_counts.values())
        logger.info(f"Total candles: {total_candles}")

        for tf in self.timeframes:
            count = self.candle_counts.get(tf, 0)
            expected = self._expected_candles(tf, duration)
            logger.info(f" {tf}: {count} candles (expected ~{expected})")

        logger.info("=" * 50)

    def _expected_candles(self, timeframe: str, duration: int) -> int:
        """Calculate the expected number of candles for a run.

        Args:
            timeframe: Label of the form ``<count><unit>`` where unit is one
                of ``s``, ``m``, ``h``, ``d`` (e.g. ``'5s'``, ``'1m'``).
            duration: Run length in seconds.

        Returns:
            ``duration`` floor-divided by the timeframe length in seconds, or
            0 when the label cannot be parsed or has a non-positive count.
        """
        unit_seconds = self._UNIT_SECONDS.get(timeframe[-1:])
        if unit_seconds is None:
            return 0

        try:
            count = int(timeframe[:-1])
        except ValueError:
            return 0

        if count <= 0:
            return 0
        return duration // (count * unit_seconds)
async def main():
    """Main function with argument parsing.

    Positional command line arguments (all optional):
        1. symbol      - instrument id, default ``BTC-USDT``
        2. duration    - run length in seconds, default 60
        3. timeframes  - comma-separated labels, default all second timeframes

    Raises:
        ValueError: If the duration is not a positive integer or the
            timeframe argument contains no labels.
    """
    # Parse command line arguments
    symbol = sys.argv[1] if len(sys.argv) > 1 else "BTC-USDT"
    duration = int(sys.argv[2]) if len(sys.argv) > 2 else 60

    if len(sys.argv) > 3:
        # Tolerate "1s, 5s" and trailing commas: strip whitespace, drop empties.
        timeframes = [tf.strip() for tf in sys.argv[3].split(',') if tf.strip()]
    else:
        # Default to testing all second timeframes
        timeframes = ['1s', '5s', '10s', '15s', '30s']

    # Fail early with a clear message instead of an obscure error mid-run.
    if duration <= 0:
        raise ValueError(f"duration must be a positive number of seconds, got {duration}")
    if not timeframes:
        raise ValueError("at least one timeframe is required")

    print("🚀 Quick Aggregation Test")
    print(f"Symbol: {symbol}")
    print(f"Duration: {duration} seconds")
    print(f"Timeframes: {', '.join(timeframes)}")
    print("Press Ctrl+C to stop early\n")

    # Run test
    tester = QuickAggregationTester(symbol, timeframes)
    await tester.run(duration)
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the documented way to stop early; not an error.
        print("\n⏹️ Test stopped")
    except Exception as e:
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        # Report the failure to the calling shell / CI instead of exiting 0.
        sys.exit(1)