# OKXTrading/riskmanager.py
#
# Source listing: 340 lines, 14 KiB, Python

import logging
import os
import time
from datetime import datetime
from database_manager import DatabaseManager
logger = logging.getLogger(__name__)


class RiskManager:
    """Risk Control Manager (Dynamic Position Management Version).

    Sizes new positions from the account's current USDT and holdings (queried
    through ``api``), scales them by confidence and market-state factors, and
    enforces a per-symbol daily trade limit.

    Numeric limits are read from environment variables in ``__init__``; the
    dynamic-adjustment switches are plain attributes that can be toggled after
    construction.
    """

    def __init__(self, api):
        """Create a risk manager.

        Args:
            api: Exchange client. This class calls its
                ``get_currency_balances()`` and ``get_current_price(symbol)``.
        """
        self.api = api
        self.max_total_position = float(os.getenv('MAX_TOTAL_POSITION', '1'))  # Maximum total position 100%
        self.max_single_position = float(os.getenv('MAX_SINGLE_POSITION', '0.3'))  # Single currency maximum position 30%
        # NOTE(review): stop_loss, take_profit, trailing_stop_percent, monitor_interval
        # and db are not read by any method of this class; presumably consumed by
        # callers — confirm before removing.
        self.stop_loss = float(os.getenv('STOP_LOSS', '0.05'))  # Stop-loss 5%
        self.take_profit = float(os.getenv('TAKE_PROFIT', '0.15'))  # Take-profit 15%
        self.max_daily_trades = int(os.getenv('MAX_DAILY_TRADES', '10'))  # Maximum trades per symbol per day
        self.dust_threshold_value = 1.0  # Dust position value threshold (USDT)
        self.low_position_ratio = 0.05  # Low position ratio threshold (5%)
        self.trailing_stop_percent = float(os.getenv('TRAILING_STOP_PERCENT', '0.03'))  # Trailing stop percentage
        self.monitor_interval = int(os.getenv('MONITOR_INTERVAL', '300'))  # Monitoring interval seconds
        # Dynamic position parameters
        self.volatility_adjustment = True
        self.trend_adjustment = True
        self.market_sentiment_adjustment = True
        # Trade records: {symbol: {'date': datetime.date, 'count': int}}
        self.daily_trade_count = {}
        self.db = DatabaseManager()

    # ------------------------------------------------------------------ assets

    def calculate_total_assets(self):
        """Calculate total account value in USDT.

        Balances are expected as ``{currency: {'amount': x, 'frozen': y}}``;
        missing fields count as 0.

        Returns:
            tuple: ``(total_usdt_value, usdt_available, position_details)``.
            ``total_usdt_value`` is USDT (available + frozen) plus the value of
            every non-USDT balance whose price could be fetched (unpriced
            currencies are skipped, so the total may be understated).
            ``usdt_available`` excludes frozen USDT. ``position_details`` maps
            ``"<CCY>-USDT"`` to ``{'amount', 'value', 'current_price'}``.
            ``(0.0, 0.0, {})`` when balances are unavailable or on error.
        """
        try:
            # Get all currency balances
            balances = self.api.get_currency_balances()
            if not balances:
                logger.error("Unable to get currency balances")
                return 0.0, 0.0, {}
            total_usdt_value = 0.0
            position_details = {}
            # USDT balance (available + frozen)
            usdt_balance = balances.get('USDT', {})
            usdt_avail = usdt_balance.get('amount', 0.0)
            usdt_frozen = usdt_balance.get('frozen', 0.0)
            total_usdt = usdt_avail + usdt_frozen
            total_usdt_value += total_usdt
            # Value of every other currency, priced against USDT
            for currency, balance_info in balances.items():
                if currency == 'USDT':
                    continue
                avail = balance_info.get('amount', 0.0)
                frozen = balance_info.get('frozen', 0.0)
                total_amount = avail + frozen
                if total_amount > 0:
                    symbol = f"{currency}-USDT"
                    current_price = self.api.get_current_price(symbol)
                    if current_price:
                        value = total_amount * current_price
                        total_usdt_value += value
                        position_details[symbol] = {
                            'amount': total_amount,  # Total quantity (available + frozen)
                            'value': value,
                            'current_price': current_price,
                        }
                    else:
                        logger.warning(f"Unable to get {currency} price, skipping value calculation")
            logger.debug(f"Total asset calculation: USDT={total_usdt:.2f}(available{usdt_avail:.2f}+frozen{usdt_frozen:.2f}), position value={total_usdt_value - total_usdt:.2f}, total={total_usdt_value:.2f}")
            return total_usdt_value, usdt_avail, position_details
        except Exception as e:
            logger.error(f"Error calculating total assets: {e}")
            return 0.0, 0.0, {}

    def get_position_ratio(self, symbol):
        """Return the value of ``symbol``'s holding as a fraction of total assets.

        Returns 0.0 when total assets are 0, the symbol is not held, or on error.
        """
        try:
            total_assets, _usdt_available, positions = self.calculate_total_assets()
            if total_assets == 0:
                return 0.0
            position_value = positions.get(symbol, {}).get('value', 0.0)
            return position_value / total_assets
        except Exception as e:
            logger.error(f"Error calculating position ratio: {e}")
            return 0.0

    def get_available_usdt_ratio(self):
        """Return available (non-frozen) USDT as a fraction of total assets.

        Returns 0.0 when total assets are 0 or on error.
        """
        try:
            total_assets, usdt_available, _ = self.calculate_total_assets()
            if total_assets == 0:
                return 0.0
            return usdt_available / total_assets
        except Exception as e:
            logger.error(f"Error calculating USDT ratio: {e}")
            return 0.0

    # ----------------------------------------------------------- position size

    def _compute_position_size(self, symbol, confidence):
        """Return ``(position_size, max_add)`` in USDT for ``symbol``; may raise.

        ``max_add`` is the largest amount the risk limits allow to be added:
        the smallest of the single-symbol headroom, the total-position headroom
        and the available USDT. ``position_size`` is ``max_add`` scaled by the
        confidence multiplier. Both are 0.0 when total assets are 0.
        """
        total_assets, usdt_balance, positions = self.calculate_total_assets()
        if total_assets == 0:
            logger.warning("Total assets is 0, unable to calculate position")
            return 0.0, 0.0
        # Current position value
        current_pos_value = positions.get(symbol, {}).get('value', 0.0)
        # Shrink the single-symbol cap when available USDT is below 1/3 of assets
        usdt_ratio = usdt_balance / total_assets
        adjusted_max_single = self.max_single_position * min(1.0, usdt_ratio * 3)
        # Headroom under each limit; (total_assets - usdt_balance) is everything
        # not held as available USDT, including frozen USDT.
        max_single_add = max(0, total_assets * adjusted_max_single - current_pos_value)
        max_total_add = max(0, total_assets * self.max_total_position - (total_assets - usdt_balance))
        max_add = min(max_single_add, max_total_add, usdt_balance)
        # Fraction of the headroom to use per confidence level (unknown -> LOW)
        multiplier = {
            'HIGH': 0.8,
            'MEDIUM': 0.5,
            'LOW': 0.2,
        }.get(confidence, 0.2)
        # Never exceed available USDT
        position_size = min(max_add * multiplier, usdt_balance)
        logger.info(f"Position calculation: {symbol}, total assets=${total_assets:.2f}, USDT=${usdt_balance:.2f}, "
                    f"current position=${current_pos_value:.2f}, suggested position=${position_size:.2f}")
        return position_size, max_add

    def get_position_size(self, symbol, confidence, current_price):
        """Calculate the USDT amount to allocate to ``symbol``.

        The allowed headroom is the smallest of:
        the single-symbol cap (``max_single_position`` of total assets, scaled
        down when available USDT is below one third of total assets) minus the
        current holding; the total cap (``max_total_position`` of total assets)
        minus everything not held as available USDT; and the available USDT.
        That headroom is scaled by confidence: HIGH 0.8, MEDIUM 0.5, else 0.2.

        Args:
            symbol: Trading pair such as ``"BTC-USDT"``, matched against the
                keys produced by ``calculate_total_assets``.
            confidence: ``'HIGH'``, ``'MEDIUM'`` or ``'LOW'``.
            current_price: Unused; kept for interface compatibility.

        Returns:
            float: Suggested size in USDT; 0.0 when total assets are 0 or on error.
        """
        try:
            position_size, _max_add = self._compute_position_size(symbol, confidence)
            return position_size
        except Exception as e:
            logger.error(f"Error calculating position size: {e}")
            return 0.0

    def get_dynamic_position_size(self, symbol, confidence, technical_indicators, current_price):
        """Scale the base position size by market conditions, within risk limits.

        ``base = get_position_size(...)`` is multiplied by volatility, trend,
        market-state and confidence factors. The trend and market factors can
        exceed 1.0, so the result is capped at the headroom the risk limits
        allow. Without that cap, sizes above the single/total position limits
        and above available USDT were possible.

        Args:
            symbol: Trading pair such as ``"BTC-USDT"``.
            confidence: ``'HIGH'``, ``'MEDIUM'`` or ``'LOW'``.
            technical_indicators: Mapping read for ``'volatility'``,
                ``'trend_strength'`` and the keys used by
                ``assess_market_state``. Values may be scalars or sequences;
                the latest element is used.
            current_price: Unused; kept for interface compatibility.

        Returns:
            float: Size in USDT. 0.0 when the base size is 0 or cannot be
            computed. The base size when the dynamic adjustment fails.
        """
        try:
            base_size, max_add = self._compute_position_size(symbol, confidence)
        except Exception as e:
            logger.error(f"Error calculating position size: {e}")
            return 0.0
        if base_size == 0:
            return 0.0
        try:
            # Market state parameters (missing/NaN -> 0, as with the old .get default)
            volatility = self._latest_value(technical_indicators.get('volatility'), 0)
            trend_strength = self._latest_value(technical_indicators.get('trend_strength'), 0)
            market_state = self.assess_market_state(technical_indicators)
            volatility_factor = self._calculate_volatility_factor(volatility)
            trend_factor = self._calculate_trend_factor(trend_strength)
            market_factor = self._calculate_market_factor(market_state)
            confidence_factor = self._calculate_confidence_factor(confidence)
            dynamic_size = base_size * volatility_factor * trend_factor * market_factor * confidence_factor
            # Boosting factors (up to 1.3 * 1.2) must not push the order past
            # the limits enforced by the base calculation.
            dynamic_size = min(dynamic_size, max_add)
            logger.info(f"Dynamic position calculation: {symbol}, base={base_size:.2f}, "
                        f"volatility factor={volatility_factor:.2f}, trend factor={trend_factor:.2f}, "
                        f"market factor={market_factor:.2f}, confidence factor={confidence_factor:.2f}, "
                        f"final={dynamic_size:.2f}")
            return dynamic_size
        except Exception as e:
            logger.error(f"Dynamic position calculation error: {e}")
            # Fall back to the already-computed base size (no second balance fetch)
            return base_size

    # ------------------------------------------------------- adjustment factors

    @staticmethod
    def _latest_value(value, default=None):
        """Reduce an indicator value to its most recent scalar.

        Objects exposing ``.iloc`` (e.g. pandas Series) yield ``.iloc[-1]``.
        Other non-string sequences yield ``value[-1]``. Scalars are returned
        as-is. ``None``, empty sequences and NaN yield ``default``.
        """
        if value is None:
            return default
        if hasattr(value, 'iloc'):
            if len(value) == 0:
                return default
            value = value.iloc[-1]
        elif hasattr(value, '__len__') and not isinstance(value, (str, bytes)):
            if len(value) == 0:
                return default
            value = value[-1]
        if value != value:  # NaN is the only value not equal to itself
            return default
        return value

    def _calculate_volatility_factor(self, volatility):
        """Return a size multiplier that shrinks as volatility rises.

        Thresholds assume ``volatility`` is an annualized fraction
        (0.8 == 80%); no conversion is done here. TODO confirm against
        the indicator producer.
        """
        if not self.volatility_adjustment:
            return 1.0
        if volatility > 0.8:  # above 80%
            return 0.3
        elif volatility > 0.6:  # 60-80%
            return 0.5
        elif volatility > 0.4:  # 40-60%
            return 0.7
        elif volatility > 0.2:  # 20-40%
            return 0.9
        else:  # 20% or less
            return 1.0

    def _calculate_trend_factor(self, trend_strength):
        """Return a size multiplier: 1.3 strong, 1.1 medium, 1.0 weak, 0.8 no trend."""
        if not self.trend_adjustment:
            return 1.0
        if trend_strength > 0.7:  # Strong trend
            return 1.3
        elif trend_strength > 0.4:  # Medium trend
            return 1.1
        elif trend_strength > 0.2:  # Weak trend
            return 1.0
        else:  # No trend
            return 0.8

    def _calculate_market_factor(self, market_state):
        """Return a size multiplier for a state from ``assess_market_state`` (unknown -> 1.0)."""
        if not self.market_sentiment_adjustment:
            return 1.0
        factors = {
            'STRONG_BULL': 1.2,
            'BULL': 1.1,
            'NEUTRAL': 1.0,
            'BEAR': 0.7,
            'STRONG_BEAR': 0.5,
        }
        return factors.get(market_state, 1.0)

    def _calculate_confidence_factor(self, confidence):
        """Return a size multiplier per confidence level (unknown -> 0.5).

        Applied on top of the confidence multiplier already used by
        ``get_position_size``.
        """
        factors = {
            'HIGH': 1.0,
            'MEDIUM': 0.7,
            'LOW': 0.4,
        }
        return factors.get(confidence, 0.5)

    def assess_market_state(self, technical_indicators):
        """Classify the market from RSI, MACD and moving-average votes.

        Votes: RSI > 50; MACD line > signal line; price > SMA20; price > SMA50.
        A moving-average vote is cast only when both the SMA and the current
        price are present, non-NaN and positive. Missing data therefore no
        longer counts as bullish. Each indicator may be a scalar or a sequence
        (its latest element is used). ``'macd'`` is expected to unpack into
        ``(macd_line, signal_line, _)``.

        Returns:
            str: 'STRONG_BULL' (bullish share >= 0.8), 'BULL' (>= 0.6),
            'NEUTRAL' (>= 0.4), 'BEAR' (>= 0.2) or 'STRONG_BEAR'.
            'NEUTRAL' on error.
        """
        try:
            rsi = self._latest_value(technical_indicators.get('rsi'), 50)
            macd = technical_indicators.get('macd')
            if macd is None:
                macd = (0, 0, 0)
            macd_line, signal_line, _ = macd
            macd_line = self._latest_value(macd_line, 0)
            signal_line = self._latest_value(signal_line, 0)
            current_price = self._latest_value(technical_indicators.get('current_price'), 0)
            sma_values = (
                self._latest_value(technical_indicators.get('sma_20')),
                self._latest_value(technical_indicators.get('sma_50')),
            )

            bull_signals = 0
            total_signals = 0
            # RSI signal
            if rsi > 50:
                bull_signals += 1
            total_signals += 1
            # MACD signal
            if macd_line > signal_line:
                bull_signals += 1
            total_signals += 1
            # Moving-average signals, only when backed by real data
            if current_price > 0:
                for sma in sma_values:
                    if sma is None or sma <= 0:
                        continue
                    if current_price > sma:
                        bull_signals += 1
                    total_signals += 1

            bull_ratio = bull_signals / total_signals
            if bull_ratio >= 0.8:
                return 'STRONG_BULL'
            elif bull_ratio >= 0.6:
                return 'BULL'
            elif bull_ratio >= 0.4:
                return 'NEUTRAL'
            elif bull_ratio >= 0.2:
                return 'BEAR'
            else:
                return 'STRONG_BEAR'
        except Exception as e:
            logger.error(f"Error assessing market state: {e}")
            return 'NEUTRAL'

    # ---------------------------------------------------------- position checks

    def is_dust_position(self, symbol, current_price, base_amount):
        """Return True if ``base_amount * current_price`` is below the dust threshold (USDT).

        A falsy ``current_price`` counts as zero value. ``symbol`` is unused.
        """
        hold_value = base_amount * current_price if current_price else 0
        return hold_value < self.dust_threshold_value

    def is_low_position_ratio(self, symbol, total_assets, current_price, base_amount):
        """Return True if the holding is below ``low_position_ratio`` of ``total_assets``.

        Returns False when ``total_assets`` is not positive. ``symbol`` is unused.
        """
        hold_value = base_amount * current_price if current_price else 0
        if total_assets > 0:
            return (hold_value / total_assets) < self.low_position_ratio
        return False

    # ------------------------------------------------------------ trade limits

    def _today_trade_entry(self, symbol):
        """Return today's ``{'date', 'count'}`` record for ``symbol``, creating or resetting it."""
        today = datetime.now().date()
        entry = self.daily_trade_count.get(symbol)
        if entry is None or entry['date'] != today:
            entry = {'date': today, 'count': 0}
            self.daily_trade_count[symbol] = entry
        return entry

    def can_trade(self, symbol, amount, current_price):
        """Return True while ``symbol`` is under today's ``max_daily_trades`` limit.

        ``amount`` and ``current_price`` are unused. Returns False on error.
        """
        try:
            entry = self._today_trade_entry(symbol)
            if entry['count'] >= self.max_daily_trades:
                logger.warning(f"{symbol} daily trade count reached {self.max_daily_trades} limit")
                return False
            return True
        except Exception as e:
            logger.error(f"Error checking trade conditions: {e}")
            return False

    def increment_trade_count(self, symbol):
        """Record one trade for ``symbol`` today.

        Previously the increment was dropped unless ``can_trade`` had already
        created today's record, which let the daily limit be bypassed. The
        record is now created or reset as needed.
        """
        try:
            self._today_trade_entry(symbol)['count'] += 1
        except Exception as e:
            logger.error(f"Error incrementing trade count: {e}")