TCPDashboard/tests/data/common/test_aggregation_safety.py
Vasily.onl e7ede7f329 Refactor aggregation module and enhance structure
- Split the `aggregation.py` file into a dedicated sub-package, improving modularity and maintainability.
- Moved `TimeframeBucket`, `RealTimeCandleProcessor`, and `BatchCandleProcessor` classes into their respective files within the new `aggregation` sub-package.
- Introduced utility functions for trade aggregation and validation, enhancing code organization.
- Updated import paths throughout the codebase to reflect the new structure, ensuring compatibility.
- Added safety net tests for the aggregation package to verify core functionality and prevent regressions during refactoring.

These changes enhance the overall architecture of the aggregation module, making it more scalable and easier to manage.
2025-06-07 01:17:22 +08:00

231 lines
8.3 KiB
Python

"""
Safety net tests for the aggregation package.
These tests verify the core functionality of the aggregation module
before and during refactoring to ensure no regressions are introduced.
"""
import unittest
from datetime import datetime, timezone, timedelta
from decimal import Decimal
from typing import Dict, List
from data.common.aggregation.bucket import TimeframeBucket
from data.common.aggregation.realtime import RealTimeCandleProcessor
from data.common.aggregation.batch import BatchCandleProcessor
from data.common.aggregation.utils import (
validate_timeframe,
parse_timeframe,
aggregate_trades_to_candles
)
from data.common import (
StandardizedTrade,
OHLCVCandle,
CandleProcessingConfig
)
class TestTimeframeBucketSafety(unittest.TestCase):
    """Guard the behaviour of ``TimeframeBucket`` across refactors.

    Every test runs against a fresh 5-minute BTC-USDT bucket that opens at
    2024-01-01 10:00 UTC (created in ``setUp``).
    """

    def setUp(self):
        self.symbol = "BTC-USDT"
        self.timeframe = "5m"
        self.start_time = datetime(2024, 1, 1, 10, 0, tzinfo=timezone.utc)
        self.bucket = TimeframeBucket(self.symbol, self.timeframe, self.start_time)

    def _trade(self, trade_id, price, size, side, offset_minutes):
        """Build a ``StandardizedTrade`` stamped relative to the bucket start.

        ``price`` and ``size`` are decimal strings. ``offset_minutes`` is how
        many minutes after ``self.start_time`` the trade is timestamped.
        """
        return StandardizedTrade(
            symbol=self.symbol,
            trade_id=trade_id,
            price=Decimal(price),
            size=Decimal(size),
            side=side,
            timestamp=self.start_time + timedelta(minutes=offset_minutes),
            exchange="test",
        )

    def _assert_ohlcv(self, open_, high, low, close, volume, trade_count):
        """Check the bucket's running OHLCV state (prices/volume as decimal strings)."""
        bucket = self.bucket
        self.assertEqual(bucket.open, Decimal(open_))
        self.assertEqual(bucket.high, Decimal(high))
        self.assertEqual(bucket.low, Decimal(low))
        self.assertEqual(bucket.close, Decimal(close))
        self.assertEqual(bucket.volume, Decimal(volume))
        self.assertEqual(bucket.trade_count, trade_count)

    def test_bucket_initialization(self):
        """A new bucket keeps its identity and spans exactly five minutes."""
        expected_end = self.start_time + timedelta(minutes=5)
        self.assertEqual(self.bucket.symbol, self.symbol)
        self.assertEqual(self.bucket.timeframe, self.timeframe)
        self.assertEqual(self.bucket.start_time, self.start_time)
        self.assertEqual(self.bucket.end_time, expected_end)

    def test_add_trade_updates_ohlcv(self):
        """Open/high/low/close, volume and trade count track each added trade."""
        buy = self._trade("1", "50000", "1", "buy", 1)
        sell = self._trade("2", "51000", "0.5", "sell", 2)

        # After the first trade, every price field equals that trade's price.
        self.bucket.add_trade(buy)
        self._assert_ohlcv("50000", "50000", "50000", "50000", "1", 1)

        # The second, higher-priced trade moves high/close and accumulates volume.
        self.bucket.add_trade(sell)
        self._assert_ohlcv("50000", "51000", "50000", "51000", "1.5", 2)

    def test_bucket_time_boundaries(self):
        """A trade inside the bucket window is accepted; one after it is rejected."""
        inside = self._trade("1", "50000", "1", "buy", 1)
        outside = self._trade("2", "51000", "1", "buy", 6)  # past the 5-minute end
        self.assertTrue(self.bucket.add_trade(inside))
        self.assertFalse(self.bucket.add_trade(outside))
class TestRealTimeCandleProcessorSafety(unittest.TestCase):
    """Guard ``RealTimeCandleProcessor`` streaming behaviour across refactors.

    The processor under test tracks BTC-USDT on the ``test`` exchange with
    1-minute and 5-minute timeframes (configured in ``setUp``).
    """

    def setUp(self):
        self.symbol = "BTC-USDT"
        self.exchange = "test"
        self.config = CandleProcessingConfig(timeframes=["1m", "5m"])
        self.processor = RealTimeCandleProcessor(self.symbol, self.exchange, self.config)

    def _trade(self, trade_id, price, side, minute, second):
        """Build a 1-unit trade stamped at 2024-01-01 10:<minute>:<second> UTC."""
        return StandardizedTrade(
            symbol=self.symbol,
            trade_id=trade_id,
            price=Decimal(price),
            size=Decimal("1"),
            side=side,
            timestamp=datetime(2024, 1, 1, 10, minute, second, tzinfo=timezone.utc),
            exchange=self.exchange,
        )

    def test_process_single_trade(self):
        """A lone trade opens candles but completes none."""
        first = self._trade("1", "50000", "buy", 0, 30)

        finished = self.processor.process_trade(first)
        self.assertEqual(len(finished), 0)  # nothing has closed yet

        in_progress = self.processor.get_current_candles()
        self.assertEqual(len(in_progress), 2)  # one per configured timeframe

    def test_candle_completion(self):
        """Crossing a minute boundary closes exactly the 1m candle."""
        early = self._trade("1", "50000", "buy", 0, 30)   # 10:00:30
        later = self._trade("2", "51000", "sell", 1, 15)  # 10:01:15, next minute

        self.assertEqual(len(self.processor.process_trade(early)), 0)

        finished = self.processor.process_trade(later)
        self.assertEqual(len(finished), 1)  # the 5m candle is still open
        self.assertEqual(finished[0].timeframe, "1m")
class TestBatchCandleProcessorSafety(unittest.TestCase):
    """Safety net tests for BatchCandleProcessor class.

    The processor under test aggregates BTC-USDT trades from the ``test``
    exchange into 1-minute and 5-minute candles.
    """

    def setUp(self):
        self.symbol = "BTC-USDT"
        self.exchange = "test"
        self.timeframes = ["1m", "5m"]
        self.processor = BatchCandleProcessor(self.symbol, self.exchange, self.timeframes)

    def test_batch_processing(self):
        """Test processing multiple trades in batch.

        Ten 1-unit trades are stamped one second apart from
        2024-01-01 10:00:00 UTC, priced 50000..50009 and alternating buy/sell.
        Every returned candle must belong to the configured symbol and
        timeframes, be marked complete, and carry non-zero volume and
        trade count.
        """
        trades = [
            StandardizedTrade(
                symbol=self.symbol,
                trade_id=str(i),
                price=Decimal(str(50000 + i)),
                size=Decimal("1"),
                side="buy" if i % 2 == 0 else "sell",
                timestamp=datetime(2024, 1, 1, 10, 0, i, tzinfo=timezone.utc),
                exchange=self.exchange,
            )
            for i in range(10)
        ]
        # Hand over a one-shot iterator rather than the list itself.
        candles = self.processor.process_trades_to_candles(iter(trades))
        self.assertGreater(len(candles), 0)

        # Verify candle integrity. The specific assert helpers report the
        # offending values, and subTest identifies which candle failed.
        for index, candle in enumerate(candles):
            with self.subTest(candle_index=index):
                self.assertEqual(candle.symbol, self.symbol)
                self.assertIn(candle.timeframe, self.timeframes)
                self.assertTrue(candle.is_complete)
                self.assertGreater(candle.volume, 0)
                self.assertGreater(candle.trade_count, 0)
class TestAggregationUtilsSafety(unittest.TestCase):
    """Safety net tests for aggregation utility functions."""

    def test_validate_timeframe(self):
        """Test timeframe validation against known-good and known-bad strings.

        Each case runs in its own subTest, so one bad timeframe neither hides
        the others nor goes unnamed in the failure report.
        """
        valid_timeframes = ['1s', '5s', '10s', '15s', '30s', '1m', '5m', '15m', '30m', '1h', '4h', '1d']
        invalid_timeframes = ['2m', '2h', '1w', 'invalid']
        for tf in valid_timeframes:
            with self.subTest(timeframe=tf, expected="valid"):
                self.assertTrue(validate_timeframe(tf))
        for tf in invalid_timeframes:
            with self.subTest(timeframe=tf, expected="invalid"):
                self.assertFalse(validate_timeframe(tf))

    def test_parse_timeframe(self):
        """Test parsing timeframes into (amount, unit) tuples.

        Malformed input must raise ValueError.
        """
        test_cases = [
            ('1s', (1, 's')),
            ('5m', (5, 'm')),
            ('1h', (1, 'h')),
            ('1d', (1, 'd')),
        ]
        for tf, expected in test_cases:
            with self.subTest(timeframe=tf):
                self.assertEqual(parse_timeframe(tf), expected)
        with self.assertRaises(ValueError):
            parse_timeframe('invalid')
if __name__ == "__main__":
    # Allow running this module directly as a script, outside a test runner.
    unittest.main()