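"""Ingest historical market data into the database.

Parses 1-minute candles from CSV files (resampling them to the requested
timeframes) or raw trades from SQLite dumps (aggregating them into candles
via BatchCandleProcessor), then batch-upserts the results. Supports a
--test-run mode that prints a sample instead of writing to the database.
"""
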
import argparse
import os
import sqlite3
from datetime import datetime
from decimal import Decimal

import pandas as pd
from tqdm import tqdm

from data.common.data_types import OHLCVCandle, StandardizedTrade
from data.common.aggregation.batch import BatchCandleProcessor
from data.common.aggregation.utils import resample_candles_to_timeframe  # Used for CSV resampling
from data.common.aggregation.bucket import TimeframeBucket  # For deriving start_time from end_time
from database.repositories.market_data_repository import MarketDataRepository
from database.repositories.raw_trade_repository import RawTradeRepository
from utils.logger import get_logger

logger = get_logger('data_ingestion')


def parse_csv_to_candles(file_path: str, exchange: str, symbol: str, sample_rows: int | None = None) -> list[OHLCVCandle]:
    """Parse a CSV file into OHLCVCandle objects, assuming 1-minute candles.

    The CSV must contain (case-insensitive) columns: timestamp, open, high,
    low, close, volume. A trades_count column is used if present.
    """
    if sample_rows:
        df = pd.read_csv(file_path, nrows=sample_rows)
        logger.info(f"Reading first {sample_rows} rows from CSV for test run.")
    else:
        df = pd.read_csv(file_path)

    # Lowercase column names so header matching is case-insensitive
    df.columns = df.columns.str.lower()
    required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
    if not all(col in df.columns for col in required_columns):
        raise ValueError(f"CSV file must contain columns: {required_columns}")

    candles = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing CSV rows"):
        try:
            # Timestamps are assumed to be Unix epoch seconds
            timestamp = datetime.fromtimestamp(row['timestamp'])
            candle = OHLCVCandle(
                exchange=exchange,
                symbol=symbol,
                timeframe='1m',  # Raw CSV data is assumed to be 1-minute candles
                start_time=timestamp,
                end_time=timestamp,  # The source provides a single timestamp per 1m candle
                open=Decimal(str(row['open'])),
                high=Decimal(str(row['high'])),
                low=Decimal(str(row['low'])),
                close=Decimal(str(row['close'])),
                volume=Decimal(str(row['volume'])),
                trade_count=int(row.get('trades_count', 0)),  # 'trades_count' is optional in source CSVs
                is_complete=True  # Historical CSV candles are always complete
            )
            candles.append(candle)
        except Exception as e:
            logger.error(f"Error parsing row: {row}. Error: {e}")
    return candles


def parse_sqlite_to_trades(db_path: str, exchange: str, symbol: str, sample_rows: int | None = None) -> list[StandardizedTrade]:
    """Read raw trades for a symbol from an SQLite database, ordered by timestamp."""
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    query = "SELECT id, instrument, price, size, side, timestamp FROM trades WHERE instrument = ? ORDER BY timestamp ASC"
    if sample_rows:
        query += f" LIMIT {sample_rows}"  # sample_rows is parsed as int by argparse, so safe to interpolate
        logger.info(f"Reading first {sample_rows} trades from SQLite for test run.")

    cursor.execute(query, (symbol,))

    # Fetch all results so tqdm can report progress over a known length
    rows = cursor.fetchall()

    trades = []
    for row in tqdm(rows, total=len(rows), desc="Processing SQLite trades"):
        trade_id, instrument, price, size, side, timestamp = row
        try:
            # Timestamps are assumed to be Unix epoch milliseconds
            trade_timestamp = datetime.fromtimestamp(int(timestamp) / 1000)
            trade = StandardizedTrade(
                symbol=instrument,
                trade_id=str(trade_id),
                price=Decimal(str(price)),
                size=Decimal(str(size)),
                side=side,
                timestamp=trade_timestamp,
                exchange=exchange
            )
            trades.append(trade)
        except Exception as e:
            logger.error(f"Error parsing trade row: {row}. Error: {e}")
    conn.close()
    return trades


def main():
    parser = argparse.ArgumentParser(description="Ingest market data into the database.")
    parser.add_argument("--file", required=True, help="Path to the input data file (CSV or SQLite).")
    parser.add_argument("--exchange", required=True, help="Exchange name (e.g., 'okx').")
    parser.add_argument("--symbol", required=True, help="Trading symbol (e.g., 'BTC-USDT').")
    parser.add_argument("--timeframes", nargs='*', default=['1m'], help="Timeframes for aggregation (e.g., '1m', '5m', '1h'). Required for SQLite, optional for CSV.")
    parser.add_argument("--force", action="store_true", help="Overwrite existing data if it conflicts.")
    parser.add_argument("--test-run", action="store_true", help="Run without inserting data; print a sample instead.")
    parser.add_argument("--sample-rows", type=int, help="Number of rows to process in test-run mode. Only effective with --test-run.")
    parser.add_argument("--batch-size", type=int, default=10000, help="Batch size for inserting data into the database.")

    args = parser.parse_args()
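
    # Example invocation (the script filename here is illustrative):
    #   python ingest_market_data.py --file btc_1m.csv --exchange okx --symbol BTC-USDT --timeframes 1m 5m 1h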

    file_path = args.file
    exchange = args.exchange
    symbol = args.symbol
    timeframes = args.timeframes
    force_update = args.force
    test_run = args.test_run
    sample_rows = args.sample_rows
    batch_size = args.batch_size

    if test_run and sample_rows is None:
        logger.warning("--- No --sample-rows specified for --test-run. Processing the full file for sample output. ---")

    market_data_repo = MarketDataRepository()
    # raw_trade_repo = RawTradeRepository()  # Not used in this script

    if not os.path.exists(file_path):
        logger.error(f"Error: File not found at {file_path}")
        return

    if file_path.endswith('.csv'):
        logger.info(f"Processing CSV file: {file_path}")

        raw_candles = parse_csv_to_candles(file_path, exchange, symbol, sample_rows=sample_rows if test_run else None)
        logger.info(f"Parsed {len(raw_candles)} raw 1m candles from CSV.")

        if not raw_candles:
            logger.info("No raw candles found to process in the CSV file.")
            return

        all_aggregated_candles = []

        # Convert raw candles to a pandas DataFrame for resampling
        df_raw_candles = pd.DataFrame([c.to_dict() for c in raw_candles])
        # Ensure 'end_time' is a datetime and use it as the index for resampling
        df_raw_candles['end_time'] = pd.to_datetime(df_raw_candles['end_time'])
        df_raw_candles = df_raw_candles.set_index('end_time')

        # Convert Decimal columns to float for pandas resampling; values are
        # converted back to Decimal after aggregation
        for col in ['open', 'high', 'low', 'close', 'volume']:
            if col in df_raw_candles.columns:
                df_raw_candles[col] = pd.to_numeric(df_raw_candles[col])
        # 'trade_count' may be absent, so guard before converting
        if 'trade_count' in df_raw_candles.columns:
            df_raw_candles['trade_count'] = pd.to_numeric(df_raw_candles['trade_count'])

        for tf in timeframes:
            logger.info(f"Aggregating 1m candles to {tf} timeframe...")
            # Resample the DataFrame to the target timeframe
            resampled_df = resample_candles_to_timeframe(df_raw_candles, tf)

            # The timeframe's duration is constant, so compute it once per
            # timeframe; _parse_timeframe_to_timedelta returns a timedelta
            time_delta = TimeframeBucket._parse_timeframe_to_timedelta(tf)

            # Convert the resampled DataFrame back to OHLCVCandle objects
            for index, row in resampled_df.iterrows():
                # The index is the end_time of the resampled candle;
                # derive start_time by subtracting the timeframe duration
                end_time = index
                start_time = end_time - time_delta

                candle = OHLCVCandle(
                    exchange=exchange,
                    symbol=symbol,
                    timeframe=tf,
                    start_time=start_time,
                    end_time=end_time,
                    open=Decimal(str(row['open'])),
                    high=Decimal(str(row['high'])),
                    low=Decimal(str(row['low'])),
                    close=Decimal(str(row['close'])),
                    volume=Decimal(str(row['volume'])),
                    # The DataFrame column (from c.to_dict()) is 'trade_count',
                    # not 'trades_count', so look up the correct key
                    trade_count=int(row.get('trade_count', 0)),
                    is_complete=True  # Resampled candles are considered complete
                )
                all_aggregated_candles.append(candle)

        # Sort candles by timeframe and then by end_time for consistent output/insertion
        all_aggregated_candles.sort(key=lambda x: (x.timeframe, x.end_time))

        logger.info(f"Aggregated {len(all_aggregated_candles)} candles for timeframes: {', '.join(timeframes)}")

        if test_run:
            logger.info("--- Test Run: Sample of Aggregated Candles (first 5) ---")
            for candle in all_aggregated_candles[:5]:
                logger.info(f"  {candle.to_dict()}")
            logger.info("--- End of Test Run Sample ---")
            logger.info("Data not inserted into database due to --test-run flag.")
        else:
            logger.info(f"Starting batch insertion of {len(all_aggregated_candles)} aggregated candles with batch size {batch_size}.")
            market_data_repo.upsert_candles_batch(all_aggregated_candles, force_update=force_update, batch_size=batch_size)
            logger.info("CSV data ingestion complete.")

    elif file_path.endswith('.db') or file_path.endswith('.sqlite'):
        logger.info(f"Processing SQLite database: {file_path}")
        if not timeframes:
            logger.error("Error: Timeframes must be specified for SQLite trade data aggregation.")
            return

        trades = parse_sqlite_to_trades(file_path, exchange, symbol, sample_rows=sample_rows if test_run else None)
        logger.info(f"Parsed {len(trades)} trades from SQLite.")

        if not trades:
            logger.info("No trades found to process in the SQLite database.")
            return

        # Aggregate trades into candles with BatchCandleProcessor
        processor = BatchCandleProcessor(symbol=symbol, exchange=exchange, timeframes=timeframes, logger=logger)
        aggregated_candles = processor.process_trades_to_candles(iter(trades))

        logger.info(f"Aggregated {len(aggregated_candles)} candles from trades for timeframes: {', '.join(timeframes)}")

        if test_run:
            logger.info("--- Test Run: Sample of Aggregated Candles (first 5) ---")
            for candle in aggregated_candles[:5]:
                logger.info(f"  {candle.to_dict()}")
            logger.info("--- End of Test Run Sample ---")
            logger.info("Data not inserted into database due to --test-run flag.")
        else:
            logger.info(f"Starting batch insertion of {len(aggregated_candles)} candles with batch size {batch_size}.")
            market_data_repo.upsert_candles_batch(aggregated_candles, force_update=force_update, batch_size=batch_size)
            logger.info("SQLite data ingestion complete.")

    else:
        logger.error("Error: Unsupported file type. Please provide a .csv or .sqlite/.db file.")


if __name__ == "__main__":
    main()