"""
Database Connection Utility for Crypto Trading Bot Platform

Provides connection pooling, session management, and database utilities.
"""

import os
import json
import logging
import time
from contextlib import contextmanager
from datetime import datetime, timedelta
from functools import wraps
from pathlib import Path
from typing import Optional, Generator, Any, Dict, List, Union, ContextManager

# Load environment variables from a .env file if one exists
try:
    from dotenv import load_dotenv

    env_file = Path(__file__).parent.parent / '.env'
    if env_file.exists():
        load_dotenv(env_file)
except ImportError:
    # python-dotenv is not installed; fall back to the process environment
    pass

from sqlalchemy import create_engine, Engine, text, event
from sqlalchemy.engine import make_url
from sqlalchemy.exc import SQLAlchemyError, OperationalError, DisconnectionError
from sqlalchemy.orm import sessionmaker, Session, scoped_session
from sqlalchemy.pool import QueuePool

from .models import Base, create_all_tables, drop_all_tables, RawTrade

# Configure logging
logger = logging.getLogger(__name__)


class DatabaseConfig:
    """Database configuration class"""

    def __init__(self):
        self.database_url = os.getenv(
            'DATABASE_URL',
            'postgresql://dashboard@localhost:5434/dashboard'
        )

        # Connection pool settings
        self.pool_size = int(os.getenv('DB_POOL_SIZE', '5'))
        self.max_overflow = int(os.getenv('DB_MAX_OVERFLOW', '10'))
        self.pool_pre_ping = os.getenv('DB_POOL_PRE_PING', 'true').lower() == 'true'
        self.pool_recycle = int(os.getenv('DB_POOL_RECYCLE', '3600'))  # 1 hour

        # Connection timeout settings
        self.connect_timeout = int(os.getenv('DB_CONNECT_TIMEOUT', '30'))
        self.statement_timeout = int(os.getenv('DB_STATEMENT_TIMEOUT', '30000'))  # 30 seconds, in ms

        # Retry settings
        self.max_retries = int(os.getenv('DB_MAX_RETRIES', '3'))
        self.retry_delay = float(os.getenv('DB_RETRY_DELAY', '1.0'))

        # SSL settings
        self.ssl_mode = os.getenv('DB_SSL_MODE', 'prefer')

        logger.info(f"Database configuration initialized for: {self._safe_url()}")

    def _safe_url(self) -> str:
        """Return the database URL with the password masked for logging"""
        url = make_url(self.database_url)
        return str(url.set(password="***"))

    def get_engine_kwargs(self) -> Dict[str, Any]:
        """Get SQLAlchemy engine configuration"""
        return {
            'poolclass': QueuePool,
            'pool_size': self.pool_size,
            'max_overflow': self.max_overflow,
            'pool_pre_ping': self.pool_pre_ping,
            'pool_recycle': self.pool_recycle,
            'connect_args': {
                'connect_timeout': self.connect_timeout,
                'options': f'-c statement_timeout={self.statement_timeout}',
                'sslmode': self.ssl_mode,
            },
            'echo': False,   # Disable SQL logging to reduce verbosity
            'future': True,  # Use SQLAlchemy 2.0 style
        }


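# The settings above are read from the environment when DatabaseConfig is
# instantiated. As an illustration (the values below are hypothetical, not the
# project defaults), a .env file at the package root could contain:
#
#   DATABASE_URL=postgresql://dashboard:secret@localhost:5434/dashboard
#   DB_POOL_SIZE=10
#   DB_MAX_OVERFLOW=20
#   DB_POOL_RECYCLE=1800
#   DB_STATEMENT_TIMEOUT=60000
#   DB_SSL_MODE=require

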
class DatabaseManager:
    """
    Database manager with connection pooling and session management
    """

    def __init__(self, config: Optional[DatabaseConfig] = None):
        self.config = config or DatabaseConfig()
        self._engine: Optional[Engine] = None
        self._session_factory: Optional[sessionmaker] = None
        self._scoped_session: Optional[scoped_session] = None

    def initialize(self) -> None:
        """Initialize database engine and session factory"""
        try:
            logger.info("Initializing database connection...")

            # Create engine with retry logic
            self._engine = self._create_engine_with_retry()

            # Set up session factory
            self._session_factory = sessionmaker(
                bind=self._engine,
                autocommit=False,
                autoflush=False,
                expire_on_commit=False
            )

            # Set up scoped session for thread safety
            self._scoped_session = scoped_session(self._session_factory)

            # Add connection event listeners
            self._setup_connection_events()

            logger.info("Database connection initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize database: {e}")
            raise

    def _create_engine_with_retry(self) -> Engine:
        """Create database engine with retry logic"""
        for attempt in range(self.config.max_retries):
            try:
                engine = create_engine(
                    self.config.database_url,
                    **self.config.get_engine_kwargs()
                )

                # Test the connection before handing the engine back
                with engine.connect() as conn:
                    conn.execute(text("SELECT 1"))
                    logger.info("Database connection test successful")

                return engine

            except (OperationalError, DisconnectionError) as e:
                attempt_num = attempt + 1
                if attempt_num == self.config.max_retries:
                    logger.error(f"Failed to connect to database after {self.config.max_retries} attempts: {e}")
                    raise

                logger.warning(f"Database connection attempt {attempt_num} failed: {e}. Retrying in {self.config.retry_delay}s...")
                time.sleep(self.config.retry_delay)

    def _setup_connection_events(self) -> None:
        """Set up SQLAlchemy connection event listeners"""
        if not self._engine:
            return

        @event.listens_for(self._engine, "connect")
        def set_postgres_session_settings(dbapi_connection, connection_record):
            """Set connection-level settings on each new DBAPI connection"""
            if 'postgresql' in str(self._engine.url):
                # Run the SET statements in autocommit mode so they are not
                # tied to (and rolled back with) an implicit transaction.
                previous_autocommit = dbapi_connection.autocommit
                dbapi_connection.autocommit = True
                try:
                    with dbapi_connection.cursor() as cursor:
                        # Set timezone to UTC
                        cursor.execute("SET timezone TO 'UTC'")
                        # Set application name for monitoring
                        cursor.execute("SET application_name TO 'crypto_trading_bot'")
                finally:
                    dbapi_connection.autocommit = previous_autocommit

        @event.listens_for(self._engine, "checkout")
        def checkout_listener(dbapi_connection, connection_record, connection_proxy):
            """Log connection checkout"""
            logger.debug("Database connection checked out from pool")

        @event.listens_for(self._engine, "checkin")
        def checkin_listener(dbapi_connection, connection_record):
            """Log connection checkin"""
            logger.debug("Database connection returned to pool")

    @property
    def engine(self) -> Engine:
        """Get database engine"""
        if not self._engine:
            raise RuntimeError("Database not initialized. Call initialize() first.")
        return self._engine

    @property
    def session_factory(self) -> sessionmaker:
        """Get session factory"""
        if not self._session_factory:
            raise RuntimeError("Database not initialized. Call initialize() first.")
        return self._session_factory

    def create_session(self) -> Session:
        """Create a new database session"""
        if not self._session_factory:
            raise RuntimeError("Database not initialized. Call initialize() first.")
        return self._session_factory()

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """
        Context manager for database sessions with automatic commit/rollback and cleanup

        Usage:
            with db_manager.get_session() as session:
                # Use session here
                pass
        """
        session = self.create_session()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    @contextmanager
    def get_scoped_session(self) -> Generator[Session, None, None]:
        """
        Context manager for scoped sessions (thread-safe)

        Usage:
            with db_manager.get_scoped_session() as session:
                # Use session here
                pass
        """
        if not self._scoped_session:
            raise RuntimeError("Database not initialized. Call initialize() first.")

        session = self._scoped_session()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            self._scoped_session.remove()

    def test_connection(self) -> bool:
        """Test database connection"""
        try:
            with self.get_session() as session:
                session.execute(text("SELECT 1"))
                logger.info("Database connection test successful")
                return True
        except Exception as e:
            logger.error(f"Database connection test failed: {e}")
            return False

    def get_pool_status(self) -> Dict[str, Any]:
        """Get connection pool status"""
        if not self._engine or not hasattr(self._engine.pool, 'size'):
            return {"status": "Pool not available"}

        pool = self._engine.pool
        return {
            "pool_size": pool.size(),
            "checked_in": pool.checkedin(),
            "checked_out": pool.checkedout(),
            "overflow": pool.overflow(),
        }
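
    # For reference, get_pool_status() returns a plain dict keyed as above;
    # a freshly initialized engine with the default settings would report
    # something roughly like (illustrative values, not measured output):
    #   {"pool_size": 5, "checked_in": 1, "checked_out": 0, "overflow": 0}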

    def create_tables(self) -> None:
        """Create all database tables"""
        try:
            create_all_tables(self.engine)
            logger.info("Database tables created successfully")
        except Exception as e:
            logger.error(f"Failed to create database tables: {e}")
            raise

    def drop_tables(self) -> None:
        """Drop all database tables"""
        try:
            drop_all_tables(self.engine)
            logger.info("Database tables dropped successfully")
        except Exception as e:
            logger.error(f"Failed to drop database tables: {e}")
            raise

    def execute_schema_file(self, schema_file_path: str) -> None:
        """Execute a SQL schema file"""
        try:
            with open(schema_file_path, 'r') as file:
                schema_sql = file.read()

            with self.get_session() as session:
                # Split on ';' and execute each statement separately. Note that
                # this simple split does not handle semicolons inside function
                # bodies or string literals.
                statements = [stmt.strip() for stmt in schema_sql.split(';') if stmt.strip()]
                for statement in statements:
                    session.execute(text(statement))

            logger.info(f"Schema file executed successfully: {schema_file_path}")
        except Exception as e:
            logger.error(f"Failed to execute schema file {schema_file_path}: {e}")
            raise

    def close(self) -> None:
        """Close database connections and clean up"""
        try:
            if self._scoped_session:
                self._scoped_session.remove()

            if self._engine:
                self._engine.dispose()

            logger.info("Database connections closed")
        except Exception as e:
            logger.error(f"Error closing database connections: {e}")


def retry_on_database_error(max_retries: int = 3, delay: float = 1.0):
    """
    Decorator to retry database operations on transient errors

    Args:
        max_retries: Maximum number of retry attempts
        delay: Delay between retries in seconds
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except (OperationalError, DisconnectionError) as e:
                    if attempt == max_retries - 1:
                        logger.error(f"Database operation failed after {max_retries} attempts: {e}")
                        raise

                    logger.warning(f"Database operation failed (attempt {attempt + 1}): {e}. Retrying in {delay}s...")
                    time.sleep(delay)
            return None
        return wrapper
    return decorator
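

# A minimal usage sketch for retry_on_database_error (the function name and
# query below are hypothetical, not part of this module):
#
#     @retry_on_database_error(max_retries=5, delay=2.0)
#     def count_raw_trades() -> int:
#         with db_manager.get_session() as session:
#             return session.execute(text("SELECT COUNT(*) FROM raw_trades")).scalar()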


# Global database manager instance
db_manager = DatabaseManager()


def get_db_manager() -> DatabaseManager:
    """Get the global database manager instance"""
    return db_manager


def init_database(config: Optional[DatabaseConfig] = None) -> DatabaseManager:
    """
    Initialize the global database manager

    Args:
        config: Optional database configuration

    Returns:
        DatabaseManager instance
    """
    global db_manager
    if config:
        db_manager = DatabaseManager(config)
    db_manager.initialize()
    return db_manager
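

# A typical startup sequence (sketch; assumes DATABASE_URL is already set in
# the environment and that creating tables at boot is desired):
#
#     manager = init_database()
#     manager.create_tables()
#     if not manager.test_connection():
#         raise SystemExit("Database is not reachable")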


# Convenience functions for common operations
def get_session() -> ContextManager[Session]:
    """Get a database session context manager (convenience function)"""
    return db_manager.get_session()


def get_scoped_session() -> ContextManager[Session]:
    """Get a scoped database session context manager (convenience function)"""
    return db_manager.get_scoped_session()


def test_connection() -> bool:
    """Test database connection (convenience function)"""
    return db_manager.test_connection()


def get_pool_status() -> Dict[str, Any]:
    """Get connection pool status (convenience function)"""
    return db_manager.get_pool_status()
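

if __name__ == "__main__":
    # Minimal smoke-test sketch: initialize the global manager, verify
    # connectivity, report pool status, and clean up. Assumes DATABASE_URL
    # (or the default local URL above) points at a reachable PostgreSQL
    # server; because this module uses a relative import, run it as a module
    # (python -m <package>.<module>, path depends on the package layout).
    logging.basicConfig(level=logging.INFO)
    manager = init_database()
    print("connection ok:", manager.test_connection())
    print("pool status:", manager.get_pool_status())
    manager.close()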