- Split the `aggregation.py` file into a dedicated sub-package, improving modularity and maintainability.
- Moved `TimeframeBucket`, `RealTimeCandleProcessor`, and `BatchCandleProcessor` classes into their respective files within the new `aggregation` sub-package.
- Introduced utility functions for trade aggregation and validation, enhancing code organization.
- Updated import paths throughout the codebase to reflect the new structure, ensuring compatibility.
- Added safety-net tests for the aggregation package to verify core functionality and prevent regressions during refactoring.

These changes enhance the overall architecture of the aggregation module, making it more scalable and easier to manage.
307 lines
12 KiB
Python
#!/usr/bin/env python3
"""
Test script for the refactored OKX data collection system.

This script tests the new common data processing framework and OKX-specific
implementations including data validation, transformation, and aggregation.
"""

import asyncio
import json
import signal
import sys
import time
from datetime import datetime, timezone
from decimal import Decimal

sys.path.append('.')

from data.exchanges.okx import OKXCollector
from data.exchanges.okx.data_processor import OKXDataProcessor
from data.common import (
    create_standardized_trade,
    StandardizedTrade,
    OHLCVCandle,
    RealTimeCandleProcessor,
    CandleProcessingConfig
)
# NOTE: RealTimeCandleProcessor is already re-exported by data.common above,
# so the redundant "from data.common.aggregation.realtime import
# RealTimeCandleProcessor" (which silently rebound the same name) was removed.
from data.base_collector import DataType
from utils.logger import get_logger
|
# Global test state shared by all test helpers and the signal handler.
test_stats = {
    'start_time': None,   # epoch seconds; set in main()
    'total_trades': 0,
    'total_candles': 0,
    'total_errors': 0,
    'collectors': [],     # every collector created, used for graceful shutdown
}
|
|
|
|
def signal_handler(signum, frame):
    """Handle SIGINT/SIGTERM: stop all tracked collectors, then exit.

    NOTE(review): asyncio.create_task here assumes a running event loop;
    if none is running it raises and is caught per collector below.
    """
    logger = get_logger("main")
    logger.info(f"Received signal {signum}, shutting down gracefully...")

    # Best-effort stop of every collector registered in test_stats.
    for collector in test_stats['collectors']:
        stop = getattr(collector, 'stop', None)
        if stop is None:
            continue
        try:
            asyncio.create_task(stop())
        except Exception as e:
            logger.error(f"Error stopping collector: {e}")

    sys.exit(0)


# Install the handler for both interactive interrupt and termination signals.
for _sig in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_sig, signal_handler)
|
|
|
|
|
|
class RealOKXCollector(OKXCollector):
    """Real OKX collector that actually stores to database (if available)."""

    def __init__(self, *args, enable_db_storage=False, **kwargs):
        super().__init__(*args, **kwargs)
        self._enable_db_storage = enable_db_storage
        self._test_mode = True
        self._raw_data_count = 0
        self._candle_storage_count = 0

        # In test mode, drop the managers so nothing reaches the database.
        if not enable_db_storage:
            self._db_manager = None
            self._raw_data_manager = None

    async def _store_processed_data(self, data_point) -> None:
        """Store or log raw data depending on configuration."""
        self._raw_data_count += 1
        if not (self._enable_db_storage and self._db_manager):
            # Just log for testing
            self.logger.debug(f"[TEST] Would store raw data: {data_point.data_type.value} for {data_point.symbol} in raw_trades table")
            return
        # Actually store to database
        await super()._store_processed_data(data_point)
        self.logger.debug(f"[REAL] Stored raw data: {data_point.data_type.value} for {data_point.symbol} in raw_trades table")

    async def _store_completed_candle(self, candle) -> None:
        """Store or log completed candle depending on configuration."""
        self._candle_storage_count += 1
        if not (self._enable_db_storage and self._db_manager):
            # Just log for testing
            self.logger.info(f"[TEST] Would store candle: {candle.symbol} {candle.timeframe} O:{candle.open} H:{candle.high} L:{candle.low} C:{candle.close} V:{candle.volume} in market_data table")
            return
        # Actually store to database
        await super()._store_completed_candle(candle)
        self.logger.info(f"[REAL] Stored candle: {candle.symbol} {candle.timeframe} O:{candle.open} H:{candle.high} L:{candle.low} C:{candle.close} V:{candle.volume} in market_data table")

    async def _store_raw_data(self, channel: str, raw_message: dict) -> None:
        """Store or log raw WebSocket data depending on configuration."""
        if self._enable_db_storage and self._raw_data_manager:
            # Actually store to database
            await super()._store_raw_data(channel, raw_message)
            if 'data' in raw_message:
                self.logger.debug(f"[REAL] Stored {len(raw_message['data'])} raw WebSocket items for channel {channel} in raw_trades table")
        elif 'data' in raw_message:
            # Just log for testing
            self.logger.debug(f"[TEST] Would store {len(raw_message['data'])} raw WebSocket items for channel {channel} in raw_trades table")

    def get_test_stats(self) -> dict:
        """Get test-specific statistics merged onto the base status dict."""
        stats = self.get_status()
        stats['test_mode'] = self._test_mode
        stats['db_storage_enabled'] = self._enable_db_storage
        stats['raw_data_stored'] = self._raw_data_count
        stats['candles_stored'] = self._candle_storage_count
        return stats
|
|
|
|
|
|
async def test_common_utilities():
    """Test the common data processing utilities."""
    log = get_logger("refactored_test")
    log.info("Testing common data utilities...")

    # Exercise create_standardized_trade with a representative trade.
    trade = create_standardized_trade(
        symbol="BTC-USDT",
        trade_id="12345",
        price=Decimal("50000.50"),
        size=Decimal("0.1"),
        side="buy",
        timestamp=datetime.now(timezone.utc),
        exchange="okx",
        raw_data={"test": "data"},
    )
    log.info(f"Created standardized trade: {trade}")

    # Run a realistic OKX trades message through the OKX data processor.
    processor = OKXDataProcessor("BTC-USDT", component_name="test_processor")
    now_ms = str(int(datetime.now(timezone.utc).timestamp() * 1000))
    sample_message = {
        "arg": {"channel": "trades", "instId": "BTC-USDT"},
        "data": [{
            "instId": "BTC-USDT",
            "tradeId": "123456789",
            "px": "50000.50",
            "sz": "0.1",
            "side": "buy",
            "ts": now_ms,
        }],
    }

    success, data_points, errors = processor.validate_and_process_message(sample_message)
    log.info(f"Message processing successful: {len(data_points)} data points")
    if data_points:
        first = data_points[0]
        log.info(f"Data point: {first.exchange} {first.symbol} {first.data_type.value}")

    # Get processor statistics
    stats = processor.get_processing_stats()
    log.info(f"Processor stats: {stats}")
|
|
|
|
|
|
async def test_single_collector(symbol: str, duration: int = 30, enable_db_storage: bool = False):
    """Test a single OKX collector for the specified duration.

    Args:
        symbol: Instrument to collect, e.g. "BTC-USDT".
        duration: How long to run the collector, in seconds.
        enable_db_storage: When True, raw data and candles are written to the
            database; otherwise storage is only simulated and logged.

    Returns:
        True if the collector connected, subscribed, ran, and shut down
        cleanly; False on any failure (errors are logged, not raised).
    """
    logger = get_logger("refactored_test")
    logger.info(f"Testing OKX collector for {symbol} for {duration} seconds...")

    if enable_db_storage:
        logger.info(f"Using REAL database storage for {symbol}")
    else:
        logger.info(f"Using TEST mode (no database) for {symbol}")

    # The original branches duplicated the whole constructor call differing
    # only in enable_db_storage; build the collector once instead.
    collector = RealOKXCollector(
        symbol=symbol,
        data_types=[DataType.TRADE, DataType.ORDERBOOK, DataType.TICKER],
        store_raw_data=True,
        enable_db_storage=enable_db_storage
    )

    test_stats['collectors'].append(collector)

    try:
        # Connect, subscribe, and start collection; bail out on any failure.
        if not await collector.connect():
            logger.error(f"Failed to connect collector for {symbol}")
            return False

        if not await collector.subscribe_to_data([symbol], collector.data_types):
            logger.error(f"Failed to subscribe to data for {symbol}")
            return False

        if not await collector.start():
            logger.error(f"Failed to start collector for {symbol}")
            return False

        logger.info(f"Successfully started collector for {symbol}")

        # Monitor for the requested duration, logging stats every 5 seconds.
        start_time = time.time()
        while time.time() - start_time < duration:
            await asyncio.sleep(5)

            stats = collector.get_test_stats()
            logger.info(f"[{symbol}] Stats: "
                        f"Messages: {stats['processing_stats']['messages_received']}, "
                        f"Trades: {stats['processing_stats']['trades_processed']}, "
                        f"Candles: {stats['processing_stats']['candles_processed']}, "
                        f"Raw stored: {stats['raw_data_stored']}, "
                        f"Candles stored: {stats['candles_stored']}")

        # Orderly shutdown: unsubscribe, stop, then disconnect.
        await collector.unsubscribe_from_data([symbol], collector.data_types)
        await collector.stop()
        await collector.disconnect()

        logger.info(f"Completed test for {symbol}")
        return True

    except Exception as e:
        logger.error(f"Error in collector test for {symbol}: {e}")
        return False
|
|
|
|
|
|
async def test_multiple_collectors(symbols: list, duration: int = 45):
    """Test multiple collectors running in parallel."""
    logger = get_logger("refactored_test")
    logger.info(f"Testing multiple collectors for {symbols} for {duration} seconds...")

    # De-duplicate so each symbol gets exactly one collector task.
    unique_symbols = list(set(symbols))

    tasks = []
    for symbol in unique_symbols:
        logger.info(f"Testing OKX collector for {symbol} for {duration} seconds...")
        tasks.append(asyncio.create_task(test_single_collector(symbol, duration)))

    # Run all collectors concurrently; exceptions become result values.
    outcomes = await asyncio.gather(*tasks, return_exceptions=True)

    successful = sum(1 for outcome in outcomes if outcome is True)
    logger.info(f"Multi-collector test completed: {successful}/{len(unique_symbols)} successful")

    return successful == len(unique_symbols)
|
|
|
|
|
|
async def main():
    """Main test function: run the full refactored OKX test suite.

    Exits the process with status 1 if any test raises.
    """
    test_stats['start_time'] = time.time()

    logger = get_logger("main")
    logger.info("Starting refactored OKX test suite...")

    # `sys` is already imported at module level; the redundant local
    # `import sys` was removed. Check the CLI flag directly.
    enable_db_storage = '--real-db' in sys.argv
    if enable_db_storage:
        logger.info("🗄️ REAL DATABASE STORAGE ENABLED")
        logger.info(" Raw trades and completed candles will be stored in database tables")
    else:
        logger.info("🧪 TEST MODE ENABLED (default)")
        logger.info(" Database operations will be simulated (no actual storage)")
        logger.info(" Use --real-db flag to enable real database storage")

    try:
        # Test 1: Common utilities
        await test_common_utilities()

        # Test 2: Single collector (with optional real DB storage)
        await test_single_collector("BTC-USDT", 30, enable_db_storage)

        # Test 3: Multiple collectors (unique symbols only)
        unique_symbols = ["BTC-USDT", "ETH-USDT"]  # Ensure no duplicates
        await test_multiple_collectors(unique_symbols, 45)

        # Final results
        runtime = time.time() - test_stats['start_time']
        logger.info("=== FINAL TEST RESULTS ===")
        logger.info(f"Total runtime: {runtime:.1f}s")
        logger.info(f"Total trades: {test_stats['total_trades']}")
        logger.info(f"Total candles: {test_stats['total_candles']}")
        logger.info(f"Total errors: {test_stats['total_errors']}")
        if enable_db_storage:
            logger.info("✅ All tests completed successfully with REAL database storage!")
        else:
            logger.info("✅ All tests completed successfully in TEST mode!")

    except Exception as e:
        logger.error(f"Test suite failed: {e}")
        sys.exit(1)

    logger.info("Test suite completed")
|
|
|
|
|
|
# Entry point: run the async test suite under a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())