data ingestion
This commit is contained in:
230
scripts/data_ingestion.py
Normal file
230
scripts/data_ingestion.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import argparse
|
||||
import os
|
||||
import pandas as pd
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from tqdm import tqdm # Import tqdm
|
||||
|
||||
from data.common.data_types import OHLCVCandle, StandardizedTrade
|
||||
from data.common.aggregation.batch import BatchCandleProcessor
|
||||
from data.common.aggregation.utils import resample_candles_to_timeframe # Import for CSV aggregation
|
||||
from data.common.aggregation.bucket import TimeframeBucket # For calculating start_time from end_time
|
||||
from database.repositories.market_data_repository import MarketDataRepository
|
||||
from database.repositories.raw_trade_repository import RawTradeRepository
|
||||
from utils.logger import get_logger # Import custom logger
|
||||
|
||||
logger = get_logger('data_ingestion')
|
||||
|
||||
def parse_csv_to_candles(file_path: str, exchange: str, symbol: str, sample_rows: int = None) -> list[OHLCVCandle]:
|
||||
"""Parses a CSV file into OHLCVCandle objects, assuming 1-minute candles."""
|
||||
if sample_rows:
|
||||
df = pd.read_csv(file_path, nrows=sample_rows)
|
||||
logger.info(f"Reading first {sample_rows} rows from CSV for test run.")
|
||||
else:
|
||||
df = pd.read_csv(file_path)
|
||||
|
||||
# Convert column names to lowercase to handle case insensitivity
|
||||
df.columns = df.columns.str.lower()
|
||||
required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
if not all(col in df.columns for col in required_columns):
|
||||
raise ValueError(f"CSV file must contain columns: {required_columns}")
|
||||
|
||||
candles = []
|
||||
for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing CSV rows"): # Add tqdm
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(row['timestamp'])
|
||||
candle = OHLCVCandle(
|
||||
exchange=exchange,
|
||||
symbol=symbol,
|
||||
timeframe='1m', # Assume 1-minute candles for raw CSV data
|
||||
start_time=timestamp,
|
||||
end_time=timestamp, # For minute data, start and end time can be the same
|
||||
open=Decimal(str(row['open'])),
|
||||
high=Decimal(str(row['high'])),
|
||||
low=Decimal(str(row['low'])),
|
||||
close=Decimal(str(row['close'])),
|
||||
volume=Decimal(str(row['volume'])),
|
||||
trade_count=int(row.get('trades_count', 0)), # trades_count might not be in all CSVs
|
||||
is_complete=True # Explicitly set to True for CSV data
|
||||
)
|
||||
candles.append(candle)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing row: {row}. Error: {e}")
|
||||
return candles
|
||||
|
||||
def parse_sqlite_to_trades(db_path: str, exchange: str, symbol: str, sample_rows: int = None) -> list[StandardizedTrade]:
|
||||
"""Reads raw trades from an SQLite database."""
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
query = "SELECT id, instrument, price, size, side, timestamp FROM trades WHERE instrument = ? ORDER BY timestamp ASC"
|
||||
if sample_rows:
|
||||
query += f" LIMIT {sample_rows}"
|
||||
logger.info(f"Reading first {sample_rows} trades from SQLite for test run.")
|
||||
|
||||
cursor.execute(query, (symbol,))
|
||||
|
||||
# Fetch all results to apply tqdm effectively over the list
|
||||
rows = cursor.fetchall()
|
||||
|
||||
trades = []
|
||||
for row in tqdm(rows, total=len(rows), desc="Processing SQLite trades"): # Add tqdm
|
||||
trade_id, instrument, price, size, side, timestamp = row
|
||||
try:
|
||||
# Assuming timestamp is in milliseconds and needs conversion to datetime
|
||||
trade_timestamp = datetime.fromtimestamp(int(timestamp) / 1000)
|
||||
trade = StandardizedTrade(
|
||||
symbol=instrument,
|
||||
trade_id=str(trade_id),
|
||||
price=Decimal(str(price)),
|
||||
size=Decimal(str(size)),
|
||||
side=side,
|
||||
timestamp=trade_timestamp,
|
||||
exchange=exchange
|
||||
)
|
||||
trades.append(trade)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing trade row: {row}. Error: {e}")
|
||||
conn.close()
|
||||
return trades
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Ingest market data into the database.")
|
||||
parser.add_argument("--file", required=True, help="Path to the input data file (CSV or SQLite).")
|
||||
parser.add_argument("--exchange", required=True, help="Exchange name (e.g., 'okx').")
|
||||
parser.add_argument("--symbol", required=True, help="Trading symbol (e.g., 'BTC-USDT').")
|
||||
parser.add_argument("--timeframes", nargs='*', default=['1m'], help="Timeframes for aggregation (e.g., '1m', '5m', '1h'). Required for SQLite, optional for CSV.")
|
||||
parser.add_argument("--force", action="store_true", help="Overwrite existing data if it conflicts.")
|
||||
parser.add_argument("--test-run", action="store_true", help="Run without inserting data, print a sample instead.")
|
||||
parser.add_argument("--sample-rows", type=int, help="Number of rows to process in test-run mode. Only effective with --test-run.")
|
||||
parser.add_argument("--batch-size", type=int, default=10000, help="Batch size for inserting data into the database.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
file_path = args.file
|
||||
exchange = args.exchange
|
||||
symbol = args.symbol
|
||||
timeframes = args.timeframes
|
||||
force_update = args.force
|
||||
test_run = args.test_run
|
||||
sample_rows = args.sample_rows
|
||||
batch_size = args.batch_size
|
||||
|
||||
if test_run and sample_rows is None:
|
||||
logger.warning("--- No --sample-rows specified for --test-run. Processing full file for sample output. ---")
|
||||
|
||||
market_data_repo = MarketDataRepository()
|
||||
# raw_trade_repo = RawTradeRepository() # Not used in this script
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
logger.error(f"Error: File not found at {file_path}")
|
||||
return
|
||||
|
||||
if file_path.endswith('.csv'):
|
||||
logger.info(f"Processing CSV file: {file_path}")
|
||||
|
||||
raw_candles = parse_csv_to_candles(file_path, exchange, symbol, sample_rows=sample_rows if test_run else None)
|
||||
logger.info(f"Parsed {len(raw_candles)} raw 1m candles from CSV.")
|
||||
|
||||
if not raw_candles:
|
||||
logger.info("No raw candles found to process in the CSV file.")
|
||||
return
|
||||
|
||||
all_aggregated_candles = []
|
||||
|
||||
# Convert raw candles to a pandas DataFrame for resampling
|
||||
df_raw_candles = pd.DataFrame([c.to_dict() for c in raw_candles])
|
||||
# Ensure 'end_time' is a datetime object and set as index for resampling
|
||||
df_raw_candles['end_time'] = pd.to_datetime(df_raw_candles['end_time'])
|
||||
df_raw_candles = df_raw_candles.set_index('end_time')
|
||||
|
||||
# Convert Decimal types to float for pandas resampling, then back to Decimal after aggregation
|
||||
# This ensures compatibility with pandas' numerical operations
|
||||
for col in ['open', 'high', 'low', 'close', 'volume']:
|
||||
if col in df_raw_candles.columns:
|
||||
df_raw_candles[col] = pd.to_numeric(df_raw_candles[col])
|
||||
# 'trade_count' might not exist, handle with .get()
|
||||
if 'trade_count' in df_raw_candles.columns:
|
||||
df_raw_candles['trade_count'] = pd.to_numeric(df_raw_candles['trade_count'])
|
||||
|
||||
for tf in timeframes:
|
||||
logger.info(f"Aggregating 1m candles to {tf} timeframe...")
|
||||
# Resample the DataFrame to the target timeframe
|
||||
resampled_df = resample_candles_to_timeframe(df_raw_candles, tf)
|
||||
|
||||
# Convert resampled DataFrame back to OHLCVCandle objects
|
||||
for index, row in resampled_df.iterrows():
|
||||
# index is the end_time for the resampled candle
|
||||
end_time = index
|
||||
# Calculate start_time based on end_time and timeframe
|
||||
# TimeframeBucket._parse_timeframe_to_timedelta returns timedelta
|
||||
time_delta = TimeframeBucket._parse_timeframe_to_timedelta(tf)
|
||||
start_time = end_time - time_delta
|
||||
|
||||
candle = OHLCVCandle(
|
||||
exchange=exchange,
|
||||
symbol=symbol,
|
||||
timeframe=tf,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
open=Decimal(str(row['open'])),
|
||||
high=Decimal(str(row['high'])),
|
||||
low=Decimal(str(row['low'])),
|
||||
close=Decimal(str(row['close'])),
|
||||
volume=Decimal(str(row['volume'])),
|
||||
trade_count=int(row.get('trades_count', 0)),
|
||||
is_complete=True # Resampled candles are considered complete
|
||||
)
|
||||
all_aggregated_candles.append(candle)
|
||||
|
||||
# Sort candles by timeframe and then by end_time for consistent output/insertion
|
||||
all_aggregated_candles.sort(key=lambda x: (x.timeframe, x.end_time))
|
||||
|
||||
logger.info(f"Aggregated {len(all_aggregated_candles)} candles for timeframes: {', '.join(timeframes)}")
|
||||
|
||||
if test_run:
|
||||
logger.info("--- Test Run: Sample of Aggregated Candles (first 5) ---")
|
||||
for i, candle in enumerate(all_aggregated_candles[:5]):
|
||||
logger.info(f" {candle.to_dict()}")
|
||||
logger.info("--- End of Test Run Sample ---")
|
||||
logger.info("Data not inserted into database due to --test-run flag.")
|
||||
else:
|
||||
logger.info(f"Starting batch insertion of {len(all_aggregated_candles)} aggregated candles with batch size {batch_size}.")
|
||||
market_data_repo.upsert_candles_batch(all_aggregated_candles, force_update=force_update, batch_size=batch_size)
|
||||
logger.info("CSV data ingestion complete.")
|
||||
|
||||
elif file_path.endswith('.db') or file_path.endswith('.sqlite'):
|
||||
logger.info(f"Processing SQLite database: {file_path}")
|
||||
if not timeframes:
|
||||
logger.error("Error: Timeframes must be specified for SQLite trade data aggregation.")
|
||||
return
|
||||
|
||||
trades = parse_sqlite_to_trades(file_path, exchange, symbol, sample_rows=sample_rows if test_run else None)
|
||||
logger.info(f"Parsed {len(trades)} trades from SQLite.")
|
||||
|
||||
if not trades:
|
||||
logger.info("No trades found to process in the SQLite database.")
|
||||
return
|
||||
|
||||
# Use BatchCandleProcessor to aggregate trades into candles
|
||||
processor = BatchCandleProcessor(symbol=symbol, exchange=exchange, timeframes=timeframes, logger=logger)
|
||||
aggregated_candles = processor.process_trades_to_candles(iter(trades))
|
||||
|
||||
logger.info(f"Aggregated {len(aggregated_candles)} candles from trades for timeframes: {', '.join(timeframes)}")
|
||||
|
||||
if test_run:
|
||||
logger.info("--- Test Run: Sample of Aggregated Candles (first 5) ---")
|
||||
for i, candle in enumerate(aggregated_candles[:5]):
|
||||
logger.info(f" {candle.to_dict()}")
|
||||
logger.info("--- End of Test Run Sample ---")
|
||||
logger.info("Data not inserted into database due to --test-run flag.")
|
||||
else:
|
||||
logger.info(f"Starting batch insertion of {len(aggregated_candles)} candles with batch size {batch_size}.")
|
||||
market_data_repo.upsert_candles_batch(aggregated_candles, force_update=force_update, batch_size=batch_size)
|
||||
logger.info("SQLite data ingestion complete.")
|
||||
|
||||
else:
|
||||
logger.error("Error: Unsupported file type. Please provide a .csv or .sqlite/.db file.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user