399 lines
16 KiB
Python
399 lines
16 KiB
Python
import logging
import os
import time
from decimal import Decimal, InvalidOperation

import numpy as np
import okx.Account as Account
import okx.MarketData as MarketData
import okx.Trade as Trade
import pandas as pd

from apiratelimiter import APIRateLimiter
|
|
|
|
# Module-level logger, named after this module; used by OKXAPIClient for all
# diagnostics.
logger = logging.getLogger(__name__)

# Value sent as the `tdMode` field of every order built by
# OKXAPIClient.create_order.
tdmode = "cross" # cross for demo account
|
|
|
|
class OKXAPIClient:
    """OKX API client (using the official SDK).

    Wraps the SDK's account, market-data and trade clients. Apart from the
    constructor, API failures are logged and reported through a sentinel
    return value (``None``, ``{}`` or ``(None, None)``) rather than raised,
    so callers must check what they get back.
    """

    def __init__(self):
        """Read credentials from the environment and build the SDK clients.

        Raises:
            ValueError: If ``OKX_API_KEY``, ``OKX_SECRET_KEY`` or
                ``OKX_PASSWORD`` is unset or empty.
        """
        self.api_key = os.getenv('OKX_API_KEY')
        self.secret_key = os.getenv('OKX_SECRET_KEY')
        self.password = os.getenv('OKX_PASSWORD')

        if not all([self.api_key, self.secret_key, self.password]):
            raise ValueError("Please set OKX API key and password")

        # SDK environment flag: "0" = live trading, "1" = demo (simulated)
        # trading. NOTE(review): this was previously labelled "live", but the
        # value "1" selects the demo environment; the value is unchanged.
        flag = "1"
        credentials = (self.api_key, self.secret_key, self.password)
        self.account_api = Account.AccountAPI(*credentials, False, flag)
        self.market_api = MarketData.MarketAPI(*credentials, False, flag)
        self.trade_api = Trade.TradeAPI(*credentials, False, flag)

        # API rate limiting. The meaning of the argument is defined by
        # APIRateLimiter (presumably a per-second budget -- TODO confirm).
        self.rate_limiter = APIRateLimiter(2)
        self.log_http = False

        # Cache of instrument info: symbol -> (minSz, lotSz)
        self.instrument_cache = {}

    @staticmethod
    def _to_float(value, default=0.0):
        """Convert an API field to float, mapping ''/None/garbage to *default*.

        The API transmits numbers as strings and uses an empty string for
        fields that have no value yet, which a bare ``float()`` rejects.
        """
        if value is None or value == '':
            return default
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _floor_to_lot(amount, lot_sz):
        """Round *amount* down to a whole multiple of *lot_sz*.

        Decimal arithmetic is used so that e.g. 0.3 with a lot of 0.1 stays
        0.3 instead of collapsing to 0.2 through binary float error. A
        non-positive or unusable lot size leaves *amount* unchanged.
        """
        if not lot_sz or lot_sz <= 0:
            return amount
        try:
            lot = Decimal(str(float(lot_sz)))
            whole_lots = Decimal(str(float(amount))) // lot
            return float(whole_lots * lot)
        except (InvalidOperation, TypeError, ValueError):
            return amount

    def get_market_data(self, symbol, timeframe='1H', limit=200):
        """Get candlestick market data.

        Args:
            symbol: Instrument id, e.g. ``"BTC-USDT"``.
            timeframe: Bar size passed to the API as ``bar``.
            limit: Number of candles requested (sent as a string).

        Returns:
            A DataFrame sorted by ascending ``timestamp`` with columns
            timestamp, open, high, low, close, volume, volCcy, volCcyQuote
            and confirm, or ``None`` on an API error, empty data or exception.
        """
        self.rate_limiter.wait("market_data")

        try:
            result = self.market_api.get_candlesticks(
                instId=symbol,
                bar=timeframe,
                limit=str(limit),
            )

            if self.log_http:
                logger.debug(f"HTTP Request: GET {symbol} {timeframe} {result.get('code', 'Unknown')}")

            # Error checking: code "0" indicates success
            if result.get('code') != '0':
                error_msg = result.get('msg', 'Unknown error')
                error_code = result.get('code', 'Unknown code')
                logger.error(f"Failed to get {symbol} market data: {error_msg} (code: {error_code})")
                return None

            rows = result.get('data')
            if not rows:
                logger.warning(f"{symbol} market data is empty")
                return None

            columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume', 'volCcy', 'volCcyQuote', 'confirm']
            df = pd.DataFrame(rows, columns=columns)

            # Timestamps arrive as epoch-millisecond strings; prices and
            # volumes as numeric strings.
            df['timestamp'] = pd.to_datetime(df['timestamp'].astype(np.int64), unit='ms')
            numeric_cols = ['open', 'high', 'low', 'close', 'volume', 'volCcy', 'volCcyQuote']
            df[numeric_cols] = df[numeric_cols].apply(pd.to_numeric)

            return df.sort_values('timestamp')

        except Exception as e:
            logger.error(f"Error getting market data: {e}")
            return None

    def get_account_balance(self):
        """Get the available USDT balance.

        Returns:
            The USDT ``availBal`` as a float (a blank value counts as 0.0),
            0.0 if the account lists no USDT entry, or ``None`` on an API
            error or exception.
        """
        self.rate_limiter.wait("balance")

        try:
            result = self.account_api.get_account_balance()

            if self.log_http:
                logger.debug(f"HTTP Request: GET balance {result.get('code', 'Unknown')}")

            # code "0" indicates success
            if result.get('code') != '0':
                logger.error(f"Failed to get balance: {result.get('msg', 'Unknown error')} (code: {result.get('code', 'Unknown code')})")
                return None

            for currency in result['data'][0]['details']:
                if currency['ccy'] == 'USDT':
                    return self._to_float(currency['availBal'])

            return 0.0

        except Exception as e:
            logger.error(f"Error getting balance: {e}")
            return None

    def get_currency_balances(self):
        """Get all currencies that have a positive available balance.

        Returns:
            ``{ccy: {'amount': availBal, 'frozen': frozenBal}}`` for every
            currency whose available balance is greater than zero; ``{}`` on
            an API error, a malformed response or an exception.
        """
        self.rate_limiter.wait("balances")

        try:
            result = self.account_api.get_account_balance()

            # code "0" indicates success
            if result.get('code') != '0':
                logger.error(f"Failed to get balance: {result.get('msg', 'Unknown error')} (code: {result.get('code', 'Unknown code')})")
                return {}

            data = result.get('data')
            if not data:
                logger.error("No balance data in API response")
                return {}

            if 'details' not in data[0]:
                logger.error("No details field in API response")
                return {}

            balances = {}
            for currency in data[0]['details']:
                # Blank numeric fields are treated as 0 so that one odd entry
                # does not discard every other balance.
                available = self._to_float(currency.get('availBal'))
                if available > 0:
                    balances[currency['ccy']] = {
                        'amount': available,
                        'frozen': self._to_float(currency.get('frozenBal')),
                    }

            return balances

        except Exception as e:
            logger.error(f"Error getting currency balances: {e}")
            return {}

    def get_positions(self):
        """Get position information derived from currency balances.

        Every non-USDT currency with a positive available balance is reported
        as a ``"<CCY>-USDT"`` position.

        Returns:
            ``{symbol: {'amount', 'value', 'avg_price'}}``. ``value`` is
            amount * last price, or 0.0 when no price could be fetched;
            ``avg_price`` is always 0.0. ``{}`` when there are no balances or
            on an exception.
        """
        try:
            balances = self.get_currency_balances()
            if not balances:
                return {}

            positions = {}
            for currency, balance in balances.items():
                amount = balance['amount']
                if currency == 'USDT' or not amount > 0:
                    continue

                symbol = f"{currency}-USDT"
                current_price = self.get_current_price(symbol)
                positions[symbol] = {
                    'amount': amount,
                    'value': amount * current_price if current_price else 0.0,
                    'avg_price': 0.0,  # Spot positions don't have average price concept
                }

            return positions

        except Exception as e:
            logger.error(f"Error getting positions: {e}")
            return {}

    def get_current_price(self, symbol):
        """Get the last traded price of *symbol*.

        Returns:
            The ticker's ``last`` field as a float, or ``None`` on an API
            error or exception.
        """
        self.rate_limiter.wait("price")

        try:
            result = self.market_api.get_ticker(instId=symbol)

            # code "0" indicates success
            if result.get('code') != '0':
                logger.error(f"Failed to get price: {result.get('msg', 'Unknown error')} (code: {result.get('code', 'Unknown code')})")
                return None

            return float(result['data'][0]['last'])

        except Exception as e:
            logger.error(f"Error getting price: {e}")
            return None

    def get_instrument_info(self, symbol):
        """Get the size constraints of a trading pair.

        The first successful lookup caches every instrument contained in the
        SPOT listing, so later symbols are served without another request.

        Returns:
            ``(minSz, lotSz)`` as floats, or ``(None, None)`` when the request
            fails or the symbol is not listed.
        """
        if symbol in self.instrument_cache:
            return self.instrument_cache[symbol]

        self.rate_limiter.wait("instrument")

        try:
            result = self.account_api.get_instruments(instType='SPOT')

            if result.get('code') != '0':
                logger.error(f"Failed to get instrument: {result.get('msg', 'Unknown error')} (code: {result.get('code', 'Unknown code')})")
                return None, None

            # The response already lists every SPOT instrument; remember all
            # of them instead of re-downloading the list per symbol.
            for inst in result.get('data') or []:
                try:
                    self.instrument_cache[inst['instId']] = (
                        float(inst['minSz']),
                        float(inst['lotSz']),
                    )
                except (KeyError, TypeError, ValueError):
                    continue  # skip entries with missing/unparseable fields

            if symbol in self.instrument_cache:
                min_sz, lot_sz = self.instrument_cache[symbol]
                logger.debug(f"Got {symbol} precision: minSz={min_sz}, lotSz={lot_sz}")
                return min_sz, lot_sz

            logger.error(f"Trading pair not found: {symbol}")
            return None, None

        except Exception as e:
            logger.error(f"Error getting instrument info: {e}")
            return None, None

    def get_default_min_size(self, symbol):
        """Get the fallback minimum order size for *symbol*.

        Used when the instrument lookup fails; unknown symbols get 0.01.
        """
        defaults = {
            'BTC-USDT': 0.0001,
            'ETH-USDT': 0.001,
            'SOL-USDT': 0.01,
            'XRP-USDT': 1.0,
        }
        return defaults.get(symbol, 0.01)

    def _build_order_params(self, symbol, side, amount):
        """Build the ``place_order`` keyword arguments, or None if invalid."""
        parts = symbol.split('-')
        if len(parts) != 2:
            logger.error(f"Invalid trading pair format: {symbol}")
            return None
        base_currency, quote_currency = parts

        side = str(side).lower()
        if side == 'buy':
            # When buying, amount is the quote currency amount (USDT amount)
            logger.info(f"[{symbol}] Create buy order: amount={amount:.2f} {quote_currency}")
            return {
                'instId': symbol,
                'tdMode': tdmode,
                'side': 'buy',
                'ordType': 'market',
                'sz': str(amount),        # Quote currency amount
                'tgtCcy': 'quote_ccy',    # Specify sz as quote currency
            }

        if side != 'sell':
            # Anything that was not exactly 'buy' used to be sold silently.
            logger.error(f"Invalid order side: {side}")
            return None

        # When selling, amount is the base currency quantity; align it with
        # the instrument's lot size.
        min_sz, lot_sz = self.get_instrument_info(symbol)
        if min_sz is None:
            min_sz = self.get_default_min_size(symbol)
        if lot_sz is None:
            lot_sz = min_sz

        amount = self._floor_to_lot(amount, lot_sz)
        if amount <= 0:
            logger.error(f"[{symbol}] Sell quantity is zero after rounding to lot size {lot_sz}")
            return None
        if amount < min_sz:
            # min_sz may be a hard-coded fallback, so warn instead of refusing.
            logger.warning(f"[{symbol}] Sell quantity {amount} is below minimum size {min_sz}")

        amount_str = f"{amount:.10f}"
        logger.info(f"[{symbol}] Create sell order: quantity={amount_str} {base_currency}")
        return {
            'instId': symbol,
            'tdMode': tdmode,
            'side': 'sell',
            'ordType': 'market',
            'sz': amount_str,  # Base currency quantity
        }

    def _extract_order_id(self, result):
        """Return the order id from a ``place_order`` response.

        Logs the reason and returns None when the response reports a failure
        or carries no usable order id.
        """
        code = result.get('code')
        data = result.get('data') or []

        if code != '0':
            logger.error(f"Failed to create order: {result.get('msg', 'Unknown error')} (code: {code})")
            for item in data:
                logger.error(f"Detailed error: {item.get('sMsg', 'Unknown')} (sCode: {item.get('sCode', 'Unknown')})")
            # Specific error handling
            if code == '50113':
                logger.error("API key may not have trading permissions, please check API key settings")
            elif code == '51020':
                logger.error("Order amount below exchange minimum requirement")
            return None

        if not data:
            logger.error("No order data in API response")
            return None

        order_data = data[0]
        if order_data.get('sCode') != '0':
            logger.error(f"Order creation failed: {order_data.get('sMsg', 'Unknown error')} (sCode: {order_data.get('sCode', 'Unknown')})")
            return None

        order_id = order_data.get('ordId')
        if not order_id:
            logger.error("Order ID is empty")
            return None
        return order_id

    def create_order(self, symbol, side, amount, retries=3):
        """Create a market order, retrying with exponential backoff.

        Args:
            symbol: Instrument id of the form ``"BASE-QUOTE"``.
            side: ``'buy'`` or ``'sell'`` (case-insensitive).
            amount: For buys, the quote-currency amount to spend; for sells,
                the base-currency quantity, rounded down to the lot size.
            retries: Maximum number of placement attempts.

        Returns:
            The exchange order id, or ``None`` if the request was invalid or
            every attempt failed.

        NOTE(review): a retry after an exception from ``place_order`` could
        duplicate an order if the first request did reach the exchange.
        """
        try:
            order_params = self._build_order_params(symbol, side, amount)
        except Exception as e:
            logger.error(f"Error creating order: {str(e)}")
            return None
        if order_params is None:
            return None

        for attempt in range(retries):
            try:
                self.rate_limiter.wait("order")
                result = self.trade_api.place_order(**order_params)

                if self.log_http:
                    logger.debug(f"HTTP Request: POST create order {result.get('code', 'Unknown')}")

                order_id = self._extract_order_id(result)
                if order_id:
                    logger.info(f"Order created successfully: {order_id}")
                    return order_id

            except Exception as e:
                logger.error(f"Error creating order (attempt {attempt+1}/{retries}): {str(e)}")

            # Back off 1s, 2s, 4s, ... between attempts, not after the last.
            if attempt < retries - 1:
                time.sleep(2 ** attempt)

        return None

    def get_order_status(self, symbol, order_id):
        """Get the status of an order.

        Returns:
            A dict with ``state``, ``avgPx``, ``accFillSz``, ``fillPx``,
            ``fillSz`` (floats; blank fields become 0.0) and ``fillTime``,
            or ``None`` on an API error, missing data or exception.
        """
        self.rate_limiter.wait("order_status")

        try:
            result = self.trade_api.get_order(instId=symbol, ordId=order_id)

            # code "0" indicates success
            if result.get('code') != '0':
                logger.error(f"Failed to get order status: {result.get('msg', 'Unknown error')} (code: {result.get('code', 'Unknown code')})")
                return None

            data = result.get('data') or []
            if not data:
                logger.error("No order data in API response")
                return None

            order_data = data[0]
            # Price/size fields are empty strings until the order has fills,
            # so they must not go through a bare float().
            return {
                'state': order_data.get('state'),
                'avgPx': self._to_float(order_data.get('avgPx')),
                'accFillSz': self._to_float(order_data.get('accFillSz')),
                'fillPx': self._to_float(order_data.get('fillPx')),
                'fillSz': self._to_float(order_data.get('fillSz')),
                'fillTime': order_data.get('fillTime'),
            }

        except Exception as e:
            logger.error(f"Error getting order status: {e}")
            return None

    def wait_for_order_completion(self, symbol, order_id, max_attempts=10, interval=1):
        """Poll an order until it is filled.

        Args:
            symbol: Instrument id of the order.
            order_id: Exchange order id.
            max_attempts: Maximum number of status polls.
            interval: Seconds slept after each non-final poll result.

        Returns:
            The status dict from ``get_order_status`` once the state is
            ``'filled'``; ``None`` if the order is canceled, its status cannot
            be read, or it is not filled within ``max_attempts`` polls.
        """
        for _ in range(max_attempts):
            order_status = self.get_order_status(symbol, order_id)
            if order_status is None:
                return None

            state = order_status['state']
            if state == 'filled':
                logger.info(f"Order completed: {order_id}, fill price={order_status['avgPx']:.2f}, fill quantity={order_status['accFillSz']:.10f}")
                return order_status
            if state == 'canceled':
                logger.warning(f"Order canceled: {order_id}")
                return None

            if state == 'partially_filled':
                logger.info(f"Order partially filled: {order_id}, filled={order_status['accFillSz']:.10f}")
            else:
                logger.info(f"Order status: {state}, waiting...")
            time.sleep(interval)

        logger.warning(f"Order not completed within specified time: {order_id}")
        return None
|