import argparse import os import pandas as pd import sqlite3 from datetime import datetime from decimal import Decimal from tqdm import tqdm # Import tqdm from data.common.data_types import OHLCVCandle, StandardizedTrade from data.common.aggregation.batch import BatchCandleProcessor from data.common.aggregation.utils import resample_candles_to_timeframe # Import for CSV aggregation from data.common.aggregation.bucket import TimeframeBucket # For calculating start_time from end_time from database.repositories.market_data_repository import MarketDataRepository from database.repositories.raw_trade_repository import RawTradeRepository from utils.logger import get_logger # Import custom logger logger = get_logger('data_ingestion') def parse_csv_to_candles(file_path: str, exchange: str, symbol: str, sample_rows: int = None) -> list[OHLCVCandle]: """Parses a CSV file into OHLCVCandle objects, assuming 1-minute candles.""" if sample_rows: df = pd.read_csv(file_path, nrows=sample_rows) logger.info(f"Reading first {sample_rows} rows from CSV for test run.") else: df = pd.read_csv(file_path) # Convert column names to lowercase to handle case insensitivity df.columns = df.columns.str.lower() required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] if not all(col in df.columns for col in required_columns): raise ValueError(f"CSV file must contain columns: {required_columns}") candles = [] for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing CSV rows"): # Add tqdm try: timestamp = datetime.fromtimestamp(row['timestamp']) candle = OHLCVCandle( exchange=exchange, symbol=symbol, timeframe='1m', # Assume 1-minute candles for raw CSV data start_time=timestamp, end_time=timestamp, # For minute data, start and end time can be the same open=Decimal(str(row['open'])), high=Decimal(str(row['high'])), low=Decimal(str(row['low'])), close=Decimal(str(row['close'])), volume=Decimal(str(row['volume'])), trade_count=int(row.get('trades_count', 0)), # trades_count might not be in all CSVs is_complete=True # Explicitly set to True for CSV data ) candles.append(candle) except Exception as e: logger.error(f"Error parsing row: {row}. Error: {e}") return candles def parse_sqlite_to_trades(db_path: str, exchange: str, symbol: str, sample_rows: int = None) -> list[StandardizedTrade]: """Reads raw trades from an SQLite database.""" conn = sqlite3.connect(db_path) cursor = conn.cursor() query = "SELECT id, instrument, price, size, side, timestamp FROM trades WHERE instrument = ? ORDER BY timestamp ASC" if sample_rows: query += f" LIMIT {sample_rows}" logger.info(f"Reading first {sample_rows} trades from SQLite for test run.") cursor.execute(query, (symbol,)) # Fetch all results to apply tqdm effectively over the list rows = cursor.fetchall() trades = [] for row in tqdm(rows, total=len(rows), desc="Processing SQLite trades"): # Add tqdm trade_id, instrument, price, size, side, timestamp = row try: # Assuming timestamp is in milliseconds and needs conversion to datetime trade_timestamp = datetime.fromtimestamp(int(timestamp) / 1000) trade = StandardizedTrade( symbol=instrument, trade_id=str(trade_id), price=Decimal(str(price)), size=Decimal(str(size)), side=side, timestamp=trade_timestamp, exchange=exchange ) trades.append(trade) except Exception as e: logger.error(f"Error parsing trade row: {row}. Error: {e}") conn.close() return trades def main(): parser = argparse.ArgumentParser(description="Ingest market data into the database.") parser.add_argument("--file", required=True, help="Path to the input data file (CSV or SQLite).") parser.add_argument("--exchange", required=True, help="Exchange name (e.g., 'okx').") parser.add_argument("--symbol", required=True, help="Trading symbol (e.g., 'BTC-USDT').") parser.add_argument("--timeframes", nargs='*', default=['1m'], help="Timeframes for aggregation (e.g., '1m', '5m', '1h'). Required for SQLite, optional for CSV.") parser.add_argument("--force", action="store_true", help="Overwrite existing data if it conflicts.") parser.add_argument("--test-run", action="store_true", help="Run without inserting data, print a sample instead.") parser.add_argument("--sample-rows", type=int, help="Number of rows to process in test-run mode. Only effective with --test-run.") parser.add_argument("--batch-size", type=int, default=10000, help="Batch size for inserting data into the database.") args = parser.parse_args() file_path = args.file exchange = args.exchange symbol = args.symbol timeframes = args.timeframes force_update = args.force test_run = args.test_run sample_rows = args.sample_rows batch_size = args.batch_size if test_run and sample_rows is None: logger.warning("--- No --sample-rows specified for --test-run. Processing full file for sample output. ---") market_data_repo = MarketDataRepository() # raw_trade_repo = RawTradeRepository() # Not used in this script if not os.path.exists(file_path): logger.error(f"Error: File not found at {file_path}") return if file_path.endswith('.csv'): logger.info(f"Processing CSV file: {file_path}") raw_candles = parse_csv_to_candles(file_path, exchange, symbol, sample_rows=sample_rows if test_run else None) logger.info(f"Parsed {len(raw_candles)} raw 1m candles from CSV.") if not raw_candles: logger.info("No raw candles found to process in the CSV file.") return all_aggregated_candles = [] # Convert raw candles to a pandas DataFrame for resampling df_raw_candles = pd.DataFrame([c.to_dict() for c in raw_candles]) # Ensure 'end_time' is a datetime object and set as index for resampling df_raw_candles['end_time'] = pd.to_datetime(df_raw_candles['end_time']) df_raw_candles = df_raw_candles.set_index('end_time') # Convert Decimal types to float for pandas resampling, then back to Decimal after aggregation # This ensures compatibility with pandas' numerical operations for col in ['open', 'high', 'low', 'close', 'volume']: if col in df_raw_candles.columns: df_raw_candles[col] = pd.to_numeric(df_raw_candles[col]) # 'trade_count' might not exist, handle with .get() if 'trade_count' in df_raw_candles.columns: df_raw_candles['trade_count'] = pd.to_numeric(df_raw_candles['trade_count']) for tf in timeframes: logger.info(f"Aggregating 1m candles to {tf} timeframe...") # Resample the DataFrame to the target timeframe resampled_df = resample_candles_to_timeframe(df_raw_candles, tf) # Convert resampled DataFrame back to OHLCVCandle objects for index, row in resampled_df.iterrows(): # index is the end_time for the resampled candle end_time = index # Calculate start_time based on end_time and timeframe # TimeframeBucket._parse_timeframe_to_timedelta returns timedelta time_delta = TimeframeBucket._parse_timeframe_to_timedelta(tf) start_time = end_time - time_delta candle = OHLCVCandle( exchange=exchange, symbol=symbol, timeframe=tf, start_time=start_time, end_time=end_time, open=Decimal(str(row['open'])), high=Decimal(str(row['high'])), low=Decimal(str(row['low'])), close=Decimal(str(row['close'])), volume=Decimal(str(row['volume'])), trade_count=int(row.get('trades_count', 0)), is_complete=True # Resampled candles are considered complete ) all_aggregated_candles.append(candle) # Sort candles by timeframe and then by end_time for consistent output/insertion all_aggregated_candles.sort(key=lambda x: (x.timeframe, x.end_time)) logger.info(f"Aggregated {len(all_aggregated_candles)} candles for timeframes: {', '.join(timeframes)}") if test_run: logger.info("--- Test Run: Sample of Aggregated Candles (first 5) ---") for i, candle in enumerate(all_aggregated_candles[:5]): logger.info(f" {candle.to_dict()}") logger.info("--- End of Test Run Sample ---") logger.info("Data not inserted into database due to --test-run flag.") else: logger.info(f"Starting batch insertion of {len(all_aggregated_candles)} aggregated candles with batch size {batch_size}.") market_data_repo.upsert_candles_batch(all_aggregated_candles, force_update=force_update, batch_size=batch_size) logger.info("CSV data ingestion complete.") elif file_path.endswith('.db') or file_path.endswith('.sqlite'): logger.info(f"Processing SQLite database: {file_path}") if not timeframes: logger.error("Error: Timeframes must be specified for SQLite trade data aggregation.") return trades = parse_sqlite_to_trades(file_path, exchange, symbol, sample_rows=sample_rows if test_run else None) logger.info(f"Parsed {len(trades)} trades from SQLite.") if not trades: logger.info("No trades found to process in the SQLite database.") return # Use BatchCandleProcessor to aggregate trades into candles processor = BatchCandleProcessor(symbol=symbol, exchange=exchange, timeframes=timeframes, logger=logger) aggregated_candles = processor.process_trades_to_candles(iter(trades)) logger.info(f"Aggregated {len(aggregated_candles)} candles from trades for timeframes: {', '.join(timeframes)}") if test_run: logger.info("--- Test Run: Sample of Aggregated Candles (first 5) ---") for i, candle in enumerate(aggregated_candles[:5]): logger.info(f" {candle.to_dict()}") logger.info("--- End of Test Run Sample ---") logger.info("Data not inserted into database due to --test-run flag.") else: logger.info(f"Starting batch insertion of {len(aggregated_candles)} candles with batch size {batch_size}.") market_data_repo.upsert_candles_batch(aggregated_candles, force_update=force_update, batch_size=batch_size) logger.info("SQLite data ingestion complete.") else: logger.error("Error: Unsupported file type. Please provide a .csv or .sqlite/.db file.") if __name__ == "__main__": main()