Cycles/trader/okx_trader.py
2025-05-27 08:42:42 +08:00

212 lines
7.8 KiB
Python

import os
import time
import hmac
import hashlib
import base64
import json
import pandas as pd
import threading
import requests
import websocket
import datetime
class OKXTrader:
    """Small client for the OKX v5 API: WebSocket streams plus a few REST calls.

    On construction it reads API credentials from ``../credentials/okx_creds.json``
    (relative to this file). It then opens a public and a private WebSocket
    connection and logs in on the private one. The ``subscribe_*`` methods only
    *send* subscribe requests; reading pushed messages from ``self.ws`` /
    ``self.ws_private`` is left to the caller. Call ``close()`` when done.
    """

    PUBLIC_WS_URL = "wss://ws.okx.com:8443/ws/v5/public"
    PRIVATE_WS_URL = "wss://ws.okx.com:8443/ws/v5/private"
    REST_URL = "https://www.okx.com"
    # Seconds to wait for a REST response (mirrors the WebSocket connect timeout).
    # Without it, requests can block forever on a stalled connection.
    REQUEST_TIMEOUT = 10
    # Timeframe units that OKX spells in upper case inside candle channel names
    # (e.g. "1h" -> "candle1H"). Minutes ("m") stay lower case, and months are
    # already "M".
    _UPPERCASE_UNITS = ("h", "d", "w")

    def __init__(self):
        """Load credentials, open both WebSockets and log in on the private one.

        Raises:
            FileNotFoundError: the credentials file does not exist.
            ValueError: the file is not a JSON object or lacks a required key.
            Exception: connecting or logging in failed (sockets are closed first).
        """
        # Load credentials from JSON config file
        config_path = os.path.join(os.path.dirname(__file__), '../credentials/okx_creds.json')
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Credentials file not found at {config_path}. Please create it with the required keys.") from err
        except json.JSONDecodeError as err:
            raise ValueError(f"Credentials file at {config_path} is not valid JSON.") from err
        if not isinstance(config, dict):
            raise ValueError(f"Credentials file at {config_path} must contain a JSON object.")
        self.api_key = config.get("OKX_API_KEY")
        self.api_secret = config.get("OKX_API_SECRET")
        self.api_passphrase = config.get("OKX_API_PASSPHRASE")
        if not self.api_key or not self.api_secret or not self.api_passphrase:
            raise ValueError("API key, secret, and passphrase must be set in the credentials JSON file.")
        self.ws = None
        self.ws_private = None
        # NOTE(review): these locks are not used by any method of this class;
        # presumably callers use them to serialize access to the sockets.
        self._lock = threading.Lock()
        self._private_lock = threading.Lock()
        print(f"[DEBUG] Connecting to public WebSocket: {self.PUBLIC_WS_URL}")
        try:
            self._connect_ws()
            self._authenticate()
        except Exception:
            # Don't leak half-opened sockets when construction fails.
            self.close()
            raise

    def _connect_ws(self):
        """Open the public socket, and the private one when credentials are present.

        Sockets that are already open are left alone.
        """
        if self.ws is None:
            self.ws = websocket.create_connection(self.PUBLIC_WS_URL, timeout=10)
        if self.api_key and self.api_secret and self.api_passphrase and self.ws_private is None:
            self.ws_private = websocket.create_connection(self.PRIVATE_WS_URL, timeout=10)

    def close(self):
        """Close any open WebSocket connections; safe to call more than once."""
        for attr in ("ws", "ws_private"):
            conn = getattr(self, attr, None)
            if conn is None:
                continue
            setattr(self, attr, None)
            try:
                conn.close()
            except Exception as e:  # best-effort cleanup: keep closing the rest
                print(f"[WARN] Error while closing {attr}: {e}")

    def _get_timestamp(self):
        """Return Unix time in seconds (rounded to ms) as a string.

        Used for the WebSocket login signature.
        """
        return str(round(time.time(), 3))

    @staticmethod
    def _get_iso_timestamp():
        """Return the current UTC time as ISO 8601 with milliseconds.

        Example: "2020-12-08T09:08:57.715Z". OKX REST authentication expects
        this format in OK-ACCESS-TIMESTAMP. The epoch-seconds string from
        ``_get_timestamp`` is what the WebSocket login uses.
        """
        now = datetime.datetime.now(datetime.timezone.utc)
        return now.isoformat(timespec="milliseconds").replace("+00:00", "Z")

    def _sign(self, timestamp, method, request_path, body):
        """Return base64(HMAC-SHA256(secret, timestamp + method + path + body)).

        A falsy ``body`` is signed as the empty string.
        """
        if not body:
            body = ''
        message = f'{timestamp}{method}{request_path}{body}'
        mac = hmac.new(self.api_secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256)
        return base64.b64encode(mac.digest()).decode()

    def _rest_headers(self, method, request_path, body=''):
        """Build the signed OK-ACCESS-* headers for a private REST request.

        ``body`` must be exactly the string that will be sent, since it is part
        of the signature.
        """
        timestamp = self._get_iso_timestamp()
        return {
            "OK-ACCESS-KEY": self.api_key,
            "OK-ACCESS-SIGN": self._sign(timestamp, method, request_path, body),
            "OK-ACCESS-TIMESTAMP": timestamp,
            "OK-ACCESS-PASSPHRASE": self.api_passphrase,
            "Content-Type": "application/json",
        }

    def _authenticate(self):
        """Send the login request on the private socket and block until OKX answers.

        Empty and non-JSON frames are skipped.

        Raises:
            Exception: OKX replied with a failed ``login`` event or an ``error`` event.
            websocket.WebSocketConnectionClosedException: the socket closed first.
            Any other error from ``recv()`` (e.g. a timeout) is logged and re-raised.
        """
        timestamp = self._get_timestamp()
        sign = self._sign(timestamp, 'GET', '/users/self/verify', '')
        login_params = {
            "op": "login",
            "args": [{
                "apiKey": self.api_key,
                "passphrase": self.api_passphrase,
                "timestamp": timestamp,
                "sign": sign
            }]
        }
        self.ws_private.send(json.dumps(login_params))
        print("Waiting for login response from OKX...")
        while True:
            try:
                resp = self.ws_private.recv()
            except websocket.WebSocketConnectionClosedException as e:
                print(f"[ERROR] WebSocket connection closed during authentication: {e}")
                raise
            except Exception as e:
                print(f"[ERROR] Exception during authentication: {e}")
                raise
            print(f"[DEBUG] Received from OKX private WS: {resp}")
            if not resp:
                continue
            try:
                msg = json.loads(resp)
            except (TypeError, ValueError):
                print(f"[WARN] Non-JSON message received: {resp}")
                continue
            if not isinstance(msg, dict):
                continue
            event = msg.get("event")
            if event == "login" and msg.get("code") == "0":
                print("[INFO] OKX WebSocket login successful.")
                return
            # A rejected login can arrive as a "login" event with a non-zero code
            # or as a generic "error" event. Treat both as failure instead of
            # waiting for recv() to time out.
            if event in ("login", "error"):
                error = Exception(f"WebSocket authentication failed: {msg}")
                print(f"[ERROR] Exception during authentication: {error}")
                raise error

    @staticmethod
    def _subscribe_payload(arg):
        """Return the JSON text of an OKX subscribe request with a single ``args`` entry."""
        return json.dumps({"op": "subscribe", "args": [arg]})

    @staticmethod
    def _candle_channel(timeframe):
        """Map a timeframe such as "1m", "1h", "4h" or "1D" to an OKX candle channel name.

        Lower-case hour/day/week units are upper-cased ("4h" -> "candle4H")
        because OKX channel names use "H", "D" and "W". Anything else is
        appended verbatim.
        """
        if timeframe[-1:] in OKXTrader._UPPERCASE_UNITS:
            timeframe = timeframe[:-1] + timeframe[-1].upper()
        return f"candle{timeframe}"

    def subscribe_candlesticks(self, instrument="BTC-USDT", timeframe="1m"):
        """Subscribe to candlesticks for ``instrument`` on the public socket.

        NOTE(review): current OKX docs list candle channels under the
        /ws/v5/business endpoint rather than /ws/v5/public. Confirm OKX
        acknowledges this subscription; if it answers with an error event,
        a separate business connection is needed.
        """
        payload = self._subscribe_payload({"channel": self._candle_channel(timeframe), "instId": instrument})
        print(f"[DEBUG] Subscribing to candlesticks: {payload}")
        self.ws.send(payload)

    def subscribe_trades(self, instrument="BTC-USDT"):
        """Subscribe to the public trades stream for ``instrument``."""
        self.ws.send(self._subscribe_payload({"channel": "trades", "instId": instrument}))

    def subscribe_ticker(self, instrument="BTC-USDT"):
        """Subscribe to the public ticker stream for ``instrument``."""
        self.ws.send(self._subscribe_payload({"channel": "tickers", "instId": instrument}))

    def subscribe_book(self, instrument="BTC-USDT", depth=5):
        """Subscribe to the order book for ``instrument``.

        ``depth <= 5`` uses "books5" (5 levels). Anything deeper uses "books"
        (400 levels), which sends an initial snapshot followed by incremental
        updates. OKX has no plain "books50" channel; its 50-level feed
        ("books50-l2-tbt") requires a logged-in VIP connection.
        """
        channel = "books5" if depth <= 5 else "books"
        self.ws.send(self._subscribe_payload({"channel": channel, "instId": instrument}))

    def subscribe_user_order(self):
        """Subscribe to the account's SPOT order updates on the private socket."""
        self.ws_private.send(self._subscribe_payload({"channel": "orders", "instType": "SPOT"}))

    def subscribe_user_trade(self):
        """Subscribe to "trades" with instType SPOT on the private socket.

        NOTE(review): OKX's "trades" channel is public and keyed by instId. The
        private channels are e.g. "orders", "fills", "account" and
        "balance_and_position". This request is therefore likely rejected;
        confirm which channel was intended.
        """
        self.ws_private.send(self._subscribe_payload({"channel": "trades", "instType": "SPOT"}))

    def subscribe_user_balance(self):
        """Subscribe to the private "balance_and_position" channel."""
        self.ws_private.send(self._subscribe_payload({"channel": "balance_and_position"}))

    def get_balance(self, currency=None):
        """Return the per-currency ``details`` entries of the account balance.

        Args:
            currency: if given, keep only entries whose "ccy" equals it.

        Returns:
            A list of detail dicts. It is empty on a non-200 response or when
            the response carries no account data.
        """
        request_path = "/api/v5/account/balance"
        headers = self._rest_headers("GET", request_path)
        resp = requests.get(f"{self.REST_URL}{request_path}", headers=headers, timeout=self.REQUEST_TIMEOUT)
        if resp.status_code != 200:
            return []
        # An error payload may carry an empty (or null) "data" list; fall back
        # to an empty account instead of indexing into it.
        accounts = resp.json().get("data") or [{}]
        balances = accounts[0].get("details", [])
        if currency:
            return [b for b in balances if b.get("ccy") == currency]
        return balances

    @staticmethod
    def _format_size(amount):
        """Render ``amount`` as a plain decimal string, never in scientific notation.

        For example, ``str(0.00005)`` is "5e-05"; this returns "0.00005".
        """
        text = str(amount)
        if "e" in text.lower():
            import decimal
            text = format(decimal.Decimal(text), "f")
        return text

    def place_order(self, side, amount, instrument="BTC-USDT", tgt_ccy=None):
        """Place a spot market order in cash (non-margin) mode.

        Args:
            side: "buy" or "sell" (case-insensitive).
            amount: order size, sent as ``sz``. For spot market orders OKX
                reads ``sz`` in the quote currency for buys and the base
                currency for sells, unless ``tgt_ccy`` overrides it.
            instrument: OKX instrument id, e.g. "BTC-USDT".
            tgt_ccy: optional "base_ccy" or "quote_ccy", sent as ``tgtCcy``.

        Returns:
            The decoded JSON response; check its "code" field for errors.
        """
        request_path = "/api/v5/trade/order"
        body_dict = {
            "instId": instrument,
            "tdMode": "cash",
            "side": side.lower(),
            "ordType": "market",
            "sz": self._format_size(amount),
        }
        if tgt_ccy is not None:
            body_dict["tgtCcy"] = tgt_ccy
        body = json.dumps(body_dict)
        # The exact body string is signed, so send that same string.
        headers = self._rest_headers("POST", request_path, body)
        resp = requests.post(f"{self.REST_URL}{request_path}", headers=headers, data=body, timeout=self.REQUEST_TIMEOUT)
        return resp.json()

    def buy_btc(self, amount, instrument="BTC-USDT"):
        """Market-buy ``instrument``; by default ``amount`` is in the quote currency."""
        return self.place_order("buy", amount, instrument)

    def sell_btc(self, amount, instrument="BTC-USDT"):
        """Market-sell ``instrument``; by default ``amount`` is in the base currency."""
        return self.place_order("sell", amount, instrument)

    def get_instruments(self, inst_type="SPOT"):
        """Return the public instrument list for ``inst_type``.

        Returns an empty list on a non-200 response.
        """
        resp = requests.get(
            f"{self.REST_URL}/api/v5/public/instruments",
            params={"instType": inst_type},
            timeout=self.REQUEST_TIMEOUT,
        )
        if resp.status_code == 200:
            return resp.json().get("data") or []
        return []