Add common data processing framework for OKX exchange

- Introduced a modular architecture for data processing, including common utilities for validation, transformation, and aggregation.
- Implemented `StandardizedTrade`, `OHLCVCandle`, and `TimeframeBucket` classes for unified data handling across exchanges.
- Developed `OKXDataProcessor` for OKX-specific data validation and processing, leveraging the new common framework.
- Enhanced `OKXCollector` to utilize the common data processing utilities, improving modularity and maintainability.
- Updated documentation to reflect the new architecture and provide guidance on the data processing framework.
- Created comprehensive tests for the new data processing components to ensure reliability and functionality.
This commit is contained in:
Vasily.onl
2025-05-31 21:58:47 +08:00
parent fa63e7eb2e
commit 8bb5f28fd2
15 changed files with 4015 additions and 214 deletions

52
data/common/__init__.py Normal file
View File

@@ -0,0 +1,52 @@
"""
Common data processing utilities for all exchanges.
This package contains shared components for data validation, transformation,
and aggregation that can be used across different exchange implementations.
"""
from .data_types import (
StandardizedTrade,
OHLCVCandle,
MarketDataPoint,
DataValidationResult
)
from .aggregation import (
TimeframeBucket,
RealTimeCandleProcessor,
CandleProcessingConfig
)
from .transformation import (
BaseDataTransformer,
UnifiedDataTransformer,
create_standardized_trade
)
from .validation import (
BaseDataValidator,
ValidationResult
)
# Explicit public API for `from data.common import *`; keep this list in
# sync with the imports above.
__all__ = [
    # Data types
    'StandardizedTrade',
    'OHLCVCandle',
    'MarketDataPoint',
    'DataValidationResult',
    # Aggregation
    'TimeframeBucket',
    'RealTimeCandleProcessor',
    'CandleProcessingConfig',
    # Transformation
    'BaseDataTransformer',
    'UnifiedDataTransformer',
    'create_standardized_trade',
    # Validation
    'BaseDataValidator',
    'ValidationResult'
]

553
data/common/aggregation.py Normal file
View File

@@ -0,0 +1,553 @@
"""
Common aggregation utilities for all exchanges.
This module provides shared functionality for building OHLCV candles
from trade data, regardless of the source exchange.
AGGREGATION STRATEGY:
- Uses RIGHT-ALIGNED timestamps (industry standard)
- Candle timestamp = end time of the interval (close time)
- 5-minute candle with timestamp 09:05:00 represents data from 09:00:01 to 09:05:00
- Prevents future leakage by only completing candles when time boundary is crossed
- Aligns with major exchanges (Binance, OKX, Coinbase)
PROCESS FLOW:
1. Trade arrives with timestamp T
2. Calculate which time bucket this trade belongs to
3. If bucket doesn't exist or time boundary crossed, complete previous bucket
4. Add trade to current bucket
5. Only emit completed candles (never future data)
"""
from datetime import datetime, timezone, timedelta
from decimal import Decimal
from typing import Dict, List, Optional, Any, Iterator, Callable
from collections import defaultdict
from .data_types import (
StandardizedTrade,
OHLCVCandle,
CandleProcessingConfig,
ProcessingStats
)
from utils.logger import get_logger
class TimeframeBucket:
    """
    Accumulates the trades of one (symbol, timeframe) interval and builds an
    OHLCV candle from them incrementally.

    IMPORTANT: Uses RIGHT-ALIGNED timestamps
    - start_time: beginning of the interval (inclusive)
    - end_time: end of the interval (exclusive) - this becomes the candle timestamp
    - Example: the 09:00:00 - 09:05:00 bucket yields a candle stamped 09:05:00
    """

    # Interval length per supported timeframe; right edge = start + delta.
    _TIMEFRAME_DELTAS = {
        '1m': timedelta(minutes=1),
        '5m': timedelta(minutes=5),
        '15m': timedelta(minutes=15),
        '30m': timedelta(minutes=30),
        '1h': timedelta(hours=1),
        '4h': timedelta(hours=4),
        '1d': timedelta(days=1),
    }

    def __init__(self, symbol: str, timeframe: str, start_time: datetime, exchange: str = "unknown"):
        """
        Initialize time bucket for candle aggregation.

        Args:
            symbol: Trading symbol (e.g., 'BTC-USDT')
            timeframe: Time period (e.g., '1m', '5m', '1h')
            start_time: Start time for this bucket (inclusive)
            exchange: Exchange name
        """
        self.symbol = symbol
        self.timeframe = timeframe
        self.start_time = start_time
        self.end_time = self._calculate_end_time(start_time, timeframe)
        self.exchange = exchange
        # OHLCV state; open stays None until the first trade arrives.
        self.open: Optional[Decimal] = None
        self.high: Optional[Decimal] = None
        self.low: Optional[Decimal] = None
        self.close: Optional[Decimal] = None
        self.volume: Decimal = Decimal('0')
        self.trade_count: int = 0
        # Bookkeeping for analysis and debugging.
        self.first_trade_time: Optional[datetime] = None
        self.last_trade_time: Optional[datetime] = None
        self.trades: List[StandardizedTrade] = []

    def add_trade(self, trade: StandardizedTrade) -> bool:
        """
        Add a trade to this bucket if it falls inside the time period.

        Args:
            trade: Standardized trade data

        Returns:
            True if the trade was added, False if outside [start_time, end_time)
        """
        # Guard: the trade must fall inside the half-open interval.
        if not (self.start_time <= trade.timestamp < self.end_time):
            return False
        if self.open is None:
            # First trade seeds open/high/low and the first-trade marker.
            self.open = self.high = self.low = trade.price
            self.first_trade_time = trade.timestamp
        if trade.price > self.high:
            self.high = trade.price
        if trade.price < self.low:
            self.low = trade.price
        # Every trade updates the close, running volume and counters.
        self.close = trade.price
        self.volume += trade.size
        self.trade_count += 1
        self.last_trade_time = trade.timestamp
        # Keep the raw trade around for detailed analysis if needed.
        self.trades.append(trade)
        return True

    def to_candle(self, is_complete: bool = True) -> OHLCVCandle:
        """
        Convert this bucket into an OHLCV candle.

        IMPORTANT: the candle timestamp is end_time (right-aligned, industry standard).
        """
        zero = Decimal('0')
        return OHLCVCandle(
            symbol=self.symbol,
            timeframe=self.timeframe,
            start_time=self.start_time,
            end_time=self.end_time,
            open=self.open if self.open is not None else zero,
            high=self.high if self.high is not None else zero,
            low=self.low if self.low is not None else zero,
            close=self.close if self.close is not None else zero,
            volume=self.volume,
            trade_count=self.trade_count,
            exchange=self.exchange,
            is_complete=is_complete,
            first_trade_time=self.first_trade_time,
            last_trade_time=self.last_trade_time
        )

    def _calculate_end_time(self, start_time: datetime, timeframe: str) -> datetime:
        """Return the right boundary (candle timestamp) for this timeframe."""
        delta = self._TIMEFRAME_DELTAS.get(timeframe)
        if delta is None:
            raise ValueError(f"Unsupported timeframe: {timeframe}")
        return start_time + delta
class RealTimeCandleProcessor:
    """
    Real-time candle processor for live trade data.

    Processes trades immediately as they arrive from WebSocket, building
    candles incrementally and emitting completed candles when time
    boundaries are crossed.

    AGGREGATION PROCESS (NO FUTURE LEAKAGE):
    1. Trade arrives from WebSocket/API with timestamp T
    2. For each configured timeframe (1m, 5m, etc.):
       a. Calculate which time bucket this trade belongs to
       b. Get current bucket for this timeframe
       c. Check if trade timestamp crosses time boundary
       d. If boundary crossed: complete and emit previous bucket, create new bucket
       e. Add trade to current bucket (updates OHLCV)
    3. Only emit candles when time boundary is definitively crossed
    4. Never emit incomplete/future candles during real-time processing

    TIMESTAMP ALIGNMENT:
    - Uses RIGHT-ALIGNED timestamps (industry standard)
    - 1-minute candle covering 09:00:00-09:01:00 gets timestamp 09:01:00
    - 5-minute candle covering 09:00:00-09:05:00 gets timestamp 09:05:00
    - Candle represents PAST data, never future
    """

    def __init__(self,
                 symbol: str,
                 exchange: str,
                 config: Optional[CandleProcessingConfig] = None,
                 component_name: str = "realtime_candle_processor"):
        """
        Initialize real-time candle processor.

        Args:
            symbol: Trading symbol (e.g., 'BTC-USDT')
            exchange: Exchange name (e.g., 'okx', 'binance')
            config: Processing configuration (defaults to CandleProcessingConfig())
            component_name: Name for logging
        """
        self.symbol = symbol
        self.exchange = exchange
        self.config = config or CandleProcessingConfig()
        self.component_name = component_name
        self.logger = get_logger(self.component_name)
        # One in-progress bucket per timeframe
        self.current_buckets: Dict[str, TimeframeBucket] = {}
        # Subscribers notified of every completed candle
        self.candle_callbacks: List[Callable[[OHLCVCandle], None]] = []
        # Statistics
        self.stats = ProcessingStats(active_timeframes=len(self.config.timeframes))
        self.logger.info(f"Initialized real-time candle processor for {symbol} on {exchange} with timeframes: {self.config.timeframes}")

    def add_candle_callback(self, callback: Callable[[OHLCVCandle], None]) -> None:
        """Add callback function to receive completed candles."""
        self.candle_callbacks.append(callback)
        self.logger.debug(f"Added candle callback: {callback.__name__ if hasattr(callback, '__name__') else str(callback)}")

    def process_trade(self, trade: StandardizedTrade) -> List[OHLCVCandle]:
        """
        Process single trade - main entry point for real-time processing.

        This is called for each trade as it arrives from WebSocket.

        CRITICAL: Only returns completed candles (time boundary crossed).
        Never returns incomplete/future candles to prevent leakage.

        Args:
            trade: Standardized trade data

        Returns:
            List of completed candles (if any time boundaries were crossed)
        """
        try:
            completed_candles = []
            # Process trade for each timeframe
            for timeframe in self.config.timeframes:
                candle = self._process_trade_for_timeframe(trade, timeframe)
                if candle:
                    completed_candles.append(candle)
            # Update statistics
            self.stats.trades_processed += 1
            self.stats.last_trade_time = trade.timestamp
            # Emit completed candles to callbacks
            for candle in completed_candles:
                self._emit_candle(candle)
            return completed_candles
        except Exception as e:
            self.logger.error(f"Error processing trade for {self.symbol}: {e}")
            self.stats.errors_count += 1
            return []

    def _process_trade_for_timeframe(self, trade: StandardizedTrade, timeframe: str) -> Optional[OHLCVCandle]:
        """
        Process trade for specific timeframe.

        CRITICAL LOGIC FOR PREVENTING FUTURE LEAKAGE:
        1. Calculate which bucket this trade belongs to
        2. If no current bucket exists, start one
        3. If the trade falls outside the current bucket (time boundary
           crossed), complete the old bucket first, then start a new one
        4. Add the trade to the (possibly new) current bucket
        5. Only return completed candles, never incomplete ones
        """
        try:
            # Calculate which bucket this trade belongs to
            trade_bucket_start = self._get_bucket_start_time(trade.timestamp, timeframe)
            current_bucket = self.current_buckets.get(timeframe)
            completed_candle = None
            if current_bucket is None:
                # First bucket for this timeframe
                current_bucket = TimeframeBucket(self.symbol, timeframe, trade_bucket_start, self.exchange)
                self.current_buckets[timeframe] = current_bucket
            elif current_bucket.start_time != trade_bucket_start:
                # Time boundary crossed - complete previous bucket
                if current_bucket.trade_count > 0:  # Only complete if it has trades
                    completed_candle = current_bucket.to_candle(is_complete=True)
                    self.stats.candles_emitted += 1
                    self.stats.last_candle_time = completed_candle.end_time
                # Create new bucket for current time period
                current_bucket = TimeframeBucket(self.symbol, timeframe, trade_bucket_start, self.exchange)
                self.current_buckets[timeframe] = current_bucket
            # Add trade to current bucket
            if not current_bucket.add_trade(trade):
                # Should be unreachable: the bucket was derived from the trade's own timestamp
                self.logger.warning(f"Trade {trade.timestamp} could not be added to bucket {current_bucket.start_time}-{current_bucket.end_time}")
            return completed_candle
        except Exception as e:
            self.logger.error(f"Error processing trade for timeframe {timeframe}: {e}")
            self.stats.errors_count += 1
            return None

    def _get_bucket_start_time(self, timestamp: datetime, timeframe: str) -> datetime:
        """
        Calculate bucket start time (LEFT boundary) for given timestamp and timeframe.

        EXAMPLES:
        - Trade at 09:03:45 for 5m timeframe -> bucket start = 09:00:00
        - Trade at 09:07:23 for 5m timeframe -> bucket start = 09:05:00
        - Trade at 14:00:00 for 1h timeframe -> bucket start = 14:00:00

        Args:
            timestamp: Trade timestamp
            timeframe: Target timeframe

        Returns:
            Bucket start time (left boundary)

        Raises:
            ValueError: for unsupported timeframes
        """
        # Drop seconds/microseconds for clean minute-aligned boundaries
        dt = timestamp.replace(second=0, microsecond=0)
        if timeframe == '1m':
            # 1-minute buckets align to minute boundaries
            return dt
        elif timeframe == '5m':
            # 5-minute buckets: 00:00, 00:05, 00:10, etc.
            return dt.replace(minute=(dt.minute // 5) * 5)
        elif timeframe == '15m':
            # 15-minute buckets: 00:00, 00:15, 00:30, 00:45
            return dt.replace(minute=(dt.minute // 15) * 15)
        elif timeframe == '30m':
            # 30-minute buckets: 00:00, 00:30
            return dt.replace(minute=(dt.minute // 30) * 30)
        elif timeframe == '1h':
            # 1-hour buckets align to hour boundaries
            return dt.replace(minute=0)
        elif timeframe == '4h':
            # 4-hour buckets: 00:00, 04:00, 08:00, 12:00, 16:00, 20:00
            return dt.replace(minute=0, hour=(dt.hour // 4) * 4)
        elif timeframe == '1d':
            # 1-day buckets align to day boundaries (midnight)
            return dt.replace(minute=0, hour=0)
        else:
            raise ValueError(f"Unsupported timeframe: {timeframe}")

    def _emit_candle(self, candle: OHLCVCandle) -> None:
        """
        Emit completed candle to all callbacks.

        Each callback is isolated: an exception in one callback is logged and
        counted but does not prevent the remaining callbacks from running.
        (Previously one failing callback silently skipped all later ones.)
        """
        for callback in self.candle_callbacks:
            try:
                callback(candle)
            except Exception as e:
                self.logger.error(f"Error in candle callback: {e}")
                self.stats.errors_count += 1

    def get_current_candles(self, incomplete: bool = True) -> List[OHLCVCandle]:
        """
        Get current in-progress candles for all timeframes.

        Args:
            incomplete: Include in-progress candles. All current buckets are
                by definition incomplete, so passing False returns an empty
                list. (Previously this flag was accepted but ignored.)

        WARNING: These are incomplete candles and should NOT be used for
        trading decisions. They are useful for monitoring/debugging only.
        """
        if not incomplete:
            return []
        return [
            bucket.to_candle(is_complete=False)
            for bucket in self.current_buckets.values()
            if bucket.trade_count > 0  # only buckets that have seen trades
        ]

    def force_complete_all_candles(self) -> List[OHLCVCandle]:
        """
        Force completion of all current candles (useful for shutdown/batch processing).

        WARNING: This should only be used during shutdown or batch processing,
        not during live trading, as it marks incomplete candles complete.
        """
        completed_candles = []
        for bucket in self.current_buckets.values():
            if bucket.trade_count > 0:
                candle = bucket.to_candle(is_complete=True)
                completed_candles.append(candle)
                self._emit_candle(candle)
        # Clear buckets so processing can restart cleanly
        self.current_buckets.clear()
        return completed_candles

    def get_stats(self) -> Dict[str, Any]:
        """Get processing statistics, including per-timeframe bucket fill counts."""
        stats_dict = self.stats.to_dict()
        stats_dict['current_buckets'] = {
            tf: bucket.trade_count for tf, bucket in self.current_buckets.items()
        }
        return stats_dict
class BatchCandleProcessor:
    """
    Batch candle processor for historical data processing.

    Turns large batches of historical trades into candles for several
    timeframes at once by driving a short-lived real-time processor.
    """

    def __init__(self,
                 symbol: str,
                 exchange: str,
                 timeframes: List[str],
                 component_name: str = "batch_candle_processor"):
        """
        Initialize batch candle processor.

        Args:
            symbol: Trading symbol
            exchange: Exchange name
            timeframes: List of timeframes to process
            component_name: Name for logging
        """
        self.symbol = symbol
        self.exchange = exchange
        self.timeframes = timeframes
        self.component_name = component_name
        self.logger = get_logger(self.component_name)
        # Running totals for this processor instance
        self.stats = ProcessingStats(active_timeframes=len(timeframes))
        self.logger.info(f"Initialized batch candle processor for {symbol} on {exchange}")

    def process_trades_to_candles(self, trades: Iterator[StandardizedTrade]) -> List[OHLCVCandle]:
        """
        Convert a trade iterator into completed candles - optimized for batch work.

        Handles ALL scenarios:
        - Historical: Batch trade iterators
        - Backfill: API trade iterators
        - Real-time batch: Multiple trades at once

        Args:
            trades: Iterator of standardized trades

        Returns:
            List of completed candles (empty list on error)
        """
        try:
            # Drive a throwaway real-time processor over the whole batch.
            batch_config = CandleProcessingConfig(timeframes=self.timeframes, auto_save_candles=False)
            worker = RealTimeCandleProcessor(
                self.symbol, self.exchange, batch_config,
                f"batch_processor_{self.symbol}_{self.exchange}"
            )
            collected: List[OHLCVCandle] = []
            for trade in trades:
                collected.extend(worker.process_trade(trade))
                self.stats.trades_processed += 1
            # Flush whatever buckets remain open at the end of the batch.
            collected.extend(worker.force_complete_all_candles())
            # Update stats
            self.stats.candles_emitted = len(collected)
            if collected:
                self.stats.last_candle_time = max(c.end_time for c in collected)
            self.logger.info(f"Batch processed {self.stats.trades_processed} trades to {len(collected)} candles")
            return collected
        except Exception as e:
            self.logger.error(f"Error in batch processing trades to candles: {e}")
            self.stats.errors_count += 1
            return []

    def get_stats(self) -> Dict[str, Any]:
        """Get processing statistics."""
        return self.stats.to_dict()
# Utility functions for common aggregation operations
def aggregate_trades_to_candles(trades: List[StandardizedTrade],
                                timeframes: List[str],
                                symbol: str,
                                exchange: str) -> List[OHLCVCandle]:
    """
    Aggregate an in-memory list of trades into completed candles.

    Thin convenience wrapper around BatchCandleProcessor.

    Args:
        trades: List of standardized trades
        timeframes: List of timeframes to generate
        symbol: Trading symbol
        exchange: Exchange name

    Returns:
        List of completed candles
    """
    batch = BatchCandleProcessor(symbol, exchange, timeframes)
    return batch.process_trades_to_candles(iter(trades))
def validate_timeframe(timeframe: str) -> bool:
    """
    Check whether a timeframe string is one of the supported intervals.

    Args:
        timeframe: Timeframe string (e.g., '1m', '5m', '1h')

    Returns:
        True if supported, False otherwise
    """
    # Set literal gives O(1) membership and keeps the whole list in one place.
    return timeframe in {'1m', '5m', '15m', '30m', '1h', '4h', '1d'}
def parse_timeframe(timeframe: str) -> tuple[int, str]:
    """
    Split a timeframe string into its numeric count and unit letter.

    Args:
        timeframe: Timeframe string (e.g., '5m', '1h'); case-insensitive

    Returns:
        Tuple of (number, unit)

    Examples:
        '5m' -> (5, 'm')
        '1h' -> (1, 'h')
        '1d' -> (1, 'd')

    Raises:
        ValueError: if the string is not of the form <digits><m|h|d>
    """
    import re
    parsed = re.match(r'^(\d+)([mhd])$', timeframe.lower())
    if parsed is None:
        raise ValueError(f"Invalid timeframe format: {timeframe}")
    count, unit = parsed.groups()
    return int(count), unit
# Public API of the aggregation module.
__all__ = [
    'TimeframeBucket',
    'RealTimeCandleProcessor',
    'BatchCandleProcessor',
    'aggregate_trades_to_candles',
    'validate_timeframe',
    'parse_timeframe'
]

182
data/common/data_types.py Normal file
View File

@@ -0,0 +1,182 @@
"""
Common data types for all exchange implementations.
These data structures provide a unified interface for market data
regardless of the source exchange.
"""
from datetime import datetime, timezone
from decimal import Decimal
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
from ..base_collector import DataType, MarketDataPoint # Import from base
@dataclass
class DataValidationResult:
    """Result of data validation - common across all exchanges."""
    # Overall verdict: True when the payload passed validation.
    is_valid: bool
    # Blocking problems found in the payload.
    errors: List[str]
    # Non-blocking issues worth surfacing but not fatal.
    warnings: List[str]
    # Optionally, a cleaned-up copy of the payload produced during validation.
    sanitized_data: Optional[Dict[str, Any]] = None
@dataclass
class StandardizedTrade:
    """
    Exchange-agnostic trade record used by every processing path.

    The same shape serves real-time, historical and backfill processing,
    so downstream aggregation never cares about the source exchange.
    """
    symbol: str
    trade_id: str
    price: Decimal
    size: Decimal
    side: str  # 'buy' or 'sell'
    timestamp: datetime
    exchange: str
    raw_data: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Normalize timezone and side, then reject invalid sides."""
        # Naive timestamps are assumed to be UTC.
        if self.timestamp.tzinfo is None:
            self.timestamp = self.timestamp.replace(tzinfo=timezone.utc)
        # Canonical lowercase side.
        self.side = self.side.lower()
        if self.side not in ('buy', 'sell'):
            raise ValueError(f"Invalid trade side: {self.side}")
@dataclass
class OHLCVCandle:
    """
    OHLCV candle for a single timeframe interval.

    Built by aggregating all trades that fell inside [start_time, end_time);
    end_time is the right-aligned candle timestamp.
    """
    symbol: str
    timeframe: str
    start_time: datetime
    end_time: datetime
    open: Decimal
    high: Decimal
    low: Decimal
    close: Decimal
    volume: Decimal
    trade_count: int
    exchange: str = "unknown"
    is_complete: bool = False
    first_trade_time: Optional[datetime] = None
    last_trade_time: Optional[datetime] = None

    def __post_init__(self):
        """Force UTC on naive timestamps and sanity-check OHLCV invariants."""
        # Naive timestamps are assumed to be UTC.
        if self.start_time.tzinfo is None:
            self.start_time = self.start_time.replace(tzinfo=timezone.utc)
        if self.end_time.tzinfo is None:
            self.end_time = self.end_time.replace(tzinfo=timezone.utc)
        # OHLC invariants.
        if self.high < self.low:
            raise ValueError("High price cannot be less than low price")
        if min(self.open, self.high, self.low, self.close) < 0:
            raise ValueError("Prices cannot be negative")
        if self.volume < 0:
            raise ValueError("Volume cannot be negative")
        if self.trade_count < 0:
            raise ValueError("Trade count cannot be negative")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize with ISO-8601 timestamps and stringified Decimals."""
        first = self.first_trade_time.isoformat() if self.first_trade_time else None
        last = self.last_trade_time.isoformat() if self.last_trade_time else None
        return {
            'symbol': self.symbol,
            'timeframe': self.timeframe,
            'start_time': self.start_time.isoformat(),
            'end_time': self.end_time.isoformat(),
            'open': str(self.open),
            'high': str(self.high),
            'low': str(self.low),
            'close': str(self.close),
            'volume': str(self.volume),
            'trade_count': self.trade_count,
            'exchange': self.exchange,
            'is_complete': self.is_complete,
            'first_trade_time': first,
            'last_trade_time': last,
        }
@dataclass
class CandleProcessingConfig:
    """Configuration for candle processing - shared across exchanges."""
    timeframes: List[str] = field(default_factory=lambda: ['1m', '5m', '15m', '1h'])
    auto_save_candles: bool = True
    emit_incomplete_candles: bool = False
    max_trades_per_candle: int = 100000  # Safety limit

    def __post_init__(self):
        """Reject any timeframe outside the supported set."""
        supported = {'1m', '5m', '15m', '30m', '1h', '4h', '1d'}
        for tf in self.timeframes:
            if tf not in supported:
                raise ValueError(f"Unsupported timeframe: {tf}")
class TradeSide(Enum):
    """Standardized trade side enumeration.

    Values are the canonical lowercase strings used by StandardizedTrade.side.
    """
    BUY = "buy"
    SELL = "sell"
class TimeframeUnit(Enum):
    """Time units for candle timeframes.

    Single-letter suffixes as used in timeframe strings like '5m', '1h', '1d'.
    """
    MINUTE = "m"
    HOUR = "h"
    DAY = "d"
@dataclass
class ProcessingStats:
    """Common processing statistics structure."""
    trades_processed: int = 0
    candles_emitted: int = 0
    errors_count: int = 0
    warnings_count: int = 0
    last_trade_time: Optional[datetime] = None
    last_candle_time: Optional[datetime] = None
    active_timeframes: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize counters; timestamps become ISO strings (or None if unset)."""
        def _iso(moment: Optional[datetime]) -> Optional[str]:
            # Timestamps are legitimately unset before any data has flowed.
            return moment.isoformat() if moment else None

        return {
            'trades_processed': self.trades_processed,
            'candles_emitted': self.candles_emitted,
            'errors_count': self.errors_count,
            'warnings_count': self.warnings_count,
            'last_trade_time': _iso(self.last_trade_time),
            'last_candle_time': _iso(self.last_candle_time),
            'active_timeframes': self.active_timeframes,
        }
# Re-export from base_collector for convenience
# (DataType and MarketDataPoint are imported above, not defined in this module).
__all__ = [
    'DataType',
    'MarketDataPoint',
    'DataValidationResult',
    'StandardizedTrade',
    'OHLCVCandle',
    'CandleProcessingConfig',
    'TradeSide',
    'TimeframeUnit',
    'ProcessingStats'
]

View File

@@ -0,0 +1,471 @@
"""
Base transformation utilities for all exchanges.
This module provides common transformation patterns and base classes
for converting exchange-specific data to standardized formats.
"""
from datetime import datetime, timezone
from decimal import Decimal
from typing import Dict, List, Optional, Any, Iterator
from abc import ABC, abstractmethod
from .data_types import StandardizedTrade, OHLCVCandle, DataValidationResult
from .aggregation import BatchCandleProcessor
from utils.logger import get_logger
class BaseDataTransformer(ABC):
    """
    Abstract base class for exchange data transformers.

    Provides common transformation patterns (timestamp/decimal conversion,
    side and symbol normalization) that exchange-specific implementations
    extend with their own wire-format parsing.
    """

    def __init__(self,
                 exchange_name: str,
                 component_name: str = "base_data_transformer"):
        """
        Initialize base data transformer.

        Args:
            exchange_name: Name of the exchange (e.g., 'okx', 'binance')
            component_name: Name for logging
        """
        self.exchange_name = exchange_name
        self.component_name = component_name
        self.logger = get_logger(self.component_name)
        self.logger.info(f"Initialized base data transformer for {exchange_name}")

    # Abstract methods that must be implemented by subclasses

    @abstractmethod
    def transform_trade_data(self, raw_data: Dict[str, Any], symbol: str) -> Optional[StandardizedTrade]:
        """Transform exchange-specific trade data to standardized format."""
        pass

    @abstractmethod
    def transform_orderbook_data(self, raw_data: Dict[str, Any], symbol: str) -> Optional[Dict[str, Any]]:
        """Transform exchange-specific orderbook data to standardized format."""
        pass

    @abstractmethod
    def transform_ticker_data(self, raw_data: Dict[str, Any], symbol: str) -> Optional[Dict[str, Any]]:
        """Transform exchange-specific ticker data to standardized format."""
        pass

    # Common transformation utilities available to all subclasses

    def timestamp_to_datetime(self, timestamp: Any, is_milliseconds: bool = True) -> datetime:
        """
        Convert various timestamp formats to timezone-aware datetime.

        Args:
            timestamp: Timestamp as str, int or float
            is_milliseconds: True if timestamp is in milliseconds

        Returns:
            Timezone-aware (UTC) datetime. On parse failure the error is
            logged and the current UTC time is returned as a deliberate
            best-effort fallback so the pipeline keeps flowing.
        """
        try:
            if isinstance(timestamp, str):
                timestamp_num = float(timestamp)
            elif isinstance(timestamp, (int, float)):
                timestamp_num = float(timestamp)
            else:
                raise ValueError(f"Invalid timestamp type: {type(timestamp)}")
            # Convert to seconds if needed
            if is_milliseconds:
                timestamp_num = timestamp_num / 1000
            return datetime.fromtimestamp(timestamp_num, tz=timezone.utc)
        except Exception as e:
            self.logger.error(f"Error converting timestamp {timestamp}: {e}")
            # Return current time as fallback
            return datetime.now(timezone.utc)

    def safe_decimal_conversion(self, value: Any, field_name: str = "value") -> Optional[Decimal]:
        """
        Safely convert value to Decimal with error handling.

        Args:
            value: Value to convert
            field_name: Name of field for error logging

        Returns:
            Decimal value, or None if the value is empty or conversion failed
        """
        try:
            if value is None or value == "":
                return None
            return Decimal(str(value))
        except Exception as e:
            self.logger.warning(f"Failed to convert {field_name} '{value}' to Decimal: {e}")
            return None

    def normalize_trade_side(self, side: str) -> str:
        """
        Normalize trade side to standard format.

        Args:
            side: Raw trade side string

        Returns:
            Normalized side ('buy' or 'sell'); unknown inputs are logged
            and default to 'buy'
        """
        normalized = side.lower().strip()
        # Handle common variations seen across exchange payloads
        if normalized in ['buy', 'bid', 'b', '1']:
            return 'buy'
        elif normalized in ['sell', 'ask', 's', '0']:
            return 'sell'
        else:
            self.logger.warning(f"Unknown trade side: {side}, defaulting to 'buy'")
            return 'buy'

    def validate_symbol_format(self, symbol: str) -> str:
        """
        Validate and normalize symbol format.

        Args:
            symbol: Raw symbol string

        Returns:
            Normalized (uppercased, stripped) symbol string

        Raises:
            ValueError: if the symbol is missing, not a string, or empty
        """
        if not symbol or not isinstance(symbol, str):
            raise ValueError(f"Invalid symbol: {symbol}")
        normalized = symbol.upper().strip()
        if not normalized:
            raise ValueError("Empty symbol after normalization")
        return normalized

    def transform_database_record(self, record: Any) -> Optional[StandardizedTrade]:
        """
        Transform database record to standardized format.

        This stub should be overridden by subclasses that know their own
        database schema; the base implementation only logs and returns None.

        Args:
            record: Database record

        Returns:
            StandardizedTrade or None if transformation failed
        """
        self.logger.warning("transform_database_record not implemented for this exchange")
        return None

    def get_transformer_info(self) -> Dict[str, Any]:
        """Describe this transformer and its capabilities."""
        # A subclass supports database transformation only if it actually
        # overrides the stub defined on this base class. The previous
        # hasattr() check was always True because the base itself defines
        # transform_database_record.
        overrides_db = (
            type(self).transform_database_record
            is not BaseDataTransformer.transform_database_record
        )
        return {
            'exchange': self.exchange_name,
            'component': self.component_name,
            'capabilities': {
                'trade_transformation': True,
                'orderbook_transformation': True,
                'ticker_transformation': True,
                'database_transformation': overrides_db
            }
        }
class UnifiedDataTransformer:
"""
Unified data transformation system for all scenarios.
This class provides a common interface for transforming data from
various sources (real-time, historical, backfill) into standardized
formats for further processing.
TRANSFORMATION PROCESS:
1. Raw Data Input (exchange format, database records, etc.)
2. Validation (using exchange-specific validators)
3. Transformation to StandardizedTrade format
4. Optional aggregation to candles
5. Output in consistent format
"""
def __init__(self,
exchange_transformer: BaseDataTransformer,
component_name: str = "unified_data_transformer"):
"""
Initialize unified data transformer.
Args:
exchange_transformer: Exchange-specific transformer instance
component_name: Name for logging
"""
self.exchange_transformer = exchange_transformer
self.component_name = component_name
self.logger = get_logger(self.component_name)
self.logger.info(f"Initialized unified data transformer with {exchange_transformer.exchange_name} transformer")
def transform_trade_data(self, raw_data: Dict[str, Any], symbol: str) -> Optional[StandardizedTrade]:
"""
Transform trade data using exchange-specific transformer.
Args:
raw_data: Raw trade data from exchange
symbol: Trading symbol
Returns:
Standardized trade or None if transformation failed
"""
try:
return self.exchange_transformer.transform_trade_data(raw_data, symbol)
except Exception as e:
self.logger.error(f"Error in trade transformation: {e}")
return None
def transform_orderbook_data(self, raw_data: Dict[str, Any], symbol: str) -> Optional[Dict[str, Any]]:
"""
Transform orderbook data using exchange-specific transformer.
Args:
raw_data: Raw orderbook data from exchange
symbol: Trading symbol
Returns:
Standardized orderbook data or None if transformation failed
"""
try:
return self.exchange_transformer.transform_orderbook_data(raw_data, symbol)
except Exception as e:
self.logger.error(f"Error in orderbook transformation: {e}")
return None
def transform_ticker_data(self, raw_data: Dict[str, Any], symbol: str) -> Optional[Dict[str, Any]]:
"""
Transform ticker data using exchange-specific transformer.
Args:
raw_data: Raw ticker data from exchange
symbol: Trading symbol
Returns:
Standardized ticker data or None if transformation failed
"""
try:
return self.exchange_transformer.transform_ticker_data(raw_data, symbol)
except Exception as e:
self.logger.error(f"Error in ticker transformation: {e}")
return None
def process_trades_to_candles(self,
trades: Iterator[StandardizedTrade],
timeframes: List[str],
symbol: str) -> List[OHLCVCandle]:
"""
Process any trade iterator to candles - unified processing function.
This function handles ALL scenarios:
- Real-time: Single trade iterators
- Historical: Batch trade iterators
- Backfill: API trade iterators
Args:
trades: Iterator of standardized trades
timeframes: List of timeframes to generate
symbol: Trading symbol
Returns:
List of completed candles
"""
try:
processor = BatchCandleProcessor(
symbol,
self.exchange_transformer.exchange_name,
timeframes,
f"unified_batch_processor_{symbol}"
)
candles = processor.process_trades_to_candles(trades)
self.logger.info(f"Processed {processor.get_stats()['trades_processed']} trades to {len(candles)} candles")
return candles
except Exception as e:
self.logger.error(f"Error processing trades to candles: {e}")
return []
def batch_transform_trades(self,
                           raw_trades: List[Dict[str, Any]],
                           symbol: str) -> List[StandardizedTrade]:
    """
    Transform multiple trade records in batch.

    Rows that fail to transform are counted and skipped so one bad record
    cannot abort the whole batch.

    Args:
        raw_trades: List of raw trade data
        symbol: Trading symbol

    Returns:
        List of successfully transformed trades
    """
    converted = []
    failed = 0
    for raw_trade in raw_trades:
        try:
            trade = self.transform_trade_data(raw_trade, symbol)
        except Exception as e:
            self.logger.error(f"Error transforming trade: {e}")
            failed += 1
            continue
        if trade:
            converted.append(trade)
        else:
            failed += 1
    self.logger.info(f"Batch transformed {len(converted)} trades successfully, {failed} errors")
    return converted
def get_transformer_info(self) -> Dict[str, Any]:
    """Return the exchange transformer's info augmented with wrapper details."""
    info = self.exchange_transformer.get_transformer_info()
    # Annotate the exchange-specific info with this unified wrapper's
    # capabilities (in place, matching the exchange transformer's dict).
    info['unified_component'] = self.component_name
    info['batch_processing'] = True
    info['candle_aggregation'] = True
    return info
# Utility functions for common transformation patterns
def create_standardized_trade(symbol: str,
                              trade_id: str,
                              price: Any,
                              size: Any,
                              side: str,
                              timestamp: Any,
                              exchange: str,
                              raw_data: Optional[Dict[str, Any]] = None,
                              is_milliseconds: bool = True) -> StandardizedTrade:
    """
    Utility function to create StandardizedTrade with proper validation.

    Args:
        symbol: Trading symbol
        trade_id: Trade identifier
        price: Trade price (any numeric type); must be positive
        size: Trade size (any numeric type); must be positive
        side: Trade side ('buy' or 'sell', case-insensitive)
        timestamp: Trade timestamp (epoch number/string or datetime; naive
            datetimes are assumed to be UTC)
        exchange: Exchange name
        raw_data: Original raw data
        is_milliseconds: True if a numeric timestamp is in milliseconds

    Returns:
        StandardizedTrade object

    Raises:
        ValueError: If data is invalid (bad timestamp type, non-numeric or
            non-positive price/size, or unknown side)
    """
    # Convert timestamp to an aware UTC datetime.
    if isinstance(timestamp, (int, float, str)):
        timestamp_num = float(timestamp)
        if is_milliseconds:
            timestamp_num = timestamp_num / 1000
        dt = datetime.fromtimestamp(timestamp_num, tz=timezone.utc)
    elif isinstance(timestamp, datetime):
        dt = timestamp
        if dt.tzinfo is None:
            # Naive datetimes are interpreted as UTC rather than local time.
            dt = dt.replace(tzinfo=timezone.utc)
    else:
        raise ValueError(f"Invalid timestamp type: {type(timestamp)}")
    # Convert price and size to Decimal (via str to avoid float artifacts).
    try:
        decimal_price = Decimal(str(price))
        decimal_size = Decimal(str(size))
    except Exception as e:
        raise ValueError(f"Invalid price or size: {e}") from e
    # Bug fix: enforce positivity as the docstring promises — previously
    # negative or zero prices/sizes were silently accepted.  This matches
    # the rules in data.common.validation (price/size must be > 0).
    if decimal_price <= 0:
        raise ValueError(f"Price must be positive, got {decimal_price}")
    if decimal_size <= 0:
        raise ValueError(f"Size must be positive, got {decimal_size}")
    # Normalize side
    normalized_side = side.lower().strip()
    if normalized_side not in ['buy', 'sell']:
        raise ValueError(f"Invalid trade side: {side}")
    return StandardizedTrade(
        symbol=symbol.upper().strip(),
        trade_id=str(trade_id),
        price=decimal_price,
        size=decimal_size,
        side=normalized_side,
        timestamp=dt,
        exchange=exchange.lower(),
        raw_data=raw_data
    )
def batch_create_standardized_trades(raw_trades: List[Dict[str, Any]],
                                     symbol: str,
                                     exchange: str,
                                     field_mapping: Dict[str, str],
                                     is_milliseconds: bool = True) -> List[StandardizedTrade]:
    """
    Batch create standardized trades from raw data.

    Rows that fail conversion are logged with a warning and skipped so
    that one malformed record never aborts the whole batch.

    Args:
        raw_trades: List of raw trade dictionaries
        symbol: Trading symbol
        exchange: Exchange name
        field_mapping: Mapping of StandardizedTrade fields to raw data fields
        is_milliseconds: True if timestamps are in milliseconds

    Returns:
        List of successfully created StandardizedTrade objects

    Example field_mapping:
        {
            'trade_id': 'id',
            'price': 'px',
            'size': 'sz',
            'side': 'side',
            'timestamp': 'ts'
        }
    """
    converted = []
    for row in raw_trades:
        try:
            converted.append(create_standardized_trade(
                symbol=symbol,
                trade_id=row[field_mapping['trade_id']],
                price=row[field_mapping['price']],
                size=row[field_mapping['size']],
                side=row[field_mapping['side']],
                timestamp=row[field_mapping['timestamp']],
                exchange=exchange,
                raw_data=row,
                is_milliseconds=is_milliseconds
            ))
        except Exception as e:
            # Log error but continue processing the remaining rows.
            logger = get_logger("batch_transform")
            logger.warning(f"Failed to transform trade: {e}")
    return converted
# Public API of the transformation module.
__all__ = [
    'BaseDataTransformer',
    'UnifiedDataTransformer',
    'create_standardized_trade',
    'batch_create_standardized_trades'
]

484
data/common/validation.py Normal file
View File

@@ -0,0 +1,484 @@
"""
Base validation utilities for all exchanges.
This module provides common validation patterns and base classes
that can be extended by exchange-specific validators.
"""
import re
from datetime import datetime, timezone, timedelta
from decimal import Decimal, InvalidOperation
from typing import Dict, List, Optional, Any, Union, Pattern
from abc import ABC, abstractmethod
from .data_types import DataValidationResult, StandardizedTrade, TradeSide
from utils.logger import get_logger
class ValidationResult:
    """Simple validation result for individual field validation.

    Attributes:
        is_valid: True when the validated value passed all blocking checks.
        errors: Blocking problems found during validation.
        warnings: Non-blocking issues worth surfacing/logging.
        sanitized_data: Normalized value produced by the validator, if any.
    """

    def __init__(self,
                 is_valid: bool,
                 errors: Optional[List[str]] = None,
                 warnings: Optional[List[str]] = None,
                 sanitized_data: Any = None):
        # Fix: annotations previously claimed List[str] while the defaults
        # were None; they are Optional.  Fresh lists are created per
        # instance to avoid shared mutable state.
        self.is_valid = is_valid
        self.errors = errors or []
        self.warnings = warnings or []
        self.sanitized_data = sanitized_data
class BaseDataValidator(ABC):
    """
    Abstract base class for exchange data validators.

    Provides common validation primitives (price, size, volume, timestamp,
    trade side, trade ID, orderbook levels) that exchange-specific
    validators can reuse.  Subclasses must implement symbol-format and
    WebSocket-message validation, which differ per exchange.

    Bug fix versus the previous revision: warnings produced by per-field
    validators are now propagated by ``validate_standardized_trade`` and
    ``validate_orderbook_side`` even when the field is valid; previously
    they were silently dropped unless the field also had errors.
    """

    def __init__(self,
                 exchange_name: str,
                 component_name: str = "base_data_validator"):
        """
        Initialize base data validator.

        Args:
            exchange_name: Name of the exchange (e.g., 'okx', 'binance')
            component_name: Name for logging
        """
        self.exchange_name = exchange_name
        self.component_name = component_name
        self.logger = get_logger(self.component_name)
        # Common validation patterns
        self._numeric_pattern = re.compile(r'^-?\d*\.?\d+$')
        self._trade_id_pattern = re.compile(r'^[a-zA-Z0-9_-]+$')  # Flexible pattern
        # Valid trade sides
        self._valid_trade_sides = {'buy', 'sell'}
        # Common price and size limits (can be overridden by subclasses)
        self._min_price = Decimal('0.00000001')  # 1 satoshi equivalent
        self._max_price = Decimal('10000000')  # 10 million
        self._min_size = Decimal('0.00000001')  # Minimum trade size
        self._max_size = Decimal('1000000000')  # 1 billion max size
        # Timestamp validation (milliseconds since epoch)
        self._min_timestamp = 1000000000000  # 2001-09-09 (reasonable minimum)
        self._max_timestamp = 9999999999999  # 2286-11-20 (reasonable maximum)
        self.logger.debug(f"Initialized base data validator for {exchange_name}")

    # Abstract methods that must be implemented by subclasses

    @abstractmethod
    def validate_symbol_format(self, symbol: str) -> ValidationResult:
        """Validate exchange-specific symbol format."""
        pass

    @abstractmethod
    def validate_websocket_message(self, message: Dict[str, Any]) -> DataValidationResult:
        """Validate complete WebSocket message structure."""
        pass

    # Common validation methods available to all subclasses

    def validate_price(self, price: Union[str, int, float, Decimal]) -> ValidationResult:
        """
        Validate price value with common rules.

        Non-positive prices are errors; out-of-range but positive prices
        only produce warnings since limits are heuristic.

        Args:
            price: Price value to validate

        Returns:
            ValidationResult with sanitized decimal price
        """
        errors = []
        warnings = []
        sanitized_data = None
        try:
            # Convert to Decimal for precise validation
            if isinstance(price, str) and price.strip() == "":
                errors.append("Empty price string")
                return ValidationResult(False, errors, warnings)
            decimal_price = Decimal(str(price))
            sanitized_data = decimal_price
            # Check for negative prices
            if decimal_price <= 0:
                errors.append(f"Price must be positive, got {decimal_price}")
            # Check price bounds
            if decimal_price < self._min_price:
                warnings.append(f"Price {decimal_price} below minimum {self._min_price}")
            elif decimal_price > self._max_price:
                warnings.append(f"Price {decimal_price} above maximum {self._max_price}")
            # Check for excessive decimal places (warn only)
            if decimal_price.as_tuple().exponent < -12:
                warnings.append(f"Price has excessive decimal precision: {decimal_price}")
        except (InvalidOperation, ValueError, TypeError) as e:
            errors.append(f"Invalid price value: {price} - {str(e)}")
        return ValidationResult(len(errors) == 0, errors, warnings, sanitized_data)

    def validate_size(self, size: Union[str, int, float, Decimal]) -> ValidationResult:
        """
        Validate size/quantity value with common rules.

        Args:
            size: Size value to validate

        Returns:
            ValidationResult with sanitized decimal size
        """
        errors = []
        warnings = []
        sanitized_data = None
        try:
            # Convert to Decimal for precise validation
            if isinstance(size, str) and size.strip() == "":
                errors.append("Empty size string")
                return ValidationResult(False, errors, warnings)
            decimal_size = Decimal(str(size))
            sanitized_data = decimal_size
            # Check for negative or zero sizes
            if decimal_size <= 0:
                errors.append(f"Size must be positive, got {decimal_size}")
            # Check size bounds
            if decimal_size < self._min_size:
                warnings.append(f"Size {decimal_size} below minimum {self._min_size}")
            elif decimal_size > self._max_size:
                warnings.append(f"Size {decimal_size} above maximum {self._max_size}")
        except (InvalidOperation, ValueError, TypeError) as e:
            errors.append(f"Invalid size value: {size} - {str(e)}")
        return ValidationResult(len(errors) == 0, errors, warnings, sanitized_data)

    def validate_volume(self, volume: Union[str, int, float, Decimal]) -> ValidationResult:
        """
        Validate volume value with common rules.

        Unlike size, a zero volume is valid (no trades in the period).

        Args:
            volume: Volume value to validate

        Returns:
            ValidationResult
        """
        errors = []
        warnings = []
        try:
            decimal_volume = Decimal(str(volume))
            # Volume can be zero (no trades in period)
            if decimal_volume < 0:
                errors.append(f"Volume cannot be negative, got {decimal_volume}")
        except (InvalidOperation, ValueError, TypeError) as e:
            errors.append(f"Invalid volume value: {volume} - {str(e)}")
        return ValidationResult(len(errors) == 0, errors, warnings)

    def validate_trade_side(self, side: str) -> ValidationResult:
        """
        Validate trade side with common rules (case-insensitive).

        Args:
            side: Trade side string

        Returns:
            ValidationResult
        """
        errors = []
        warnings = []
        if not isinstance(side, str):
            errors.append(f"Trade side must be string, got {type(side)}")
            return ValidationResult(False, errors, warnings)
        normalized_side = side.lower()
        if normalized_side not in self._valid_trade_sides:
            errors.append(f"Invalid trade side: {side}. Must be 'buy' or 'sell'")
        return ValidationResult(len(errors) == 0, errors, warnings)

    def validate_timestamp(self, timestamp: Union[str, int], is_milliseconds: bool = True) -> ValidationResult:
        """
        Validate timestamp value with common rules.

        Hard bounds (roughly 2001..2286) are errors; timestamps more than a
        year away from now are only warnings.

        Args:
            timestamp: Timestamp value to validate
            is_milliseconds: True if timestamp is in milliseconds, False for seconds

        Returns:
            ValidationResult
        """
        errors = []
        warnings = []
        try:
            # Convert to int
            if isinstance(timestamp, str):
                # isdigit() also rejects signs and decimal points, which is
                # intentional: only plain non-negative integers are accepted.
                if not timestamp.isdigit():
                    errors.append(f"Invalid timestamp format: {timestamp}")
                    return ValidationResult(False, errors, warnings)
                timestamp_int = int(timestamp)
            elif isinstance(timestamp, int):
                timestamp_int = timestamp
            else:
                errors.append(f"Timestamp must be string or int, got {type(timestamp)}")
                return ValidationResult(False, errors, warnings)
            # Convert to milliseconds if needed
            if not is_milliseconds:
                timestamp_int = timestamp_int * 1000
            # Check timestamp bounds
            if timestamp_int < self._min_timestamp:
                errors.append(f"Timestamp {timestamp_int} too old")
            elif timestamp_int > self._max_timestamp:
                errors.append(f"Timestamp {timestamp_int} too far in future")
            # Check if timestamp is reasonable (within last year to next year)
            current_time_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
            one_year_ms = 365 * 24 * 60 * 60 * 1000
            if timestamp_int < (current_time_ms - one_year_ms):
                warnings.append(f"Timestamp {timestamp_int} is older than 1 year")
            elif timestamp_int > (current_time_ms + one_year_ms):
                warnings.append(f"Timestamp {timestamp_int} is more than 1 year in future")
        except (ValueError, TypeError) as e:
            errors.append(f"Invalid timestamp: {timestamp} - {str(e)}")
        return ValidationResult(len(errors) == 0, errors, warnings)

    def validate_trade_id(self, trade_id: Union[str, int]) -> ValidationResult:
        """
        Validate trade ID with flexible rules.

        Unusual characters only produce a warning because exchanges use
        many different ID formats.

        Args:
            trade_id: Trade ID to validate

        Returns:
            ValidationResult
        """
        errors = []
        warnings = []
        if isinstance(trade_id, int):
            trade_id = str(trade_id)
        if not isinstance(trade_id, str):
            errors.append(f"Trade ID must be string or int, got {type(trade_id)}")
            return ValidationResult(False, errors, warnings)
        if not trade_id.strip():
            errors.append("Trade ID cannot be empty")
            return ValidationResult(False, errors, warnings)
        # Flexible validation - allow alphanumeric, underscore, hyphen
        if not self._trade_id_pattern.match(trade_id):
            warnings.append(f"Trade ID has unusual format: {trade_id}")
        return ValidationResult(len(errors) == 0, errors, warnings)

    def validate_symbol_match(self, symbol: str, expected_symbol: Optional[str] = None) -> ValidationResult:
        """
        Validate symbol matches expected value.

        A mismatch is only a warning: data may legitimately arrive for a
        different instrument and is filtered upstream.

        Args:
            symbol: Symbol to validate
            expected_symbol: Expected symbol value

        Returns:
            ValidationResult
        """
        errors = []
        warnings = []
        if not isinstance(symbol, str):
            errors.append(f"Symbol must be string, got {type(symbol)}")
            return ValidationResult(False, errors, warnings)
        if expected_symbol and symbol != expected_symbol:
            warnings.append(f"Symbol mismatch: expected {expected_symbol}, got {symbol}")
        return ValidationResult(len(errors) == 0, errors, warnings)

    def validate_orderbook_side(self, side_data: List[List[str]], side_name: str) -> ValidationResult:
        """
        Validate orderbook side (asks or bids) with common rules.

        Args:
            side_data: List of price/size pairs
            side_name: Name of side for error messages

        Returns:
            ValidationResult with sanitized data (only fully valid levels)
        """
        errors = []
        warnings = []
        sanitized_data = []
        if not isinstance(side_data, list):
            errors.append(f"{side_name} must be a list")
            return ValidationResult(False, errors, warnings)
        for i, level in enumerate(side_data):
            if not isinstance(level, list) or len(level) < 2:
                errors.append(f"{side_name}[{i}] must be a list with at least 2 elements")
                continue
            # Validate price and size
            price_result = self.validate_price(level[0])
            size_result = self.validate_size(level[1])
            if not price_result.is_valid:
                errors.extend([f"{side_name}[{i}] price: {error}" for error in price_result.errors])
            if not size_result.is_valid:
                errors.extend([f"{side_name}[{i}] size: {error}" for error in size_result.errors])
            # Bug fix: propagate non-fatal warnings from the per-level
            # validations instead of silently dropping them.
            warnings.extend([f"{side_name}[{i}] price: {warning}" for warning in price_result.warnings])
            warnings.extend([f"{side_name}[{i}] size: {warning}" for warning in size_result.warnings])
            # Add sanitized level
            if price_result.is_valid and size_result.is_valid:
                sanitized_level = [str(price_result.sanitized_data), str(size_result.sanitized_data)]
                # Include additional fields if present
                if len(level) > 2:
                    sanitized_level.extend(level[2:])
                sanitized_data.append(sanitized_level)
        return ValidationResult(len(errors) == 0, errors, warnings, sanitized_data)

    def validate_standardized_trade(self, trade: StandardizedTrade) -> DataValidationResult:
        """
        Validate a standardized trade object field by field.

        Args:
            trade: StandardizedTrade object to validate

        Returns:
            DataValidationResult aggregating all field errors and warnings
        """
        errors = []
        warnings = []
        try:
            # Validate price
            price_result = self.validate_price(trade.price)
            if not price_result.is_valid:
                errors.extend([f"price: {error}" for error in price_result.errors])
            # Bug fix: warnings were previously only collected when the
            # field was also invalid; propagate them unconditionally.
            warnings.extend([f"price: {warning}" for warning in price_result.warnings])
            # Validate size
            size_result = self.validate_size(trade.size)
            if not size_result.is_valid:
                errors.extend([f"size: {error}" for error in size_result.errors])
            warnings.extend([f"size: {warning}" for warning in size_result.warnings])
            # Validate side
            side_result = self.validate_trade_side(trade.side)
            if not side_result.is_valid:
                errors.extend([f"side: {error}" for error in side_result.errors])
            warnings.extend([f"side: {warning}" for warning in side_result.warnings])
            # Validate trade ID
            trade_id_result = self.validate_trade_id(trade.trade_id)
            if not trade_id_result.is_valid:
                errors.extend([f"trade_id: {error}" for error in trade_id_result.errors])
            warnings.extend([f"trade_id: {warning}" for warning in trade_id_result.warnings])
            # Validate symbol format (exchange-specific)
            symbol_result = self.validate_symbol_format(trade.symbol)
            if not symbol_result.is_valid:
                errors.extend([f"symbol: {error}" for error in symbol_result.errors])
            warnings.extend([f"symbol: {warning}" for warning in symbol_result.warnings])
            # Validate timestamp
            timestamp_ms = int(trade.timestamp.timestamp() * 1000)
            timestamp_result = self.validate_timestamp(timestamp_ms, is_milliseconds=True)
            if not timestamp_result.is_valid:
                errors.extend([f"timestamp: {error}" for error in timestamp_result.errors])
            warnings.extend([f"timestamp: {warning}" for warning in timestamp_result.warnings])
            return DataValidationResult(len(errors) == 0, errors, warnings)
        except Exception as e:
            errors.append(f"Exception during trade validation: {str(e)}")
            return DataValidationResult(False, errors, warnings)

    def get_validator_info(self) -> Dict[str, Any]:
        """Get validator configuration information."""
        return {
            'exchange': self.exchange_name,
            'component': self.component_name,
            'limits': {
                'min_price': str(self._min_price),
                'max_price': str(self._max_price),
                'min_size': str(self._min_size),
                'max_size': str(self._max_size),
                'min_timestamp': self._min_timestamp,
                'max_timestamp': self._max_timestamp
            },
            'patterns': {
                'numeric': self._numeric_pattern.pattern,
                'trade_id': self._trade_id_pattern.pattern
            }
        }
# Utility functions for common validation patterns
def is_valid_decimal(value: Any) -> bool:
    """Check if value can be converted to a valid decimal."""
    try:
        Decimal(str(value))
    except (InvalidOperation, ValueError, TypeError):
        return False
    return True
def normalize_symbol(symbol: str, exchange: str) -> str:
    """
    Normalize symbol format for exchange.

    Args:
        symbol: Raw symbol string
        exchange: Exchange name (currently unused; reserved for
            exchange-specific normalization rules)

    Returns:
        Normalized symbol string (trimmed, upper-cased)
    """
    # Basic normalization - can be extended per exchange
    trimmed = symbol.strip()
    return trimmed.upper()
def validate_required_fields(data: Dict[str, Any], required_fields: List[str]) -> List[str]:
    """
    Validate that all required fields are present in data.

    A field counts as missing when it is absent or its value is None.

    Args:
        data: Data dictionary to check
        required_fields: List of required field names

    Returns:
        List of missing field names (preserves the order of required_fields)
    """
    return [field for field in required_fields if data.get(field) is None]
# Public API of the validation module.
__all__ = [
    'ValidationResult',
    'BaseDataValidator',
    'is_valid_decimal',
    'normalize_symbol',
    'validate_required_fields'
]

View File

@@ -8,18 +8,19 @@ error handling, health monitoring, and database integration.
import asyncio
from datetime import datetime, timezone
from decimal import Decimal
from typing import Dict, List, Optional, Any, Set
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from ...base_collector import (
BaseDataCollector, DataType, CollectorStatus, MarketDataPoint,
OHLCVData, DataValidationError, ConnectionError
)
from ...common import StandardizedTrade, OHLCVCandle
from .websocket import (
OKXWebSocketClient, OKXSubscription, OKXChannelType,
ConnectionState, OKXWebSocketError
)
from .data_processor import OKXDataProcessor
from database.connection import get_db_manager, get_raw_data_manager
from database.models import MarketData, RawTrade
from utils.logger import get_logger
@@ -41,6 +42,8 @@ class OKXCollector(BaseDataCollector):
This collector handles a single trading pair and collects real-time data
including trades, orderbook, and ticker information from OKX exchange.
Uses the new common data processing framework for validation, transformation,
and aggregation.
"""
def __init__(self,
@@ -86,14 +89,22 @@ class OKXCollector(BaseDataCollector):
# WebSocket client
self._ws_client: Optional[OKXWebSocketClient] = None
# Data processor using new common framework
self._data_processor = OKXDataProcessor(symbol, component_name=f"{component_name}_processor")
# Add callbacks for processed data
self._data_processor.add_trade_callback(self._on_trade_processed)
self._data_processor.add_candle_callback(self._on_candle_processed)
# Database managers
self._db_manager = None
self._raw_data_manager = None
# Data processing
self._message_buffer: List[Dict[str, Any]] = []
self._last_trade_id: Optional[str] = None
self._last_orderbook_ts: Optional[int] = None
# Data processing counters
self._message_count = 0
self._processed_trades = 0
self._processed_candles = 0
self._error_count = 0
# OKX channel mapping
self._channel_mapping = {
@@ -103,6 +114,7 @@ class OKXCollector(BaseDataCollector):
}
self.logger.info(f"Initialized OKX collector for {symbol} with data types: {[dt.value for dt in data_types]}")
self.logger.info(f"Using common data processing framework")
async def connect(self) -> bool:
"""
@@ -200,14 +212,13 @@ class OKXCollector(BaseDataCollector):
# Subscribe to channels
success = await self._ws_client.subscribe(subscriptions)
if success:
self.logger.info(f"Successfully subscribed to {len(subscriptions)} channels for {self.symbol}")
return True
else:
self.logger.error(f"Failed to subscribe to channels for {self.symbol}")
return success
return False
except Exception as e:
self.logger.error(f"Error subscribing to data for {self.symbol}: {e}")
return False
@@ -224,11 +235,11 @@ class OKXCollector(BaseDataCollector):
True if unsubscription successful, False otherwise
"""
if not self._ws_client or not self._ws_client.is_connected:
self.logger.warning("WebSocket client not connected for unsubscription")
return True # Consider it successful if already disconnected
self.logger.warning("WebSocket client not connected")
return True # Consider it successful if not connected
try:
# Build unsubscriptions
# Build unsubscription list
subscriptions = []
for data_type in data_types:
if data_type in self._channel_mapping:
@@ -236,7 +247,7 @@ class OKXCollector(BaseDataCollector):
subscription = OKXSubscription(
channel=channel,
inst_id=self.symbol,
enabled=False
enabled=False # False for unsubscribe
)
subscriptions.append(subscription)
@@ -245,241 +256,223 @@ class OKXCollector(BaseDataCollector):
# Unsubscribe from channels
success = await self._ws_client.unsubscribe(subscriptions)
if success:
self.logger.info(f"Successfully unsubscribed from {len(subscriptions)} channels for {self.symbol}")
return True
else:
self.logger.warning(f"Failed to unsubscribe from channels for {self.symbol}")
return success
self.logger.error(f"Failed to unsubscribe from channels for {self.symbol}")
return False
except Exception as e:
self.logger.error(f"Error unsubscribing from data for {self.symbol}: {e}")
return False
async def _process_message(self, message: Any) -> Optional[MarketDataPoint]:
"""
Process incoming message from OKX WebSocket.
Process received message using the new data processor.
Args:
message: Raw message from WebSocket
Returns:
Processed MarketDataPoint or None if processing failed
MarketDataPoint if processing successful, None otherwise
"""
if not isinstance(message, dict):
self.logger.warning(f"Received non-dict message: {type(message)}")
return None
try:
if not isinstance(message, dict):
self.logger.warning(f"Unexpected message type: {type(message)}")
self._message_count += 1
# Use the new data processor for validation and processing
success, market_data_points, errors = self._data_processor.validate_and_process_message(
message, expected_symbol=self.symbol
)
if not success:
self._error_count += 1
self.logger.error(f"Message processing failed: {errors}")
return None
# Extract channel and data
arg = message.get('arg', {})
channel = arg.get('channel')
inst_id = arg.get('instId')
data_list = message.get('data', [])
if errors:
self.logger.warning(f"Message processing warnings: {errors}")
# Validate message structure
if not channel or not inst_id or not data_list:
self.logger.debug(f"Incomplete message structure: {message}")
return None
# Store raw data if enabled (for debugging/compliance)
if self.store_raw_data and 'data' in message and 'arg' in message:
await self._store_raw_data(message['arg'].get('channel', 'unknown'), message)
# Check if this message is for our symbol
if inst_id != self.symbol:
self.logger.debug(f"Message for different symbol: {inst_id} (expected: {self.symbol})")
return None
# Store processed market data points in raw_trades table
for data_point in market_data_points:
await self._store_processed_data(data_point)
# Process each data item
market_data_points = []
for data_item in data_list:
data_point = await self._process_data_item(channel, data_item)
if data_point:
market_data_points.append(data_point)
# Store raw data if enabled
if self.store_raw_data and self._raw_data_manager:
await self._store_raw_data(channel, message)
# Return the first processed data point (for the base class interface)
# Return the first data point for compatibility (most use cases have single data point per message)
return market_data_points[0] if market_data_points else None
except Exception as e:
self.logger.error(f"Error processing message for {self.symbol}: {e}")
self._error_count += 1
self.logger.error(f"Error processing message: {e}")
return None
async def _handle_messages(self) -> None:
"""
Handle incoming messages from WebSocket.
This is called by the base class message loop.
"""
# The actual message handling is done through the WebSocket client callback
# This method satisfies the abstract method requirement
if self._ws_client and self._ws_client.is_connected:
# Just sleep briefly to yield control
await asyncio.sleep(0.1)
else:
# If not connected, sleep longer to avoid busy loop
await asyncio.sleep(1.0)
async def _process_data_item(self, channel: str, data_item: Dict[str, Any]) -> Optional[MarketDataPoint]:
"""
Process individual data item from OKX message.
Args:
channel: OKX channel name
data_item: Individual data item
Returns:
Processed MarketDataPoint or None
"""
try:
# Determine data type from channel
data_type = None
for dt, ch in self._channel_mapping.items():
if ch == channel:
data_type = dt
break
if not data_type:
self.logger.warning(f"Unknown channel: {channel}")
return None
# Extract timestamp
timestamp_ms = data_item.get('ts')
if timestamp_ms:
timestamp = datetime.fromtimestamp(int(timestamp_ms) / 1000, tz=timezone.utc)
else:
timestamp = datetime.now(timezone.utc)
# Create MarketDataPoint
market_data_point = MarketDataPoint(
exchange="okx",
symbol=self.symbol,
timestamp=timestamp,
data_type=data_type,
data=data_item
)
# Store processed data to database
await self._store_processed_data(market_data_point)
# Update statistics
self._stats['messages_processed'] += 1
self._stats['last_message_time'] = timestamp
return market_data_point
except Exception as e:
self.logger.error(f"Error processing data item for {self.symbol}: {e}")
self._stats['errors'] += 1
return None
"""Handle message processing in the background."""
# The new data processor handles messages through callbacks
# This method exists for compatibility with BaseDataCollector
await asyncio.sleep(0.1)
async def _store_processed_data(self, data_point: MarketDataPoint) -> None:
"""
Store processed data to MarketData table.
Store raw market data in the raw_trades table.
Args:
data_point: Processed market data point
"""
try:
# For now, we'll focus on trade data storage
# Orderbook and ticker storage can be added later
if data_point.data_type == DataType.TRADE:
await self._store_trade_data(data_point)
except Exception as e:
self.logger.error(f"Error storing processed data for {self.symbol}: {e}")
async def _store_trade_data(self, data_point: MarketDataPoint) -> None:
"""
Store trade data to database.
Args:
data_point: Trade data point
data_point: Raw market data point (trade, orderbook, ticker)
"""
try:
if not self._db_manager:
return
trade_data = data_point.data
# Extract trade information
trade_id = trade_data.get('tradeId')
price = Decimal(str(trade_data.get('px', '0')))
size = Decimal(str(trade_data.get('sz', '0')))
side = trade_data.get('side', 'unknown')
# Skip duplicate trades
if trade_id == self._last_trade_id:
return
self._last_trade_id = trade_id
# For now, we'll log the trade data
# Actual database storage will be implemented in the next phase
self.logger.debug(f"Trade: {self.symbol} - {side} {size} @ {price} (ID: {trade_id})")
# Store raw market data points in raw_trades table
with self._db_manager.get_session() as session:
raw_trade = RawTrade(
exchange="okx",
symbol=data_point.symbol,
timestamp=data_point.timestamp,
data_type=data_point.data_type.value,
raw_data=data_point.data
)
session.add(raw_trade)
self.logger.debug(f"Stored raw data: {data_point.data_type.value} for {data_point.symbol}")
except Exception as e:
self.logger.error(f"Error storing trade data for {self.symbol}: {e}")
self.logger.error(f"Error storing raw market data: {e}")
async def _store_completed_candle(self, candle: OHLCVCandle) -> None:
"""
Store completed OHLCV candle in the market_data table.
Args:
candle: Completed OHLCV candle
"""
try:
if not self._db_manager:
return
# Store completed candles in market_data table
with self._db_manager.get_session() as session:
market_data = MarketData(
exchange=candle.exchange,
symbol=candle.symbol,
timeframe=candle.timeframe,
timestamp=candle.start_time, # Use start_time as the candle timestamp
open=candle.open,
high=candle.high,
low=candle.low,
close=candle.close,
volume=candle.volume,
trades_count=candle.trade_count
)
session.add(market_data)
self.logger.info(f"Stored completed candle: {candle.symbol} {candle.timeframe} at {candle.start_time}")
except Exception as e:
self.logger.error(f"Error storing completed candle: {e}")
async def _store_raw_data(self, channel: str, raw_message: Dict[str, Any]) -> None:
"""
Store raw data for debugging and compliance.
Store raw WebSocket data for debugging in raw_trades table.
Args:
channel: OKX channel name
raw_message: Complete raw message
channel: Channel name
raw_message: Raw WebSocket message
"""
try:
if not self._raw_data_manager:
if not self._raw_data_manager or 'data' not in raw_message:
return
# Store raw data using the raw data manager
self._raw_data_manager.store_raw_data(
exchange="okx",
symbol=self.symbol,
data_type=channel,
raw_data=raw_message,
timestamp=datetime.now(timezone.utc)
)
# Store each data item as a separate raw data record
for data_item in raw_message['data']:
self._raw_data_manager.store_raw_data(
exchange="okx",
symbol=self.symbol,
data_type=f"raw_{channel}", # Prefix with 'raw_' to distinguish from processed data
raw_data=data_item,
timestamp=datetime.now(timezone.utc)
)
except Exception as e:
self.logger.error(f"Error storing raw data for {self.symbol}: {e}")
self.logger.error(f"Error storing raw WebSocket data: {e}")
def _on_message(self, message: Dict[str, Any]) -> None:
"""
Callback function for WebSocket messages.
Handle incoming WebSocket message.
Args:
message: Message received from WebSocket
message: WebSocket message from OKX
"""
try:
# Add message to buffer for processing
self._message_buffer.append(message)
# Process message asynchronously
asyncio.create_task(self._process_message(message))
except Exception as e:
self.logger.error(f"Error in message callback for {self.symbol}: {e}")
self.logger.error(f"Error handling WebSocket message: {e}")
def _on_trade_processed(self, trade: StandardizedTrade) -> None:
"""
Callback for processed trades from data processor.
Args:
trade: Processed standardized trade
"""
self._processed_trades += 1
self.logger.debug(f"Processed trade: {trade.symbol} {trade.side} {trade.size}@{trade.price}")
def _on_candle_processed(self, candle: OHLCVCandle) -> None:
"""
Callback for completed candles from data processor.
Args:
candle: Completed OHLCV candle
"""
self._processed_candles += 1
self.logger.info(f"Completed candle: {candle.symbol} {candle.timeframe} O:{candle.open} H:{candle.high} L:{candle.low} C:{candle.close} V:{candle.volume}")
# Store completed candle in market_data table
if candle.is_complete:
asyncio.create_task(self._store_completed_candle(candle))
def get_status(self) -> Dict[str, Any]:
    """
    Get current collector status including processing statistics.

    Combines the base collector status with OKX-specific fields:
    WebSocket connection state, message-buffer size, processing
    counters, and per-component statistics.

    Returns:
        Dictionary containing collector status information
    """
    base_status = super().get_status()

    # OKX-specific status (previously two merged dict literals with
    # duplicate keys and a missing comma — a syntax error; now unified).
    okx_status = {
        'symbol': self.symbol,
        'websocket_connected': self._ws_client.is_connected if self._ws_client else False,
        'websocket_state': self._ws_client.connection_state.value if self._ws_client else 'disconnected',
        'last_trade_id': self._last_trade_id,
        'message_buffer_size': len(self._message_buffer),
        'store_raw_data': self.store_raw_data,
        'processing_stats': {
            'messages_received': self._message_count,
            'trades_processed': self._processed_trades,
            'candles_processed': self._processed_candles,
            'errors': self._error_count
        }
    }

    # Add WebSocket statistics if a client exists.
    if self._ws_client:
        okx_status['websocket_stats'] = self._ws_client.get_stats()

    # Add data processor statistics if available.
    if self._data_processor:
        okx_status['data_processor_stats'] = self._data_processor.get_processing_stats()

    # Non-mutating merge; okx_status keys win on collision.
    return {**base_status, **okx_status}
def __repr__(self) -> str:
    """Return a concise string representation of the collector."""
    # A bad merge left a second, unreachable return after this one;
    # the first (executed) representation is kept.
    return f"<OKXCollector(symbol={self.symbol}, status={self.status.value}, data_types={[dt.value for dt in self.data_types]})>"

View File

@@ -0,0 +1,726 @@
"""
OKX-specific data processing utilities.
This module provides OKX-specific data validation, transformation, and processing
utilities that extend the common data processing framework.
"""
import re
from datetime import datetime, timezone
from decimal import Decimal
from typing import Dict, List, Optional, Any, Union, Tuple
from enum import Enum
from ...base_collector import DataType, MarketDataPoint
from ...common import (
DataValidationResult,
StandardizedTrade,
OHLCVCandle,
CandleProcessingConfig,
RealTimeCandleProcessor,
BaseDataValidator,
ValidationResult,
BaseDataTransformer,
UnifiedDataTransformer,
create_standardized_trade
)
from utils.logger import get_logger
class OKXMessageType(Enum):
    """
    OKX WebSocket message types.

    Derived from a message's 'event' field, or from the presence of
    'data'/'arg' keys (see OKXDataValidator._identify_message_type).
    """
    DATA = "data"                           # market data push message
    SUBSCRIPTION_SUCCESS = "subscribe"      # ack for a subscribe request
    UNSUBSCRIPTION_SUCCESS = "unsubscribe"  # ack for an unsubscribe request
    ERROR = "error"                         # error event from the exchange
    PING = "ping"                           # heartbeat request
    PONG = "pong"                           # heartbeat response
class OKXTradeField(Enum):
    """
    OKX trade data field names (wire format of the 'trades' channel).

    All fields are required; see OKXDataValidator.validate_trade_data.
    """
    INST_ID = "instId"    # instrument ID, e.g. "BTC-USDT"
    TRADE_ID = "tradeId"  # numeric trade identifier
    PRICE = "px"          # trade price (string on the wire)
    SIZE = "sz"           # trade size (string on the wire)
    SIDE = "side"         # "buy" or "sell" — assumed from validate_trade_side; confirm in base validator
    TIMESTAMP = "ts"      # epoch milliseconds (string)
class OKXOrderbookField(Enum):
    """
    OKX orderbook data field names (books5/books50/books-l2-tbt channels).

    SEQID is optional; the rest are required by validate_orderbook_data.
    """
    INST_ID = "instId"  # instrument ID, e.g. "BTC-USDT"
    ASKS = "asks"       # list of ask price levels
    BIDS = "bids"       # list of bid price levels
    TIMESTAMP = "ts"    # epoch milliseconds (string)
    SEQID = "seqId"     # optional sequence number for book updates
class OKXTickerField(Enum):
    """
    OKX ticker data field names ('tickers' channel).

    Only INST_ID, LAST, and TIMESTAMP are required; the remaining
    fields are optional (see validate_ticker_data).
    """
    INST_ID = "instId"        # instrument ID, e.g. "BTC-USDT"
    LAST = "last"             # last traded price
    LAST_SZ = "lastSz"        # size of the last trade
    ASK_PX = "askPx"          # best ask price
    ASK_SZ = "askSz"          # best ask size
    BID_PX = "bidPx"          # best bid price
    BID_SZ = "bidSz"          # best bid size
    OPEN_24H = "open24h"      # 24h open price
    HIGH_24H = "high24h"      # 24h high price
    LOW_24H = "low24h"        # 24h low price
    VOL_24H = "vol24h"        # 24h volume
    # NOTE(review): member name says CNY but "volCcy24h" is OKX's
    # volume-in-currency field — confirm intent before renaming.
    VOL_CNY_24H = "volCcy24h"
    TIMESTAMP = "ts"          # epoch milliseconds (string)
class OKXDataValidator(BaseDataValidator):
    """
    OKX-specific data validator extending the common base validator.

    This class provides OKX-specific validation for message formats,
    symbol patterns, and data structures. Field-level checks
    (validate_price, validate_size, validate_timestamp, validate_trade_id,
    validate_trade_side, validate_volume, validate_symbol_match,
    validate_orderbook_side) are inherited from BaseDataValidator.

    Validation results accumulate ``errors`` (fatal) and ``warnings``
    (non-fatal); a result is valid only when no errors were collected.
    """

    def __init__(self, component_name: str = "okx_data_validator"):
        """Initialize OKX data validator."""
        super().__init__("okx", component_name)
        # OKX-specific patterns
        self._symbol_pattern = re.compile(r'^[A-Z0-9]+-[A-Z0-9]+$') # BTC-USDT, ETH-USDC
        self._trade_id_pattern = re.compile(r'^\d+$') # OKX uses numeric trade IDs
        # OKX-specific valid channels
        self._valid_channels = {
            'trades', 'books5', 'books50', 'books-l2-tbt', 'tickers',
            'candle1m', 'candle5m', 'candle15m', 'candle1H', 'candle4H', 'candle1D'
        }
        self.logger.debug("Initialized OKX data validator")

    def validate_symbol_format(self, symbol: str) -> ValidationResult:
        """Validate OKX symbol format (e.g., BTC-USDT)."""
        errors = []
        warnings = []
        # Type check first; a non-string can never match the pattern.
        if not isinstance(symbol, str):
            errors.append(f"Symbol must be string, got {type(symbol)}")
            return ValidationResult(False, errors, warnings)
        if not self._symbol_pattern.match(symbol):
            errors.append(f"Invalid OKX symbol format: {symbol}. Expected format: BASE-QUOTE (e.g., BTC-USDT)")
        return ValidationResult(len(errors) == 0, errors, warnings)

    def validate_websocket_message(self, message: Dict[str, Any]) -> DataValidationResult:
        """
        Validate OKX WebSocket message structure.

        Dispatches to a type-specific validator based on the identified
        message type; ping/pong messages are always considered valid.
        """
        errors = []
        warnings = []
        try:
            # Check basic message structure
            if not isinstance(message, dict):
                errors.append(f"Message must be a dictionary, got {type(message)}")
                return DataValidationResult(False, errors, warnings)
            # Identify message type
            message_type = self._identify_message_type(message)
            if message_type == OKXMessageType.DATA:
                return self._validate_data_message(message)
            elif message_type in [OKXMessageType.SUBSCRIPTION_SUCCESS, OKXMessageType.UNSUBSCRIPTION_SUCCESS]:
                return self._validate_subscription_message(message)
            elif message_type == OKXMessageType.ERROR:
                return self._validate_error_message(message)
            elif message_type in [OKXMessageType.PING, OKXMessageType.PONG]:
                return DataValidationResult(True, [], []) # Ping/pong are always valid
            else:
                warnings.append("Unknown message type, basic validation only")
                return DataValidationResult(True, [], warnings)
        except Exception as e:
            errors.append(f"Exception during message validation: {str(e)}")
            return DataValidationResult(False, errors, warnings)

    def validate_trade_data(self, data: Dict[str, Any], symbol: Optional[str] = None) -> DataValidationResult:
        """
        Validate OKX trade data structure and values.

        Args:
            data: Single trade record from the 'trades' channel
            symbol: Optional expected symbol; a mismatch adds warnings only

        Returns:
            DataValidationResult whose sanitized_data carries normalized
            (string) price and size values when those fields validate.
        """
        errors = []
        warnings = []
        sanitized_data = data.copy()
        try:
            # Check required fields
            required_fields = [field.value for field in OKXTradeField]
            missing_fields = []
            for field in required_fields:
                if field not in data:
                    missing_fields.append(field)
            if missing_fields:
                errors.extend([f"Missing required trade field: {field}" for field in missing_fields])
                return DataValidationResult(False, errors, warnings)
            # Validate individual fields using base validator methods
            symbol_result = self.validate_symbol_format(data[OKXTradeField.INST_ID.value])
            if not symbol_result.is_valid:
                errors.extend(symbol_result.errors)
            if symbol:
                match_result = self.validate_symbol_match(data[OKXTradeField.INST_ID.value], symbol)
                warnings.extend(match_result.warnings)
            trade_id_result = self.validate_trade_id(data[OKXTradeField.TRADE_ID.value])
            if not trade_id_result.is_valid:
                errors.extend(trade_id_result.errors)
            warnings.extend(trade_id_result.warnings)
            price_result = self.validate_price(data[OKXTradeField.PRICE.value])
            if not price_result.is_valid:
                errors.extend(price_result.errors)
            else:
                # Store the sanitized value back as a string (wire format).
                sanitized_data[OKXTradeField.PRICE.value] = str(price_result.sanitized_data)
            warnings.extend(price_result.warnings)
            size_result = self.validate_size(data[OKXTradeField.SIZE.value])
            if not size_result.is_valid:
                errors.extend(size_result.errors)
            else:
                sanitized_data[OKXTradeField.SIZE.value] = str(size_result.sanitized_data)
            warnings.extend(size_result.warnings)
            side_result = self.validate_trade_side(data[OKXTradeField.SIDE.value])
            if not side_result.is_valid:
                errors.extend(side_result.errors)
            timestamp_result = self.validate_timestamp(data[OKXTradeField.TIMESTAMP.value])
            if not timestamp_result.is_valid:
                errors.extend(timestamp_result.errors)
            warnings.extend(timestamp_result.warnings)
            return DataValidationResult(len(errors) == 0, errors, warnings, sanitized_data)
        except Exception as e:
            errors.append(f"Exception during trade validation: {str(e)}")
            return DataValidationResult(False, errors, warnings)

    def validate_orderbook_data(self, data: Dict[str, Any], symbol: Optional[str] = None) -> DataValidationResult:
        """
        Validate OKX orderbook data structure and values.

        Args:
            data: Single orderbook record from a books* channel
            symbol: Optional expected symbol; a mismatch adds warnings only

        Returns:
            DataValidationResult with sanitized ask/bid sides when valid.
        """
        errors = []
        warnings = []
        sanitized_data = data.copy()
        try:
            # Check required fields (seqId is optional and checked later)
            required_fields = [OKXOrderbookField.INST_ID.value, OKXOrderbookField.ASKS.value,
                               OKXOrderbookField.BIDS.value, OKXOrderbookField.TIMESTAMP.value]
            missing_fields = []
            for field in required_fields:
                if field not in data:
                    missing_fields.append(field)
            if missing_fields:
                errors.extend([f"Missing required orderbook field: {field}" for field in missing_fields])
                return DataValidationResult(False, errors, warnings)
            # Validate symbol
            symbol_result = self.validate_symbol_format(data[OKXOrderbookField.INST_ID.value])
            if not symbol_result.is_valid:
                errors.extend(symbol_result.errors)
            if symbol:
                match_result = self.validate_symbol_match(data[OKXOrderbookField.INST_ID.value], symbol)
                warnings.extend(match_result.warnings)
            # Validate timestamp
            timestamp_result = self.validate_timestamp(data[OKXOrderbookField.TIMESTAMP.value])
            if not timestamp_result.is_valid:
                errors.extend(timestamp_result.errors)
            warnings.extend(timestamp_result.warnings)
            # Validate asks and bids using base validator
            asks_result = self.validate_orderbook_side(data[OKXOrderbookField.ASKS.value], "asks")
            if not asks_result.is_valid:
                errors.extend(asks_result.errors)
            else:
                sanitized_data[OKXOrderbookField.ASKS.value] = asks_result.sanitized_data
            warnings.extend(asks_result.warnings)
            bids_result = self.validate_orderbook_side(data[OKXOrderbookField.BIDS.value], "bids")
            if not bids_result.is_valid:
                errors.extend(bids_result.errors)
            else:
                sanitized_data[OKXOrderbookField.BIDS.value] = bids_result.sanitized_data
            warnings.extend(bids_result.warnings)
            # Validate sequence ID if present (int, or digit-only string)
            if OKXOrderbookField.SEQID.value in data:
                seq_id = data[OKXOrderbookField.SEQID.value]
                if not isinstance(seq_id, (int, str)) or (isinstance(seq_id, str) and not seq_id.isdigit()):
                    errors.append("Invalid sequence ID format")
            return DataValidationResult(len(errors) == 0, errors, warnings, sanitized_data)
        except Exception as e:
            errors.append(f"Exception during orderbook validation: {str(e)}")
            return DataValidationResult(False, errors, warnings)

    def validate_ticker_data(self, data: Dict[str, Any], symbol: Optional[str] = None) -> DataValidationResult:
        """
        Validate OKX ticker data structure and values.

        Only instId, last, and ts are required; other price/size/volume
        fields are validated when present and non-empty.
        """
        errors = []
        warnings = []
        sanitized_data = data.copy()
        try:
            # Check required fields
            required_fields = [OKXTickerField.INST_ID.value, OKXTickerField.LAST.value, OKXTickerField.TIMESTAMP.value]
            missing_fields = []
            for field in required_fields:
                if field not in data:
                    missing_fields.append(field)
            if missing_fields:
                errors.extend([f"Missing required ticker field: {field}" for field in missing_fields])
                return DataValidationResult(False, errors, warnings)
            # Validate symbol
            symbol_result = self.validate_symbol_format(data[OKXTickerField.INST_ID.value])
            if not symbol_result.is_valid:
                errors.extend(symbol_result.errors)
            if symbol:
                match_result = self.validate_symbol_match(data[OKXTickerField.INST_ID.value], symbol)
                warnings.extend(match_result.warnings)
            # Validate timestamp
            timestamp_result = self.validate_timestamp(data[OKXTickerField.TIMESTAMP.value])
            if not timestamp_result.is_valid:
                errors.extend(timestamp_result.errors)
            warnings.extend(timestamp_result.warnings)
            # Validate price fields (optional fields); errors/warnings are
            # prefixed with the field name for traceability.
            price_fields = [OKXTickerField.LAST, OKXTickerField.ASK_PX, OKXTickerField.BID_PX,
                            OKXTickerField.OPEN_24H, OKXTickerField.HIGH_24H, OKXTickerField.LOW_24H]
            for field in price_fields:
                if field.value in data and data[field.value] not in [None, ""]:
                    price_result = self.validate_price(data[field.value])
                    if not price_result.is_valid:
                        errors.extend([f"{field.value}: {error}" for error in price_result.errors])
                    else:
                        sanitized_data[field.value] = str(price_result.sanitized_data)
                    warnings.extend([f"{field.value}: {warning}" for warning in price_result.warnings])
            # Validate size fields (optional fields)
            size_fields = [OKXTickerField.LAST_SZ, OKXTickerField.ASK_SZ, OKXTickerField.BID_SZ]
            for field in size_fields:
                if field.value in data and data[field.value] not in [None, ""]:
                    size_result = self.validate_size(data[field.value])
                    if not size_result.is_valid:
                        errors.extend([f"{field.value}: {error}" for error in size_result.errors])
                    else:
                        sanitized_data[field.value] = str(size_result.sanitized_data)
                    warnings.extend([f"{field.value}: {warning}" for warning in size_result.warnings])
            # Validate volume fields (optional fields; values not sanitized)
            volume_fields = [OKXTickerField.VOL_24H, OKXTickerField.VOL_CNY_24H]
            for field in volume_fields:
                if field.value in data and data[field.value] not in [None, ""]:
                    volume_result = self.validate_volume(data[field.value])
                    if not volume_result.is_valid:
                        errors.extend([f"{field.value}: {error}" for error in volume_result.errors])
                    warnings.extend([f"{field.value}: {warning}" for warning in volume_result.warnings])
            return DataValidationResult(len(errors) == 0, errors, warnings, sanitized_data)
        except Exception as e:
            errors.append(f"Exception during ticker validation: {str(e)}")
            return DataValidationResult(False, errors, warnings)

    # Private helper methods for OKX-specific validation

    def _identify_message_type(self, message: Dict[str, Any]) -> OKXMessageType:
        """Identify the type of OKX WebSocket message from its 'event'/'data' keys."""
        if 'event' in message:
            event = message['event']
            if event == 'subscribe':
                return OKXMessageType.SUBSCRIPTION_SUCCESS
            elif event == 'unsubscribe':
                return OKXMessageType.UNSUBSCRIPTION_SUCCESS
            elif event == 'error':
                return OKXMessageType.ERROR
        if 'data' in message and 'arg' in message:
            return OKXMessageType.DATA
        # Default to data type for unknown messages
        return OKXMessageType.DATA

    def _validate_data_message(self, message: Dict[str, Any]) -> DataValidationResult:
        """Validate OKX data message structure ('arg' metadata plus 'data' list)."""
        errors = []
        warnings = []
        # Check required fields
        if 'arg' not in message:
            errors.append("Missing 'arg' field in data message")
        if 'data' not in message:
            errors.append("Missing 'data' field in data message")
        if errors:
            return DataValidationResult(False, errors, warnings)
        # Validate arg structure
        arg = message['arg']
        if not isinstance(arg, dict):
            errors.append("'arg' field must be a dictionary")
        else:
            if 'channel' not in arg:
                errors.append("Missing 'channel' in arg")
            elif arg['channel'] not in self._valid_channels:
                # Unknown channels are a warning, not an error, so new
                # OKX channels do not break processing outright.
                warnings.append(f"Unknown channel: {arg['channel']}")
            if 'instId' not in arg:
                errors.append("Missing 'instId' in arg")
        # Validate data structure
        data = message['data']
        if not isinstance(data, list):
            errors.append("'data' field must be a list")
        elif len(data) == 0:
            warnings.append("Empty data array")
        return DataValidationResult(len(errors) == 0, errors, warnings)

    def _validate_subscription_message(self, message: Dict[str, Any]) -> DataValidationResult:
        """Validate subscription/unsubscription acknowledgement message."""
        errors = []
        warnings = []
        if 'event' not in message:
            errors.append("Missing 'event' field")
        if 'arg' not in message:
            errors.append("Missing 'arg' field")
        return DataValidationResult(len(errors) == 0, errors, warnings)

    def _validate_error_message(self, message: Dict[str, Any]) -> DataValidationResult:
        """Validate error message; the exchange's error text is surfaced as a warning."""
        errors = []
        warnings = []
        if 'event' not in message or message['event'] != 'error':
            errors.append("Invalid error message structure")
        if 'msg' in message:
            warnings.append(f"OKX error: {message['msg']}")
        return DataValidationResult(len(errors) == 0, errors, warnings)
class OKXDataTransformer(BaseDataTransformer):
    """
    Transformer converting raw OKX payloads into standardized formats.

    Extends the common BaseDataTransformer with the OKX field mappings
    for trades, orderbooks, and tickers. All transform methods return
    None (after logging) when the payload cannot be converted.
    """

    def __init__(self, component_name: str = "okx_data_transformer"):
        """Initialize OKX data transformer."""
        super().__init__("okx", component_name)

    def transform_trade_data(self, raw_data: Dict[str, Any], symbol: str) -> Optional[StandardizedTrade]:
        """Transform OKX trade data to standardized format."""
        try:
            trade_fields = {
                'symbol': raw_data[OKXTradeField.INST_ID.value],
                'trade_id': raw_data[OKXTradeField.TRADE_ID.value],
                'price': raw_data[OKXTradeField.PRICE.value],
                'size': raw_data[OKXTradeField.SIZE.value],
                'side': raw_data[OKXTradeField.SIDE.value],
                'timestamp': raw_data[OKXTradeField.TIMESTAMP.value],
            }
            # OKX 'ts' values are epoch milliseconds.
            return create_standardized_trade(
                exchange="okx",
                raw_data=raw_data,
                is_milliseconds=True,
                **trade_fields
            )
        except Exception as e:
            self.logger.error(f"Error transforming OKX trade data: {e}")
            return None

    def transform_orderbook_data(self, raw_data: Dict[str, Any], symbol: str) -> Optional[Dict[str, Any]]:
        """Transform OKX orderbook data to standardized format."""
        try:
            fields = OKXOrderbookField
            # Basic transformation - can be enhanced as needed
            return {
                'symbol': raw_data[fields.INST_ID.value],
                'asks': raw_data[fields.ASKS.value],
                'bids': raw_data[fields.BIDS.value],
                'timestamp': self.timestamp_to_datetime(raw_data[fields.TIMESTAMP.value]),
                'exchange': 'okx',
                'raw_data': raw_data
            }
        except Exception as e:
            self.logger.error(f"Error transforming OKX orderbook data: {e}")
            return None

    def transform_ticker_data(self, raw_data: Dict[str, Any], symbol: str) -> Optional[Dict[str, Any]]:
        """Transform OKX ticker data to standardized format."""
        try:
            result = {
                'symbol': raw_data[OKXTickerField.INST_ID.value],
                'timestamp': self.timestamp_to_datetime(raw_data[OKXTickerField.TIMESTAMP.value]),
                'exchange': 'okx',
                'raw_data': raw_data
            }
            # Optional price fields: copied only when present, non-empty,
            # and convertible to Decimal.
            price_map = (
                ('last', OKXTickerField.LAST.value),
                ('bid', OKXTickerField.BID_PX.value),
                ('ask', OKXTickerField.ASK_PX.value),
                ('open_24h', OKXTickerField.OPEN_24H.value),
                ('high_24h', OKXTickerField.HIGH_24H.value),
                ('low_24h', OKXTickerField.LOW_24H.value),
            )
            for target, source in price_map:
                if source not in raw_data or raw_data[source] in [None, ""]:
                    continue
                converted = self.safe_decimal_conversion(raw_data[source], target)
                if converted:
                    result[target] = converted
            # 24h volume, when present and convertible.
            if OKXTickerField.VOL_24H.value in raw_data:
                vol = self.safe_decimal_conversion(raw_data[OKXTickerField.VOL_24H.value], 'volume_24h')
                if vol:
                    result['volume_24h'] = vol
            return result
        except Exception as e:
            self.logger.error(f"Error transforming OKX ticker data: {e}")
            return None
class OKXDataProcessor:
    """
    Main OKX data processor using common utilities.

    This class provides a simplified interface for OKX data processing,
    leveraging the common validation, transformation, and aggregation
    utilities. It validates WebSocket messages, converts data items into
    MarketDataPoint objects, and feeds trades for the configured symbol
    into a real-time candle processor.
    """

    def __init__(self,
                 symbol: str,
                 config: Optional[CandleProcessingConfig] = None,
                 component_name: str = "okx_data_processor"):
        """
        Initialize OKX data processor.

        Args:
            symbol: Trading symbol to process
            config: Candle processing configuration (defaults applied when None)
            component_name: Name for logging
        """
        self.symbol = symbol
        self.component_name = component_name
        self.logger = get_logger(self.component_name)
        # Core components using common utilities
        self.validator = OKXDataValidator(f"{component_name}_validator")
        self.transformer = OKXDataTransformer(f"{component_name}_transformer")
        self.unified_transformer = UnifiedDataTransformer(self.transformer, f"{component_name}_unified")
        # Real-time candle processing using common utilities
        self.config = config or CandleProcessingConfig()
        self.candle_processor = RealTimeCandleProcessor(
            symbol, "okx", self.config, f"{component_name}_candles"
        )
        # Callbacks registered by consumers (trades and completed candles)
        self.trade_callbacks: List[callable] = []
        self.candle_callbacks: List[callable] = []
        # Connect candle processor callbacks so completed candles fan out
        self.candle_processor.add_candle_callback(self._emit_candle_to_callbacks)
        self.logger.info(f"Initialized OKX data processor for {symbol} with real-time candle processing")

    def add_trade_callback(self, callback: callable) -> None:
        """Add callback for processed trades."""
        self.trade_callbacks.append(callback)

    def add_candle_callback(self, callback: callable) -> None:
        """Add callback for completed candles."""
        self.candle_callbacks.append(callback)

    def validate_and_process_message(self, message: Dict[str, Any], expected_symbol: Optional[str] = None) -> Tuple[bool, List[MarketDataPoint], List[str]]:
        """
        Validate and process complete OKX WebSocket message.

        This is the main entry point for real-time WebSocket data.
        Non-data messages (acks, errors, ping/pong) that pass validation
        return success with an empty data-point list.

        Args:
            message: Complete WebSocket message from OKX
            expected_symbol: Expected trading symbol for validation

        Returns:
            Tuple of (success, list of market data points, list of errors)
        """
        try:
            # First validate the message structure
            validation_result = self.validator.validate_websocket_message(message)
            if not validation_result.is_valid:
                self.logger.error(f"Message validation failed: {validation_result.errors}")
                return False, [], validation_result.errors
            # Log warnings if any
            if validation_result.warnings:
                self.logger.warning(f"Message validation warnings: {validation_result.warnings}")
            # Process data if it's a data message
            if 'data' in message and 'arg' in message:
                return self._process_data_message(message, expected_symbol)
            # Non-data messages are considered successfully processed but return no data points
            return True, [], []
        except Exception as e:
            error_msg = f"Exception during message validation and processing: {str(e)}"
            self.logger.error(error_msg)
            return False, [], [error_msg]

    def _process_data_message(self, message: Dict[str, Any], expected_symbol: Optional[str] = None) -> Tuple[bool, List[MarketDataPoint], List[str]]:
        """
        Process OKX data message and return market data points.

        Each item in the message's 'data' list is validated per channel
        type; invalid items are skipped and their errors accumulated.
        Success is reported only when no item produced an error.
        """
        errors = []
        market_data_points = []
        try:
            arg = message['arg']
            channel = arg['channel']
            inst_id = arg['instId']
            data_list = message['data']
            # Determine data type from channel
            data_type = self._channel_to_data_type(channel)
            if not data_type:
                errors.append(f"Unsupported channel: {channel}")
                return False, [], errors
            # Process each data item
            for data_item in data_list:
                try:
                    # Validate and transform based on channel type
                    if channel == 'trades':
                        validation_result = self.validator.validate_trade_data(data_item, expected_symbol)
                    elif channel in ['books5', 'books50', 'books-l2-tbt']:
                        validation_result = self.validator.validate_orderbook_data(data_item, expected_symbol)
                    elif channel == 'tickers':
                        validation_result = self.validator.validate_ticker_data(data_item, expected_symbol)
                    else:
                        errors.append(f"Unsupported channel for validation: {channel}")
                        continue
                    if not validation_result.is_valid:
                        errors.extend(validation_result.errors)
                        continue
                    if validation_result.warnings:
                        self.logger.warning(f"Data validation warnings: {validation_result.warnings}")
                    # Create MarketDataPoint using sanitized data
                    sanitized_data = validation_result.sanitized_data or data_item
                    timestamp_ms = sanitized_data.get('ts')
                    if timestamp_ms:
                        # OKX 'ts' is epoch milliseconds (string)
                        timestamp = datetime.fromtimestamp(int(timestamp_ms) / 1000, tz=timezone.utc)
                    else:
                        # Fall back to receive time when the payload lacks a timestamp
                        timestamp = datetime.now(timezone.utc)
                    market_data_point = MarketDataPoint(
                        exchange="okx",
                        symbol=inst_id,
                        timestamp=timestamp,
                        data_type=data_type,
                        data=sanitized_data
                    )
                    market_data_points.append(market_data_point)
                    # Real-time processing for trades (only for our configured symbol)
                    if channel == 'trades' and inst_id == self.symbol:
                        self._process_real_time_trade(sanitized_data)
                except Exception as e:
                    self.logger.error(f"Error processing data item: {e}")
                    errors.append(f"Error processing data item: {str(e)}")
            return len(errors) == 0, market_data_points, errors
        except Exception as e:
            error_msg = f"Exception during data message processing: {str(e)}"
            errors.append(error_msg)
            return False, [], errors

    def _process_real_time_trade(self, trade_data: Dict[str, Any]) -> None:
        """
        Process real-time trade for candle generation.

        Transforms the raw trade into a StandardizedTrade, feeds it to the
        candle processor, and notifies trade callbacks. Callback failures
        are logged but never propagate.
        """
        try:
            # Transform to standardized format using the unified transformer
            standardized_trade = self.unified_transformer.transform_trade_data(trade_data, self.symbol)
            if standardized_trade:
                # Process for real-time candles using common utilities
                completed_candles = self.candle_processor.process_trade(standardized_trade)
                # Emit trade to callbacks
                for callback in self.trade_callbacks:
                    try:
                        callback(standardized_trade)
                    except Exception as e:
                        self.logger.error(f"Error in trade callback: {e}")
                # Note: Candle callbacks are handled by _emit_candle_to_callbacks
        except Exception as e:
            self.logger.error(f"Error processing real-time trade: {e}")

    def _emit_candle_to_callbacks(self, candle: OHLCVCandle) -> None:
        """Emit candle to all registered callbacks; failures are logged and isolated."""
        for callback in self.candle_callbacks:
            try:
                callback(candle)
            except Exception as e:
                self.logger.error(f"Error in candle callback: {e}")

    def _channel_to_data_type(self, channel: str) -> Optional[DataType]:
        """Convert OKX channel name to DataType enum; None for unsupported channels."""
        channel_mapping = {
            'trades': DataType.TRADE,
            'books5': DataType.ORDERBOOK,
            'books50': DataType.ORDERBOOK,
            'books-l2-tbt': DataType.ORDERBOOK,
            'tickers': DataType.TICKER
        }
        return channel_mapping.get(channel)

    def get_processing_stats(self) -> Dict[str, Any]:
        """Get comprehensive processing statistics for all sub-components."""
        return {
            'candle_processor': self.candle_processor.get_stats(),
            'current_candles': self.candle_processor.get_current_candles(),
            'callbacks': {
                'trade_callbacks': len(self.trade_callbacks),
                'candle_callbacks': len(self.candle_callbacks)
            },
            'validator_info': self.validator.get_validator_info(),
            'transformer_info': self.unified_transformer.get_transformer_info()
        }
# Public API of this module: OKX field/message enums plus the
# validator, transformer, and high-level processor classes.
__all__ = [
    'OKXMessageType',
    'OKXTradeField',
    'OKXOrderbookField',
    'OKXTickerField',
    'OKXDataValidator',
    'OKXDataTransformer',
    'OKXDataProcessor'
]