# OKXTrading/database_manager.py
# 2025-11-05 16:39:46 +08:00
# (repository viewer metadata: 265 lines, 11 KiB, Python)

import logging
import os
import sqlite3
import threading
from contextlib import contextmanager
from datetime import datetime
# Module-level logger, named after this module so handlers/levels can be
# configured per module by the application.
logger = logging.getLogger(name=__name__)
class DatabaseManager:
    """SQLite-backed store for positions, trade records and dynamic stops.

    Each public method opens its own short-lived connection to ``db_path``
    and holds ``self.lock`` (a per-instance ``threading.Lock``) while it
    works, so calls made through one instance from several threads are
    serialized. The lock does not coordinate with other instances or other
    processes that use the same database file.
    """

    def __init__(self, db_path=None):
        """Create the manager and make sure the schema exists.

        Args:
            db_path: Path of the SQLite database file. When ``None``, the
                ``DB_PATH`` environment variable is used, falling back to
                ``'trading_system.db'``.
        """
        if db_path is None:
            db_path = os.getenv('DB_PATH', 'trading_system.db')
        self.db_path = db_path
        self.lock = threading.Lock()
        # init_database() logs and swallows its own errors; only announce
        # completion when it actually succeeded.
        if self.init_database():
            logger.info(f"Database initialization completed, path: {self.db_path}")

    @contextmanager
    def _connect(self):
        """Yield a fresh connection; commit on normal exit and always close it.

        ``sqlite3.Connection``'s own context manager commits/rolls back but
        does not close the connection, hence this helper. If the ``with``
        body raises, the commit is skipped and closing the connection
        discards the uncommitted transaction.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            yield conn
            conn.commit()
        finally:
            conn.close()

    def init_database(self):
        """Create missing tables and migrate older ``dynamic_stops`` schemas.

        Errors (``Exception`` subclasses) are logged with a traceback and
        swallowed rather than raised.

        Returns:
            bool: ``True`` if the schema work was committed, ``False`` if an
            error occurred.
        """
        with self.lock:
            try:
                with self._connect() as conn:
                    self._create_tables(conn)
                    self._migrate_dynamic_stops(conn)
            except Exception as e:
                # Cleanup already happened inside _connect(), so this handler
                # never touches a connection that failed to open.
                logger.exception(f"Database initialization error: {e}")
                return False
        return True

    @staticmethod
    def _create_tables(conn):
        """Create the positions, trade_records and dynamic_stops tables if absent."""
        conn.execute('''
            CREATE TABLE IF NOT EXISTS positions (
                symbol TEXT PRIMARY KEY,
                base_amount REAL NOT NULL DEFAULT 0,
                entry_price REAL NOT NULL DEFAULT 0,
                stop_loss_price REAL,
                take_profit_price REAL,
                created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        conn.execute('''
            CREATE TABLE IF NOT EXISTS trade_records (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                symbol TEXT NOT NULL,
                side TEXT NOT NULL,
                amount REAL NOT NULL,
                price REAL NOT NULL,
                order_id TEXT,
                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        conn.execute('''
            CREATE TABLE IF NOT EXISTS dynamic_stops (
                symbol TEXT PRIMARY KEY,
                initial_price REAL NOT NULL,
                current_stop_loss REAL,
                current_take_profit REAL,
                trailing_percent REAL DEFAULT 0.03,
                highest_price REAL,
                created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP
            )
        ''')

    @staticmethod
    def _migrate_dynamic_stops(conn):
        """Add ``highest_price``/``updated_at`` to old tables and backfill NULLs.

        ``CREATE TABLE IF NOT EXISTS`` leaves an existing table untouched, so
        a ``dynamic_stops`` table created before these columns existed must
        have them added explicitly.
        """
        # PRAGMA table_info rows: index 1 is the column name.
        columns = {row[1] for row in conn.execute("PRAGMA table_info(dynamic_stops)")}
        if 'highest_price' not in columns:
            conn.execute('ALTER TABLE dynamic_stops ADD COLUMN highest_price REAL')
            logger.info("Migration: Added dynamic_stops.highest_price column")
        if 'updated_at' not in columns:
            conn.execute('ALTER TABLE dynamic_stops ADD COLUMN updated_at TIMESTAMP')
            logger.info("Migration: Added dynamic_stops.updated_at column")
        # Backfill rows that predate (or were written without) these columns.
        conn.execute('''
            UPDATE dynamic_stops
            SET highest_price = initial_price
            WHERE highest_price IS NULL
        ''')
        conn.execute('''
            UPDATE dynamic_stops
            SET updated_at = CURRENT_TIMESTAMP
            WHERE updated_at IS NULL
        ''')

    def save_position(self, symbol, base_amount, entry_price, stop_loss_price=None, take_profit_price=None):
        """Insert or update the position for ``symbol``.

        Uses an UPSERT so that an existing row keeps its original
        ``created_time``; ``INSERT OR REPLACE`` would delete and re-insert
        the row and reset it. ``updated_time`` is the local wall-clock time,
        stored in the same ``YYYY-MM-DD HH:MM:SS.ffffff`` text form that
        sqlite3's (deprecated) default datetime adapter produced.
        """
        updated_time = datetime.now().isoformat(sep=' ')
        with self.lock, self._connect() as conn:
            conn.execute('''
                INSERT INTO positions
                (symbol, base_amount, entry_price, stop_loss_price, take_profit_price, updated_time)
                VALUES (?, ?, ?, ?, ?, ?)
                ON CONFLICT(symbol) DO UPDATE SET
                    base_amount = excluded.base_amount,
                    entry_price = excluded.entry_price,
                    stop_loss_price = excluded.stop_loss_price,
                    take_profit_price = excluded.take_profit_price,
                    updated_time = excluded.updated_time
            ''', (symbol, base_amount, entry_price, stop_loss_price, take_profit_price, updated_time))

    def load_position(self, symbol):
        """Return the stored position for ``symbol`` as a dict, or ``None``.

        Keys: ``base_amount``, ``entry_price``, ``stop_loss_price``,
        ``take_profit_price``.
        """
        with self.lock, self._connect() as conn:
            row = conn.execute('''
                SELECT base_amount, entry_price, stop_loss_price, take_profit_price
                FROM positions WHERE symbol = ?
            ''', (symbol,)).fetchone()
        if row is None:
            return None
        return {
            'base_amount': row[0],
            'entry_price': row[1],
            'stop_loss_price': row[2],
            'take_profit_price': row[3],
        }

    def delete_position(self, symbol):
        """Delete the position row for ``symbol`` (no-op if absent)."""
        with self.lock, self._connect() as conn:
            conn.execute('DELETE FROM positions WHERE symbol = ?', (symbol,))

    def save_trade_record(self, symbol, side, amount, price, order_id):
        """Append one row to ``trade_records``; ``timestamp`` defaults to CURRENT_TIMESTAMP."""
        with self.lock, self._connect() as conn:
            conn.execute('''
                INSERT INTO trade_records
                (symbol, side, amount, price, order_id)
                VALUES (?, ?, ?, ?, ?)
            ''', (symbol, side, amount, price, order_id))

    def set_dynamic_stop(self, symbol, initial_price, trailing_percent=0.03, multiplier=2):
        """Create (or overwrite) the dynamic stop for ``symbol``.

        The stop loss starts ``trailing_percent`` below ``initial_price``, the
        take profit ``trailing_percent * multiplier`` above it, and the
        highest observed price at ``initial_price``. An existing row is
        replaced (``INSERT OR REPLACE``), which also resets its
        ``created_time``.
        """
        current_stop_loss = initial_price * (1 - trailing_percent)
        current_take_profit = initial_price * (1 + trailing_percent * multiplier)
        with self.lock, self._connect() as conn:
            conn.execute('''
                INSERT OR REPLACE INTO dynamic_stops
                (symbol, initial_price, current_stop_loss, current_take_profit, trailing_percent, highest_price, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
            ''', (symbol, initial_price, current_stop_loss, current_take_profit, trailing_percent, initial_price))

    def update_dynamic_stop(self, symbol, current_price, trailing_percent=0.03, multiplier=2):
        """Ratchet the trailing stop for ``symbol`` with a new ``current_price``.

        ``highest_price`` only increases. The stop loss is recomputed as
        ``highest_price * (1 - trailing_percent)`` and only ever moves up. The
        take profit is recomputed as
        ``highest_price * (1 + trailing_percent * multiplier)``. The row is
        written back only when something changed, including NULL values that
        had to be repaired.

        NOTE(review): the ``trailing_percent`` stored by set_dynamic_stop is
        not read here; callers must pass the same value themselves.

        Returns:
            The (possibly unchanged) stop loss, or ``None`` when no dynamic
            stop exists for ``symbol``.
        """
        with self.lock, self._connect() as conn:
            row = conn.execute(
                'SELECT highest_price, current_stop_loss FROM dynamic_stops WHERE symbol = ?',
                (symbol,),
            ).fetchone()
            if row is None:
                return None

            highest_price, current_stop_loss = row
            # Must start False: when the price neither makes a new high nor
            # lifts the stop, nothing else below assigns it.
            updated = False
            if highest_price is None:
                highest_price = current_price
                updated = True  # persist the repaired value
                logger.warning(f"Dynamic stop record {symbol} highest_price is empty, initializing with current price")
            if current_stop_loss is None:
                current_stop_loss = current_price * (1 - trailing_percent)
                updated = True  # persist the repaired value
                logger.warning(f"Dynamic stop record {symbol} current_stop_loss is empty, recalculating")

            if current_price > highest_price:
                highest_price = current_price
                updated = True
            # The stop trails the highest price and never moves down.
            new_stop_loss = highest_price * (1 - trailing_percent)
            if new_stop_loss > current_stop_loss:
                current_stop_loss = new_stop_loss
                updated = True
            current_take_profit = highest_price * (1 + trailing_percent * multiplier)

            if updated:
                conn.execute('''
                    UPDATE dynamic_stops
                    SET highest_price = ?, current_stop_loss = ?, current_take_profit = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE symbol = ?
                ''', (highest_price, current_stop_loss, current_take_profit, symbol))
            return current_stop_loss

    def get_dynamic_stop(self, symbol):
        """Return the dynamic stop for ``symbol`` as a dict, or ``None``.

        Keys: ``current_stop_loss``, ``current_take_profit``,
        ``trailing_percent``, ``highest_price``. A NULL ``highest_price``
        falls back to ``initial_price`` via COALESCE.
        """
        with self.lock, self._connect() as conn:
            row = conn.execute('''
                SELECT current_stop_loss, current_take_profit, trailing_percent,
                       COALESCE(highest_price, initial_price) AS highest_price
                FROM dynamic_stops WHERE symbol = ?
            ''', (symbol,)).fetchone()
        if row is None:
            return None
        return {
            'current_stop_loss': row[0],
            'current_take_profit': row[1],
            'trailing_percent': row[2],
            'highest_price': row[3],
        }

    def delete_dynamic_stop(self, symbol):
        """Delete the dynamic stop row for ``symbol`` (no-op if absent)."""
        with self.lock, self._connect() as conn:
            conn.execute('DELETE FROM dynamic_stops WHERE symbol = ?', (symbol,))