"""
Safety net tests for the aggregation package.

These tests pin down the core behavior of the aggregation package
(TimeframeBucket, RealTimeCandleProcessor, BatchCandleProcessor and the
timeframe helpers in ``utils``) so that refactoring it does not introduce
regressions. Run them before and after every refactoring step.
"""

import unittest
from datetime import datetime, timezone, timedelta
from decimal import Decimal
from typing import Dict, List

from data.common.aggregation.bucket import TimeframeBucket
from data.common.aggregation.realtime import RealTimeCandleProcessor
from data.common.aggregation.batch import BatchCandleProcessor
from data.common.aggregation.utils import (
    validate_timeframe,
    parse_timeframe,
    aggregate_trades_to_candles,
)
from data.common import (
    StandardizedTrade,
    OHLCVCandle,
    CandleProcessingConfig,
)


class TestTimeframeBucketSafety(unittest.TestCase):
    """Safety net tests for the TimeframeBucket class.

    Every test gets a fresh 5-minute BTC-USDT bucket that starts at
    2024-01-01 10:00 UTC (built in ``setUp``).
    """

    def setUp(self):
        self.symbol = "BTC-USDT"
        self.timeframe = "5m"
        self.start_time = datetime(2024, 1, 1, 10, 0, tzinfo=timezone.utc)
        self.bucket = TimeframeBucket(self.symbol, self.timeframe, self.start_time)

    def _trade(self, trade_id, price, size, side, offset_minutes):
        """Build a trade for this bucket's symbol on the "test" exchange.

        ``price`` and ``size`` are decimal strings; the timestamp is
        ``self.start_time`` shifted by ``offset_minutes``.
        """
        return StandardizedTrade(
            symbol=self.symbol,
            trade_id=trade_id,
            price=Decimal(price),
            size=Decimal(size),
            side=side,
            timestamp=self.start_time + timedelta(minutes=offset_minutes),
            exchange="test",
        )

    def _assert_ohlcv(self, open_, high, low, close, volume, trade_count):
        """Check the bucket's running OHLC prices, volume and trade count."""
        bucket = self.bucket
        self.assertEqual(bucket.open, Decimal(open_))
        self.assertEqual(bucket.high, Decimal(high))
        self.assertEqual(bucket.low, Decimal(low))
        self.assertEqual(bucket.close, Decimal(close))
        self.assertEqual(bucket.volume, Decimal(volume))
        self.assertEqual(bucket.trade_count, trade_count)

    def test_bucket_initialization(self):
        """A new bucket echoes its constructor args and spans 5 minutes."""
        bucket = self.bucket
        expected_end = self.start_time + timedelta(minutes=5)
        self.assertEqual(bucket.symbol, self.symbol)
        self.assertEqual(bucket.timeframe, self.timeframe)
        self.assertEqual(bucket.start_time, self.start_time)
        self.assertEqual(bucket.end_time, expected_end)

    def test_add_trade_updates_ohlcv(self):
        """Each added trade updates open/high/low/close, volume and count."""
        first = self._trade("1", "50000", "1", "buy", 1)
        second = self._trade("2", "51000", "0.5", "sell", 2)

        # With a single trade, every price field equals that trade's price.
        self.bucket.add_trade(first)
        self._assert_ohlcv("50000", "50000", "50000", "50000", "1", 1)

        # A higher second trade lifts high and close; open and low stay put.
        self.bucket.add_trade(second)
        self._assert_ohlcv("50000", "51000", "50000", "51000", "1.5", 2)

    def test_bucket_time_boundaries(self):
        """add_trade accepts in-window trades and rejects later ones."""
        inside = self._trade("1", "50000", "1", "buy", 1)
        # 6 minutes after start is past the 5-minute end_time.
        outside = self._trade("2", "51000", "1", "buy", 6)

        self.assertTrue(self.bucket.add_trade(inside))
        self.assertFalse(self.bucket.add_trade(outside))


class TestRealTimeCandleProcessorSafety(unittest.TestCase):
    """Safety net tests for the RealTimeCandleProcessor class.

    The processor under test is configured for the "1m" and "5m"
    timeframes on BTC-USDT.
    """

    def setUp(self):
        self.symbol = "BTC-USDT"
        self.exchange = "test"
        self.config = CandleProcessingConfig(timeframes=["1m", "5m"])
        self.processor = RealTimeCandleProcessor(self.symbol, self.exchange, self.config)

    def _trade(self, trade_id, price, side, minute, second):
        """Build a size-1 trade at 2024-01-01 10:<minute>:<second> UTC."""
        return StandardizedTrade(
            symbol=self.symbol,
            trade_id=trade_id,
            price=Decimal(price),
            size=Decimal("1"),
            side=side,
            timestamp=datetime(2024, 1, 1, 10, minute, second, tzinfo=timezone.utc),
            exchange=self.exchange,
        )

    def test_process_single_trade(self):
        """One trade completes nothing but opens a candle per timeframe."""
        trade = self._trade("1", "50000", "buy", 0, 30)

        completed = self.processor.process_trade(trade)
        self.assertEqual(len(completed), 0)  # No completed candles yet

        in_progress = self.processor.get_current_candles()
        self.assertEqual(len(in_progress), 2)  # One for each timeframe

    def test_candle_completion(self):
        """Crossing a minute boundary completes only the 1m candle."""
        # 10:00:30 sits in the first minute ...
        opening = self._trade("1", "50000", "buy", 0, 30)
        # ... and 10:01:15 lands in the next minute, which should close the
        # 1m candle; the 5m candle must not be among the completed ones.
        crossing = self._trade("2", "51000", "sell", 1, 15)

        self.assertEqual(len(self.processor.process_trade(opening)), 0)

        closed = self.processor.process_trade(crossing)
        self.assertEqual(len(closed), 1)  # 1m candle completed
        self.assertEqual(closed[0].timeframe, "1m")


class TestBatchCandleProcessorSafety(unittest.TestCase):
    """Safety net tests for the BatchCandleProcessor class."""

    def setUp(self):
        self.symbol = "BTC-USDT"
        self.exchange = "test"
        self.timeframes = ["1m", "5m"]
        self.processor = BatchCandleProcessor(self.symbol, self.exchange, self.timeframes)

    def test_batch_processing(self):
        """Test processing multiple trades in batch.

        Ten size-1 trades, one second apart from 10:00:00 to 10:00:09 UTC,
        with prices 50000..50009 and alternating buy/sell sides, are fed to
        the processor. Every emitted candle must belong to this symbol and
        to one of the configured timeframes, be marked complete, and carry
        positive volume and a positive trade count.
        """
        trades = [
            StandardizedTrade(
                symbol=self.symbol,
                trade_id=str(i),
                price=Decimal(str(50000 + i)),
                size=Decimal("1"),
                side="buy" if i % 2 == 0 else "sell",
                timestamp=datetime(2024, 1, 1, 10, 0, i, tzinfo=timezone.utc),
                exchange=self.exchange,
            )
            for i in range(10)
        ]

        # Pass a one-shot iterator rather than the list itself.
        candles = self.processor.process_trades_to_candles(iter(trades))
        # assertGreater/assertIn report the offending values on failure,
        # unlike assertTrue(<comparison>) which only says "False is not true".
        self.assertGreater(len(candles), 0)

        # Verify candle integrity; subTest tags each failure with the index
        # of the candle that caused it.
        for index, candle in enumerate(candles):
            with self.subTest(index=index):
                self.assertEqual(candle.symbol, self.symbol)
                self.assertIn(candle.timeframe, self.timeframes)
                self.assertTrue(candle.is_complete)
                self.assertGreater(candle.volume, 0)
                self.assertGreater(candle.trade_count, 0)


class TestAggregationUtilsSafety(unittest.TestCase):
    """Safety net tests for aggregation utility functions."""

    def test_validate_timeframe(self):
        """validate_timeframe accepts the supported set and rejects others."""
        valid_timeframes = ['1s', '5s', '10s', '15s', '30s', '1m', '5m', '15m', '30m', '1h', '4h', '1d']
        invalid_timeframes = ['2m', '2h', '1w', 'invalid']

        # subTest names the timeframe in the failure report; a bare
        # assertTrue/assertFalse would only say "False is not true".
        for tf in valid_timeframes:
            with self.subTest(timeframe=tf, expected_valid=True):
                self.assertTrue(validate_timeframe(tf))

        for tf in invalid_timeframes:
            with self.subTest(timeframe=tf, expected_valid=False):
                self.assertFalse(validate_timeframe(tf))

    def test_parse_timeframe(self):
        """parse_timeframe splits '<count><unit>' into (int, unit).

        Unparseable input such as 'invalid' must raise ValueError.
        """
        test_cases = [
            ('1s', (1, 's')),
            ('5m', (5, 'm')),
            ('1h', (1, 'h')),
            ('1d', (1, 'd')),
        ]

        for tf, expected in test_cases:
            with self.subTest(timeframe=tf):
                self.assertEqual(parse_timeframe(tf), expected)

        with self.assertRaises(ValueError):
            parse_timeframe('invalid')


if __name__ == '__main__':
    # Executing this file directly hands control to unittest's CLI runner,
    # which collects the TestCase classes defined in this module.
    unittest.main()