"""Replay recorded order book events around a target timestamp.

Rebuilds per-token order books from a market's ``orderbook_events.jsonl``
log and writes the best bid/ask state for each event inside the requested
window to a CSV file.
"""

from __future__ import annotations

import csv
import json
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any

from analysis.safe_btc5.types import parse_dt

from .recorder import OrderBook


def _parse_at(value: str) -> datetime:
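    """Parse a user-supplied timestamp and normalise it to UTC."""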
    parsed = parse_dt(value)
    if parsed is None:
        raise ValueError(f"invalid datetime: {value}")
    return parsed.astimezone(timezone.utc)


def _event_time(row: dict[str, Any]) -> datetime | None:
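    """Return the wall-clock receive time of an event row, if present."""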
    return parse_dt(row.get("received_at_wall"))


def _apply_event(books: dict[str, OrderBook], row: dict[str, Any]) -> None:
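    """Apply one recorded event to the per-token order books in place.

    Snapshot events (``book`` and the REST book variants) replace the book
    for their asset; ``price_change`` events are grouped by asset id and
    applied as deltas. Other event types are ignored.
    """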
    event_type = row.get("event_type")
    asset_id = str(row.get("asset_id") or "")
    raw = row.get("raw_json") or {}
    if event_type in {"rest_book", "rest_book_after_reconnect", "book"} and asset_id:
        books.setdefault(asset_id, OrderBook()).apply_snapshot(raw)
        return
    if event_type == "price_change":
        grouped: dict[str, list[dict[str, Any]]] = {}
        for change in raw.get("price_changes", []):
            token_id = str(change.get("asset_id") or "")
            if token_id:
                grouped.setdefault(token_id, []).append(change)
        for token_id, changes in grouped.items():
            books.setdefault(token_id, OrderBook()).apply_delta(changes)


def _state_rows(
    *,
    row: dict[str, Any],
    books: dict[str, OrderBook],
    token_to_outcome: dict[str, str],
) -> list[dict[str, Any]]:
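    """Snapshot the best bid/ask of every tracked book as CSV rows tagged with the triggering event."""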
    output: list[dict[str, Any]] = []
    for token_id, book in books.items():
        bid, bid_size = book.best_bid()
        ask, ask_size = book.best_ask()
        output.append(
            {
                "received_at_wall": row.get("received_at_wall", ""),
                "message_index": row.get("message_index", ""),
                "event_type": row.get("event_type", ""),
                "asset_id": token_id,
                "outcome": token_to_outcome.get(token_id, ""),
                "best_bid": "" if bid is None else bid,
                "best_bid_size": "" if bid_size is None else bid_size,
                "best_ask": "" if ask is None else ask,
                "best_ask_size": "" if ask_size is None else ask_size,
                "spread": "" if bid is None or ask is None else round(ask - bid, 8),
            }
        )
    return output


def replay_orderbook_window(
    *,
    slug: str,
    at: str,
    window_seconds: int,
    data_root: Path = Path("data/lpl"),
    output_path: Path | None = None,
) -> dict[str, Any]:
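    """Replay ``orderbook_events.jsonl`` for ``slug`` and export a window to CSV.

    Events are applied in file order to rebuild per-token books; for every
    event received within ``window_seconds`` of ``at``, the top-of-book state
    of all tracked tokens is written to ``output_path`` (default:
    ``replay_<timestamp>_<window>s.csv`` inside the market directory).
    Returns a summary dict with the target time, event and row counts, and
    the output path.
    """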
    target = _parse_at(at)
    start = target - timedelta(seconds=window_seconds)
    end = target + timedelta(seconds=window_seconds)
    market_dir = data_root / slug
    events_path = market_dir / "orderbook_events.jsonl"
    meta_path = market_dir / "recording_meta.json"
    if not events_path.exists():
        raise FileNotFoundError(events_path)
    if not meta_path.exists():
        raise FileNotFoundError(meta_path)

    meta = json.loads(meta_path.read_text(encoding="utf-8"))
    token_ids = [str(token_id) for token_id in meta.get("token_ids", [])]
    outcomes = [str(outcome) for outcome in meta.get("outcomes", [])]
    token_to_outcome = {
        token_id: outcomes[idx] if idx < len(outcomes) else ""
        for idx, token_id in enumerate(token_ids)
    }
    books = {token_id: OrderBook() for token_id in token_ids}

    rows: list[dict[str, Any]] = []
    processed = 0
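    # Events before the window still update the books so the state is correct
    # when the window opens. The loop assumes the log is in receive-time order
    # and stops at the first event past the window end.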
    with events_path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)
            ts = _event_time(row)
            if ts is None:
                continue
            if ts > end:
                break
            _apply_event(books, row)
            processed += 1
            if start <= ts <= end:
                rows.extend(_state_rows(row=row, books=books, token_to_outcome=token_to_outcome))

    if output_path is None:
        stamp = target.strftime("%Y%m%dT%H%M%SZ")
        output_path = market_dir / f"replay_{stamp}_{window_seconds}s.csv"
    output_path.parent.mkdir(parents=True, exist_ok=True)
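    # Build the CSV header from the union of keys seen across rows, keeping
    # first-seen order.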
    fieldnames: list[str] = []
    for row in rows:
        for key in row:
            if key not in fieldnames:
                fieldnames.append(key)
    with output_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    return {
        "slug": slug,
        "target": target.isoformat(),
        "window_seconds": window_seconds,
        "processed_events": processed,
        "rows": len(rows),
        "output_path": str(output_path),
    }
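

# Minimal command-line sketch (an assumption, not part of the recorded
# interface): it forwards positional arguments to replay_orderbook_window
# and prints the returned summary as JSON.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Replay an order book window to CSV.")
    parser.add_argument("slug", help="market slug under the data root")
    parser.add_argument("at", help="target timestamp understood by parse_dt")
    parser.add_argument("window_seconds", type=int, help="half-width of the window in seconds")
    parser.add_argument("--data-root", type=Path, default=Path("data/lpl"))
    cli_args = parser.parse_args()

    print(
        json.dumps(
            replay_orderbook_window(
                slug=cli_args.slug,
                at=cli_args.at,
                window_seconds=cli_args.window_seconds,
                data_root=cli_args.data_root,
            ),
            indent=2,
        )
    )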
