import logging
import os
import time
from datetime import datetime

from database_manager import DatabaseManager

logger = logging.getLogger(__name__)


class RiskManager:
    """Risk Control Manager (Dynamic Position Management Version).

    Sizes new positions in USDT from the account balances returned by the
    exchange API, scales them by market indicators, and enforces a
    per-symbol daily trade limit. Limits are read from environment
    variables once, at construction time.
    """

    def __init__(self, api):
        """Store the API client and load risk parameters.

        Args:
            api: Exchange client. This class calls its
                ``get_currency_balances()`` and ``get_current_price(symbol)``
                methods.
        """
        self.api = api
        self.max_total_position = float(os.getenv('MAX_TOTAL_POSITION', '1'))  # Maximum total position 100%
        self.max_single_position = float(os.getenv('MAX_SINGLE_POSITION', '0.3'))  # Single currency maximum position 30%
        self.stop_loss = float(os.getenv('STOP_LOSS', '0.05'))  # Stop-loss 5%
        self.take_profit = float(os.getenv('TAKE_PROFIT', '0.15'))  # Take-profit 15%
        self.max_daily_trades = int(os.getenv('MAX_DAILY_TRADES', '10'))  # Maximum daily trades (enforced per symbol)
        self.dust_threshold_value = 1.0  # Dust position value threshold (USDT)
        self.low_position_ratio = 0.05  # Low position ratio threshold (5%)
        self.trailing_stop_percent = float(os.getenv('TRAILING_STOP_PERCENT', '0.03'))  # Trailing stop percentage
        self.monitor_interval = int(os.getenv('MONITOR_INTERVAL', '300'))  # Monitoring interval seconds

        # Dynamic position parameters: toggles for the individual sizing factors.
        self.volatility_adjustment = True
        self.trend_adjustment = True
        self.market_sentiment_adjustment = True

        # Trade records: symbol -> {'date': datetime.date, 'count': int}
        self.daily_trade_count = {}
        self.db = DatabaseManager()

    # ------------------------------------------------------------------
    # Asset accounting
    # ------------------------------------------------------------------

    def calculate_total_assets(self):
        """Calculate total assets (in USDT).

        Returns:
            tuple: ``(total_usdt_value, usdt_available, position_details)``.

            - ``total_usdt_value`` counts USDT (available + frozen) plus every
              other currency priced via its ``<CURRENCY>-USDT`` symbol.
            - ``usdt_available`` is the available (non-frozen) USDT only.
            - ``position_details`` maps ``"<CURRENCY>-USDT"`` to a dict with
              ``amount`` (available + frozen), ``value`` and ``current_price``.

            Currencies whose price cannot be fetched are left out of both the
            total and the details. Returns ``(0.0, 0.0, {})`` on any failure.
        """
        try:
            balances = self.api.get_currency_balances()
            if not balances:
                logger.error("Unable to get currency balances")
                return 0.0, 0.0, {}

            position_details = {}

            # USDT counts fully (available + frozen) toward total assets,
            # but only the available part is spendable. `or 0.0` guards
            # against explicit None values in the balance payload.
            usdt_balance = balances.get('USDT') or {}
            usdt_avail = usdt_balance.get('amount') or 0.0
            usdt_frozen = usdt_balance.get('frozen') or 0.0
            total_usdt = usdt_avail + usdt_frozen
            total_usdt_value = total_usdt

            # Value every other currency at its current USDT price.
            for currency, balance_info in balances.items():
                if currency == 'USDT':
                    continue
                avail = balance_info.get('amount') or 0.0
                frozen = balance_info.get('frozen') or 0.0
                total_amount = avail + frozen
                if not (total_amount > 0):
                    continue

                symbol = f"{currency}-USDT"
                current_price = self.api.get_current_price(symbol)
                if not current_price:
                    logger.warning(f"Unable to get {currency} price, skipping value calculation")
                    continue

                value = total_amount * current_price
                total_usdt_value += value
                position_details[symbol] = {
                    'amount': total_amount,  # Total quantity (available + frozen)
                    'value': value,
                    'current_price': current_price,
                }

            logger.debug(f"Total asset calculation: USDT={total_usdt:.2f}(available{usdt_avail:.2f}+frozen{usdt_frozen:.2f}), position value={total_usdt_value - total_usdt:.2f}, total={total_usdt_value:.2f}")
            return total_usdt_value, usdt_avail, position_details
        except Exception as e:
            logger.error(f"Error calculating total assets: {e}")
            return 0.0, 0.0, {}

    def get_position_ratio(self, symbol):
        """Return the value of *symbol*'s position as a fraction of total assets.

        Returns 0.0 when total assets are zero, the symbol is not held, or an
        error occurs.
        """
        try:
            total_assets, _usdt_balance, positions = self.calculate_total_assets()
            if total_assets == 0:
                return 0.0
            position_value = positions.get(symbol, {}).get('value', 0.0)
            return position_value / total_assets
        except Exception as e:
            logger.error(f"Error calculating position ratio: {e}")
            return 0.0

    def get_available_usdt_ratio(self):
        """Return available (non-frozen) USDT as a fraction of total assets.

        Returns 0.0 when total assets are zero or an error occurs.
        """
        try:
            total_assets, usdt_balance, _ = self.calculate_total_assets()
            if total_assets == 0:
                return 0.0
            return usdt_balance / total_assets
        except Exception as e:
            logger.error(f"Error calculating USDT ratio: {e}")
            return 0.0

    # ------------------------------------------------------------------
    # Position sizing
    # ------------------------------------------------------------------

    def _size_position(self, symbol, confidence):
        """Compute the confidence-scaled buy size and the hard cap for *symbol*.

        Returns:
            tuple: ``(position_size, max_add)`` in USDT. ``max_add`` is the
            largest addition allowed by the single-symbol cap, the total
            position cap and the available USDT. ``position_size`` is
            ``max_add`` times the confidence multiplier. Returns
            ``(0.0, 0.0)`` when total assets are zero.

        Raises:
            Whatever the underlying arithmetic raises; callers handle it.
        """
        total_assets, usdt_balance, positions = self.calculate_total_assets()
        if total_assets == 0:
            logger.warning("Total assets is 0, unable to calculate position")
            return 0.0, 0.0

        # Value already held in this symbol
        current_pos_value = positions.get(symbol, {}).get('value', 0.0)

        # Shrink the single-symbol cap when little USDT is free. The full cap
        # applies once available USDT is at least 1/3 of total assets.
        usdt_ratio = usdt_balance / total_assets
        adjusted_max_single = self.max_single_position * min(1.0, usdt_ratio * 3)

        max_single_add = max(0, total_assets * adjusted_max_single - current_pos_value)
        # Everything that is not available USDT (including frozen USDT)
        # counts as already committed toward the total-position cap.
        max_total_add = max(0, total_assets * self.max_total_position - (total_assets - usdt_balance))
        max_add = min(max_single_add, max_total_add, usdt_balance)

        # Fraction of the allowed addition to use for each confidence level
        multiplier = {
            'HIGH': 0.8,    # High confidence uses 80% of available amount
            'MEDIUM': 0.5,  # Medium confidence uses 50%
            'LOW': 0.2,     # Low confidence uses 20%
        }.get(confidence, 0.2)

        # Never exceed available USDT balance
        position_size = min(max_add * multiplier, usdt_balance)

        logger.info(f"Position calculation: {symbol}, total assets=${total_assets:.2f}, USDT=${usdt_balance:.2f}, "
                    f"current position=${current_pos_value:.2f}, suggested position=${position_size:.2f}")
        return position_size, max_add

    def get_position_size(self, symbol, confidence, current_price):
        """Calculate position size based on confidence level and USDT availability.

        Args:
            symbol: Trading pair, e.g. ``"BTC-USDT"``.
            confidence: ``'HIGH'``, ``'MEDIUM'`` or ``'LOW'``. Unknown values
                are treated like ``'LOW'``.
            current_price: Accepted for interface compatibility; not used
                in the calculation.

        Returns:
            float: Suggested buy size in USDT, or 0.0 on error or when no
            assets are available.
        """
        try:
            position_size, _ = self._size_position(symbol, confidence)
            return position_size
        except Exception as e:
            logger.error(f"Error calculating position size: {e}")
            return 0.0

    def get_dynamic_position_size(self, symbol, confidence, technical_indicators, current_price):
        """Dynamic position calculation.

        Scales the base size from :meth:`get_position_size` by volatility,
        trend, market-state and confidence factors. The result is clamped so
        it never exceeds the single-symbol, total-position and available-USDT
        limits that bound the base size. Amplifying factors (>1) could
        otherwise push the order past them.

        Args:
            symbol: Trading pair, e.g. ``"BTC-USDT"``.
            confidence: ``'HIGH'``, ``'MEDIUM'`` or ``'LOW'``.
            technical_indicators: Mapping read for ``'volatility'``,
                ``'trend_strength'`` and the keys used by
                :meth:`assess_market_state`.
            current_price: Accepted for interface compatibility; not used
                in the calculation.

        Returns:
            float: Position size in USDT. If scaling fails, the unscaled base
            size is returned.
        """
        try:
            base_size, max_add = self._size_position(symbol, confidence)
        except Exception as e:
            logger.error(f"Error calculating position size: {e}")
            return 0.0
        if base_size == 0:
            return 0.0

        try:
            # Market state parameters
            volatility = technical_indicators.get('volatility', 0)
            trend_strength = technical_indicators.get('trend_strength', 0)
            market_state = self.assess_market_state(technical_indicators)

            volatility_factor = self._calculate_volatility_factor(volatility)  # high volatility reduces position
            trend_factor = self._calculate_trend_factor(trend_strength)  # strong trend increases position
            market_factor = self._calculate_market_factor(market_state)
            confidence_factor = self._calculate_confidence_factor(confidence)

            raw_size = base_size * volatility_factor * trend_factor * market_factor * confidence_factor
            # Respect the same risk limits that bound the base size.
            dynamic_size = min(raw_size, max_add)
            if raw_size > dynamic_size:
                logger.info(f"Dynamic position for {symbol} capped at limit ${max_add:.2f} (raw={raw_size:.2f})")

            logger.info(f"Dynamic position calculation: {symbol}, base={base_size:.2f}, "
                        f"volatility factor={volatility_factor:.2f}, trend factor={trend_factor:.2f}, "
                        f"market factor={market_factor:.2f}, confidence factor={confidence_factor:.2f}, "
                        f"final={dynamic_size:.2f}")
            return dynamic_size
        except Exception as e:
            logger.error(f"Dynamic position calculation error: {e}")
            # Fall back to the unscaled base size (already within limits)
            return base_size

    def _calculate_volatility_factor(self, volatility):
        """Return a sizing multiplier that shrinks as annualized volatility grows.

        Returns 1.0 when volatility adjustment is disabled.
        """
        if not self.volatility_adjustment:
            return 1.0
        if volatility > 0.8:    # 80%+ annualized volatility
            return 0.3
        elif volatility > 0.6:  # 60-80%
            return 0.5
        elif volatility > 0.4:  # 40-60%
            return 0.7
        elif volatility > 0.2:  # 20-40%
            return 0.9
        else:                   # <20%
            return 1.0

    def _calculate_trend_factor(self, trend_strength):
        """Return a sizing multiplier that grows with trend strength.

        Returns 1.0 when trend adjustment is disabled.
        """
        if not self.trend_adjustment:
            return 1.0
        if trend_strength > 0.7:    # Strong trend
            return 1.3
        elif trend_strength > 0.4:  # Medium trend
            return 1.1
        elif trend_strength > 0.2:  # Weak trend
            return 1.0
        else:                       # No trend
            return 0.8

    def _calculate_market_factor(self, market_state):
        """Return a sizing multiplier for a state label from :meth:`assess_market_state`.

        Unknown states map to 1.0; so does everything when sentiment
        adjustment is disabled.
        """
        if not self.market_sentiment_adjustment:
            return 1.0
        factors = {
            'STRONG_BULL': 1.2,
            'BULL': 1.1,
            'NEUTRAL': 1.0,
            'BEAR': 0.7,
            'STRONG_BEAR': 0.5,
        }
        return factors.get(market_state, 1.0)

    def _calculate_confidence_factor(self, confidence):
        """Return a sizing multiplier for the confidence label (unknown labels -> 0.5)."""
        factors = {
            'HIGH': 1.0,
            'MEDIUM': 0.7,
            'LOW': 0.4,
        }
        return factors.get(confidence, 0.5)

    # ------------------------------------------------------------------
    # Market assessment
    # ------------------------------------------------------------------

    @staticmethod
    def _last_value(value):
        """Return the latest element of a series-like indicator, or *value* if scalar.

        Supports pandas objects (``.iloc``) and plain sequences or arrays;
        strings and scalars are returned unchanged.
        """
        if hasattr(value, 'iloc'):
            return value.iloc[-1]
        if hasattr(value, '__len__') and not isinstance(value, (str, bytes)):
            return value[-1]
        return value

    def assess_market_state(self, technical_indicators):
        """Assess market state from four equally weighted bullish signals.

        The signals are RSI above 50, MACD line above its signal line, price
        above SMA-20 and price above SMA-50. Each indicator may be a scalar
        or a series, in which case its latest value is used.

        Returns:
            str: ``'STRONG_BULL'``, ``'BULL'``, ``'NEUTRAL'``, ``'BEAR'`` or
            ``'STRONG_BEAR'``. Returns ``'NEUTRAL'`` on any error.
        """
        try:
            rsi = self._last_value(technical_indicators.get('rsi', 50))
            macd_line, signal_line, _ = technical_indicators.get('macd', (0, 0, 0))
            macd_line = self._last_value(macd_line)
            signal_line = self._last_value(signal_line)
            sma_20 = self._last_value(technical_indicators.get('sma_20', 0))
            sma_50 = self._last_value(technical_indicators.get('sma_50', 0))
            current_price = self._last_value(technical_indicators.get('current_price', 0))

            signals = (
                rsi > 50,                  # RSI signal
                macd_line > signal_line,   # MACD signal
                current_price > sma_20,    # Short moving average signal
                current_price > sma_50,    # Long moving average signal
            )
            bull_ratio = sum(1 for s in signals if s) / len(signals)

            if bull_ratio >= 0.8:
                return 'STRONG_BULL'
            elif bull_ratio >= 0.6:
                return 'BULL'
            elif bull_ratio >= 0.4:
                return 'NEUTRAL'
            elif bull_ratio >= 0.2:
                return 'BEAR'
            else:
                return 'STRONG_BEAR'
        except Exception as e:
            logger.error(f"Error assessing market state: {e}")
            return 'NEUTRAL'

    def is_dust_position(self, symbol, current_price, base_amount):
        """Return True if the holding is worth less than ``dust_threshold_value`` USDT.

        A missing or zero price values the holding at 0 (i.e. dust).
        """
        hold_value = base_amount * current_price if current_price else 0
        return hold_value < self.dust_threshold_value

    def is_low_position_ratio(self, symbol, total_assets, current_price, base_amount):
        """Return True if the holding is below ``low_position_ratio`` of total assets.

        Returns False when ``total_assets`` is not positive.
        """
        if total_assets <= 0:
            return False
        hold_value = base_amount * current_price if current_price else 0
        return (hold_value / total_assets) < self.low_position_ratio

    # ------------------------------------------------------------------
    # Daily trade limits
    # ------------------------------------------------------------------

    def _trade_record(self, symbol):
        """Return today's trade-count record for *symbol*, resetting it on a new day."""
        today = datetime.now().date()
        record = self.daily_trade_count.get(symbol)
        if record is None or record['date'] != today:
            record = {'date': today, 'count': 0}
            self.daily_trade_count[symbol] = record
        return record

    def can_trade(self, symbol, amount, current_price):
        """Return True if *symbol* has not yet hit today's ``max_daily_trades``.

        ``amount`` and ``current_price`` are accepted for interface
        compatibility but are not checked. Returns False on any error.
        """
        try:
            if self._trade_record(symbol)['count'] >= self.max_daily_trades:
                logger.warning(f"{symbol} daily trade count reached {self.max_daily_trades} limit")
                return False
            return True
        except Exception as e:
            logger.error(f"Error checking trade conditions: {e}")
            return False

    def increment_trade_count(self, symbol):
        """Record one executed trade for *symbol* today.

        Creates or resets today's record as needed, so the count is kept
        even if :meth:`can_trade` was not called first or the day rolled
        over.
        """
        try:
            self._trade_record(symbol)['count'] += 1
        except Exception as e:
            logger.error(f"Error incrementing trade count: {e}")