diff --git a/data/common/aggregation/utils.py b/data/common/aggregation/utils.py
index 35faf47..002f0a5 100644
--- a/data/common/aggregation/utils.py
+++ b/data/common/aggregation/utils.py
@@ -8,6 +8,7 @@ and trade data aggregation.
 import re
 from typing import List, Tuple
 from utils.timeframe_utils import load_timeframe_options
+import pandas as pd
 
 from ..data_types import StandardizedTrade, OHLCVCandle
 
@@ -74,8 +75,82 @@ def parse_timeframe(timeframe: str) -> Tuple[int, str]:
     return number, unit
 
 
+def resample_candles_to_timeframe(df: pd.DataFrame, target_timeframe: str) -> pd.DataFrame:
+    """
+    Resamples a DataFrame of OHLCV candles to a higher timeframe.
+
+    Args:
+        df (pd.DataFrame): Input DataFrame with a datetime index and 'open', 'high', 'low', 'close', 'volume',
+            and optionally 'trades_count' columns.
+        target_timeframe (str): The target timeframe for resampling (e.g., '1h', '1d').
+
+    Returns:
+        pd.DataFrame: Resampled DataFrame with OHLCV data for the target timeframe.
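+
+    Example (illustrative, not a doctest; resamples two 1m candles into one 5m candle):
+        >>> df_1m = pd.DataFrame(
+        ...     {'open': [1.0, 1.1], 'high': [1.2, 1.3], 'low': [0.9, 1.0],
+        ...      'close': [1.1, 1.2], 'volume': [10.0, 12.0]},
+        ...     index=pd.to_datetime(['2024-01-01 00:00', '2024-01-01 00:01']))
+        >>> df_5m = resample_candles_to_timeframe(df_1m, '5m')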
available.""" - if self.logger: - self.logger.info(message) + self.logger.info(message) def log_debug(self, message: str) -> None: """Log debug message if logger is available.""" - if self.logger: - self.logger.debug(message) + self.logger.debug(message) def log_error(self, message: str) -> None: """Log error message if logger is available.""" - if self.logger: - self.logger.error(message) + self.logger.error(message) @contextmanager def get_session(self): diff --git a/database/repositories/market_data_repository.py b/database/repositories/market_data_repository.py index 7d66dbb..b45b6e3 100644 --- a/database/repositories/market_data_repository.py +++ b/database/repositories/market_data_repository.py @@ -10,6 +10,7 @@ from sqlalchemy.dialects.postgresql import insert from ..models import MarketData from data.common.data_types import OHLCVCandle from .base_repository import BaseRepository, DatabaseOperationError +from tqdm import tqdm class MarketDataRepository(BaseRepository): @@ -68,6 +69,63 @@ class MarketDataRepository(BaseRepository): self.log_error(f"Error storing candle {candle.symbol} {candle.timeframe}: {e}") raise DatabaseOperationError(f"Failed to store candle: {e}") + def upsert_candles_batch(self, candles: List[OHLCVCandle], force_update: bool = False, batch_size: int = 1000) -> int: + """ + Insert or update multiple candles in the market_data table in batches. + """ + total_processed = 0 + try: + for i in tqdm(range(0, len(candles), batch_size), desc="Inserting candles in batches"): + batch = candles[i:i + batch_size] + + values = [ + { + 'exchange': candle.exchange, + 'symbol': candle.symbol, + 'timeframe': candle.timeframe, + 'timestamp': candle.end_time, + 'open': candle.open, + 'high': candle.high, + 'low': candle.low, + 'close': candle.close, + 'volume': candle.volume, + 'trades_count': candle.trade_count + } + for candle in batch + ] + + with self.get_session() as session: + stmt = insert(MarketData).values(values) + + if force_update: + final_stmt = stmt.on_conflict_do_update( + index_elements=['exchange', 'symbol', 'timeframe', 'timestamp'], + set_={ + 'open': stmt.excluded.open, + 'high': stmt.excluded.high, + 'low': stmt.excluded.low, + 'close': stmt.excluded.close, + 'volume': stmt.excluded.volume, + 'trades_count': stmt.excluded.trades_count + } + ) + action = "Updated" + else: + final_stmt = stmt.on_conflict_do_nothing( + index_elements=['exchange', 'symbol', 'timeframe', 'timestamp'] + ) + action = "Stored" + + session.execute(final_stmt) + session.commit() + total_processed += len(batch) + self.log_debug(f"{action} {len(batch)} candles in batch. Total processed: {total_processed}") + return total_processed + + except Exception as e: + self.log_error(f"Error storing candles in batch: {e}") + raise DatabaseOperationError(f"Failed to store candles in batch: {e}") + def get_candles(self, symbol: str, timeframe: str, @@ -77,6 +135,7 @@ class MarketDataRepository(BaseRepository): """ Retrieve candles from the database using the ORM. 
""" + self.log_debug(f"DB: get_candles called with: symbol={symbol}, timeframe={timeframe}, start_time={start_time}, end_time={end_time}, exchange={exchange}") try: with self.get_session() as session: query = ( @@ -102,7 +161,7 @@ class MarketDataRepository(BaseRepository): } for r in results ] - self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}") + self.log_debug(f"DB: Retrieved {len(candles)} candles for {symbol} {timeframe} from {start_time} to {end_time}") return candles except Exception as e: @@ -195,4 +254,20 @@ class MarketDataRepository(BaseRepository): except Exception as e: self.log_error(f"Error retrieving candles as DataFrame: {e}") - raise DatabaseOperationError(f"Failed to retrieve candles as DataFrame: {e}") \ No newline at end of file + raise DatabaseOperationError(f"Failed to retrieve candles as DataFrame: {e}") + + def delete_candles_before_timestamp(self, timestamp: datetime) -> int: + """ + Delete candles from the market_data table that are older than the specified timestamp. + """ + try: + with self.get_session() as session: + deleted_count = session.query(MarketData).filter( + MarketData.timestamp < timestamp + ).delete(synchronize_session=False) + session.commit() + self.logger.warning(f"Deleted {deleted_count} candles older than {timestamp}") + return deleted_count + except Exception as e: + self.log_error(f"Error deleting candles older than {timestamp}: {e}") + raise DatabaseOperationError(f"Failed to delete candles: {e}") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6fa1aea..08ba543 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "pytest>=8.3.5", "psutil>=7.0.0", "tzlocal>=5.3.1", + "tdqm>=0.0.1", ] [project.optional-dependencies] diff --git a/scripts/data_ingestion.py b/scripts/data_ingestion.py new file mode 100644 index 0000000..175faf4 --- /dev/null +++ b/scripts/data_ingestion.py @@ -0,0 +1,230 @@ +import argparse +import os +import pandas as pd +import sqlite3 +from datetime import datetime +from decimal import Decimal +from tqdm import tqdm # Import tqdm + +from data.common.data_types import OHLCVCandle, StandardizedTrade +from data.common.aggregation.batch import BatchCandleProcessor +from data.common.aggregation.utils import resample_candles_to_timeframe # Import for CSV aggregation +from data.common.aggregation.bucket import TimeframeBucket # For calculating start_time from end_time +from database.repositories.market_data_repository import MarketDataRepository +from database.repositories.raw_trade_repository import RawTradeRepository +from utils.logger import get_logger # Import custom logger + +logger = get_logger('data_ingestion') + +def parse_csv_to_candles(file_path: str, exchange: str, symbol: str, sample_rows: int = None) -> list[OHLCVCandle]: + """Parses a CSV file into OHLCVCandle objects, assuming 1-minute candles.""" + if sample_rows: + df = pd.read_csv(file_path, nrows=sample_rows) + logger.info(f"Reading first {sample_rows} rows from CSV for test run.") + else: + df = pd.read_csv(file_path) + + # Convert column names to lowercase to handle case insensitivity + df.columns = df.columns.str.lower() + required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] + if not all(col in df.columns for col in required_columns): + raise ValueError(f"CSV file must contain columns: {required_columns}") + + candles = [] + for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing CSV rows"): # Add tqdm + try: + timestamp = 
+def parse_csv_to_candles(file_path: str, exchange: str, symbol: str, sample_rows: int = None) -> list[OHLCVCandle]:
+    """Parses a CSV file into OHLCVCandle objects, assuming 1-minute candles."""
+    if sample_rows:
+        df = pd.read_csv(file_path, nrows=sample_rows)
+        logger.info(f"Reading first {sample_rows} rows from CSV for test run.")
+    else:
+        df = pd.read_csv(file_path)
+
+    # Convert column names to lowercase to handle case insensitivity
+    df.columns = df.columns.str.lower()
+    required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
+    if not all(col in df.columns for col in required_columns):
+        raise ValueError(f"CSV file must contain columns: {required_columns}")
+
+    candles = []
+    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing CSV rows"):
+        try:
+            timestamp = datetime.fromtimestamp(row['timestamp'])
+            candle = OHLCVCandle(
+                exchange=exchange,
+                symbol=symbol,
+                timeframe='1m',  # Assume 1-minute candles for raw CSV data
+                start_time=timestamp,
+                end_time=timestamp,  # For minute data, start and end time can be the same
+                open=Decimal(str(row['open'])),
+                high=Decimal(str(row['high'])),
+                low=Decimal(str(row['low'])),
+                close=Decimal(str(row['close'])),
+                volume=Decimal(str(row['volume'])),
+                trade_count=int(row.get('trades_count', 0)),  # trades_count might not be in all CSVs
+                is_complete=True  # Explicitly set to True for CSV data
+            )
+            candles.append(candle)
+        except Exception as e:
+            logger.error(f"Error parsing row: {row}. Error: {e}")
+    return candles
+
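+# Expected SQLite schema (illustrative): a 'trades' table with columns
+# id, instrument, price, size, side, timestamp (epoch milliseconds),
+# matching the SELECT in parse_sqlite_to_trades.
+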
---") + + market_data_repo = MarketDataRepository() + # raw_trade_repo = RawTradeRepository() # Not used in this script + + if not os.path.exists(file_path): + logger.error(f"Error: File not found at {file_path}") + return + + if file_path.endswith('.csv'): + logger.info(f"Processing CSV file: {file_path}") + + raw_candles = parse_csv_to_candles(file_path, exchange, symbol, sample_rows=sample_rows if test_run else None) + logger.info(f"Parsed {len(raw_candles)} raw 1m candles from CSV.") + + if not raw_candles: + logger.info("No raw candles found to process in the CSV file.") + return + + all_aggregated_candles = [] + + # Convert raw candles to a pandas DataFrame for resampling + df_raw_candles = pd.DataFrame([c.to_dict() for c in raw_candles]) + # Ensure 'end_time' is a datetime object and set as index for resampling + df_raw_candles['end_time'] = pd.to_datetime(df_raw_candles['end_time']) + df_raw_candles = df_raw_candles.set_index('end_time') + + # Convert Decimal types to float for pandas resampling, then back to Decimal after aggregation + # This ensures compatibility with pandas' numerical operations + for col in ['open', 'high', 'low', 'close', 'volume']: + if col in df_raw_candles.columns: + df_raw_candles[col] = pd.to_numeric(df_raw_candles[col]) + # 'trade_count' might not exist, handle with .get() + if 'trade_count' in df_raw_candles.columns: + df_raw_candles['trade_count'] = pd.to_numeric(df_raw_candles['trade_count']) + + for tf in timeframes: + logger.info(f"Aggregating 1m candles to {tf} timeframe...") + # Resample the DataFrame to the target timeframe + resampled_df = resample_candles_to_timeframe(df_raw_candles, tf) + + # Convert resampled DataFrame back to OHLCVCandle objects + for index, row in resampled_df.iterrows(): + # index is the end_time for the resampled candle + end_time = index + # Calculate start_time based on end_time and timeframe + # TimeframeBucket._parse_timeframe_to_timedelta returns timedelta + time_delta = TimeframeBucket._parse_timeframe_to_timedelta(tf) + start_time = end_time - time_delta + + candle = OHLCVCandle( + exchange=exchange, + symbol=symbol, + timeframe=tf, + start_time=start_time, + end_time=end_time, + open=Decimal(str(row['open'])), + high=Decimal(str(row['high'])), + low=Decimal(str(row['low'])), + close=Decimal(str(row['close'])), + volume=Decimal(str(row['volume'])), + trade_count=int(row.get('trades_count', 0)), + is_complete=True # Resampled candles are considered complete + ) + all_aggregated_candles.append(candle) + + # Sort candles by timeframe and then by end_time for consistent output/insertion + all_aggregated_candles.sort(key=lambda x: (x.timeframe, x.end_time)) + + logger.info(f"Aggregated {len(all_aggregated_candles)} candles for timeframes: {', '.join(timeframes)}") + + if test_run: + logger.info("--- Test Run: Sample of Aggregated Candles (first 5) ---") + for i, candle in enumerate(all_aggregated_candles[:5]): + logger.info(f" {candle.to_dict()}") + logger.info("--- End of Test Run Sample ---") + logger.info("Data not inserted into database due to --test-run flag.") + else: + logger.info(f"Starting batch insertion of {len(all_aggregated_candles)} aggregated candles with batch size {batch_size}.") + market_data_repo.upsert_candles_batch(all_aggregated_candles, force_update=force_update, batch_size=batch_size) + logger.info("CSV data ingestion complete.") + + elif file_path.endswith('.db') or file_path.endswith('.sqlite'): + logger.info(f"Processing SQLite database: {file_path}") + if not timeframes: + 
logger.error("Error: Timeframes must be specified for SQLite trade data aggregation.") + return + + trades = parse_sqlite_to_trades(file_path, exchange, symbol, sample_rows=sample_rows if test_run else None) + logger.info(f"Parsed {len(trades)} trades from SQLite.") + + if not trades: + logger.info("No trades found to process in the SQLite database.") + return + + # Use BatchCandleProcessor to aggregate trades into candles + processor = BatchCandleProcessor(symbol=symbol, exchange=exchange, timeframes=timeframes, logger=logger) + aggregated_candles = processor.process_trades_to_candles(iter(trades)) + + logger.info(f"Aggregated {len(aggregated_candles)} candles from trades for timeframes: {', '.join(timeframes)}") + + if test_run: + logger.info("--- Test Run: Sample of Aggregated Candles (first 5) ---") + for i, candle in enumerate(aggregated_candles[:5]): + logger.info(f" {candle.to_dict()}") + logger.info("--- End of Test Run Sample ---") + logger.info("Data not inserted into database due to --test-run flag.") + else: + logger.info(f"Starting batch insertion of {len(aggregated_candles)} candles with batch size {batch_size}.") + market_data_repo.upsert_candles_batch(aggregated_candles, force_update=force_update, batch_size=batch_size) + logger.info("SQLite data ingestion complete.") + + else: + logger.error("Error: Unsupported file type. Please provide a .csv or .sqlite/.db file.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/uv.lock b/uv.lock index 7ffe441..baee9cb 100644 --- a/uv.lock +++ b/uv.lock @@ -441,6 +441,7 @@ dependencies = [ { name = "requests" }, { name = "sqlalchemy" }, { name = "structlog" }, + { name = "tdqm" }, { name = "tzlocal" }, { name = "waitress" }, { name = "watchdog" }, @@ -498,6 +499,7 @@ requires-dist = [ { name = "requests", specifier = ">=2.31.0" }, { name = "sqlalchemy", specifier = ">=2.0.0" }, { name = "structlog", specifier = ">=23.1.0" }, + { name = "tdqm", specifier = ">=0.0.1" }, { name = "tzlocal", specifier = ">=5.3.1" }, { name = "waitress", specifier = ">=3.0.0" }, { name = "watchdog", specifier = ">=3.0.0" }, @@ -1745,6 +1747,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f5/52/7a2c7a317b254af857464da3d60a0d3730c44f912f8c510c76a738a207fd/structlog-25.3.0-py3-none-any.whl", hash = "sha256:a341f5524004c158498c3127eecded091eb67d3a611e7a3093deca30db06e172", size = 68240 }, ] +[[package]] +name = "tdqm" +version = "0.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/38/58c9e22b95e98666fe29f35e217ab6126a03798dadf3f4f8f3e5f898510b/tdqm-0.0.1.tar.gz", hash = "sha256:f050004a76b1d22f70b78209b48781353c82440215b6ba7d22b1f499b05a0101", size = 1388 } + [[package]] name = "tomli" version = "2.2.1" @@ -1784,6 +1795,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 }, ] +[[package]] +name = "tqdm" +version = "4.67.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } +wheels = [ 
+ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 }, +] + [[package]] name = "typing-extensions" version = "4.13.2"