from __future__ import annotations

import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable

from analysis.safe_btc5.types import parse_dt

from .recorder import OrderBook


# Event types whose raw payload is applied as a full snapshot to the book of
# the row's asset_id (see apply_orderbook_event).
BOOK_EVENTS = {
    "rest_book",
    "book",
    "rest_book_after_reconnect",
    "rest_book_after_tick_size_change",
}
# Event types that produce a frame by default: every snapshot type plus
# incremental level changes, trade prints and tick-size changes.
FRAME_EVENTS = BOOK_EVENTS | {
    "price_change",
    "last_trade_price",
    "tick_size_change",
}


def parse_event_time(value: Any) -> datetime | None:
    """Parse ``value`` with ``parse_dt`` and convert the result to UTC.

    Returns ``None`` whenever ``parse_dt`` yields ``None``.
    """
    # NOTE(review): if parse_dt can return a naive datetime, astimezone()
    # interprets it as local time — confirm parse_dt always returns aware values.
    moment = parse_dt(value)
    return None if moment is None else moment.astimezone(timezone.utc)


def event_delay_ms(row: dict[str, Any]) -> float | None:
    """Return ``received_at_wall - exchange_ts`` in milliseconds, rounded to 3 places.

    Returns ``None`` if either timestamp is missing or cannot be parsed.
    The result is not clamped, so clock skew can make it negative.
    """
    sent = parse_event_time(row.get("exchange_ts"))
    received = parse_event_time(row.get("received_at_wall"))
    if sent is not None and received is not None:
        elapsed = received - sent
        return round(elapsed.total_seconds() * 1000.0, 3)
    return None


def load_market_meta(
    market_dir: Path,
    *,
    market_kind: str | None = None,
) -> tuple[dict[str, Any], list[str], dict[str, str]]:
    meta_path = market_dir / "recording_meta.json"
    if not meta_path.exists():
        raise FileNotFoundError(meta_path)
    meta = json.loads(meta_path.read_text(encoding="utf-8"))
    token_ids = [str(token_id) for token_id in meta.get("token_ids", [])]
    outcomes = [str(outcome) for outcome in meta.get("outcomes", [])]
    token_to_outcome = {
        token_id: outcomes[idx] if idx < len(outcomes) else token_id
        for idx, token_id in enumerate(token_ids)
    }
    if market_kind:
        target = None
        for item in meta.get("target_markets") or []:
            if str(item.get("kind") or "") == market_kind:
                target = item
                break
        if target is None:
            raise ValueError(f"market kind not found in recording_meta.json: {market_kind}")
        token_ids = [str(token_id) for token_id in target.get("token_ids") or []]
        target_outcomes = [str(outcome) for outcome in target.get("outcomes") or []]
        token_to_outcome = {
            token_id: (
                f"{market_kind}:{target_outcomes[idx]}"
                if idx < len(target_outcomes)
                else f"{market_kind}:{token_id}"
            )
            for idx, token_id in enumerate(token_ids)
        }
        meta = {
            **meta,
            "slug": target.get("slug") or meta.get("slug"),
            "question": target.get("question") or meta.get("question"),
            "condition_id": target.get("condition_id") or "",
            "token_ids": token_ids,
            "outcomes": [token_to_outcome[token_id] for token_id in token_ids],
            "market_kind": market_kind,
        }
    return meta, token_ids, token_to_outcome


def compact_levels(
    levels: list[dict[str, float]],
    *,
    depth: int,
) -> list[dict[str, float]]:
    """Return the first ``depth`` levels as new ``{"price", "size"}`` dicts.

    Both fields are coerced with ``float`` and rounded to 6 decimal places;
    any other keys on the input levels are dropped.
    """

    def _six_places(value: Any) -> float:
        return round(float(value), 6)

    return [
        {"price": _six_places(entry["price"]), "size": _six_places(entry["size"])}
        for entry in levels[:depth]
    ]


def book_payload(book: OrderBook, *, depth: int) -> dict[str, Any]:
    """Summarise ``book`` as best bid/ask prices plus up to ``depth`` levels per side.

    Only the price component of ``best_bid()`` / ``best_ask()`` is kept; the
    level lists come from ``compact_levels``.
    """
    bid_price, _ = book.best_bid()
    ask_price, _ = book.best_ask()
    bid_side = compact_levels(book.bid_levels(), depth=depth)
    ask_side = compact_levels(book.ask_levels(), depth=depth)
    return {
        "best_bid": bid_price,
        "best_ask": ask_price,
        "bids": bid_side,
        "asks": ask_side,
    }


def trade_payload(row: dict[str, Any]) -> dict[str, Any] | None:
    raw = row.get("raw_json") or {}
    if row.get("event_type") != "last_trade_price":
        return None
    try:
        price = float(raw.get("price"))
        size = float(raw.get("size") or 0)
    except (TypeError, ValueError):
        return None
    return {
        "asset_id": str(row.get("asset_id") or raw.get("asset_id") or ""),
        "outcome": str(row.get("outcome") or ""),
        "price": price,
        "size": size,
        "side": str(raw.get("side") or ""),
        "transaction_hash": str(raw.get("transaction_hash") or ""),
    }


def changed_levels(row: dict[str, Any]) -> list[dict[str, Any]]:
    raw = row.get("raw_json") or {}
    output: list[dict[str, Any]] = []
    if row.get("event_type") != "price_change":
        return output
    for change in raw.get("price_changes") or []:
        try:
            price = float(change.get("price"))
            size = float(change.get("size") or 0)
        except (TypeError, ValueError):
            continue
        output.append(
            {
                "asset_id": str(change.get("asset_id") or ""),
                "side": str(change.get("side") or ""),
                "price": price,
                "size": size,
                "best_bid": _float_or_none(change.get("best_bid")),
                "best_ask": _float_or_none(change.get("best_ask")),
            }
        )
    return output


def filter_changed_levels(
    levels: list[dict[str, Any]],
    token_set: set[str],
) -> list[dict[str, Any]]:
    """Return the levels whose stringified ``asset_id`` is in ``token_set``.

    A missing or falsy ``asset_id`` is compared as ``""``. Input order is kept.
    """

    def _belongs(level: dict[str, Any]) -> bool:
        return str(level.get("asset_id") or "") in token_set

    return list(filter(_belongs, levels))


def _float_or_none(value: Any) -> float | None:
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def apply_orderbook_event(books: dict[str, OrderBook], row: dict[str, Any]) -> None:
    """Update ``books`` in place from one recorded event row.

    - ``BOOK_EVENTS`` rows with a non-empty ``asset_id`` replace that token's
      book via ``apply_snapshot``.
    - ``price_change`` rows are grouped by each change's ``asset_id``, and each
      group is passed to ``apply_delta``. Changes without an asset id are ignored.
    - Any other event type is a no-op.

    A missing book is created on demand.
    """
    kind = str(row.get("event_type") or "")
    payload = row.get("raw_json") or {}

    if kind in BOOK_EVENTS:
        target = str(row.get("asset_id") or "")
        if target:
            books.setdefault(target, OrderBook()).apply_snapshot(payload)
        return
    if kind != "price_change":
        return

    by_token: dict[str, list[dict[str, Any]]] = {}
    for change in payload.get("price_changes") or []:
        token_id = str(change.get("asset_id") or "")
        if not token_id:
            continue
        if token_id not in by_token:
            by_token[token_id] = []
        by_token[token_id].append(change)

    # Insertion order of by_token follows first appearance in the payload.
    for token_id in by_token:
        books.setdefault(token_id, OrderBook()).apply_delta(by_token[token_id])


def iter_event_frames(
    *,
    slug: str,
    data_root: Path = Path("data/lpl"),
    depth: int = 5,
    frame_events: Iterable[str] = FRAME_EVENTS,
    market_kind: str | None = None,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Replay ``<data_root>/<slug>/orderbook_events.jsonl`` into a list of frames.

    Each non-blank JSON line updates a per-token ``OrderBook`` and the latest
    trade per token. Events that belong to untracked tokens are skipped
    entirely:

    - a ``price_change`` with no level for a tracked token;
    - any other event whose ``asset_id`` (or, for trades, the fallback asset id
      from ``raw_json``) is not tracked.

    A frame is appended for each remaining event whose type is in
    ``frame_events``, once every tracked token has at least one level on
    either side.

    Args:
        slug: Market directory name under ``data_root``.
        data_root: Root directory of the recordings.
        depth: Levels per side kept in each frame's book payload.
        frame_events: Event types that emit a frame. Other events still update
            books and last-trade state.
        market_kind: Optional ``target_markets`` kind passed to ``load_market_meta``.

    Returns:
        ``(meta, frames)``, with frames in file order.

    Raises:
        FileNotFoundError: If the events file or ``recording_meta.json`` is missing.
        ValueError: If ``market_kind`` is unknown, or a line is not valid JSON
            (``json.JSONDecodeError``).
    """
    market_dir = data_root / slug
    events_path = market_dir / "orderbook_events.jsonl"
    if not events_path.exists():
        raise FileNotFoundError(events_path)

    meta, token_ids, token_to_outcome = load_market_meta(market_dir, market_kind=market_kind)
    token_set = set(token_ids)
    books = {token_id: OrderBook() for token_id in token_ids}
    last_trade_by_asset: dict[str, dict[str, Any]] = {}
    allowed_events = set(frame_events)
    frames: list[dict[str, Any]] = []

    with events_path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)
            event_type = str(row.get("event_type") or "")
            asset_id = str(row.get("asset_id") or "")

            frame_changes: list[dict[str, Any]] = []
            if event_type == "price_change":
                # One price_change can carry levels for several tokens; keep it
                # only if at least one level belongs to a tracked token.
                frame_changes = filter_changed_levels(changed_levels(row), token_set)
                if not frame_changes:
                    continue
            elif asset_id and asset_id not in token_set:
                # Snapshots, trades, tick-size changes, etc. addressed to an
                # untracked token belong to another market.
                continue

            trade = trade_payload(row)
            if trade is not None and trade["asset_id"] and trade["asset_id"] not in token_set:
                # The row had no asset_id and raw_json names an untracked token.
                continue

            apply_orderbook_event(books, row)
            delay_ms = event_delay_ms(row)
            if trade is not None:
                last_trade_by_asset[trade["asset_id"]] = {
                    **trade,
                    "received_at_wall": row.get("received_at_wall") or "",
                    "exchange_ts": row.get("exchange_ts") or "",
                    "delay_ms": delay_ms,
                }
            if event_type not in allowed_events:
                continue
            # Hold back frames until every tracked token has some liquidity.
            if token_ids and not all(
                books[token_id].bids or books[token_id].asks for token_id in token_ids
            ):
                continue
            frames.append(
                {
                    "received_at_wall": row.get("received_at_wall") or "",
                    "exchange_ts": row.get("exchange_ts") or "",
                    "delay_ms": delay_ms,
                    "message_index": row.get("message_index"),
                    "connection_id": row.get("connection_id"),
                    "event_type": event_type,
                    "asset_id": asset_id,
                    "outcome": str(row.get("outcome") or ""),
                    "changed_levels": frame_changes,
                    "trade": trade,
                    "books": _frame_books(
                        books,
                        token_ids,
                        token_to_outcome,
                        last_trade_by_asset,
                        depth=depth,
                    ),
                }
            )

    return meta, frames


def _frame_books(
    books: dict[str, OrderBook],
    token_ids: list[str],
    token_to_outcome: dict[str, str],
    last_trade_by_asset: dict[str, dict[str, Any]],
    *,
    depth: int,
) -> dict[str, dict[str, Any]]:
    """Build one frame's per-token payloads.

    Each payload holds the ``book_payload`` fields plus the asset id, the
    outcome label and the last trade for that token.
    """
    books_payload: dict[str, dict[str, Any]] = {}
    for token_id in token_ids:
        payload = book_payload(books[token_id], depth=depth)
        payload["asset_id"] = token_id
        payload["outcome"] = token_to_outcome.get(token_id, token_id)
        payload["last_trade"] = last_trade_by_asset.get(token_id)
        books_payload[token_id] = payload
    return books_payload
