TCPDashboard/scripts/data_ingestion.py

230 lines
11 KiB
Python
Raw Permalink Normal View History

2025-06-13 16:49:29 +08:00
import argparse
import os
import pandas as pd
import sqlite3
from datetime import datetime
from decimal import Decimal
from tqdm import tqdm # Import tqdm
from data.common.data_types import OHLCVCandle, StandardizedTrade
from data.common.aggregation.batch import BatchCandleProcessor
from data.common.aggregation.utils import resample_candles_to_timeframe # Import for CSV aggregation
from data.common.aggregation.bucket import TimeframeBucket # For calculating start_time from end_time
from database.repositories.market_data_repository import MarketDataRepository
from database.repositories.raw_trade_repository import RawTradeRepository
from utils.logger import get_logger # Import custom logger
# Logger shared by every function in this script, obtained from the project's
# utils.logger helper under the name "data_ingestion".
logger = get_logger("data_ingestion")
def parse_csv_to_candles(file_path: str, exchange: str, symbol: str, sample_rows: int = None) -> list[OHLCVCandle]:
    """Parse a CSV file of 1-minute bars into ``OHLCVCandle`` objects.

    Column names are matched case-insensitively. The file must provide
    ``timestamp``, ``open``, ``high``, ``low``, ``close`` and ``volume``; an
    optional ``trades_count`` column populates ``trade_count`` (0 when absent
    or blank).

    Args:
        file_path: Path to the CSV file.
        exchange: Exchange name stamped on every candle.
        symbol: Trading symbol stamped on every candle.
        sample_rows: If truthy, only this many leading rows are read (test runs).

    Returns:
        One ``'1m'`` candle per usable row, in file order. Rows with a missing
        value in any required column are skipped with a single warning; rows
        that fail to convert are logged individually and skipped.

    Raises:
        ValueError: If any required column is absent from the file.
    """
    if sample_rows:
        df = pd.read_csv(file_path, nrows=sample_rows)
        logger.info(f"Reading first {sample_rows} rows from CSV for test run.")
    else:
        df = pd.read_csv(file_path)

    # Normalise headers so e.g. "Timestamp" and "timestamp" are both accepted.
    df.columns = df.columns.str.lower()
    required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
    if not all(col in df.columns for col in required_columns):
        raise ValueError(f"CSV file must contain columns: {required_columns}")

    # Decimal(str(nan)) yields Decimal('NaN') instead of raising, so rows with
    # gaps would otherwise turn into NaN-priced candles. Drop them up front in
    # one vectorized pass and report a single count rather than per-row errors.
    missing_mask = df[required_columns].isna().any(axis=1)
    skipped = int(missing_mask.sum())
    if skipped:
        logger.warning(f"Skipping {skipped} CSV rows with missing values in {required_columns}.")
        df = df[~missing_mask]

    candles = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing CSV rows"):
        try:
            # NOTE(review): fromtimestamp() without a tz returns naive *local* time,
            # so results depend on the host timezone; assumes epoch seconds.
            timestamp = datetime.fromtimestamp(row['timestamp'])
            # trades_count is optional and may itself contain blanks (NaN).
            raw_trade_count = row.get('trades_count', 0)
            trade_count = 0 if pd.isna(raw_trade_count) else int(raw_trade_count)
            candle = OHLCVCandle(
                exchange=exchange,
                symbol=symbol,
                timeframe='1m',  # Raw CSV rows are treated as 1-minute candles.
                start_time=timestamp,
                end_time=timestamp,  # Same instant as start_time for raw minute rows.
                open=Decimal(str(row['open'])),
                high=Decimal(str(row['high'])),
                low=Decimal(str(row['low'])),
                close=Decimal(str(row['close'])),
                volume=Decimal(str(row['volume'])),
                trade_count=trade_count,
                is_complete=True,  # Historical CSV bars are finished candles.
            )
            candles.append(candle)
        except Exception as e:
            logger.error(f"Error parsing row: {row}. Error: {e}")
    return candles
def parse_sqlite_to_trades(db_path: str, exchange: str, symbol: str, sample_rows: int = None) -> list[StandardizedTrade]:
    """Read raw trades for one instrument from an SQLite database.

    Queries a ``trades`` table for columns ``id, instrument, price, size,
    side, timestamp`` where ``instrument`` equals *symbol*, ordered by
    ``timestamp`` ascending.

    Args:
        db_path: Path to the SQLite database file.
        exchange: Exchange name stamped on every trade.
        symbol: Instrument to select (matched against the ``instrument`` column).
        sample_rows: If truthy, only this many trades are read (test runs).

    Returns:
        Parsed trades in timestamp order. Rows that fail to convert are logged
        and skipped.
    """
    query = "SELECT id, instrument, price, size, side, timestamp FROM trades WHERE instrument = ? ORDER BY timestamp ASC"
    params: tuple = (symbol,)
    if sample_rows:
        # Bind LIMIT as a parameter instead of formatting it into the SQL text.
        query += " LIMIT ?"
        params += (int(sample_rows),)
        logger.info(f"Reading first {sample_rows} trades from SQLite for test run.")

    conn = sqlite3.connect(db_path)
    try:
        # Materialise all rows so tqdm can show a total.
        rows = conn.execute(query, params).fetchall()
    finally:
        # Release the handle even if the query fails (e.g. missing table);
        # the rows are already in memory, so it is not needed afterwards.
        conn.close()

    trades = []
    for row in tqdm(rows, total=len(rows), desc="Processing SQLite trades"):
        trade_id, instrument, price, size, side, timestamp = row
        try:
            # Assumes timestamp is epoch milliseconds.
            # NOTE(review): fromtimestamp() without a tz returns naive *local* time.
            trade_timestamp = datetime.fromtimestamp(int(timestamp) / 1000)
            trade = StandardizedTrade(
                symbol=instrument,
                trade_id=str(trade_id),
                price=Decimal(str(price)),
                size=Decimal(str(size)),
                side=side,
                timestamp=trade_timestamp,
                exchange=exchange,
            )
            trades.append(trade)
        except Exception as e:
            logger.error(f"Error parsing trade row: {row}. Error: {e}")
    return trades
def _resample_csv_candles(raw_candles: list[OHLCVCandle], exchange: str, symbol: str, timeframes: list[str]) -> list[OHLCVCandle]:
    """Resample parsed 1m CSV candles into every requested timeframe.

    Returns the aggregated candles sorted by ``(timeframe, end_time)``.
    """
    # Resampling is driven by a DatetimeIndex built from each candle's end_time.
    df_raw_candles = pd.DataFrame([c.to_dict() for c in raw_candles])
    df_raw_candles['end_time'] = pd.to_datetime(df_raw_candles['end_time'])
    df_raw_candles = df_raw_candles.set_index('end_time')
    # Decimal columns become floats for pandas' numeric aggregation and are
    # converted back to Decimal when the candles are rebuilt below.
    for col in ['open', 'high', 'low', 'close', 'volume', 'trade_count']:
        if col in df_raw_candles.columns:
            df_raw_candles[col] = pd.to_numeric(df_raw_candles[col])

    aggregated = []
    for tf in timeframes:
        logger.info(f"Aggregating 1m candles to {tf} timeframe...")
        resampled_df = resample_candles_to_timeframe(df_raw_candles, tf)
        # Loop-invariant: the candle span depends only on the timeframe.
        time_delta = TimeframeBucket._parse_timeframe_to_timedelta(tf)
        for end_time, row in resampled_df.iterrows():
            # The frame's count column is 'trade_count' (the name converted
            # above); reading only 'trades_count' always fell back to 0.
            # 'trades_count' is kept as a fallback in case resampling renames it.
            raw_trade_count = row.get('trade_count', row.get('trades_count', 0))
            trade_count = 0 if pd.isna(raw_trade_count) else int(raw_trade_count)
            # NOTE(review): if the resampler keeps empty bins, their OHLC values
            # are presumably NaN and become Decimal('NaN') here — confirm.
            aggregated.append(OHLCVCandle(
                exchange=exchange,
                symbol=symbol,
                timeframe=tf,
                start_time=end_time - time_delta,
                end_time=end_time,  # The resampled index is the candle end time.
                open=Decimal(str(row['open'])),
                high=Decimal(str(row['high'])),
                low=Decimal(str(row['low'])),
                close=Decimal(str(row['close'])),
                volume=Decimal(str(row['volume'])),
                trade_count=trade_count,
                is_complete=True,  # Resampled historical candles are complete.
            ))
    # Deterministic order for sample output and insertion.
    aggregated.sort(key=lambda c: (c.timeframe, c.end_time))
    return aggregated


def _report_or_insert(candles, market_data_repo, *, test_run: bool, force_update: bool, batch_size: int, insert_noun: str, done_message: str) -> None:
    """On a test run log the first five candles; otherwise upsert them in batches."""
    if test_run:
        logger.info("--- Test Run: Sample of Aggregated Candles (first 5) ---")
        for candle in candles[:5]:
            logger.info(f" {candle.to_dict()}")
        logger.info("--- End of Test Run Sample ---")
        logger.info("Data not inserted into database due to --test-run flag.")
        return
    logger.info(f"Starting batch insertion of {len(candles)} {insert_noun} with batch size {batch_size}.")
    market_data_repo.upsert_candles_batch(candles, force_update=force_update, batch_size=batch_size)
    logger.info(done_message)


def main() -> int:
    """Command-line entry point: aggregate a CSV or SQLite file into candles and store them.

    CSV input (``.csv``) is parsed as 1m candles and resampled to each
    ``--timeframes`` value; SQLite input (``.db``/``.sqlite``) is read as raw
    trades and aggregated with ``BatchCandleProcessor``. With ``--test-run`` a
    sample is logged and nothing is written. Extensions match case-insensitively.

    Returns:
        0 on success (including "nothing to process"), 1 when the input file is
        missing, has an unsupported extension, or no timeframes were given.
    """
    parser = argparse.ArgumentParser(description="Ingest market data into the database.")
    parser.add_argument("--file", required=True, help="Path to the input data file (CSV or SQLite).")
    parser.add_argument("--exchange", required=True, help="Exchange name (e.g., 'okx').")
    parser.add_argument("--symbol", required=True, help="Trading symbol (e.g., 'BTC-USDT').")
    parser.add_argument("--timeframes", nargs='*', default=['1m'], help="Timeframes for aggregation (e.g., '1m', '5m', '1h'). Required for SQLite, optional for CSV.")
    parser.add_argument("--force", action="store_true", help="Overwrite existing data if it conflicts.")
    parser.add_argument("--test-run", action="store_true", help="Run without inserting data, print a sample instead.")
    parser.add_argument("--sample-rows", type=int, help="Number of rows to process in test-run mode. Only effective with --test-run.")
    parser.add_argument("--batch-size", type=int, default=10000, help="Batch size for inserting data into the database.")
    args = parser.parse_args()

    file_path = args.file
    exchange = args.exchange
    symbol = args.symbol
    timeframes = args.timeframes
    test_run = args.test_run
    # --sample-rows only limits reading during a test run.
    read_limit = args.sample_rows if test_run else None

    if test_run and args.sample_rows is None:
        logger.warning("--- No --sample-rows specified for --test-run. Processing full file for sample output. ---")

    # Validate all inputs before any expensive parsing.
    if not os.path.exists(file_path):
        logger.error(f"Error: File not found at {file_path}")
        return 1
    lowered_path = file_path.lower()
    is_csv = lowered_path.endswith('.csv')
    is_sqlite = lowered_path.endswith(('.db', '.sqlite'))
    if not (is_csv or is_sqlite):
        logger.error("Error: Unsupported file type. Please provide a .csv or .sqlite/.db file.")
        return 1
    if not timeframes:
        # `--timeframes` given with no values: nothing could be aggregated.
        source = "CSV candle" if is_csv else "SQLite trade data"
        logger.error(f"Error: Timeframes must be specified for {source} aggregation.")
        return 1

    # A test run never writes, so the repository is only created for real runs.
    market_data_repo = None if test_run else MarketDataRepository()
    report_options = dict(test_run=test_run, force_update=args.force, batch_size=args.batch_size)

    if is_csv:
        logger.info(f"Processing CSV file: {file_path}")
        raw_candles = parse_csv_to_candles(file_path, exchange, symbol, sample_rows=read_limit)
        logger.info(f"Parsed {len(raw_candles)} raw 1m candles from CSV.")
        if not raw_candles:
            logger.info("No raw candles found to process in the CSV file.")
            return 0
        candles = _resample_csv_candles(raw_candles, exchange, symbol, timeframes)
        logger.info(f"Aggregated {len(candles)} candles for timeframes: {', '.join(timeframes)}")
        _report_or_insert(candles, market_data_repo, insert_noun="aggregated candles", done_message="CSV data ingestion complete.", **report_options)
        return 0

    logger.info(f"Processing SQLite database: {file_path}")
    trades = parse_sqlite_to_trades(file_path, exchange, symbol, sample_rows=read_limit)
    logger.info(f"Parsed {len(trades)} trades from SQLite.")
    if not trades:
        logger.info("No trades found to process in the SQLite database.")
        return 0
    processor = BatchCandleProcessor(symbol=symbol, exchange=exchange, timeframes=timeframes, logger=logger)
    candles = processor.process_trades_to_candles(iter(trades))
    logger.info(f"Aggregated {len(candles)} candles from trades for timeframes: {', '.join(timeframes)}")
    _report_or_insert(candles, market_data_repo, insert_noun="candles", done_message="SQLite data ingestion complete.", **report_options)
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (SystemExit(None) exits with status 0).
    raise SystemExit(main())