# orderflow_backtest/visualizer.py
# (256 lines, 9.8 KiB, Python)

# Set Qt5Agg as the default backend before importing pyplot
import os
import matplotlib
matplotlib.use('Qt5Agg')
import logging
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.patches import Rectangle
from datetime import datetime, timezone
from collections import deque
from typing import Deque, Optional
from pathlib import Path
from storage import Book, BookSnapshot
from models import Metric
from repositories.sqlite_metrics_repository import SQLiteMetricsRepository
class Visualizer:
    """Render OHLC candles, volume, OBI and CVD charts from order book data.

    Mid-prices (mean of best bid and best ask) are aggregated into
    fixed-width OHLC bars. Per-bar traded volume is drawn beneath them,
    followed by OBI and CVD series loaded from the metrics database when a
    database path has been set. The module selects the Qt5Agg backend for
    interactive charts.

    Public methods:
    - update_from_book: process all snapshots from a Book and display charts
    - set_db_path: set database path for loading stored metrics
    - flush: finalize and draw the last in-progress bar
    - show: display the Matplotlib window
    """

    # Magnitude thresholds used to guess the unit of an epoch timestamp.
    _NS_THRESHOLD = 100_000_000_000_000_000  # > 1e17 -> nanoseconds
    _US_THRESHOLD = 100_000_000_000_000  # > 1e14 -> microseconds
    _MS_THRESHOLD = 100_000_000_000  # > 1e11 -> milliseconds

    def __init__(self, window_seconds: int = 60, max_bars: int = 200) -> None:
        """Create the figure and the empty bar buffer.

        Args:
            window_seconds: Width of one OHLC bar in seconds (clamped to >= 1).
            max_bars: Maximum number of finished bars kept (clamped to >= 1);
                older bars are dropped by the bounded deque.
        """
        # Four stacked subplots sharing the time axis: OHLC, volume, OBI, CVD.
        self.fig, (self.ax_ohlc, self.ax_volume, self.ax_obi, self.ax_cvd) = plt.subplots(
            4, 1, figsize=(12, 10), sharex=True
        )
        self.window_seconds = int(max(1, window_seconds))
        self.max_bars = int(max(1, max_bars))
        self._db_path: Optional[Path] = None

        # Finished bars: (start_ts, open, high, low, close, volume).
        self._bars: Deque[tuple[int, float, float, float, float, float]] = deque(
            maxlen=self.max_bars
        )

        # State of the bar currently being accumulated.
        self._current_bucket_ts: Optional[int] = None
        self._open: Optional[float] = None
        self._high: Optional[float] = None
        self._low: Optional[float] = None
        self._close: Optional[float] = None
        self._volume: float = 0.0

    # ------------------------------------------------------------------
    # Timestamp helpers
    # ------------------------------------------------------------------
    def _bucket_start(self, ts: int) -> int:
        """Return the start (epoch seconds) of the window containing ``ts``."""
        seconds = int(ts)
        return seconds - (seconds % self.window_seconds)

    @classmethod
    def _seconds_divisor(cls, ts) -> int:
        """Return how many units of ``ts`` make one second, guessed by magnitude."""
        if ts > cls._NS_THRESHOLD:
            return 1_000_000_000
        if ts > cls._US_THRESHOLD:
            return 1_000_000
        if ts > cls._MS_THRESHOLD:
            return 1_000
        return 1

    @classmethod
    def _epoch_seconds(cls, ts) -> float:
        """Return ``ts`` as fractional epoch seconds, keeping sub-second precision."""
        return ts / cls._seconds_divisor(ts)

    def _normalize_ts_seconds(self, ts: int) -> int:
        """Return whole epoch seconds from possibly ms/us/ns timestamps.

        Heuristic based on magnitude:
        - >1e17: nanoseconds  -> divide by 1e9
        - >1e14: microseconds -> divide by 1e6
        - >1e11: milliseconds -> divide by 1e3
        - else: seconds
        """
        its = int(ts)
        return its // self._seconds_divisor(its)

    @staticmethod
    def _to_datenum(epoch_seconds: float) -> float:
        """Convert epoch seconds to a Matplotlib date number (naive UTC)."""
        naive_utc = datetime.fromtimestamp(epoch_seconds, tz=timezone.utc).replace(tzinfo=None)
        return mdates.date2num(naive_utc)

    # ------------------------------------------------------------------
    # Metrics loading
    # ------------------------------------------------------------------
    def set_db_path(self, db_path: Path) -> None:
        """Set the database path for loading stored metrics."""
        self._db_path = db_path

    def _load_stored_metrics(self, start_timestamp: int, end_timestamp: int) -> "list[Metric]":
        """Load stored metrics from database for the given time range.

        Returns an empty list when no database path is set or when loading
        fails; failures are logged, not raised, so charting continues without
        the metric panels.
        """
        if not self._db_path:
            return []
        try:
            metrics_repo = SQLiteMetricsRepository(self._db_path)
            with metrics_repo.connect() as conn:
                return metrics_repo.load_metrics_by_timerange(conn, start_timestamp, end_timestamp)
        except Exception as e:  # best-effort: the chart is still useful without metrics
            logging.error("Error loading metrics for visualization: %s", e)
            return []

    # ------------------------------------------------------------------
    # Bar accumulation
    # ------------------------------------------------------------------
    def _reset_current_bar(self) -> None:
        """Forget the in-progress bar so it cannot be appended twice."""
        self._current_bucket_ts = None
        self._open = self._high = self._low = self._close = None
        self._volume = 0.0

    def _append_current_bar(self) -> None:
        """Append the in-progress bar to the buffer; no-op when there is none."""
        if self._current_bucket_ts is None or self._open is None:
            return
        open_ = float(self._open)
        self._bars.append(
            (
                self._current_bucket_ts,
                open_,
                float(self._high if self._high is not None else open_),
                float(self._low if self._low is not None else open_),
                float(self._close if self._close is not None else open_),
                float(self._volume),
            )
        )

    def _aggregate_snapshots(self, snapshots) -> int:
        """Rebuild the bar buffer from ``snapshots`` and return the valid count.

        Snapshots are processed in ascending ``timestamp`` order. A snapshot
        is skipped when either side of the book is empty or its best prices
        cannot be determined. The final bar is appended and the in-progress
        state cleared, so a later ``flush()`` does not duplicate it.
        """
        self._bars.clear()
        self._reset_current_bar()

        valid_count = 0
        for snapshot in sorted(snapshots, key=lambda s: s.timestamp):
            if not snapshot.bids or not snapshot.asks:
                continue
            try:
                best_bid = max(snapshot.bids.keys())
                best_ask = min(snapshot.asks.keys())
            except (ValueError, TypeError):
                continue

            mid = (float(best_bid) + float(best_ask)) / 2.0
            ts = self._normalize_ts_seconds(int(snapshot.timestamp))
            bucket_ts = self._bucket_start(ts)
            snapshot_volume = sum(trade.size for trade in snapshot.trades)
            valid_count += 1

            if bucket_ts != self._current_bucket_ts:
                # New bucket (or very first one): store the previous bar, if
                # any, and start a fresh one at this mid-price.
                self._append_current_bar()
                self._current_bucket_ts = bucket_ts
                self._open = self._high = self._low = self._close = mid
                self._volume = snapshot_volume
            else:
                # Same bucket: extend the range, move the close, add volume.
                if self._high is None or mid > self._high:
                    self._high = mid
                if self._low is None or mid < self._low:
                    self._low = mid
                self._close = mid
                self._volume += snapshot_volume

        self._append_current_bar()
        self._reset_current_bar()
        return valid_count

    # ------------------------------------------------------------------
    # Drawing
    # ------------------------------------------------------------------
    def _draw_candles(self, xs: list, width: float) -> None:
        """Draw one wick and one body per bar on the OHLC axis."""
        for x, (_, open_, high_, low_, close_, _) in zip(xs, self._bars):
            # Wick, centred on the bar.
            self.ax_ohlc.vlines(x + width / 2.0, low_, high_, color="black", linewidth=1.0)
            # Body; a tiny minimum height keeps flat bars visible.
            lower = min(open_, close_)
            height = max(1e-12, abs(close_ - open_))
            color = "green" if close_ >= open_ else "red"
            self.ax_ohlc.add_patch(
                Rectangle((x, lower), width, height, facecolor=color, edgecolor=color, linewidth=1.0)
            )

    def _draw_volume(self, xs: list, width: float) -> None:
        """Draw per-bar volume, aligned with the candle bodies above."""
        volumes = [bar[5] for bar in self._bars]
        # align='edge' makes each bar span [x, x + width], like its candle body.
        self.ax_volume.bar(xs, volumes, width=width, alpha=0.7, color='blue', align='edge')

    def _draw_metrics(self) -> None:
        """Plot OBI and CVD for the time range covered by the bars, if available."""
        first_ts = self._bars[0][0]
        last_ts = self._bars[-1][0]
        # NOTE(review): the range is passed in epoch seconds; confirm that
        # matches the unit load_metrics_by_timerange filters on.
        metrics = self._load_stored_metrics(first_ts, last_ts + self.window_seconds)
        if not metrics:
            return
        # Use the same magnitude heuristic as for snapshots instead of
        # assuming milliseconds, so second-resolution metrics land on the
        # shared time axis too.
        xs = [self._to_datenum(self._epoch_seconds(m.timestamp)) for m in metrics]
        self.ax_obi.plot(xs, [m.obi for m in metrics], 'b-', linewidth=1, label='OBI')
        self.ax_obi.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
        self.ax_cvd.plot(xs, [m.cvd for m in metrics], 'r-', linewidth=1, label='CVD')

    def _configure_axes(self) -> None:
        """Set titles, labels, the OBI range and the time formatter."""
        self.ax_ohlc.set_title("Mid-price OHLC")
        self.ax_ohlc.set_ylabel("Price")
        self.ax_volume.set_title("Volume")
        self.ax_volume.set_ylabel("Volume")
        self.ax_obi.set_title("Order Book Imbalance (OBI)")
        self.ax_obi.set_ylabel("OBI")
        self.ax_obi.set_ylim(-1.1, 1.1)
        self.ax_cvd.set_title("Cumulative Volume Delta (CVD)")
        self.ax_cvd.set_ylabel("CVD")
        self.ax_cvd.set_xlabel("Time (UTC)")
        # The x axis is shared, so only the bottom subplot is formatted.
        self.ax_cvd.xaxis_date()
        self.ax_cvd.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))

    def _draw(self) -> None:
        """Clear all subplots and redraw them from the current bar buffer."""
        for ax in (self.ax_ohlc, self.ax_volume, self.ax_obi, self.ax_cvd):
            ax.clear()
        if not self._bars:
            self.fig.canvas.draw_idle()
            return

        # Matplotlib date numbers are in days, so express the bar width in days.
        width = self.window_seconds / (24 * 60 * 60)
        xs = [self._to_datenum(bar[0]) for bar in self._bars]

        self._draw_candles(xs, width)
        self._draw_volume(xs, width)
        self._draw_metrics()
        self._configure_axes()

        self.fig.tight_layout()
        self.fig.canvas.draw_idle()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def update_from_book(self, book: "Book") -> None:
        """Update the visualizer with all snapshots from the book.

        Uses best bid/ask to compute mid-price; aggregates into OHLC bars.
        Any previously stored bars are discarded. Snapshots are processed in
        chronological order and the result is drawn.
        """
        if not book.snapshots:
            logging.warning("Book has no snapshots to visualize")
            return
        logging.info("Visualizing %d snapshots", len(book.snapshots))
        valid_count = self._aggregate_snapshots(book.snapshots)
        logging.info(
            "Created %d OHLC bars from %d valid snapshots", len(self._bars), valid_count
        )
        self._draw()

    def flush(self) -> None:
        """Finalize the in-progress bar (if any) and redraw."""
        self._append_current_bar()
        self._reset_current_bar()
        self._draw()

    def show(self) -> None:
        """Display the Matplotlib window (blocks per ``plt.show`` semantics)."""
        plt.show()