data ingestion

This commit is contained in:
Vasily.onl 2025-06-13 16:49:29 +08:00
parent f09864d61b
commit 622fda9d2e
6 changed files with 408 additions and 10 deletions

View File

@ -8,6 +8,7 @@ and trade data aggregation.
import re
from typing import List, Tuple
from utils.timeframe_utils import load_timeframe_options
import pandas as pd
from ..data_types import StandardizedTrade, OHLCVCandle
@ -74,8 +75,75 @@ def parse_timeframe(timeframe: str) -> Tuple[int, str]:
return number, unit
def resample_candles_to_timeframe(df: pd.DataFrame, target_timeframe: str) -> pd.DataFrame:
"""
Resamples a DataFrame of OHLCV candles to a higher timeframe.
Args:
df (pd.DataFrame): Input DataFrame with a datetime index and 'open', 'high', 'low', 'close', 'volume',
and optionally 'trades_count' columns.
target_timeframe (str): The target timeframe for resampling (e.g., '1h', '1d').
Returns:
pd.DataFrame: Resampled DataFrame with OHLCV data for the target timeframe.
"""
if df.empty:
return pd.DataFrame()
# Ensure the DataFrame index is a datetime index
if not isinstance(df.index, pd.DatetimeIndex):
df['timestamp'] = pd.to_datetime(df['timestamp'])
df = df.set_index('timestamp')
# Convert timedelta string to pandas frequency string
# '1m' -> '1T', '1h' -> '1H', '1d' -> '1D'
timeframe_map = {
's': 'S',
'm': 'T',
'h': 'H',
'd': 'D'
}
# Convert target_timeframe to pandas offset string
match = re.match(r'^(\d+)([smhd])$', target_timeframe.lower())
if not match:
raise ValueError(f"Invalid target timeframe format: {target_timeframe}")
number = match.group(1)
unit = timeframe_map.get(match.group(2))
if not unit:
raise ValueError(f"Unsupported timeframe unit: {target_timeframe}")
resample_freq = f"{number}{unit}"
# Define how to aggregate each column
ohlcv_dict = {
'open': 'first',
'high': 'max',
'low': 'min',
'close': 'last',
'volume': 'sum',
}
# Only include 'trades_count' if it exists in the DataFrame
if 'trades_count' in df.columns:
ohlcv_dict['trades_count'] = 'sum'
# Resample the data
resampled_df = df.resample(resample_freq).apply(ohlcv_dict)
# Drop rows where all OHLCV values are NaN (e.g., periods with no data)
resampled_df.dropna(subset=['open', 'high', 'low', 'close'], inplace=True)
# Fill NaN trades_count with 0 after resampling
if 'trades_count' in resampled_df.columns:
resampled_df['trades_count'] = resampled_df['trades_count'].fillna(0).astype(int)
return resampled_df
__all__ = [
'aggregate_trades_to_candles',
'validate_timeframe',
'parse_timeframe'
'parse_timeframe',
'resample_candles_to_timeframe'
]

View File

@ -5,6 +5,7 @@ from contextlib import contextmanager
from typing import Optional
from ..connection import get_db_manager
from utils.logger import get_logger
class DatabaseOperationError(Exception):
@ -17,24 +18,24 @@ class BaseRepository:
def __init__(self, logger: Optional[logging.Logger] = None):
"""Initialize repository with optional logger."""
self.logger = logger
if logger is None:
self.logger = get_logger(self.__class__.__name__)
else:
self.logger = logger
self._db_manager = get_db_manager()
self._db_manager.initialize()
def log_info(self, message: str) -> None:
"""Log info message if logger is available."""
if self.logger:
self.logger.info(message)
self.logger.info(message)
def log_debug(self, message: str) -> None:
"""Log debug message if logger is available."""
if self.logger:
self.logger.debug(message)
self.logger.debug(message)
def log_error(self, message: str) -> None:
"""Log error message if logger is available."""
if self.logger:
self.logger.error(message)
self.logger.error(message)
@contextmanager
def get_session(self):

View File

@ -10,6 +10,7 @@ from sqlalchemy.dialects.postgresql import insert
from ..models import MarketData
from data.common.data_types import OHLCVCandle
from .base_repository import BaseRepository, DatabaseOperationError
from tqdm import tqdm
class MarketDataRepository(BaseRepository):
@ -68,6 +69,63 @@ class MarketDataRepository(BaseRepository):
self.log_error(f"Error storing candle {candle.symbol} {candle.timeframe}: {e}")
raise DatabaseOperationError(f"Failed to store candle: {e}")
def upsert_candles_batch(self, candles: List[OHLCVCandle], force_update: bool = False, batch_size: int = 1000) -> int:
"""
Insert or update multiple candles in the market_data table in batches.
"""
total_processed = 0
try:
for i in tqdm(range(0, len(candles), batch_size), desc="Inserting candles in batches"):
batch = candles[i:i + batch_size]
values = [
{
'exchange': candle.exchange,
'symbol': candle.symbol,
'timeframe': candle.timeframe,
'timestamp': candle.end_time,
'open': candle.open,
'high': candle.high,
'low': candle.low,
'close': candle.close,
'volume': candle.volume,
'trades_count': candle.trade_count
}
for candle in batch
]
with self.get_session() as session:
stmt = insert(MarketData).values(values)
if force_update:
final_stmt = stmt.on_conflict_do_update(
index_elements=['exchange', 'symbol', 'timeframe', 'timestamp'],
set_={
'open': stmt.excluded.open,
'high': stmt.excluded.high,
'low': stmt.excluded.low,
'close': stmt.excluded.close,
'volume': stmt.excluded.volume,
'trades_count': stmt.excluded.trades_count
}
)
action = "Updated"
else:
final_stmt = stmt.on_conflict_do_nothing(
index_elements=['exchange', 'symbol', 'timeframe', 'timestamp']
)
action = "Stored"
session.execute(final_stmt)
session.commit()
total_processed += len(batch)
self.log_debug(f"{action} {len(batch)} candles in batch. Total processed: {total_processed}")
return total_processed
except Exception as e:
self.log_error(f"Error storing candles in batch: {e}")
raise DatabaseOperationError(f"Failed to store candles in batch: {e}")
def get_candles(self,
symbol: str,
timeframe: str,
@ -77,6 +135,7 @@ class MarketDataRepository(BaseRepository):
"""
Retrieve candles from the database using the ORM.
"""
self.log_debug(f"DB: get_candles called with: symbol={symbol}, timeframe={timeframe}, start_time={start_time}, end_time={end_time}, exchange={exchange}")
try:
with self.get_session() as session:
query = (
@ -102,7 +161,7 @@ class MarketDataRepository(BaseRepository):
} for r in results
]
self.log_debug(f"Retrieved {len(candles)} candles for {symbol} {timeframe}")
self.log_debug(f"DB: Retrieved {len(candles)} candles for {symbol} {timeframe} from {start_time} to {end_time}")
return candles
except Exception as e:
@ -195,4 +254,20 @@ class MarketDataRepository(BaseRepository):
except Exception as e:
self.log_error(f"Error retrieving candles as DataFrame: {e}")
raise DatabaseOperationError(f"Failed to retrieve candles as DataFrame: {e}")
raise DatabaseOperationError(f"Failed to retrieve candles as DataFrame: {e}")
def delete_candles_before_timestamp(self, timestamp: datetime) -> int:
"""
Delete candles from the market_data table that are older than the specified timestamp.
"""
try:
with self.get_session() as session:
deleted_count = session.query(MarketData).filter(
MarketData.timestamp < timestamp
).delete(synchronize_session=False)
session.commit()
self.logger.warning(f"Deleted {deleted_count} candles older than {timestamp}")
return deleted_count
except Exception as e:
self.log_error(f"Error deleting candles older than {timestamp}: {e}")
raise DatabaseOperationError(f"Failed to delete candles: {e}")

View File

@ -40,6 +40,7 @@ dependencies = [
"pytest>=8.3.5",
"psutil>=7.0.0",
"tzlocal>=5.3.1",
"tdqm>=0.0.1",
]
[project.optional-dependencies]

230
scripts/data_ingestion.py Normal file
View File

@ -0,0 +1,230 @@
import argparse
import os
import pandas as pd
import sqlite3
from datetime import datetime
from decimal import Decimal
from tqdm import tqdm # Import tqdm
from data.common.data_types import OHLCVCandle, StandardizedTrade
from data.common.aggregation.batch import BatchCandleProcessor
from data.common.aggregation.utils import resample_candles_to_timeframe # Import for CSV aggregation
from data.common.aggregation.bucket import TimeframeBucket # For calculating start_time from end_time
from database.repositories.market_data_repository import MarketDataRepository
from database.repositories.raw_trade_repository import RawTradeRepository
from utils.logger import get_logger # Import custom logger
logger = get_logger('data_ingestion')
def parse_csv_to_candles(file_path: str, exchange: str, symbol: str, sample_rows: int = None) -> list[OHLCVCandle]:
"""Parses a CSV file into OHLCVCandle objects, assuming 1-minute candles."""
if sample_rows:
df = pd.read_csv(file_path, nrows=sample_rows)
logger.info(f"Reading first {sample_rows} rows from CSV for test run.")
else:
df = pd.read_csv(file_path)
# Convert column names to lowercase to handle case insensitivity
df.columns = df.columns.str.lower()
required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
if not all(col in df.columns for col in required_columns):
raise ValueError(f"CSV file must contain columns: {required_columns}")
candles = []
for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing CSV rows"): # Add tqdm
try:
timestamp = datetime.fromtimestamp(row['timestamp'])
candle = OHLCVCandle(
exchange=exchange,
symbol=symbol,
timeframe='1m', # Assume 1-minute candles for raw CSV data
start_time=timestamp,
end_time=timestamp, # For minute data, start and end time can be the same
open=Decimal(str(row['open'])),
high=Decimal(str(row['high'])),
low=Decimal(str(row['low'])),
close=Decimal(str(row['close'])),
volume=Decimal(str(row['volume'])),
trade_count=int(row.get('trades_count', 0)), # trades_count might not be in all CSVs
is_complete=True # Explicitly set to True for CSV data
)
candles.append(candle)
except Exception as e:
logger.error(f"Error parsing row: {row}. Error: {e}")
return candles
def parse_sqlite_to_trades(db_path: str, exchange: str, symbol: str, sample_rows: int = None) -> list[StandardizedTrade]:
"""Reads raw trades from an SQLite database."""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
query = "SELECT id, instrument, price, size, side, timestamp FROM trades WHERE instrument = ? ORDER BY timestamp ASC"
if sample_rows:
query += f" LIMIT {sample_rows}"
logger.info(f"Reading first {sample_rows} trades from SQLite for test run.")
cursor.execute(query, (symbol,))
# Fetch all results to apply tqdm effectively over the list
rows = cursor.fetchall()
trades = []
for row in tqdm(rows, total=len(rows), desc="Processing SQLite trades"): # Add tqdm
trade_id, instrument, price, size, side, timestamp = row
try:
# Assuming timestamp is in milliseconds and needs conversion to datetime
trade_timestamp = datetime.fromtimestamp(int(timestamp) / 1000)
trade = StandardizedTrade(
symbol=instrument,
trade_id=str(trade_id),
price=Decimal(str(price)),
size=Decimal(str(size)),
side=side,
timestamp=trade_timestamp,
exchange=exchange
)
trades.append(trade)
except Exception as e:
logger.error(f"Error parsing trade row: {row}. Error: {e}")
conn.close()
return trades
def main():
parser = argparse.ArgumentParser(description="Ingest market data into the database.")
parser.add_argument("--file", required=True, help="Path to the input data file (CSV or SQLite).")
parser.add_argument("--exchange", required=True, help="Exchange name (e.g., 'okx').")
parser.add_argument("--symbol", required=True, help="Trading symbol (e.g., 'BTC-USDT').")
parser.add_argument("--timeframes", nargs='*', default=['1m'], help="Timeframes for aggregation (e.g., '1m', '5m', '1h'). Required for SQLite, optional for CSV.")
parser.add_argument("--force", action="store_true", help="Overwrite existing data if it conflicts.")
parser.add_argument("--test-run", action="store_true", help="Run without inserting data, print a sample instead.")
parser.add_argument("--sample-rows", type=int, help="Number of rows to process in test-run mode. Only effective with --test-run.")
parser.add_argument("--batch-size", type=int, default=10000, help="Batch size for inserting data into the database.")
args = parser.parse_args()
file_path = args.file
exchange = args.exchange
symbol = args.symbol
timeframes = args.timeframes
force_update = args.force
test_run = args.test_run
sample_rows = args.sample_rows
batch_size = args.batch_size
if test_run and sample_rows is None:
logger.warning("--- No --sample-rows specified for --test-run. Processing full file for sample output. ---")
market_data_repo = MarketDataRepository()
# raw_trade_repo = RawTradeRepository() # Not used in this script
if not os.path.exists(file_path):
logger.error(f"Error: File not found at {file_path}")
return
if file_path.endswith('.csv'):
logger.info(f"Processing CSV file: {file_path}")
raw_candles = parse_csv_to_candles(file_path, exchange, symbol, sample_rows=sample_rows if test_run else None)
logger.info(f"Parsed {len(raw_candles)} raw 1m candles from CSV.")
if not raw_candles:
logger.info("No raw candles found to process in the CSV file.")
return
all_aggregated_candles = []
# Convert raw candles to a pandas DataFrame for resampling
df_raw_candles = pd.DataFrame([c.to_dict() for c in raw_candles])
# Ensure 'end_time' is a datetime object and set as index for resampling
df_raw_candles['end_time'] = pd.to_datetime(df_raw_candles['end_time'])
df_raw_candles = df_raw_candles.set_index('end_time')
# Convert Decimal types to float for pandas resampling, then back to Decimal after aggregation
# This ensures compatibility with pandas' numerical operations
for col in ['open', 'high', 'low', 'close', 'volume']:
if col in df_raw_candles.columns:
df_raw_candles[col] = pd.to_numeric(df_raw_candles[col])
# 'trade_count' might not exist, handle with .get()
if 'trade_count' in df_raw_candles.columns:
df_raw_candles['trade_count'] = pd.to_numeric(df_raw_candles['trade_count'])
for tf in timeframes:
logger.info(f"Aggregating 1m candles to {tf} timeframe...")
# Resample the DataFrame to the target timeframe
resampled_df = resample_candles_to_timeframe(df_raw_candles, tf)
# Convert resampled DataFrame back to OHLCVCandle objects
for index, row in resampled_df.iterrows():
# index is the end_time for the resampled candle
end_time = index
# Calculate start_time based on end_time and timeframe
# TimeframeBucket._parse_timeframe_to_timedelta returns timedelta
time_delta = TimeframeBucket._parse_timeframe_to_timedelta(tf)
start_time = end_time - time_delta
candle = OHLCVCandle(
exchange=exchange,
symbol=symbol,
timeframe=tf,
start_time=start_time,
end_time=end_time,
open=Decimal(str(row['open'])),
high=Decimal(str(row['high'])),
low=Decimal(str(row['low'])),
close=Decimal(str(row['close'])),
volume=Decimal(str(row['volume'])),
trade_count=int(row.get('trades_count', 0)),
is_complete=True # Resampled candles are considered complete
)
all_aggregated_candles.append(candle)
# Sort candles by timeframe and then by end_time for consistent output/insertion
all_aggregated_candles.sort(key=lambda x: (x.timeframe, x.end_time))
logger.info(f"Aggregated {len(all_aggregated_candles)} candles for timeframes: {', '.join(timeframes)}")
if test_run:
logger.info("--- Test Run: Sample of Aggregated Candles (first 5) ---")
for i, candle in enumerate(all_aggregated_candles[:5]):
logger.info(f" {candle.to_dict()}")
logger.info("--- End of Test Run Sample ---")
logger.info("Data not inserted into database due to --test-run flag.")
else:
logger.info(f"Starting batch insertion of {len(all_aggregated_candles)} aggregated candles with batch size {batch_size}.")
market_data_repo.upsert_candles_batch(all_aggregated_candles, force_update=force_update, batch_size=batch_size)
logger.info("CSV data ingestion complete.")
elif file_path.endswith('.db') or file_path.endswith('.sqlite'):
logger.info(f"Processing SQLite database: {file_path}")
if not timeframes:
logger.error("Error: Timeframes must be specified for SQLite trade data aggregation.")
return
trades = parse_sqlite_to_trades(file_path, exchange, symbol, sample_rows=sample_rows if test_run else None)
logger.info(f"Parsed {len(trades)} trades from SQLite.")
if not trades:
logger.info("No trades found to process in the SQLite database.")
return
# Use BatchCandleProcessor to aggregate trades into candles
processor = BatchCandleProcessor(symbol=symbol, exchange=exchange, timeframes=timeframes, logger=logger)
aggregated_candles = processor.process_trades_to_candles(iter(trades))
logger.info(f"Aggregated {len(aggregated_candles)} candles from trades for timeframes: {', '.join(timeframes)}")
if test_run:
logger.info("--- Test Run: Sample of Aggregated Candles (first 5) ---")
for i, candle in enumerate(aggregated_candles[:5]):
logger.info(f" {candle.to_dict()}")
logger.info("--- End of Test Run Sample ---")
logger.info("Data not inserted into database due to --test-run flag.")
else:
logger.info(f"Starting batch insertion of {len(aggregated_candles)} candles with batch size {batch_size}.")
market_data_repo.upsert_candles_batch(aggregated_candles, force_update=force_update, batch_size=batch_size)
logger.info("SQLite data ingestion complete.")
else:
logger.error("Error: Unsupported file type. Please provide a .csv or .sqlite/.db file.")
if __name__ == "__main__":
main()

23
uv.lock generated
View File

@ -441,6 +441,7 @@ dependencies = [
{ name = "requests" },
{ name = "sqlalchemy" },
{ name = "structlog" },
{ name = "tdqm" },
{ name = "tzlocal" },
{ name = "waitress" },
{ name = "watchdog" },
@ -498,6 +499,7 @@ requires-dist = [
{ name = "requests", specifier = ">=2.31.0" },
{ name = "sqlalchemy", specifier = ">=2.0.0" },
{ name = "structlog", specifier = ">=23.1.0" },
{ name = "tdqm", specifier = ">=0.0.1" },
{ name = "tzlocal", specifier = ">=5.3.1" },
{ name = "waitress", specifier = ">=3.0.0" },
{ name = "watchdog", specifier = ">=3.0.0" },
@ -1745,6 +1747,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f5/52/7a2c7a317b254af857464da3d60a0d3730c44f912f8c510c76a738a207fd/structlog-25.3.0-py3-none-any.whl", hash = "sha256:a341f5524004c158498c3127eecded091eb67d3a611e7a3093deca30db06e172", size = 68240 },
]
[[package]]
name = "tdqm"
version = "0.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5a/38/58c9e22b95e98666fe29f35e217ab6126a03798dadf3f4f8f3e5f898510b/tdqm-0.0.1.tar.gz", hash = "sha256:f050004a76b1d22f70b78209b48781353c82440215b6ba7d22b1f499b05a0101", size = 1388 }
[[package]]
name = "tomli"
version = "2.2.1"
@ -1784,6 +1795,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 },
]
[[package]]
name = "tqdm"
version = "4.67.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 },
]
[[package]]
name = "typing-extensions"
version = "4.13.2"