TCPDashboard/tests/strategies/test_base_strategy.py

221 lines
9.1 KiB
Python
Raw Normal View History

import pytest
import pandas as pd
from datetime import datetime, timezone
from unittest.mock import MagicMock
import numpy as np
from decimal import Decimal
from strategies.base import BaseStrategy
from strategies.data_types import StrategyResult, StrategySignal, SignalType
from data.common.data_types import OHLCVCandle
# Mock logger for testing
class MockLogger:
    """In-memory stand-in for a logger that records every message by level.

    Each level method appends the rendered message to the matching
    ``*_calls`` list so tests can assert on what was logged.

    The methods accept the same call shape as ``logging.Logger`` methods
    (``msg, *args, **kwargs``). Code under test that logs with lazy
    %-style arguments or with keyword options such as ``exc_info=True``
    therefore does not fail with ``TypeError`` against this mock.
    """

    def __init__(self):
        self.debug_calls = []
        self.info_calls = []
        self.warning_calls = []
        self.error_calls = []

    @staticmethod
    def _render(msg, args):
        """Return the message as it would be rendered with %-style args.

        With no ``args`` the message object is returned untouched, so
        single-argument calls record exactly what was passed in.
        """
        if not args:
            return msg
        # Mirror the logging module: a single non-empty dict argument is
        # used as the mapping for "%(name)s"-style placeholders.
        if len(args) == 1 and isinstance(args[0], dict) and args[0]:
            args = args[0]
        return str(msg) % args

    def debug(self, msg, *args, **kwargs):
        """Record a debug-level message; keyword options are ignored."""
        self.debug_calls.append(self._render(msg, args))

    def info(self, msg, *args, **kwargs):
        """Record an info-level message; keyword options are ignored."""
        self.info_calls.append(self._render(msg, args))

    def warning(self, msg, *args, **kwargs):
        """Record a warning-level message; keyword options are ignored."""
        self.warning_calls.append(self._render(msg, args))

    def error(self, msg, *args, **kwargs):
        """Record an error-level message; keyword options are ignored."""
        self.error_calls.append(self._render(msg, args))
# Concrete implementation of BaseStrategy for testing purposes
class ConcreteStrategy(BaseStrategy):
    """Bare-bones BaseStrategy subclass used only by the tests in this module.

    Both overridden methods return empty lists, so the tests exercise the
    helper methods inherited from BaseStrategy rather than any strategy logic.
    """

    def __init__(self, logger=None):
        """Initialise the base class under the fixed name "ConcreteStrategy"."""
        super().__init__(logger=logger, strategy_name="ConcreteStrategy")

    def get_required_indicators(self) -> list[dict]:
        """Return an empty list: this stub declares no indicators."""
        no_indicators = []
        return no_indicators

    def calculate(self, df: pd.DataFrame, indicators_data: dict, **kwargs) -> list:
        """Return an empty list regardless of the inputs (test stub)."""
        no_results = []
        return no_results
@pytest.fixture
def mock_logger():
    """Provide a new MockLogger whose call lists start out empty."""
    recorder = MockLogger()
    return recorder
@pytest.fixture
def concrete_strategy(mock_logger):
    """Provide a ConcreteStrategy wired to the ``mock_logger`` fixture.

    Tests that request both fixtures can inspect ``mock_logger`` to see
    what the strategy logged.
    """
    strategy = ConcreteStrategy(logger=mock_logger)
    return strategy
@pytest.fixture
def sample_ohlcv_data():
    """Return five hourly OHLCV rows for BTC/USDT, indexed by timestamp.

    Columns: open, high, low, close, volume, trade_count, symbol and
    timeframe. The timestamps are timezone-naive and run from
    2023-01-01 00:00 to 04:00 in one-hour steps.
    """
    row_count = 5
    hourly_stamps = [f'2023-01-01 {hour:02d}:00:00' for hour in range(row_count)]
    columns = {
        'timestamp': pd.to_datetime(hourly_stamps),
        'open': list(range(100, 105)),
        'high': list(range(105, 110)),
        'low': list(range(99, 104)),
        'close': list(range(102, 107)),
        'volume': list(range(1000, 1500, 100)),
        'trade_count': list(range(100, 150, 10)),
        'symbol': ['BTC/USDT'] * row_count,
        'timeframe': ['1h'] * row_count,
    }
    # The timestamp column becomes the index of the returned frame.
    return pd.DataFrame(columns).set_index('timestamp')
def test_prepare_dataframe_initial_data(concrete_strategy, sample_ohlcv_data):
    """prepare_dataframe on a full candle list yields the expected frame.

    One OHLCVCandle is built per row of the sample data, the list is passed
    to ``prepare_dataframe``, and the result is compared with a hand-built
    frame that encodes what this test expects back.
    """
    records = sample_ohlcv_data.reset_index().to_dict(orient='records')
    candles_list = []
    for record in records:
        # All four time fields reuse the row timestamp for simplicity.
        moment = record['timestamp']
        candles_list.append(
            OHLCVCandle(
                symbol=record['symbol'],
                timeframe=record['timeframe'],
                start_time=moment,
                end_time=moment,
                open=Decimal(str(record['open'])),
                high=Decimal(str(record['high'])),
                low=Decimal(str(record['low'])),
                close=Decimal(str(record['close'])),
                volume=Decimal(str(record['volume'])),
                trade_count=record['trade_count'],
                exchange="test_exchange",  # dummy value
                is_complete=True,  # dummy value
                first_trade_time=moment,
                last_trade_time=moment,
            )
        )

    prepared_df = concrete_strategy.prepare_dataframe(candles_list)

    float_columns = ['open', 'high', 'low', 'close', 'volume']
    expected_columns_order = ['symbol', 'timeframe', *float_columns, 'trade_count', 'timestamp']

    # Expected shape: UTC-aware timestamps used as the index AND repeated
    # as a regular column, with that column placed last.
    expected_df = sample_ohlcv_data.copy().reset_index()
    expected_df['timestamp'] = expected_df['timestamp'].apply(
        lambda stamp: stamp.replace(tzinfo=timezone.utc)
    )
    expected_df = expected_df.set_index('timestamp')
    expected_df['timestamp'] = expected_df.index
    expected_df = expected_df[expected_columns_order]
    # Price and volume columns are expected as floats; trade_count stays integer.
    expected_df = expected_df.astype({name: float for name in float_columns})

    pd.testing.assert_frame_equal(prepared_df, expected_df)
def test_prepare_dataframe_sparse_data(concrete_strategy, sample_ohlcv_data):
    """prepare_dataframe keeps a gap in the candle series as a gap.

    The middle row of the sample data is dropped before the candles are
    built. The expected frame is derived from the same four remaining rows,
    so the test expects no row to be inserted for the missing hour.
    """
    # Simulate sparse data by removing the middle (third of five) row.
    thinned = sample_ohlcv_data.drop(sample_ohlcv_data.index[2])

    sparse_candles_list = []
    for record in thinned.reset_index().to_dict(orient='records'):
        # All four time fields reuse the row timestamp.
        moment = record['timestamp']
        sparse_candles_list.append(
            OHLCVCandle(
                symbol=record['symbol'],
                timeframe=record['timeframe'],
                start_time=moment,
                end_time=moment,
                open=Decimal(str(record['open'])),
                high=Decimal(str(record['high'])),
                low=Decimal(str(record['low'])),
                close=Decimal(str(record['close'])),
                volume=Decimal(str(record['volume'])),
                trade_count=record['trade_count'],
                exchange="test_exchange",
                is_complete=True,
                first_trade_time=moment,
                last_trade_time=moment,
            )
        )

    prepared_df = concrete_strategy.prepare_dataframe(sparse_candles_list)

    float_columns = ['open', 'high', 'low', 'close', 'volume']
    expected_columns_order = ['symbol', 'timeframe', *float_columns, 'trade_count', 'timestamp']

    # Expected shape: UTC-aware timestamps used as the index AND repeated
    # as a regular column, with that column placed last.
    expected_df_sparse = thinned.reset_index()
    expected_df_sparse['timestamp'] = expected_df_sparse['timestamp'].apply(
        lambda stamp: stamp.replace(tzinfo=timezone.utc)
    )
    expected_df_sparse = expected_df_sparse.set_index('timestamp')
    expected_df_sparse['timestamp'] = expected_df_sparse.index
    expected_df_sparse = expected_df_sparse[expected_columns_order]
    # Price and volume columns are expected as floats; trade_count stays integer.
    expected_df_sparse = expected_df_sparse.astype({name: float for name in float_columns})

    pd.testing.assert_frame_equal(prepared_df, expected_df_sparse)
def test_validate_dataframe_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
    """A well-formed frame with enough rows validates without logging problems.

    ``min_periods`` is set to the number of rows in the sample frame, so the
    frame has exactly as many rows as requested.
    """
    is_valid = concrete_strategy.validate_dataframe(
        sample_ohlcv_data, min_periods=len(sample_ohlcv_data)
    )
    # The other validate_dataframe tests in this module assert a truthy
    # result for the same min_periods=len(df) call, so the fully valid
    # frame must pass as well.
    assert is_valid
    # Ensure no warnings/errors are logged for valid data.
    assert not mock_logger.warning_calls
    assert not mock_logger.error_calls
def test_validate_dataframe_missing_column(concrete_strategy, sample_ohlcv_data, mock_logger):
    """validate_dataframe still passes when the 'open' column is absent."""
    frame_without_open = sample_ohlcv_data.drop(columns=['open'])
    row_count = len(frame_without_open)
    # BaseStrategy.validate_dataframe does not check for missing columns
    assert concrete_strategy.validate_dataframe(frame_without_open, min_periods=row_count)
def test_validate_dataframe_invalid_index(concrete_strategy, sample_ohlcv_data, mock_logger):
    """validate_dataframe still passes when the frame has no DatetimeIndex."""
    # reset_index() moves the timestamps into a column and leaves a default index.
    frame_with_default_index = sample_ohlcv_data.reset_index()
    row_count = len(frame_with_default_index)
    # BaseStrategy.validate_dataframe does not check index validity
    assert concrete_strategy.validate_dataframe(frame_with_default_index, min_periods=row_count)
def test_validate_dataframe_non_monotonic_index(concrete_strategy, sample_ohlcv_data, mock_logger):
    """validate_dataframe still passes when the rows are in reverse time order."""
    # Reversing the rows makes the timestamp index descending.
    reversed_frame = sample_ohlcv_data.iloc[::-1]
    row_count = len(reversed_frame)
    # BaseStrategy.validate_dataframe does not check index monotonicity
    assert concrete_strategy.validate_dataframe(reversed_frame, min_periods=row_count)
def test_validate_indicators_data_valid(concrete_strategy, sample_ohlcv_data, mock_logger):
    """Complete, NaN-free indicator data validates without warnings or errors."""
    candle_index = sample_ohlcv_data.index
    raw_values = {
        'ema_12': [101, 102, 103, 104, 105],
        'ema_26': [100, 101, 102, 103, 104],
    }
    # One Series per indicator key, aligned to the sample candle index.
    indicators_data = {
        key: pd.Series(values, index=candle_index) for key, values in raw_values.items()
    }
    required_indicators = [{'type': 'ema', 'period': period} for period in (12, 26)]

    concrete_strategy.validate_indicators_data(indicators_data, required_indicators)

    assert not mock_logger.warning_calls
    assert not mock_logger.error_calls
def test_validate_indicators_data_missing_indicator(concrete_strategy, sample_ohlcv_data, mock_logger):
    """A required indicator with no data makes validate_indicators_data raise ValueError."""
    candle_index = sample_ohlcv_data.index
    # Only ema_12 is supplied; ema_26 is required below but deliberately missing.
    indicators_data = {'ema_12': pd.Series([101, 102, 103, 104, 105], index=candle_index)}
    required_indicators = [{'type': 'ema', 'period': period} for period in (12, 26)]

    expected_message = "Missing required indicator data for key: ema_26"
    with pytest.raises(ValueError, match=expected_message):
        concrete_strategy.validate_indicators_data(indicators_data, required_indicators)
def test_validate_indicators_data_nan_values(concrete_strategy, sample_ohlcv_data, mock_logger):
    """A NaN inside an indicator series is reported as a logged warning."""
    candle_index = sample_ohlcv_data.index
    raw_values = {
        'ema_12': [101, 102, np.nan, 104, 105],  # NaN in the middle position
        'ema_26': [100, 101, 102, 103, 104],
    }
    indicators_data = {
        key: pd.Series(values, index=candle_index) for key, values in raw_values.items()
    }
    required_indicators = [{'type': 'ema', 'period': period} for period in (12, 26)]

    concrete_strategy.validate_indicators_data(indicators_data, required_indicators)

    expected_warning = "NaN values found in indicator data for key: ema_12"
    assert expected_warning in mock_logger.warning_calls