Refactor technical indicators module and enhance structure
- Introduced a dedicated sub-package for technical indicators under `data/common/indicators/`, improving modularity and maintainability. - Moved `TechnicalIndicators` and `IndicatorResult` classes to their respective files, along with utility functions for configuration management. - Updated import paths throughout the codebase to reflect the new structure, ensuring compatibility. - Added comprehensive safety net tests for the indicators module to verify core functionality and prevent regressions during refactoring. - Enhanced documentation to provide clear usage examples and details on the new package structure. These changes improve the overall architecture of the technical indicators module, making it more scalable and easier to manage.
This commit is contained in:
parent
e7ede7f329
commit
c8d8d980aa
@ -17,7 +17,8 @@ from ..error_handling import (
|
||||
)
|
||||
|
||||
from .base import BaseLayer, LayerConfig
|
||||
from data.common.indicators import TechnicalIndicators, OHLCVCandle
|
||||
from data.common.indicators import TechnicalIndicators
|
||||
from data.common.data_types import OHLCVCandle
|
||||
from components.charts.utils import get_indicator_colors
|
||||
from utils.logger import get_logger
|
||||
|
||||
|
||||
@ -14,7 +14,8 @@ from dataclasses import dataclass
|
||||
|
||||
from .base import BaseChartLayer, LayerConfig
|
||||
from .indicators import BaseIndicatorLayer, IndicatorLayerConfig
|
||||
from data.common.indicators import TechnicalIndicators, IndicatorResult, OHLCVCandle
|
||||
from data.common.indicators import TechnicalIndicators, IndicatorResult
|
||||
from data.common.data_types import OHLCVCandle
|
||||
from components.charts.utils import get_indicator_colors
|
||||
from utils.logger import get_logger
|
||||
from ..error_handling import (
|
||||
|
||||
26
data/common/indicators/__init__.py
Normal file
26
data/common/indicators/__init__.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""
|
||||
Technical Indicators Package
|
||||
|
||||
This package provides technical indicator calculations optimized for sparse OHLCV data
|
||||
as produced by the TCP Trading Platform's aggregation strategy.
|
||||
|
||||
IMPORTANT: Handles Sparse Data
|
||||
- Missing candles (time gaps) are normal in this system
|
||||
- Indicators properly handle gaps without interpolation
|
||||
- Uses pandas for efficient vectorized calculations
|
||||
- Follows right-aligned timestamp convention
|
||||
"""
|
||||
|
||||
from .technical import TechnicalIndicators
|
||||
from .result import IndicatorResult
|
||||
from .utils import (
|
||||
create_default_indicators_config,
|
||||
validate_indicator_config
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'TechnicalIndicators',
|
||||
'IndicatorResult',
|
||||
'create_default_indicators_config',
|
||||
'validate_indicator_config'
|
||||
]
|
||||
29
data/common/indicators/result.py
Normal file
29
data/common/indicators/result.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""
|
||||
Technical Indicator Result Container
|
||||
|
||||
This module provides the IndicatorResult dataclass for storing
|
||||
technical indicator calculation results in a standardized format.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndicatorResult:
|
||||
"""
|
||||
Container for technical indicator calculation results.
|
||||
|
||||
Attributes:
|
||||
timestamp: Candle timestamp (right-aligned)
|
||||
symbol: Trading symbol
|
||||
timeframe: Candle timeframe
|
||||
values: Dictionary of indicator values
|
||||
metadata: Additional calculation metadata
|
||||
"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
timeframe: str
|
||||
values: Dict[str, float]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
@ -18,33 +18,13 @@ Supported Indicators:
|
||||
- Bollinger Bands
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Dict, List, Optional, Any, Union, Tuple
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .data_types import OHLCVCandle
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndicatorResult:
|
||||
"""
|
||||
Container for technical indicator calculation results.
|
||||
|
||||
Attributes:
|
||||
timestamp: Candle timestamp (right-aligned)
|
||||
symbol: Trading symbol
|
||||
timeframe: Candle timeframe
|
||||
values: Dictionary of indicator values
|
||||
metadata: Additional calculation metadata
|
||||
"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
timeframe: str
|
||||
values: Dict[str, float]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
from .result import IndicatorResult
|
||||
from ..data_types import OHLCVCandle
|
||||
|
||||
|
||||
class TechnicalIndicators:
|
||||
@ -112,7 +92,7 @@ class TechnicalIndicators:
|
||||
|
||||
return df
|
||||
|
||||
def sma(self, df: pd.DataFrame, period: int,
|
||||
def sma(self, df: pd.DataFrame, period: int,
|
||||
price_column: str = 'close') -> List[IndicatorResult]:
|
||||
"""
|
||||
Calculate Simple Moving Average (SMA).
|
||||
@ -231,7 +211,7 @@ class TechnicalIndicators:
|
||||
|
||||
return results
|
||||
|
||||
def macd(self, df: pd.DataFrame,
|
||||
def macd(self, df: pd.DataFrame,
|
||||
fast_period: int = 12, slow_period: int = 26, signal_period: int = 9,
|
||||
price_column: str = 'close') -> List[IndicatorResult]:
|
||||
"""
|
||||
@ -289,7 +269,7 @@ class TechnicalIndicators:
|
||||
|
||||
return results
|
||||
|
||||
def bollinger_bands(self, df: pd.DataFrame, period: int = 20,
|
||||
def bollinger_bands(self, df: pd.DataFrame, period: int = 20,
|
||||
std_dev: float = 2.0, price_column: str = 'close') -> List[IndicatorResult]:
|
||||
"""
|
||||
Calculate Bollinger Bands.
|
||||
@ -345,13 +325,13 @@ class TechnicalIndicators:
|
||||
|
||||
return results
|
||||
|
||||
def calculate_multiple_indicators(self, candles: List[OHLCVCandle],
|
||||
def calculate_multiple_indicators(self, df: pd.DataFrame,
|
||||
indicators_config: Dict[str, Dict[str, Any]]) -> Dict[str, List[IndicatorResult]]:
|
||||
"""
|
||||
Calculate multiple indicators at once for efficiency.
|
||||
|
||||
Args:
|
||||
candles: List of OHLCV candles
|
||||
df: DataFrame with OHLCV data
|
||||
indicators_config: Configuration for indicators to calculate
|
||||
Example: {
|
||||
'sma_20': {'type': 'sma', 'period': 20},
|
||||
@ -373,30 +353,30 @@ class TechnicalIndicators:
|
||||
if indicator_type == 'sma':
|
||||
period = config.get('period', 20)
|
||||
price_column = config.get('price_column', 'close')
|
||||
results[indicator_name] = self.sma(candles, period, price_column)
|
||||
results[indicator_name] = self.sma(df, period, price_column)
|
||||
|
||||
elif indicator_type == 'ema':
|
||||
period = config.get('period', 20)
|
||||
price_column = config.get('price_column', 'close')
|
||||
results[indicator_name] = self.ema(candles, period, price_column)
|
||||
results[indicator_name] = self.ema(df, period, price_column)
|
||||
|
||||
elif indicator_type == 'rsi':
|
||||
period = config.get('period', 14)
|
||||
price_column = config.get('price_column', 'close')
|
||||
results[indicator_name] = self.rsi(candles, period, price_column)
|
||||
results[indicator_name] = self.rsi(df, period, price_column)
|
||||
|
||||
elif indicator_type == 'macd':
|
||||
fast_period = config.get('fast_period', 12)
|
||||
slow_period = config.get('slow_period', 26)
|
||||
signal_period = config.get('signal_period', 9)
|
||||
price_column = config.get('price_column', 'close')
|
||||
results[indicator_name] = self.macd(candles, fast_period, slow_period, signal_period, price_column)
|
||||
results[indicator_name] = self.macd(df, fast_period, slow_period, signal_period, price_column)
|
||||
|
||||
elif indicator_type == 'bollinger_bands':
|
||||
period = config.get('period', 20)
|
||||
std_dev = config.get('std_dev', 2.0)
|
||||
price_column = config.get('price_column', 'close')
|
||||
results[indicator_name] = self.bollinger_bands(candles, period, std_dev, price_column)
|
||||
results[indicator_name] = self.bollinger_bands(df, period, std_dev, price_column)
|
||||
|
||||
else:
|
||||
if self.logger:
|
||||
@ -410,13 +390,13 @@ class TechnicalIndicators:
|
||||
|
||||
return results
|
||||
|
||||
def calculate(self, indicator_type: str, candles: Union[pd.DataFrame, List[OHLCVCandle]], **kwargs) -> Optional[Dict[str, Any]]:
|
||||
def calculate(self, indicator_type: str, df: pd.DataFrame, **kwargs) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Calculate a single indicator with dynamic dispatch.
|
||||
|
||||
Args:
|
||||
indicator_type: Name of the indicator (e.g., 'sma', 'ema')
|
||||
candles: List of OHLCV candles or a pre-prepared DataFrame
|
||||
df: DataFrame with OHLCV data
|
||||
**kwargs: Indicator-specific parameters (e.g., period=20)
|
||||
|
||||
Returns:
|
||||
@ -430,14 +410,6 @@ class TechnicalIndicators:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Prepare DataFrame if input is a list of candles
|
||||
if isinstance(candles, list):
|
||||
df = self._prepare_dataframe_from_list(candles)
|
||||
elif isinstance(candles, pd.DataFrame):
|
||||
df = candles
|
||||
else:
|
||||
raise TypeError("Input 'candles' must be a list of OHLCVCandle objects or a pandas DataFrame.")
|
||||
|
||||
if df.empty:
|
||||
return {'data': [], 'metadata': {}}
|
||||
|
||||
@ -458,56 +430,4 @@ class TechnicalIndicators:
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"TechnicalIndicators: Error calculating {indicator_type}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def create_default_indicators_config() -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Create default configuration for common technical indicators.
|
||||
|
||||
Returns:
|
||||
Dictionary with default indicator configurations
|
||||
"""
|
||||
return {
|
||||
'sma_20': {'type': 'sma', 'period': 20},
|
||||
'sma_50': {'type': 'sma', 'period': 50},
|
||||
'ema_12': {'type': 'ema', 'period': 12},
|
||||
'ema_26': {'type': 'ema', 'period': 26},
|
||||
'rsi_14': {'type': 'rsi', 'period': 14},
|
||||
'macd_default': {'type': 'macd'},
|
||||
'bollinger_bands_20': {'type': 'bollinger_bands', 'period': 20}
|
||||
}
|
||||
|
||||
|
||||
def validate_indicator_config(config: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate technical indicator configuration.
|
||||
|
||||
Args:
|
||||
config: Indicator configuration dictionary
|
||||
|
||||
Returns:
|
||||
True if configuration is valid, False otherwise
|
||||
"""
|
||||
required_fields = ['type']
|
||||
|
||||
# Check required fields
|
||||
for field in required_fields:
|
||||
if field not in config:
|
||||
return False
|
||||
|
||||
# Validate indicator type
|
||||
valid_types = ['sma', 'ema', 'rsi', 'macd', 'bollinger_bands']
|
||||
if config['type'] not in valid_types:
|
||||
return False
|
||||
|
||||
# Validate period fields
|
||||
if 'period' in config and (not isinstance(config['period'], int) or config['period'] <= 0):
|
||||
return False
|
||||
|
||||
# Validate standard deviation for Bollinger Bands
|
||||
if config['type'] == 'bollinger_bands' and 'std_dev' in config:
|
||||
if not isinstance(config['std_dev'], (int, float)) or config['std_dev'] <= 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
return None
|
||||
60
data/common/indicators/utils.py
Normal file
60
data/common/indicators/utils.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""
|
||||
Technical Indicator Utilities
|
||||
|
||||
This module provides utility functions for managing technical indicator
|
||||
configurations and validation.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def create_default_indicators_config() -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Create default configuration for common technical indicators.
|
||||
|
||||
Returns:
|
||||
Dictionary with default indicator configurations
|
||||
"""
|
||||
return {
|
||||
'sma_20': {'type': 'sma', 'period': 20},
|
||||
'sma_50': {'type': 'sma', 'period': 50},
|
||||
'ema_12': {'type': 'ema', 'period': 12},
|
||||
'ema_26': {'type': 'ema', 'period': 26},
|
||||
'rsi_14': {'type': 'rsi', 'period': 14},
|
||||
'macd_default': {'type': 'macd'},
|
||||
'bollinger_bands_20': {'type': 'bollinger_bands', 'period': 20}
|
||||
}
|
||||
|
||||
|
||||
def validate_indicator_config(config: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate technical indicator configuration.
|
||||
|
||||
Args:
|
||||
config: Indicator configuration dictionary
|
||||
|
||||
Returns:
|
||||
True if configuration is valid, False otherwise
|
||||
"""
|
||||
required_fields = ['type']
|
||||
|
||||
# Check required fields
|
||||
for field in required_fields:
|
||||
if field not in config:
|
||||
return False
|
||||
|
||||
# Validate indicator type
|
||||
valid_types = ['sma', 'ema', 'rsi', 'macd', 'bollinger_bands']
|
||||
if config['type'] not in valid_types:
|
||||
return False
|
||||
|
||||
# Validate period fields
|
||||
if 'period' in config and (not isinstance(config['period'], int) or config['period'] <= 0):
|
||||
return False
|
||||
|
||||
# Validate standard deviation for Bollinger Bands
|
||||
if config['type'] == 'bollinger_bands' and 'std_dev' in config:
|
||||
if not isinstance(config['std_dev'], (int, float)) or config['std_dev'] <= 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
@ -4,7 +4,17 @@ The Technical Indicators module provides a suite of common technical analysis to
|
||||
|
||||
## Overview
|
||||
|
||||
The module has been refactored to be **DataFrame-centric**. All calculation methods now expect a pandas DataFrame with a `DatetimeIndex` and the required OHLCV columns (`open`, `high`, `low`, `close`, `volume`). This change simplifies the data pipeline, improves performance through vectorization, and ensures consistency across the platform.
|
||||
The module has been refactored into a dedicated package structure under `data/common/indicators/`. All calculation methods now expect a pandas DataFrame with a `DatetimeIndex` and the required OHLCV columns (`open`, `high`, `low`, `close`, `volume`). This change simplifies the data pipeline, improves performance through vectorization, and ensures consistency across the platform.
|
||||
|
||||
### Package Structure
|
||||
|
||||
```
|
||||
data/common/indicators/
|
||||
├── __init__.py # Package exports
|
||||
├── technical.py # TechnicalIndicators class implementation
|
||||
├── result.py # IndicatorResult dataclass
|
||||
└── utils.py # Utility functions for configuration
|
||||
```
|
||||
|
||||
The module implements five core technical indicators:
|
||||
|
||||
@ -20,9 +30,22 @@ The module implements five core technical indicators:
|
||||
- **Vectorized Calculations**: Leverages pandas and numpy for high-speed computation.
|
||||
- **Flexible `calculate` Method**: A single entry point for calculating any supported indicator by name.
|
||||
- **Standardized Output**: All methods return a DataFrame containing the calculated indicator values, indexed by timestamp.
|
||||
- **Modular Architecture**: Clear separation between calculation logic, result types, and utilities.
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Importing the Required Components
|
||||
|
||||
```python
|
||||
from data.common.indicators import (
|
||||
TechnicalIndicators,
|
||||
IndicatorResult,
|
||||
create_default_indicators_config,
|
||||
validate_indicator_config
|
||||
)
|
||||
from data.common.data_types import OHLCVCandle
|
||||
```
|
||||
|
||||
### Preparing the DataFrame
|
||||
|
||||
Before you can calculate indicators, you need a properly formatted pandas DataFrame. The `prepare_chart_data` utility is the recommended way to create one from a list of candle dictionaries.
|
||||
@ -115,15 +138,11 @@ The following details the parameters and the columns returned in the result Data
|
||||
- **Parameters**: `period` (int), `std_dev` (float), `price_column` (str, default: 'close')
|
||||
- **Returned Columns**: `upper_band`, `middle_band`, `lower_band`
|
||||
|
||||
## Integration with the TCP Platform
|
||||
|
||||
The refactored `TechnicalIndicators` module is now tightly integrated with the `ChartBuilder`, which handles all data preparation and calculation automatically when indicators are added to a chart. For custom analysis or strategy development, you can use the class directly as shown in the examples above. The key is to always start with a properly prepared DataFrame using `prepare_chart_data`.
|
||||
|
||||
## Data Structures
|
||||
|
||||
### IndicatorResult
|
||||
|
||||
Container for technical indicator calculation results.
|
||||
The `IndicatorResult` class (from `data.common.indicators.result`) contains technical indicator calculation results:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
@ -135,79 +154,50 @@ class IndicatorResult:
|
||||
metadata: Optional[Dict[str, Any]] = None # Calculation metadata
|
||||
```
|
||||
|
||||
### Configuration Format
|
||||
### Configuration Management
|
||||
|
||||
Indicator configurations use a standardized JSON format:
|
||||
|
||||
```json
|
||||
{
|
||||
"indicator_name": {
|
||||
"type": "sma|ema|rsi|macd|bollinger_bands",
|
||||
"period": 20,
|
||||
"price_column": "close",
|
||||
// Additional parameters specific to indicator type
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Integration with TCP Platform
|
||||
|
||||
### Aggregation Strategy Compatibility
|
||||
|
||||
The indicators module is designed to work seamlessly with the TCP platform's aggregation strategy:
|
||||
|
||||
- **Right-Aligned Timestamps**: Uses `end_time` from OHLCV candles
|
||||
- **Sparse Data Support**: Handles missing candles without interpolation
|
||||
- **No Future Leakage**: Only processes completed candles
|
||||
- **Time Boundary Respect**: Maintains proper temporal ordering
|
||||
|
||||
### Real-Time Processing
|
||||
The module provides utilities for managing indicator configurations (from `data.common.indicators.utils`):
|
||||
|
||||
```python
|
||||
from data.common.aggregation.realtime import RealTimeCandleProcessor
|
||||
from data.common.indicators import TechnicalIndicators
|
||||
# Create default configurations
|
||||
config = create_default_indicators_config()
|
||||
|
||||
# Set up real-time processing
|
||||
candle_processor = RealTimeCandleProcessor(symbol='BTC-USDT', exchange='okx')
|
||||
# Validate a configuration
|
||||
is_valid = validate_indicator_config({
|
||||
'type': 'sma',
|
||||
'period': 20,
|
||||
'price_column': 'close'
|
||||
})
|
||||
```
|
||||
|
||||
### Integration with TCP Platform
|
||||
|
||||
The indicators module is designed to work seamlessly with the platform's components:
|
||||
|
||||
```python
|
||||
from data.common.indicators import TechnicalIndicators
|
||||
from data.common.data_types import OHLCVCandle
|
||||
from components.charts.utils import prepare_chart_data
|
||||
|
||||
# Initialize calculator
|
||||
indicators = TechnicalIndicators()
|
||||
|
||||
# Process incoming trades and calculate indicators
|
||||
def on_new_candle(candle):
|
||||
# Get recent candles for indicator calculation
|
||||
recent_candles = get_recent_candles(symbol='BTC-USDT', count=50)
|
||||
|
||||
# Calculate indicators
|
||||
sma_results = indicators.sma(recent_candles, period=20)
|
||||
rsi_results = indicators.rsi(recent_candles, period=14)
|
||||
|
||||
# Use indicator values for trading decisions
|
||||
if sma_results and rsi_results:
|
||||
latest_sma = sma_results[-1].values['sma']
|
||||
latest_rsi = rsi_results[-1].values['rsi']
|
||||
|
||||
# Trading logic here...
|
||||
```
|
||||
# Calculate indicators
|
||||
results = indicators.calculate_multiple_indicators(df, {
|
||||
'sma_20': {'type': 'sma', 'period': 20},
|
||||
'rsi_14': {'type': 'rsi', 'period': 14}
|
||||
})
|
||||
|
||||
### Database Integration
|
||||
|
||||
```python
|
||||
from database.models import IndicatorData
|
||||
|
||||
# Store indicator results in database
|
||||
def store_indicators(indicator_results, indicator_type):
|
||||
# Access results
|
||||
for indicator_name, indicator_results in results.items():
|
||||
for result in indicator_results:
|
||||
indicator_data = IndicatorData(
|
||||
symbol=result.symbol,
|
||||
timeframe=result.timeframe,
|
||||
timestamp=result.timestamp,
|
||||
indicator_type=indicator_type,
|
||||
values=result.values,
|
||||
metadata=result.metadata
|
||||
)
|
||||
session.add(indicator_data)
|
||||
session.commit()
|
||||
print(f"{indicator_name}: {result.values}")
|
||||
```
|
||||
|
||||
## Integration with the TCP Platform
|
||||
|
||||
The refactored `TechnicalIndicators` module is now tightly integrated with the `ChartBuilder`, which handles all data preparation and calculation automatically when indicators are added to a chart. For custom analysis or strategy development, you can use the class directly as shown in the examples above. The key is to always start with a properly prepared DataFrame using `prepare_chart_data`.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Memory Usage
|
||||
|
||||
@ -27,16 +27,16 @@
|
||||
- [x] 1.9 Delete the original `data/common/aggregation.py` file.
|
||||
- [x] 1.10 Run tests to verify the aggregation logic still works as expected.
|
||||
|
||||
- [ ] 2.0 Refactor `indicators.py` into a dedicated sub-package.
|
||||
- [ ] 2.1 Create safety net tests for indicators module.
|
||||
- [ ] 2.2 Create a new directory `data/common/indicators`.
|
||||
- [ ] 2.3 Create `data/common/indicators/__init__.py` to mark it as a package.
|
||||
- [ ] 2.4 Move the `TechnicalIndicators` class to `data/common/indicators/technical.py`.
|
||||
- [ ] 2.5 Move the `IndicatorResult` class to `data/common/indicators/result.py`.
|
||||
- [ ] 2.6 Move the utility functions to `data/common/indicators/utils.py`.
|
||||
- [ ] 2.7 Update `data/common/indicators/__init__.py` to expose all public classes and functions.
|
||||
- [ ] 2.8 Delete the original `data/common/indicators.py` file.
|
||||
- [ ] 2.9 Run tests to verify the indicators logic still works as expected.
|
||||
- [x] 2.0 Refactor `indicators.py` into a dedicated sub-package.
|
||||
- [x] 2.1 Create safety net tests for indicators module.
|
||||
- [x] 2.2 Create a new directory `data/common/indicators`.
|
||||
- [x] 2.3 Create `data/common/indicators/__init__.py` to mark it as a package.
|
||||
- [x] 2.4 Move the `TechnicalIndicators` class to `data/common/indicators/technical.py`.
|
||||
- [x] 2.5 Move the `IndicatorResult` class to `data/common/indicators/result.py`.
|
||||
- [x] 2.6 Move the utility functions to `data/common/indicators/utils.py`.
|
||||
- [x] 2.7 Update `data/common/indicators/__init__.py` to expose all public classes and functions.
|
||||
- [x] 2.8 Delete the original `data/common/indicators.py` file.
|
||||
- [x] 2.9 Run tests to verify the indicators logic still works as expected.
|
||||
|
||||
- [ ] 3.0 Refactor `validation.py` for better modularity.
|
||||
- [ ] 3.1 Create safety net tests for validation module.
|
||||
|
||||
325
tests/test_indicators_safety.py
Normal file
325
tests/test_indicators_safety.py
Normal file
@ -0,0 +1,325 @@
|
||||
"""
|
||||
Safety net tests for technical indicators module.
|
||||
|
||||
These tests ensure that the core functionality of the indicators module
|
||||
remains intact during refactoring.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from decimal import Decimal
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from data.common.indicators import (
|
||||
TechnicalIndicators,
|
||||
IndicatorResult,
|
||||
create_default_indicators_config,
|
||||
validate_indicator_config
|
||||
)
|
||||
from data.common.data_types import OHLCVCandle
|
||||
|
||||
|
||||
class TestTechnicalIndicatorsSafety:
|
||||
"""Safety net test suite for TechnicalIndicators class."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_candles(self):
|
||||
"""Create sample OHLCV candles for testing."""
|
||||
candles = []
|
||||
base_time = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
# Create 30 candles with realistic price movement
|
||||
prices = [100.0, 101.0, 102.5, 101.8, 103.0, 104.2, 103.8, 105.0, 104.5, 106.0,
|
||||
107.5, 108.0, 107.2, 109.0, 108.5, 110.0, 109.8, 111.0, 110.5, 112.0,
|
||||
111.8, 113.0, 112.5, 114.0, 113.2, 115.0, 114.8, 116.0, 115.5, 117.0]
|
||||
|
||||
for i, price in enumerate(prices):
|
||||
candle = OHLCVCandle(
|
||||
symbol='BTC-USDT',
|
||||
timeframe='1m',
|
||||
start_time=base_time + timedelta(minutes=i),
|
||||
end_time=base_time + timedelta(minutes=i+1),
|
||||
open=Decimal(str(price - 0.2)),
|
||||
high=Decimal(str(price + 0.5)),
|
||||
low=Decimal(str(price - 0.5)),
|
||||
close=Decimal(str(price)),
|
||||
volume=Decimal('1000'),
|
||||
trade_count=10,
|
||||
exchange='test',
|
||||
is_complete=True
|
||||
)
|
||||
candles.append(candle)
|
||||
|
||||
return candles
|
||||
|
||||
@pytest.fixture
|
||||
def sparse_candles(self):
|
||||
"""Create sample OHLCV candles with time gaps for testing."""
|
||||
candles = []
|
||||
base_time = datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
# Create 15 candles with gaps (every other minute)
|
||||
prices = [100.0, 102.5, 104.2, 105.0, 106.0,
|
||||
108.0, 109.0, 110.0, 111.0, 112.0,
|
||||
113.0, 114.0, 115.0, 116.0, 117.0]
|
||||
|
||||
for i, price in enumerate(prices):
|
||||
# Create 2-minute gaps between candles
|
||||
candle = OHLCVCandle(
|
||||
symbol='BTC-USDT',
|
||||
timeframe='1m',
|
||||
start_time=base_time + timedelta(minutes=i*2),
|
||||
end_time=base_time + timedelta(minutes=(i*2)+1),
|
||||
open=Decimal(str(price - 0.2)),
|
||||
high=Decimal(str(price + 0.5)),
|
||||
low=Decimal(str(price - 0.5)),
|
||||
close=Decimal(str(price)),
|
||||
volume=Decimal('1000'),
|
||||
trade_count=10,
|
||||
exchange='test',
|
||||
is_complete=True
|
||||
)
|
||||
candles.append(candle)
|
||||
|
||||
return candles
|
||||
|
||||
@pytest.fixture
|
||||
def indicators(self):
|
||||
"""Create TechnicalIndicators instance."""
|
||||
return TechnicalIndicators()
|
||||
|
||||
def test_initialization(self, indicators):
|
||||
"""Test indicator calculator initialization."""
|
||||
assert isinstance(indicators, TechnicalIndicators)
|
||||
|
||||
def test_prepare_dataframe_from_list(self, indicators, sample_candles):
|
||||
"""Test DataFrame preparation from OHLCV candles."""
|
||||
df = indicators._prepare_dataframe_from_list(sample_candles)
|
||||
assert isinstance(df, pd.DataFrame)
|
||||
assert not df.empty
|
||||
assert len(df) == len(sample_candles)
|
||||
assert 'close' in df.columns
|
||||
assert 'timestamp' in df.index.names
|
||||
|
||||
def test_prepare_dataframe_empty(self, indicators):
|
||||
"""Test DataFrame preparation with empty candles list."""
|
||||
df = indicators._prepare_dataframe_from_list([])
|
||||
assert isinstance(df, pd.DataFrame)
|
||||
assert df.empty
|
||||
|
||||
def test_sma_calculation(self, indicators, sample_candles):
|
||||
"""Test Simple Moving Average calculation."""
|
||||
period = 5
|
||||
df = indicators._prepare_dataframe_from_list(sample_candles)
|
||||
results = indicators.sma(df, period)
|
||||
|
||||
assert len(results) > 0
|
||||
assert isinstance(results[0], IndicatorResult)
|
||||
assert 'sma' in results[0].values
|
||||
assert results[0].metadata['period'] == period
|
||||
|
||||
def test_sma_insufficient_data(self, indicators, sample_candles):
|
||||
"""Test SMA with insufficient data."""
|
||||
period = 50 # More than available candles
|
||||
df = indicators._prepare_dataframe_from_list(sample_candles)
|
||||
results = indicators.sma(df, period)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_ema_calculation(self, indicators, sample_candles):
|
||||
"""Test Exponential Moving Average calculation."""
|
||||
period = 10
|
||||
df = indicators._prepare_dataframe_from_list(sample_candles)
|
||||
results = indicators.ema(df, period)
|
||||
|
||||
assert len(results) > 0
|
||||
assert isinstance(results[0], IndicatorResult)
|
||||
assert 'ema' in results[0].values
|
||||
assert results[0].metadata['period'] == period
|
||||
|
||||
def test_rsi_calculation(self, indicators, sample_candles):
|
||||
"""Test Relative Strength Index calculation."""
|
||||
period = 14
|
||||
df = indicators._prepare_dataframe_from_list(sample_candles)
|
||||
results = indicators.rsi(df, period)
|
||||
|
||||
assert len(results) > 0
|
||||
assert isinstance(results[0], IndicatorResult)
|
||||
assert 'rsi' in results[0].values
|
||||
assert results[0].metadata['period'] == period
|
||||
assert 0 <= results[0].values['rsi'] <= 100
|
||||
|
||||
def test_macd_calculation(self, indicators, sample_candles):
|
||||
"""Test MACD calculation."""
|
||||
fast_period = 12
|
||||
slow_period = 26
|
||||
signal_period = 9
|
||||
df = indicators._prepare_dataframe_from_list(sample_candles)
|
||||
results = indicators.macd(df, fast_period, slow_period, signal_period)
|
||||
|
||||
# MACD should start producing results after slow_period periods
|
||||
assert len(results) > 0
|
||||
|
||||
if results: # Only test if we have results
|
||||
first_result = results[0]
|
||||
assert isinstance(first_result, IndicatorResult)
|
||||
assert 'macd' in first_result.values
|
||||
assert 'signal' in first_result.values
|
||||
assert 'histogram' in first_result.values
|
||||
|
||||
# Histogram should equal MACD - Signal
|
||||
expected_histogram = first_result.values['macd'] - first_result.values['signal']
|
||||
assert abs(first_result.values['histogram'] - expected_histogram) < 0.001
|
||||
|
||||
def test_bollinger_bands_calculation(self, indicators, sample_candles):
|
||||
"""Test Bollinger Bands calculation."""
|
||||
period = 20
|
||||
std_dev = 2.0
|
||||
df = indicators._prepare_dataframe_from_list(sample_candles)
|
||||
results = indicators.bollinger_bands(df, period, std_dev)
|
||||
|
||||
assert len(results) > 0
|
||||
assert isinstance(results[0], IndicatorResult)
|
||||
assert 'upper_band' in results[0].values
|
||||
assert 'middle_band' in results[0].values
|
||||
assert 'lower_band' in results[0].values
|
||||
assert results[0].metadata['period'] == period
|
||||
assert results[0].metadata['std_dev'] == std_dev
|
||||
|
||||
def test_sparse_data_handling(self, indicators, sparse_candles):
|
||||
"""Test indicators with sparse data (time gaps)."""
|
||||
period = 5
|
||||
df = indicators._prepare_dataframe_from_list(sparse_candles)
|
||||
sma_results = indicators.sma(df, period)
|
||||
|
||||
assert len(sma_results) > 0
|
||||
# Verify that gaps are preserved (no interpolation)
|
||||
timestamps = [r.timestamp for r in sma_results]
|
||||
for i in range(1, len(timestamps)):
|
||||
time_diff = timestamps[i] - timestamps[i-1]
|
||||
assert time_diff >= timedelta(minutes=1)
|
||||
|
||||
def test_calculate_multiple_indicators(self, indicators, sample_candles):
|
||||
"""Test calculating multiple indicators at once."""
|
||||
config = {
|
||||
'sma_10': {'type': 'sma', 'period': 10},
|
||||
'ema_12': {'type': 'ema', 'period': 12},
|
||||
'rsi_14': {'type': 'rsi', 'period': 14},
|
||||
'macd': {'type': 'macd'},
|
||||
'bb_20': {'type': 'bollinger_bands', 'period': 20}
|
||||
}
|
||||
|
||||
df = indicators._prepare_dataframe_from_list(sample_candles)
|
||||
results = indicators.calculate_multiple_indicators(df, config)
|
||||
|
||||
assert len(results) == len(config)
|
||||
assert 'sma_10' in results
|
||||
assert 'ema_12' in results
|
||||
assert 'rsi_14' in results
|
||||
assert 'macd' in results
|
||||
assert 'bb_20' in results
|
||||
|
||||
# Check that each indicator has appropriate results
|
||||
assert len(results['sma_10']) > 0
|
||||
assert len(results['ema_12']) > 0
|
||||
assert len(results['rsi_14']) > 0
|
||||
assert len(results['macd']) > 0
|
||||
assert len(results['bb_20']) > 0
|
||||
|
||||
def test_different_price_columns(self, indicators, sample_candles):
|
||||
"""Test indicators with different price columns."""
|
||||
df = indicators._prepare_dataframe_from_list(sample_candles)
|
||||
|
||||
# Test SMA with 'high' price column
|
||||
sma_high = indicators.sma(df, 5, price_column='high')
|
||||
assert len(sma_high) > 0
|
||||
|
||||
# Test SMA with 'low' price column
|
||||
sma_low = indicators.sma(df, 5, price_column='low')
|
||||
assert len(sma_low) > 0
|
||||
|
||||
# Values should be different
|
||||
assert sma_high[0].values['sma'] != sma_low[0].values['sma']
|
||||
|
||||
|
||||
class TestIndicatorHelperFunctions:
|
||||
"""Test suite for indicator helper functions."""
|
||||
|
||||
def test_create_default_indicators_config(self):
|
||||
"""Test default indicator configuration creation."""
|
||||
config = create_default_indicators_config()
|
||||
assert isinstance(config, dict)
|
||||
assert len(config) > 0
|
||||
assert 'sma_20' in config
|
||||
assert 'ema_12' in config
|
||||
assert 'rsi_14' in config
|
||||
assert 'macd_default' in config
|
||||
assert 'bollinger_bands_20' in config
|
||||
|
||||
def test_validate_indicator_config_valid(self):
|
||||
"""Test indicator configuration validation with valid config."""
|
||||
valid_configs = [
|
||||
{'type': 'sma', 'period': 20},
|
||||
{'type': 'ema', 'period': 12},
|
||||
{'type': 'rsi', 'period': 14},
|
||||
{'type': 'macd'},
|
||||
{'type': 'bollinger_bands', 'period': 20, 'std_dev': 2.0}
|
||||
]
|
||||
|
||||
for config in valid_configs:
|
||||
assert validate_indicator_config(config)
|
||||
|
||||
def test_validate_indicator_config_invalid(self):
|
||||
"""Test indicator configuration validation with invalid config."""
|
||||
invalid_configs = [
|
||||
{}, # Empty config
|
||||
{'type': 'unknown'}, # Invalid type
|
||||
{'type': 'sma', 'period': -1}, # Invalid period
|
||||
{'type': 'bollinger_bands', 'std_dev': -1}, # Invalid std_dev
|
||||
{'type': 'sma', 'period': 'not_a_number'} # Wrong type for period
|
||||
]
|
||||
|
||||
for config in invalid_configs:
|
||||
assert not validate_indicator_config(config)
|
||||
|
||||
|
||||
class TestIndicatorResultDataClass:
|
||||
"""Test suite for IndicatorResult dataclass."""
|
||||
|
||||
def test_indicator_result_creation(self):
|
||||
"""Test IndicatorResult creation with all fields."""
|
||||
timestamp = datetime.now(timezone.utc)
|
||||
values = {'sma': 100.0}
|
||||
metadata = {'period': 20}
|
||||
|
||||
result = IndicatorResult(
|
||||
timestamp=timestamp,
|
||||
symbol='BTC-USDT',
|
||||
timeframe='1m',
|
||||
values=values,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
assert result.timestamp == timestamp
|
||||
assert result.symbol == 'BTC-USDT'
|
||||
assert result.timeframe == '1m'
|
||||
assert result.values == values
|
||||
assert result.metadata == metadata
|
||||
|
||||
def test_indicator_result_without_metadata(self):
|
||||
"""Test IndicatorResult creation without optional metadata."""
|
||||
timestamp = datetime.now(timezone.utc)
|
||||
values = {'sma': 100.0}
|
||||
|
||||
result = IndicatorResult(
|
||||
timestamp=timestamp,
|
||||
symbol='BTC-USDT',
|
||||
timeframe='1m',
|
||||
values=values
|
||||
)
|
||||
|
||||
assert result.timestamp == timestamp
|
||||
assert result.symbol == 'BTC-USDT'
|
||||
assert result.timeframe == '1m'
|
||||
assert result.values == values
|
||||
assert result.metadata is None
|
||||
Loading…
x
Reference in New Issue
Block a user