306 lines
12 KiB
Python
306 lines
12 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Test script for the refactored OKX data collection system.
|
||
|
|
|
||
|
|
This script tests the new common data processing framework and OKX-specific
|
||
|
|
implementations including data validation, transformation, and aggregation.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import signal
|
||
|
|
import sys
|
||
|
|
import time
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from decimal import Decimal
|
||
|
|
|
||
|
|
sys.path.append('.')
|
||
|
|
|
||
|
|
from data.exchanges.okx import OKXCollector
|
||
|
|
from data.exchanges.okx.data_processor import OKXDataProcessor
|
||
|
|
from data.common import (
|
||
|
|
create_standardized_trade,
|
||
|
|
StandardizedTrade,
|
||
|
|
OHLCVCandle,
|
||
|
|
RealTimeCandleProcessor,
|
||
|
|
CandleProcessingConfig
|
||
|
|
)
|
||
|
|
from data.base_collector import DataType
|
||
|
|
from utils.logger import get_logger
|
||
|
|
|
||
|
|
# Shared bookkeeping for this test run: main() sets start_time and reports the
# total_* values; collectors holds every collector created so that the signal
# handler can stop them.
test_stats = dict(
    start_time=None,
    total_trades=0,
    total_candles=0,
    total_errors=0,
    collectors=[],
)
|
||
|
|
|
||
|
|
# Signal handler for graceful shutdown
def signal_handler(signum, frame):
    """Handle SIGINT/SIGTERM: stop every registered collector, then exit(0).

    If an asyncio event loop is running in this thread, a shutdown task is
    scheduled on it. That task awaits ``stop()`` on each collector in
    ``test_stats['collectors']`` and then raises ``SystemExit(0)``, which
    propagates out of ``asyncio.run``. If no loop is running, nothing can be
    awaited, so the process exits immediately. A second signal received while
    a shutdown is already pending also exits immediately.

    Args:
        signum: number of the received signal.
        frame: interrupted stack frame (unused; required by ``signal.signal``).
    """
    logger = get_logger("main")
    logger.info(f"Received signal {signum}, shutting down gracefully...")

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop is None or getattr(signal_handler, "_shutting_down", False):
        sys.exit(0)
    signal_handler._shutting_down = True

    async def _stop_collectors_and_exit():
        # Iterate over a snapshot: the test coroutines may still append to
        # the shared list while we are awaiting.
        for collector in list(test_stats['collectors']):
            if not hasattr(collector, 'stop'):
                continue
            try:
                await collector.stop()
            except Exception as e:
                logger.error(f"Error stopping collector: {e}")
        sys.exit(0)

    def _schedule_shutdown():
        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task may vanish mid-execution.
        signal_handler._shutdown_task = loop.create_task(_stop_collectors_and_exit())

    # The handler can run while the loop is blocked waiting for I/O; the
    # thread-safe variant also wakes the loop so the shutdown starts promptly.
    loop.call_soon_threadsafe(_schedule_shutdown)
|
||
|
|
|
||
|
|
# Route Ctrl-C (SIGINT) and termination requests (SIGTERM) to signal_handler.
for _shutdown_signal in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_shutdown_signal, signal_handler)
del _shutdown_signal
|
||
|
|
|
||
|
|
|
||
|
|
class RealOKXCollector(OKXCollector):
    """OKXCollector variant used by this test script.

    The keyword-only ``enable_db_storage`` flag selects how the storage hooks
    behave. When it is true and the relevant manager is present, they delegate
    to the parent implementation. Otherwise they only log what would have been
    stored. Processed data points and completed candles are counted and
    exposed through :meth:`get_test_stats`.
    """

    def __init__(self, *args, enable_db_storage=False, **kwargs):
        super().__init__(*args, **kwargs)
        self._enable_db_storage = enable_db_storage
        self._test_mode = True
        # Counters reported by get_test_stats().
        self._raw_data_count = self._candle_storage_count = 0

        if not enable_db_storage:
            # Clear the database managers (presumably created by the parent
            # constructor) so the hooks below always take the log-only path.
            self._db_manager = self._raw_data_manager = None

    async def _store_processed_data(self, data_point) -> None:
        """Persist a processed data point (real mode) or just log it (test mode)."""
        self._raw_data_count += 1
        use_database = bool(self._enable_db_storage and self._db_manager)
        if not use_database:
            self.logger.debug(
                f"[TEST] Would store raw data: {data_point.data_type.value} "
                f"for {data_point.symbol} in raw_trades table"
            )
            return
        await super()._store_processed_data(data_point)
        self.logger.debug(
            f"[REAL] Stored raw data: {data_point.data_type.value} "
            f"for {data_point.symbol} in raw_trades table"
        )

    async def _store_completed_candle(self, candle) -> None:
        """Persist a completed candle (real mode) or just log it (test mode)."""
        self._candle_storage_count += 1
        use_database = bool(self._enable_db_storage and self._db_manager)
        if not use_database:
            self.logger.info(
                f"[TEST] Would store candle: {candle.symbol} {candle.timeframe} "
                f"O:{candle.open} H:{candle.high} L:{candle.low} C:{candle.close} "
                f"V:{candle.volume} in market_data table"
            )
            return
        await super()._store_completed_candle(candle)
        self.logger.info(
            f"[REAL] Stored candle: {candle.symbol} {candle.timeframe} "
            f"O:{candle.open} H:{candle.high} L:{candle.low} C:{candle.close} "
            f"V:{candle.volume} in market_data table"
        )

    async def _store_raw_data(self, channel: str, raw_message: dict) -> None:
        """Persist (real mode) or log (test mode) a raw WebSocket message.

        Only messages that carry a ``'data'`` key produce a log line.
        """
        use_database = bool(self._enable_db_storage and self._raw_data_manager)
        if use_database:
            await super()._store_raw_data(channel, raw_message)
        if 'data' not in raw_message:
            return
        item_count = len(raw_message['data'])
        if use_database:
            self.logger.debug(
                f"[REAL] Stored {item_count} raw WebSocket items "
                f"for channel {channel} in raw_trades table"
            )
        else:
            self.logger.debug(
                f"[TEST] Would store {item_count} raw WebSocket items "
                f"for channel {channel} in raw_trades table"
            )

    def get_test_stats(self) -> dict:
        """Return ``get_status()`` extended with the test-specific counters."""
        stats = self.get_status()
        stats.update(
            test_mode=self._test_mode,
            db_storage_enabled=self._enable_db_storage,
            raw_data_stored=self._raw_data_count,
            candles_stored=self._candle_storage_count,
        )
        return stats
|
||
|
|
|
||
|
|
|
||
|
|
async def test_common_utilities():
    """Exercise the shared data utilities and the OKX message processor offline.

    Builds a standardized trade with ``create_standardized_trade``. Then it
    feeds a hand-built OKX ``trades`` message through ``OKXDataProcessor`` and
    logs the outcome and the processor statistics. No network or database
    access happens here.

    A processing failure is logged together with the reported errors and
    counted in ``test_stats['total_errors']``.

    Returns:
        bool: the success flag reported by
        ``OKXDataProcessor.validate_and_process_message`` (callers that
        ignore the return value are unaffected).
    """
    logger = get_logger("refactored_test")
    logger.info("Testing common data utilities...")

    # One timestamp for both fixtures so the trade and the message agree.
    now = datetime.now(timezone.utc)

    # Test create_standardized_trade
    trade = create_standardized_trade(
        symbol="BTC-USDT",
        trade_id="12345",
        price=Decimal("50000.50"),
        size=Decimal("0.1"),
        side="buy",
        timestamp=now,
        exchange="okx",
        raw_data={"test": "data"},
    )
    logger.info(f"Created standardized trade: {trade}")

    # Test OKX data processor with a sample push message (ts in milliseconds)
    processor = OKXDataProcessor("BTC-USDT", component_name="test_processor")
    sample_message = {
        "arg": {"channel": "trades", "instId": "BTC-USDT"},
        "data": [{
            "instId": "BTC-USDT",
            "tradeId": "123456789",
            "px": "50000.50",
            "sz": "0.1",
            "side": "buy",
            "ts": str(int(now.timestamp() * 1000)),
        }],
    }

    success, data_points, errors = processor.validate_and_process_message(sample_message)
    if success:
        logger.info(f"Message processing successful: {len(data_points)} data points")
    else:
        # Previously the failure flag and errors were ignored and success was
        # logged unconditionally.
        logger.error(f"Message processing failed: {errors}")
        test_stats['total_errors'] += 1

    if data_points:
        first = data_points[0]
        logger.info(f"Data point: {first.exchange} {first.symbol} {first.data_type.value}")

    stats = processor.get_processing_stats()
    logger.info(f"Processor stats: {stats}")

    return bool(success)
|
||
|
|
|
||
|
|
|
||
|
|
async def test_single_collector(symbol: str, duration: int = 30, enable_db_storage: bool = False):
    """Run one live OKX collector for *symbol* for *duration* seconds.

    The collector is connected, subscribed to trade, order-book and ticker data,
    and started. Its statistics are then logged about every 5 seconds. Every
    setup step that succeeded is undone afterwards (unsubscribe, stop,
    disconnect), including when a later step fails or raises, so no
    connection is leaked.

    Trade and candle counts are added to ``test_stats``, and every failed run
    increments ``test_stats['total_errors']``.

    Args:
        symbol: instrument to collect, e.g. ``"BTC-USDT"``.
        duration: monitoring time in seconds.
        enable_db_storage: when true, the collector stores data through the
            inherited storage hooks; otherwise it only logs.

    Returns:
        bool: True if setup, monitoring and cleanup all succeeded.
    """
    logger = get_logger("refactored_test")
    logger.info(f"Testing OKX collector for {symbol} for {duration} seconds...")

    if enable_db_storage:
        logger.info(f"Using REAL database storage for {symbol}")
    else:
        logger.info(f"Using TEST mode (no database) for {symbol}")

    # Both modes use the same class; only the storage flag differs.
    collector = RealOKXCollector(
        symbol=symbol,
        data_types=[DataType.TRADE, DataType.ORDERBOOK, DataType.TICKER],
        store_raw_data=True,
        enable_db_storage=bool(enable_db_storage),
    )
    test_stats['collectors'].append(collector)

    completed_steps = set()
    success = False
    try:
        success = await _run_collector_session(collector, symbol, duration, logger, completed_steps)
    except Exception as e:
        logger.error(f"Error in collector test for {symbol}: {e}")
    finally:
        # Runs on success, failure, exception and cancellation alike.
        cleanup_ok = await _teardown_collector(collector, symbol, logger, completed_steps)

    success = success and cleanup_ok
    if success:
        logger.info(f"Completed test for {symbol}")
    else:
        test_stats['total_errors'] += 1
    return success


async def _run_collector_session(collector, symbol, duration, logger, completed_steps):
    """Connect, subscribe, start and monitor *collector*; return True on success.

    The name of each setup step that succeeds is added to *completed_steps* so
    the caller knows what to undo.
    """
    if not await collector.connect():
        logger.error(f"Failed to connect collector for {symbol}")
        return False
    completed_steps.add('connected')

    if not await collector.subscribe_to_data([symbol], collector.data_types):
        logger.error(f"Failed to subscribe to data for {symbol}")
        return False
    completed_steps.add('subscribed')

    if not await collector.start():
        logger.error(f"Failed to start collector for {symbol}")
        return False
    completed_steps.add('started')

    logger.info(f"Successfully started collector for {symbol}")

    start_time = time.time()
    while True:
        remaining = duration - (time.time() - start_time)
        if remaining <= 0:
            break
        # Report every 5 s, but never sleep past the requested duration.
        await asyncio.sleep(min(5, remaining))

        stats = collector.get_test_stats()
        processing = stats['processing_stats']
        logger.info(f"[{symbol}] Stats: "
                    f"Messages: {processing['messages_received']}, "
                    f"Trades: {processing['trades_processed']}, "
                    f"Candles: {processing['candles_processed']}, "
                    f"Raw stored: {stats['raw_data_stored']}, "
                    f"Candles stored: {stats['candles_stored']}")

    # Feed the suite-wide totals that main() reports.
    # NOTE(review): assumes these processing counters are numeric.
    processing = collector.get_test_stats()['processing_stats']
    test_stats['total_trades'] += processing['trades_processed']
    test_stats['total_candles'] += processing['candles_processed']
    return True


async def _teardown_collector(collector, symbol, logger, completed_steps):
    """Undo the setup steps recorded in *completed_steps*.

    Steps run in the order unsubscribe, stop, disconnect. Each one is attempted
    even if an earlier one raised, and failures are logged.

    Returns:
        bool: True if every attempted step succeeded.
    """
    undo_steps = (
        ('subscribed', 'unsubscribe',
         lambda: collector.unsubscribe_from_data([symbol], collector.data_types)),
        ('started', 'stop', collector.stop),
        ('connected', 'disconnect', collector.disconnect),
    )
    all_ok = True
    for step, action, undo in undo_steps:
        if step not in completed_steps:
            continue
        try:
            await undo()
        except Exception as e:
            all_ok = False
            logger.error(f"Error during {action} for {symbol}: {e}")
    return all_ok
|
||
|
|
|
||
|
|
|
||
|
|
async def test_multiple_collectors(symbols: list, duration: int = 45, enable_db_storage: bool = False):
    """Run one ``test_single_collector`` per distinct symbol, concurrently.

    Args:
        symbols: symbols to test. Duplicates are dropped (first occurrence
            wins and the caller's order is preserved), so each symbol gets
            exactly one collector.
        duration: seconds each collector runs.
        enable_db_storage: forwarded to ``test_single_collector``. It defaults
            to False, which matches the previous behaviour.

    Returns:
        bool: True when every collector test returned True. This is vacuously
        True for an empty symbol list.
    """
    logger = get_logger("refactored_test")
    logger.info(f"Testing multiple collectors for {symbols} for {duration} seconds...")

    # dict.fromkeys de-duplicates while keeping a deterministic order
    # (list(set(...)) did not).
    unique_symbols = list(dict.fromkeys(symbols))

    # test_single_collector logs its own "Testing OKX collector ..." line.
    tasks = [
        asyncio.create_task(test_single_collector(symbol, duration, enable_db_storage))
        for symbol in unique_symbols
    ]

    # return_exceptions=True: one failing collector must not abort the others.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    for symbol, result in zip(unique_symbols, results):
        if isinstance(result, BaseException):
            logger.error(f"Collector test for {symbol} raised: {result!r}")

    successful = sum(1 for result in results if result is True)
    logger.info(f"Multi-collector test completed: {successful}/{len(unique_symbols)} successful")

    return successful == len(unique_symbols)
|
||
|
|
|
||
|
|
|
||
|
|
async def main():
    """Run the refactored OKX test suite.

    Steps:
    1. Offline utility checks.
    2. One BTC-USDT collector for 30 s.
    3. BTC-USDT and ETH-USDT collectors in parallel for 45 s.

    Passing ``--real-db`` on the command line enables database storage for the
    single-collector run.

    Exits with status 1 if a collector test reports failure or an unexpected
    exception escapes. Otherwise it returns normally.
    """
    test_stats['start_time'] = time.time()

    logger = get_logger("main")
    logger.info("Starting refactored OKX test suite...")

    # Check if user wants real database storage (sys is imported at module level)
    enable_db_storage = '--real-db' in sys.argv
    if enable_db_storage:
        logger.info("🗄️ REAL DATABASE STORAGE ENABLED")
        logger.info(" Raw trades and completed candles will be stored in database tables")
    else:
        logger.info("🧪 TEST MODE ENABLED (default)")
        logger.info(" Database operations will be simulated (no actual storage)")
        logger.info(" Use --real-db flag to enable real database storage")

    all_passed = False
    try:
        # Test 1: Common utilities
        await test_common_utilities()

        # Test 2: Single collector (with optional real DB storage)
        single_ok = await test_single_collector("BTC-USDT", 30, enable_db_storage)

        # Test 3: Multiple collectors (unique symbols only)
        unique_symbols = ["BTC-USDT", "ETH-USDT"]
        multi_ok = await test_multiple_collectors(unique_symbols, 45)

        # Previously these results were ignored and success was always reported.
        all_passed = bool(single_ok) and bool(multi_ok)

        # Final results
        runtime = time.time() - test_stats['start_time']
        logger.info("=== FINAL TEST RESULTS ===")
        logger.info(f"Total runtime: {runtime:.1f}s")
        logger.info(f"Total trades: {test_stats['total_trades']}")
        logger.info(f"Total candles: {test_stats['total_candles']}")
        logger.info(f"Total errors: {test_stats['total_errors']}")
        if not all_passed:
            logger.error(
                f"❌ Collector tests failed (single: {'ok' if single_ok else 'FAILED'}, "
                f"multiple: {'ok' if multi_ok else 'FAILED'})"
            )
        elif enable_db_storage:
            logger.info("✅ All tests completed successfully with REAL database storage!")
        else:
            logger.info("✅ All tests completed successfully in TEST mode!")

    except Exception as e:
        logger.error(f"Test suite failed: {e}")
        sys.exit(1)

    logger.info("Test suite completed")

    # Non-zero exit status so scripts/CI can detect failed collector tests.
    if not all_passed:
        sys.exit(1)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Script entry point: drive the async test suite on a fresh event loop.
    suite = main()
    asyncio.run(suite)
|