"""
Data management for OHLCV data download and storage.

Handles data retrieval from exchanges and local file management.
"""
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
import ccxt
|
|
import pandas as pd
|
|
|
|
from engine.logging_config import get_logger
|
|
from engine.market import MarketType, get_ccxt_symbol
|
|
|
|
# Module-level logger obtained from the project's logging configuration.
logger = get_logger(__name__)
|
|
|
|
|
|
class DataManager:
    """
    Manages OHLCV data download and storage for different market types.

    Data is stored in: data/ccxt/{exchange}/{market_type}/{symbol}/{timeframe}.csv
    """

    def __init__(self, data_dir: str = "data/ccxt"):
        """
        Create the manager and ensure the root data directory exists.

        Args:
            data_dir: Root directory under which all CSV files are stored
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        # Cache of exchange instances keyed by CCXT exchange id, so each
        # exchange is instantiated only once per manager.
        self.exchanges: dict[str, ccxt.Exchange] = {}

    def get_exchange(self, exchange_id: str) -> ccxt.Exchange:
        """
        Get or create a CCXT exchange instance.

        Instances are cached per ``exchange_id`` and are created with CCXT's
        built-in rate limiter enabled (``enableRateLimit=True``).

        Args:
            exchange_id: CCXT exchange id (e.g., 'okx')

        Returns:
            The cached exchange instance

        Raises:
            AttributeError: If the ``ccxt`` module has no attribute named
                ``exchange_id``
        """
        if exchange_id not in self.exchanges:
            exchange_class = getattr(ccxt, exchange_id)
            self.exchanges[exchange_id] = exchange_class({
                'enableRateLimit': True,
            })
        return self.exchanges[exchange_id]

    def _get_data_path(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str,
        market_type: MarketType
    ) -> Path:
        """
        Get the file path for storing/loading data.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            market_type: Market type (spot or perpetual)

        Returns:
            Path to the CSV file
        """
        # '/' would otherwise create an extra directory level.
        safe_symbol = symbol.replace('/', '-')
        return (
            self.data_dir
            / exchange_id
            / market_type.value
            / safe_symbol
            / f"{timeframe}.csv"
        )

    def download_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        start_date: str | None = None,
        end_date: str | None = None,
        market_type: MarketType = MarketType.SPOT
    ) -> pd.DataFrame | None:
        """
        Download OHLCV data from exchange and save to CSV.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            start_date: Start date string (YYYY-MM-DD); defaults to one
                year before now
            end_date: End date string (YYYY-MM-DD), inclusive; defaults to now
            market_type: Market type (spot or perpetual)

        Returns:
            DataFrame with OHLCV data, or None if download failed

        Raises:
            ValueError: If start_date or end_date cannot be parsed
        """
        exchange = self.get_exchange(exchange_id)
        file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)
        ccxt_symbol = get_ccxt_symbol(symbol, market_type)
        since, until = self._parse_date_range(exchange, start_date, end_date)

        logger.info(
            "Downloading %s (%s) from %s...",
            symbol, market_type.value, exchange_id
        )

        all_ohlcv = self._fetch_all_candles(exchange, ccxt_symbol, timeframe, since, until)

        if not all_ohlcv:
            logger.warning("No data downloaded.")
            return None

        df = self._convert_to_dataframe(all_ohlcv)
        # Create the directory only once there is something to write, so a
        # failed download does not leave empty folders behind.
        file_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(file_path)
        logger.info("Saved %d candles to %s", len(df), file_path)
        return df

    def load_data(
        self,
        exchange_id: str,
        symbol: str,
        timeframe: str = '1m',
        market_type: MarketType = MarketType.SPOT
    ) -> pd.DataFrame:
        """
        Load saved OHLCV data for vectorbt.

        Args:
            exchange_id: Exchange name (e.g., 'okx')
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe (e.g., '1m')
            market_type: Market type (spot or perpetual)

        Returns:
            DataFrame with OHLCV data indexed by timestamp

        Raises:
            FileNotFoundError: If data file does not exist
        """
        file_path = self._get_data_path(exchange_id, symbol, timeframe, market_type)

        if not file_path.exists():
            raise FileNotFoundError(
                f"Data not found at {file_path}. "
                f"Run: uv run python main.py download --pair {symbol} "
                f"--market {market_type.value}"
            )

        return pd.read_csv(file_path, index_col='timestamp', parse_dates=True)

    def _parse_date_range(
        self,
        exchange: ccxt.Exchange,
        start_date: str | None,
        end_date: str | None
    ) -> tuple[int, int]:
        """
        Parse date strings into millisecond timestamps (UTC).

        Args:
            exchange: CCXT exchange used for ``parse8601``/``milliseconds``
            start_date: YYYY-MM-DD, interpreted as 00:00:00 UTC; if empty,
                defaults to 365 days before now
            end_date: YYYY-MM-DD, interpreted as 23:59:59 UTC; if empty,
                defaults to now

        Returns:
            ``(since, until)`` in milliseconds since the epoch

        Raises:
            ValueError: If a given date string cannot be parsed
        """
        now = exchange.milliseconds()

        if start_date:
            since = exchange.parse8601(f"{start_date}T00:00:00Z")
        else:
            since = now - 365 * 24 * 60 * 60 * 1000

        if end_date:
            until = exchange.parse8601(f"{end_date}T23:59:59Z")
        else:
            until = now

        # parse8601 returns None on malformed input; fail loudly here rather
        # than with an opaque TypeError in the fetch loop.
        if since is None:
            raise ValueError(f"Invalid start_date {start_date!r}; expected YYYY-MM-DD")
        if until is None:
            raise ValueError(f"Invalid end_date {end_date!r}; expected YYYY-MM-DD")

        return since, until

    def _fetch_page(
        self,
        exchange: ccxt.Exchange,
        symbol: str,
        timeframe: str,
        since: int,
        limit: int,
        max_retries: int
    ) -> list | None:
        """
        Fetch one page of candles, retrying transient network errors.

        Returns:
            The (possibly empty) list returned by ``fetch_ohlcv``, or None if
            the request ultimately failed (the error is logged).
        """
        for attempt in range(max_retries + 1):
            try:
                return exchange.fetch_ohlcv(symbol, timeframe, since, limit=limit)
            except ccxt.NetworkError as e:
                if attempt == max_retries:
                    logger.error("Error fetching data: %s", e)
                    return None
                # Exponential backoff, starting at no less than one second.
                delay = max(1.0, exchange.rateLimit / 1000) * (2 ** attempt)
                logger.warning(
                    "Network error fetching data (attempt %d/%d): %s; retrying in %.1fs",
                    attempt + 1, max_retries + 1, e, delay
                )
                time.sleep(delay)
            except Exception as e:
                # Non-network errors (bad symbol, exchange error, ...) are not
                # retried; the caller keeps whatever was fetched so far.
                logger.error("Error fetching data: %s", e)
                return None
        return None

    def _fetch_all_candles(
        self,
        exchange: ccxt.Exchange,
        symbol: str,
        timeframe: str,
        since: int,
        until: int,
        limit: int = 100,
        max_retries: int = 3
    ) -> list:
        """
        Fetch all candles whose open time lies in ``[since, until]``.

        Pages forward through ``fetch_ohlcv``, requesting up to ``limit``
        candles per call and starting each page just after the last candle
        received.

        Args:
            exchange: CCXT exchange instance
            symbol: CCXT market symbol
            timeframe: Candle timeframe (e.g., '1m')
            since: Start of range, inclusive (ms since epoch)
            until: End of range, inclusive (ms since epoch)
            limit: Maximum number of candles requested per call
            max_retries: Extra attempts per page after a ``ccxt.NetworkError``

        Returns:
            List of raw OHLCV rows. If a request fails, the rows fetched so
            far are returned.
        """
        all_ohlcv = []

        while since < until:
            ohlcv = self._fetch_page(exchange, symbol, timeframe, since, limit, max_retries)
            if ohlcv is None:
                if all_ohlcv:
                    logger.warning(
                        "Download stopped early; keeping %d candles fetched so far",
                        len(all_ohlcv)
                    )
                break
            if not ohlcv:
                break

            last_ts = ohlcv[-1][0]
            if last_ts < since:
                # The exchange did not honour ``since``; paging would never
                # advance, so stop instead of looping forever.
                logger.warning("Exchange returned candles before %d; stopping.", since)
                break

            # The last page can extend past the requested end (e.g. beyond
            # 23:59:59 of end_date) or overlap the previous page; keep only
            # candles inside the current window.
            all_ohlcv.extend(row for row in ohlcv if since <= row[0] <= until)
            since = last_ts + 1

            current_date = datetime.fromtimestamp(
                since / 1000, tz=timezone.utc
            ).strftime('%Y-%m-%d')
            logger.debug("Fetched up to %s", current_date)

            # With CCXT's built-in limiter enabled, requests are already
            # throttled; sleeping here as well would double every delay.
            if not getattr(exchange, 'enableRateLimit', False):
                time.sleep(exchange.rateLimit / 1000)

        return all_ohlcv

    def _convert_to_dataframe(self, ohlcv: list) -> pd.DataFrame:
        """
        Convert raw OHLCV rows to a DataFrame indexed by UTC timestamp.

        Rows are sorted by time and duplicate timestamps are collapsed,
        keeping the last occurrence. This guards against overlapping pages
        from exchanges that do not strictly honour ``since``.
        """
        df = pd.DataFrame(
            ohlcv,
            columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']
        )
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
        df.set_index('timestamp', inplace=True)
        df = df[~df.index.duplicated(keep='last')].sort_index()
        return df