- Split the `aggregation.py` file into a dedicated sub-package, improving modularity and maintainability.
- Moved the `TimeframeBucket`, `RealTimeCandleProcessor`, and `BatchCandleProcessor` classes into their own files within the new `aggregation` sub-package.
- Introduced utility functions for trade aggregation and validation, improving code organization.
- Updated import paths throughout the codebase to reflect the new structure, preserving compatibility.
- Added safety net tests for the aggregation package to verify core functionality and prevent regressions during refactoring.

These changes improve the overall architecture of the aggregation module, making it more scalable and easier to maintain.
231 lines
8.3 KiB
Python
231 lines
8.3 KiB
Python
"""
|
|
Safety net tests for the aggregation package.
|
|
|
|
These tests verify the core functionality of the aggregation module
|
|
before and during refactoring to ensure no regressions are introduced.
|
|
"""
|
|
|
|
import unittest
|
|
from datetime import datetime, timezone, timedelta
|
|
from decimal import Decimal
|
|
from typing import Dict, List
|
|
|
|
from data.common.aggregation.bucket import TimeframeBucket
|
|
from data.common.aggregation.realtime import RealTimeCandleProcessor
|
|
from data.common.aggregation.batch import BatchCandleProcessor
|
|
from data.common.aggregation.utils import (
|
|
validate_timeframe,
|
|
parse_timeframe,
|
|
aggregate_trades_to_candles
|
|
)
|
|
from data.common import (
|
|
StandardizedTrade,
|
|
OHLCVCandle,
|
|
CandleProcessingConfig
|
|
)
|
|
|
|
class TestTimeframeBucketSafety(unittest.TestCase):
    """Safety net tests for TimeframeBucket class.

    Every test starts from a fresh 5-minute ``BTC-USDT`` bucket opening at
    2024-01-01 10:00 UTC (see ``setUp``).
    """

    def setUp(self):
        self.symbol = "BTC-USDT"
        self.timeframe = "5m"
        self.start_time = datetime(2024, 1, 1, 10, 0, tzinfo=timezone.utc)
        self.bucket = TimeframeBucket(self.symbol, self.timeframe, self.start_time)

    def _make_trade(self, trade_id, price, size, side, offset):
        """Return a trade for ``self.symbol`` stamped ``offset`` after ``start_time``.

        ``price`` and ``size`` are given as strings so they convert to exact
        ``Decimal`` values.
        """
        return StandardizedTrade(
            symbol=self.symbol,
            trade_id=trade_id,
            price=Decimal(price),
            size=Decimal(size),
            side=side,
            timestamp=self.start_time + offset,
            exchange="test",
        )

    def _assert_bucket_state(self, open_, high, low, close, volume, trade_count):
        """Assert the bucket's OHLCV values (given as strings) and its trade count."""
        self.assertEqual(self.bucket.open, Decimal(open_))
        self.assertEqual(self.bucket.high, Decimal(high))
        self.assertEqual(self.bucket.low, Decimal(low))
        self.assertEqual(self.bucket.close, Decimal(close))
        self.assertEqual(self.bucket.volume, Decimal(volume))
        self.assertEqual(self.bucket.trade_count, trade_count)

    def test_bucket_initialization(self):
        """Test bucket initialization and time boundaries."""
        self.assertEqual(self.bucket.symbol, self.symbol)
        self.assertEqual(self.bucket.timeframe, self.timeframe)
        self.assertEqual(self.bucket.start_time, self.start_time)
        self.assertEqual(self.bucket.end_time, self.start_time + timedelta(minutes=5))

    def test_add_trade_updates_ohlcv(self):
        """Test that adding trades correctly updates OHLCV data."""
        trade1 = self._make_trade("1", "50000", "1", "buy", timedelta(minutes=1))
        trade2 = self._make_trade("2", "51000", "0.5", "sell", timedelta(minutes=2))

        # The first trade sets open/high/low/close to its own price.
        self.bucket.add_trade(trade1)
        self._assert_bucket_state("50000", "50000", "50000", "50000", "1", 1)

        # A second, higher trade moves high and close; open and low stay put.
        self.bucket.add_trade(trade2)
        self._assert_bucket_state("50000", "51000", "50000", "51000", "1.5", 2)

    def test_bucket_time_boundaries(self):
        """Test that trades are only added within correct time boundaries.

        A rejected trade must also leave the bucket's aggregate untouched.
        """
        valid_trade = self._make_trade("1", "50000", "1", "buy", timedelta(minutes=1))
        # Six minutes after start lies past this 5-minute bucket's end_time.
        invalid_trade = self._make_trade("2", "51000", "1", "buy", timedelta(minutes=6))

        self.assertTrue(self.bucket.add_trade(valid_trade))
        self.assertFalse(self.bucket.add_trade(invalid_trade))

        # Only the accepted trade may be reflected in the bucket's state.
        self._assert_bucket_state("50000", "50000", "50000", "50000", "1", 1)
|
|
class TestRealTimeCandleProcessorSafety(unittest.TestCase):
    """Safety net tests for RealTimeCandleProcessor class.

    The processor under test tracks ``1m`` and ``5m`` candles for ``BTC-USDT``.
    """

    def setUp(self):
        self.symbol = "BTC-USDT"
        self.exchange = "test"
        self.config = CandleProcessingConfig(timeframes=["1m", "5m"])
        self.processor = RealTimeCandleProcessor(self.symbol, self.exchange, self.config)

    def _make_trade(self, trade_id, price, side, timestamp):
        """Return a size-1 trade for this processor's symbol and exchange."""
        return StandardizedTrade(
            symbol=self.symbol,
            trade_id=trade_id,
            price=Decimal(price),
            size=Decimal("1"),
            side=side,
            timestamp=timestamp,
            exchange=self.exchange,
        )

    def test_process_single_trade(self):
        """Test processing a single trade."""
        trade = self._make_trade(
            "1", "50000", "buy", datetime(2024, 1, 1, 10, 0, 30, tzinfo=timezone.utc)
        )

        completed_candles = self.processor.process_trade(trade)
        self.assertEqual(len(completed_candles), 0)  # No completed candles yet

        current_candles = self.processor.get_current_candles()
        self.assertEqual(len(current_candles), 2)  # One for each timeframe

    def test_candle_completion(self):
        """Test that candles are completed at correct time boundaries.

        The 1m candle closed by the second trade must contain only the first
        trade.
        """
        # First trade in the first minute.
        trade1 = self._make_trade(
            "1", "50000", "buy", datetime(2024, 1, 1, 10, 0, 30, tzinfo=timezone.utc)
        )
        # Second trade in the next minute - should complete the 1m candle.
        trade2 = self._make_trade(
            "2", "51000", "sell", datetime(2024, 1, 1, 10, 1, 15, tzinfo=timezone.utc)
        )

        completed1 = self.processor.process_trade(trade1)
        self.assertEqual(len(completed1), 0)

        completed2 = self.processor.process_trade(trade2)
        self.assertEqual(len(completed2), 1)  # 1m candle completed; 5m still open
        candle = completed2[0]
        self.assertEqual(candle.timeframe, "1m")

        # The closed candle aggregates trade1 alone; trade2 belongs to the next minute.
        self.assertEqual(candle.symbol, self.symbol)
        self.assertTrue(candle.is_complete)
        self.assertEqual(candle.volume, Decimal("1"))
        self.assertEqual(candle.trade_count, 1)
|
|
class TestBatchCandleProcessorSafety(unittest.TestCase):
    """Safety net tests for BatchCandleProcessor class."""

    def setUp(self):
        self.symbol = "BTC-USDT"
        self.exchange = "test"
        self.timeframes = ["1m", "5m"]
        self.processor = BatchCandleProcessor(self.symbol, self.exchange, self.timeframes)

    def test_batch_processing(self):
        """Test processing multiple trades in batch.

        Ten trades one second apart (10:00:00-10:00:09 UTC), with prices rising
        by 1 and sides alternating buy/sell, are fed in as an iterator.
        """
        trades = [
            StandardizedTrade(
                symbol=self.symbol,
                trade_id=str(i),
                price=Decimal(str(50000 + i)),
                size=Decimal("1"),
                side="buy" if i % 2 == 0 else "sell",
                timestamp=datetime(2024, 1, 1, 10, 0, i, tzinfo=timezone.utc),
                exchange=self.exchange,
            )
            for i in range(10)
        ]

        candles = self.processor.process_trades_to_candles(iter(trades))
        self.assertGreater(len(candles), 0)

        # Verify candle integrity; subTest reports exactly which candle failed,
        # and assertIn/assertGreater show the offending values on failure.
        for index, candle in enumerate(candles):
            with self.subTest(index=index, timeframe=candle.timeframe):
                self.assertEqual(candle.symbol, self.symbol)
                self.assertIn(candle.timeframe, self.timeframes)
                self.assertTrue(candle.is_complete)
                self.assertGreater(candle.volume, 0)
                self.assertGreater(candle.trade_count, 0)
|
|
class TestAggregationUtilsSafety(unittest.TestCase):
    """Safety net tests for aggregation utility functions."""

    def test_validate_timeframe(self):
        """Test timeframe validation against known-good and known-bad strings."""
        valid_timeframes = ['1s', '5s', '10s', '15s', '30s', '1m', '5m', '15m', '30m', '1h', '4h', '1d']
        invalid_timeframes = ['2m', '2h', '1w', 'invalid']

        # subTest keeps checking after a failure and names the failing timeframe.
        for tf in valid_timeframes:
            with self.subTest(timeframe=tf, expected_valid=True):
                self.assertTrue(validate_timeframe(tf))

        for tf in invalid_timeframes:
            with self.subTest(timeframe=tf, expected_valid=False):
                self.assertFalse(validate_timeframe(tf))

    def test_parse_timeframe(self):
        """Test timeframe parsing into ``(amount, unit)`` tuples.

        An unparseable string must raise ``ValueError``.
        """
        test_cases = [
            ('1s', (1, 's')),
            ('5m', (5, 'm')),
            ('1h', (1, 'h')),
            ('1d', (1, 'd')),
        ]

        for tf, expected in test_cases:
            with self.subTest(timeframe=tf):
                self.assertEqual(parse_timeframe(tf), expected)

        with self.assertRaises(ValueError):
            parse_timeframe('invalid')
|
|
|
# Allow running this module directly (``python <file>``) as well as via a test runner.
if __name__ == '__main__':
    unittest.main()