"""Tests for the shared helpers on ``BaseStrategy``.

Covers ``prepare_dataframe``, ``validate_dataframe`` and
``validate_indicators_data``, exercised through a minimal concrete subclass
and a recording logger.
"""

from datetime import datetime, timezone
from decimal import Decimal
from unittest.mock import MagicMock

import numpy as np
import pandas as pd
import pytest

from strategies.base import BaseStrategy
from strategies.data_types import StrategyResult, StrategySignal, SignalType
from data.common.data_types import OHLCVCandle


# Columns that the expected frames carry as floats (the candles hold Decimals).
_FLOAT_COLUMNS = ('open', 'high', 'low', 'close', 'volume')

# Column order the tests expect from ``prepare_dataframe``.
_PREPARED_COLUMN_ORDER = [
    'symbol', 'timeframe', 'open', 'high', 'low', 'close', 'volume',
    'trade_count', 'timestamp',
]


# Mock logger for testing
class MockLogger:
    """Logger stand-in that records every message in a per-level list.

    Each method accepts the same call shape as ``logging.Logger`` methods
    (``msg, *args, **kwargs``) so code under test that logs with lazy
    %-style arguments or keywords such as ``exc_info=`` does not crash the
    mock. Calls that pass only ``msg`` are recorded unchanged.
    """

    def __init__(self):
        self.debug_calls = []
        self.info_calls = []
        self.warning_calls = []
        self.error_calls = []

    @staticmethod
    def _render(msg, args):
        """Return ``msg`` with ``args`` interpolated when any were given."""
        if not args:
            return msg
        try:
            return msg % args
        except (TypeError, ValueError):
            # A malformed format string must not fail the test that merely
            # triggered the log call; keep the raw message instead.
            return msg

    def debug(self, msg, *args, **kwargs):
        self.debug_calls.append(self._render(msg, args))

    def info(self, msg, *args, **kwargs):
        self.info_calls.append(self._render(msg, args))

    def warning(self, msg, *args, **kwargs):
        self.warning_calls.append(self._render(msg, args))

    def error(self, msg, *args, **kwargs):
        self.error_calls.append(self._render(msg, args))


# Concrete implementation of BaseStrategy for testing purposes
class ConcreteStrategy(BaseStrategy):
    """Smallest possible ``BaseStrategy`` subclass: no indicators, no signals."""

    def __init__(self, logger=None):
        super().__init__(strategy_name="ConcreteStrategy", logger=logger)

    def get_required_indicators(self) -> list[dict]:
        return []

    def calculate(self, df: pd.DataFrame, indicators_data: dict, **kwargs) -> list:
        # Dummy implementation for testing
        return []


@pytest.fixture
def mock_logger():
    """Provide a fresh recording logger for each test."""
    return MockLogger()


@pytest.fixture
def concrete_strategy(mock_logger):
    """Provide a ``ConcreteStrategy`` wired to the recording logger."""
    return ConcreteStrategy(logger=mock_logger)


@pytest.fixture
def sample_ohlcv_data():
    """Return five hourly OHLCV rows indexed by a naive ``timestamp``."""
    # Create a sample DataFrame that mimics OHLCVCandle structure
    data = {
        'timestamp': pd.to_datetime([
            '2023-01-01 00:00:00',
            '2023-01-01 01:00:00',
            '2023-01-01 02:00:00',
            '2023-01-01 03:00:00',
            '2023-01-01 04:00:00',
        ]),
        'open': [100, 101, 102, 103, 104],
        'high': [105, 106, 107, 108, 109],
        'low': [99, 100, 101, 102, 103],
        'close': [102, 103, 104, 105, 106],
        'volume': [1000, 1100, 1200, 1300, 1400],
        'trade_count': [100, 110, 120, 130, 140],
        'symbol': ['BTC/USDT'] * 5,
        'timeframe': ['1h'] * 5,
    }
    df = pd.DataFrame(data)
    df = df.set_index('timestamp')  # Ensure timestamp is the index
    return df


def _candles_from_frame(frame):
    """Build one ``OHLCVCandle`` per row of a timestamp-indexed OHLCV frame.

    The row timestamp is reused for every time field (start, end, first and
    last trade) for simplicity; ``exchange`` and ``is_complete`` are dummies.
    """
    return [
        OHLCVCandle(
            symbol=row['symbol'],
            timeframe=row['timeframe'],
            start_time=row['timestamp'],
            end_time=row['timestamp'],
            open=Decimal(str(row['open'])),
            high=Decimal(str(row['high'])),
            low=Decimal(str(row['low'])),
            close=Decimal(str(row['close'])),
            volume=Decimal(str(row['volume'])),
            trade_count=row['trade_count'],
            exchange="test_exchange",
            is_complete=True,
            first_trade_time=row['timestamp'],
            last_trade_time=row['timestamp'],
        )
        for row in frame.reset_index().to_dict(orient='records')
    ]


def _expected_prepared_frame(frame):
    """Return the frame the tests expect ``prepare_dataframe`` to produce.

    The expectation is: a UTC-aware ``timestamp`` index, the same timestamps
    repeated as a trailing column, a fixed column order, and float
    price/volume columns.
    """
    expected = frame.reset_index()
    # Ensure timezone awareness
    expected['timestamp'] = expected['timestamp'].apply(
        lambda ts: ts.replace(tzinfo=timezone.utc)
    )
    expected = expected.set_index('timestamp')
    # The timestamp is expected both as the index and as a regular column.
    expected['timestamp'] = expected.index
    expected = expected[_PREPARED_COLUMN_ORDER]
    # One astype call instead of per-column assignment on a column subset.
    return expected.astype({column: float for column in _FLOAT_COLUMNS})


def test_prepare_dataframe_initial_data(concrete_strategy, sample_ohlcv_data):
    """Contiguous candles are converted into the expected frame."""
    candles_list = _candles_from_frame(sample_ohlcv_data)

    prepared_df = concrete_strategy.prepare_dataframe(candles_list)

    pd.testing.assert_frame_equal(
        prepared_df,
        _expected_prepared_frame(sample_ohlcv_data),
    )


def test_prepare_dataframe_sparse_data(concrete_strategy, sample_ohlcv_data):
    """A gap in the candle series is preserved, not filled."""
    # Simulate sparse data by removing the middle row
    sparse_data = sample_ohlcv_data.drop(sample_ohlcv_data.index[2])
    sparse_candles_list = _candles_from_frame(sparse_data)

    prepared_df = concrete_strategy.prepare_dataframe(sparse_candles_list)

    pd.testing.assert_frame_equal(
        prepared_df,
        _expected_prepared_frame(sparse_data),
    )


def test_validate_dataframe_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
    """Valid data passes validation without logging warnings or errors."""
    is_valid = concrete_strategy.validate_dataframe(
        sample_ohlcv_data, min_periods=len(sample_ohlcv_data)
    )

    assert is_valid
    # Ensure no warnings/errors are logged for valid data
    assert not mock_logger.warning_calls
    assert not mock_logger.error_calls


def test_validate_dataframe_missing_column(concrete_strategy, sample_ohlcv_data, mock_logger):
    """A frame lacking the ``open`` column is still reported as valid."""
    invalid_df = sample_ohlcv_data.drop(columns=['open'])

    is_valid = concrete_strategy.validate_dataframe(invalid_df, min_periods=len(invalid_df))

    assert is_valid  # BaseStrategy.validate_dataframe does not check for missing columns


def test_validate_dataframe_invalid_index(concrete_strategy, sample_ohlcv_data, mock_logger):
    """A frame without a DatetimeIndex is still reported as valid."""
    invalid_df = sample_ohlcv_data.reset_index()  # Remove DatetimeIndex

    is_valid = concrete_strategy.validate_dataframe(invalid_df, min_periods=len(invalid_df))

    assert is_valid  # BaseStrategy.validate_dataframe does not check index validity


def test_validate_dataframe_non_monotonic_index(concrete_strategy, sample_ohlcv_data, mock_logger):
    """A frame in reverse time order is still reported as valid."""
    # Reverse order to make it non-monotonic
    invalid_df = sample_ohlcv_data.iloc[::-1]

    is_valid = concrete_strategy.validate_dataframe(invalid_df, min_periods=len(invalid_df))

    assert is_valid  # BaseStrategy.validate_dataframe does not check index monotonicity


def test_validate_indicators_data_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
    """Complete, NaN-free indicator data validates without any log output."""
    indicators_data = {
        'ema_12': pd.Series([101, 102, 103, 104, 105], index=sample_ohlcv_data.index),
        'ema_26': pd.Series([100, 101, 102, 103, 104], index=sample_ohlcv_data.index),
    }
    required_indicators = [
        {'type': 'ema', 'period': 12},
        {'type': 'ema', 'period': 26},
    ]

    concrete_strategy.validate_indicators_data(indicators_data, required_indicators)

    assert not mock_logger.warning_calls
    assert not mock_logger.error_calls


def test_validate_indicators_data_missing_indicator(concrete_strategy, sample_ohlcv_data, mock_logger):
    """A required indicator with no data raises ``ValueError`` naming its key."""
    indicators_data = {
        'ema_12': pd.Series([101, 102, 103, 104, 105], index=sample_ohlcv_data.index),
    }
    required_indicators = [
        {'type': 'ema', 'period': 12},
        {'type': 'ema', 'period': 26},  # Missing
    ]

    with pytest.raises(ValueError, match="Missing required indicator data for key: ema_26"):
        concrete_strategy.validate_indicators_data(indicators_data, required_indicators)


def test_validate_indicators_data_nan_values(concrete_strategy, sample_ohlcv_data, mock_logger):
    """NaN values in an indicator series produce a warning, not an error."""
    indicators_data = {
        'ema_12': pd.Series([101, 102, np.nan, 104, 105], index=sample_ohlcv_data.index),
        'ema_26': pd.Series([100, 101, 102, 103, 104], index=sample_ohlcv_data.index),
    }
    required_indicators = [
        {'type': 'ema', 'period': 12},
        {'type': 'ema', 'period': 26},
    ]

    concrete_strategy.validate_indicators_data(indicators_data, required_indicators)

    assert "NaN values found in indicator data for key: ema_12" in mock_logger.warning_calls